In this article, I will briefly go through how to use TensorFlow.js to train a model on CSV data and export it directly to a browser-compatible format. I will use the MNIST dataset (10 digit classes) from Kaggle, which is encoded as CSV files.
Prepare Data
Readers can download and view the data directly on this website. In short, the MNIST dataset is for handwritten digit recognition: each sample is a grayscale image of size 28 × 28.
In our case, the dataset downloaded from Kaggle comes in two parts: train.csv and test.csv. Let’s look at train.csv first, using the tf.data functions.
Taking one element from the dataset and printing it to the console shows that each element has two parts: xs, a hashmap (dictionary) that maps each pixel to its value (in the range 0–255), and ys, the corresponding label.
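A minimal sketch of such a peek might look like the following. It assumes `tf` is the module returned by `require('@tensorflow/tfjs-node')`, that train.csv sits in the working directory, and that the label column is named `label`; `labelConfig` and `peekFirstRow` are hypothetical helper names, not part of TensorFlow.js.

```javascript
// Sketch only: `tf` is expected to be require('@tensorflow/tfjs-node').
// labelConfig and peekFirstRow are hypothetical helper names.

// Build the tf.data.csv config that marks one column as the label,
// so each element is split into { xs: features, ys: labels }.
function labelConfig(labelColumn) {
  return { columnConfigs: { [labelColumn]: { isLabel: true } } };
}

// Read the CSV lazily and print only its first row.
async function peekFirstRow(tf, csvUrl, labelColumn = 'label') {
  const dataset = tf.data.csv(csvUrl, labelConfig(labelColumn));
  await dataset.take(1).forEachAsync((row) => console.log(row));
}

// Usage (assumed file location):
//   const tf = require('@tensorflow/tfjs-node');
//   peekFirstRow(tf, 'file://./train.csv');
```

Passing `tf` in as a parameter keeps the helper usable with either `@tensorflow/tfjs-node` or the browser build.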
Okay, first we need to wrap xs and ys as tensors so that TensorFlow.js can work with them. We also want to normalize the pixel values and reshape each image tensor to 28 × 28 × 1.
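The wrapping, normalizing, and reshaping could be sketched roughly as below. The helper names and constants are illustrative, `tf` is again assumed to be `require('@tensorflow/tfjs-node')`, and the one-hot encoding of the label is an assumption chosen to pair with a `categoricalCrossentropy` loss; the text above only says that ys becomes a tensor.

```javascript
// Sketch only: `tf` is expected to be require('@tensorflow/tfjs-node').
// Helper names and constants are illustrative.
const IMAGE_SIZE = 28;
const NUM_CLASSES = 10;

// Scale raw pixel values from 0-255 down to 0-1, keeping column order.
function normalizePixels(xs) {
  return Object.values(xs).map((value) => value / 255);
}

// Encode a digit label as a one-hot array, e.g. 3 -> [0,0,0,1,0,0,0,0,0,0].
function toOneHot(label, numClasses = NUM_CLASSES) {
  const encoded = new Array(numClasses).fill(0);
  encoded[label] = 1;
  return encoded;
}

// Turn one CSV row ({ xs, ys }) into tensors the model can consume.
function rowToTensors(tf, { xs, ys }) {
  return {
    xs: tf.tensor(normalizePixels(xs), [IMAGE_SIZE, IMAGE_SIZE, 1]),
    ys: tf.tensor1d(toOneHot(ys.label)),
  };
}

// Usage: map every row, then shuffle and batch (sizes are assumptions):
//   const trainData = csvDataset
//     .map((row) => rowToTensors(tf, row))
//     .shuffle(1000)
//     .batch(32);
```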
Cool! We now have a good training dataset and can move on to building and training the model.
Model Training
Wait! Before we get to the training phase, we still need to define a model. For this simple task, a relatively simple network is enough.
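As an illustration, a small convolutional network for 28 × 28 × 1 inputs could be defined along these lines. The layer sizes, dropout rate, optimizer, and loss are assumptions made for this sketch rather than the exact network used for the results in this article; `buildModel` is a hypothetical name and `tf` is assumed to be `require('@tensorflow/tfjs-node')`.

```javascript
// Sketch only: `tf` is expected to be require('@tensorflow/tfjs-node').
// The layer sizes are illustrative, not the article's exact network.
function buildModel(tf) {
  const model = tf.sequential();
  // Two small convolution + pooling stages extract local stroke features.
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1],
    filters: 32,
    kernelSize: 3,
    activation: 'relu',
  }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
  model.add(tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
  // Flatten the feature maps and classify into the 10 digit classes.
  model.add(tf.layers.flatten());
  model.add(tf.layers.dropout({ rate: 0.25 }));
  model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

  // categoricalCrossentropy assumes one-hot labels of length 10.
  model.compile({
    optimizer: 'adam',
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  return model;
}
```

Calling `buildModel(tf).summary()` prints the layer shapes, which is a quick way to check that the input shape matches the dataset.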
Then we can move on to training, which is super easy: we just need to call one function.
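A sketch of that single call, assuming the model is a compiled `tf.LayersModel` and the training data is a batched `tf.data.Dataset` of `{ xs, ys }` pairs. `trainModel` and `saveUrl` are hypothetical names; the optional save step is included only because the goal is a browser-compatible model, and the output directory is an assumption.

```javascript
// Sketch only: `model` is a compiled tf.LayersModel and `trainData` a batched
// tf.data.Dataset of { xs, ys }. trainModel and saveUrl are hypothetical names.
async function trainModel(model, trainData, { epochs = 20, saveUrl } = {}) {
  // fitDataset streams batches from the dataset for the given number of epochs.
  const history = await model.fitDataset(trainData, { epochs });

  // Optionally write model.json plus binary weights, which
  // tf.loadLayersModel can later load in the browser.
  if (saveUrl) {
    await model.save(saveUrl);
  }
  return history;
}

// Usage (assumed output directory):
//   await trainModel(model, trainData, { saveUrl: 'file://./mnist-model' });
```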
And boom! We’ve got our model training on track. Now let’s make some coffee and relax. Training time depends on your device, but it normally won’t take more than thirty minutes.
Another amazing thing about TensorFlow.js on Node.js (tfjs-node) is that you can use TensorBoard to visualize the training output directly, just as with native TensorFlow.
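One way to hook this up is the `tf.node.tensorBoard` callback that ships with tfjs-node. The sketch below assumes `tf` is `require('@tensorflow/tfjs-node')`; `withTensorBoard` is a hypothetical helper and the log directory is an arbitrary choice. The official examples pass this callback to `model.fit`; using it with `fitDataset` is an assumption here.

```javascript
// Sketch only: `tf` is expected to be require('@tensorflow/tfjs-node');
// withTensorBoard is a hypothetical helper and the log directory is arbitrary.
function withTensorBoard(tf, fitConfig, logDir = '/tmp/fit_logs') {
  // tf.node.tensorBoard returns a callback that writes loss and metric
  // summaries to logDir while the model trains.
  const tensorBoardCallback = tf.node.tensorBoard(logDir);
  // Keep any callbacks that were already configured.
  const callbacks = [].concat(fitConfig.callbacks || [], tensorBoardCallback);
  return { ...fitConfig, callbacks };
}

// Usage:
//   await model.fitDataset(trainData, withTensorBoard(tf, { epochs: 20 }));
// then, in a terminal:
//   tensorboard --logdir /tmp/fit_logs
```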
Prediction
After 20 epochs, we have our trained model and can run the prediction. Wow! And do not forget that our mission is to submit the result to Kaggle, so we also need to convert the final predictions into the required CSV file.
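A minimal sketch of both steps follows. It assumes `tf` is `require('@tensorflow/tfjs-node')`, that test.csv contains only the 784 pixel columns, and that the submission file uses Kaggle's `ImageId,Label` layout with ImageId counting from 1. The function names, batch size, and file paths are hypothetical.

```javascript
const fs = require('fs');

// Sketch only: `tf` is expected to be require('@tensorflow/tfjs-node') and
// `model` the trained model. Function names and paths are hypothetical.

// Format predicted digits as a Kaggle submission: an "ImageId,Label" header
// followed by one line per image, with ImageId counting from 1.
function toSubmissionCsv(labels) {
  const lines = labels.map((label, index) => `${index + 1},${label}`);
  return ['ImageId,Label', ...lines].join('\n') + '\n';
}

// Run the model over test.csv (pixel columns only) and collect the digits.
async function predictLabels(tf, model, testCsvUrl, batchSize = 256) {
  const images = tf.data
    .csv(testCsvUrl)
    .map((row) => tf.tensor(Object.values(row).map((v) => v / 255), [28, 28, 1]))
    .batch(batchSize);

  const labels = [];
  await images.forEachAsync((batch) => {
    // argMax over the class axis picks the most probable digit per image.
    const digits = tf.tidy(() => model.predict(batch).argMax(-1).dataSync());
    labels.push(...digits);
    batch.dispose();
  });
  return labels;
}

// Write the submission file to disk.
function writeSubmission(path, labels) {
  fs.writeFileSync(path, toSubmissionCsv(labels));
}

// Usage (assumed file locations):
//   const labels = await predictLabels(tf, model, 'file://./test.csv');
//   writeSubmission('submission.csv', labels);
```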
Above is the code for running the prediction and saving the result in the required CSV format.
Then we can upload our result to see how it goes!
Wow! We got a 98% score, which is not bad! To improve it, we could also try data augmentation, refining the network structure, and hyperparameter tuning, which might be covered in the next article.
For the complete code, please refer to my GitHub: https://github.com/WenheLI/training-in-tfjs-node
I will keep working on using tfjs for Kaggle in this repo!