A Gentle Introduction to TensorFlow.js
April 03, 2018
Posted by Zaid Alyafeai

TensorFlow.js is a library built on deeplearn.js for creating deep learning models directly in the browser. With it you can build CNNs, RNNs and other networks in the browser and train them using the client's GPU (through WebGL). Hence, a server GPU is not needed to train the network. This tutorial starts by explaining the basic building blocks of TensorFlow.js and the operations on them. Then, we describe how to create some more complex models.

One Note or Two …

If you want to play with the code, I created an interactive coding session on Observable. I also created many mini-projects, including simple classification, style transfer, pose estimation and pix2pix translation.

Getting Started

Since TensorFlow.js runs in the browser, you just need to include the following script in the header of the HTML file:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
This will load the latest published version of the bundle.

Tensors (The building blocks)

If you are familiar with deep learning platforms like TensorFlow, you will recognize tensors as n-dimensional arrays that are consumed by operators. Hence they are the building blocks of any deep learning application. Let us create a scalar tensor:
const tensor = tf.scalar(2);
This creates a scalar tensor. We can also convert arrays to tensors:
const input = tf.tensor([2,2]);
This creates a constant tensor from the array [2,2]. In other words, we converted the one-dimensional array to a tensor by applying the tensor function. We can use input.shape to retrieve the size of the tensor:
const tensor_s = tf.tensor([2,2]).shape;
This has the shape [2]. We can also create a tensor with a specific size. For instance, here we create a tensor of zeros with shape [2,2]:
const input = tf.zeros([2,2]);
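To make the idea of a shape concrete, here is a small plain-JavaScript sketch (no TensorFlow.js involved) of how the shape of a nested array can be read off: it is the length at each nesting level. The helper name shapeOf is hypothetical, and it assumes a rectangular array (every row at a given depth has the same length).

```javascript
// Plain-JavaScript sketch (not part of TensorFlow.js): read the shape of a
// nested array. Assumes the array is rectangular, i.e. every sub-array at a
// given depth has the same length.
function shapeOf(value) {
  const shape = [];
  let current = value;
  // Walk down the first element at each depth, recording the length.
  while (Array.isArray(current)) {
    shape.push(current.length);
    current = current[0];
  }
  return shape; // a non-array (scalar) value yields []
}

console.log(shapeOf(2));                // []        like tf.scalar(2).shape
console.log(shapeOf([2, 2]));           // [ 2 ]     like tf.tensor([2,2]).shape
console.log(shapeOf([[0, 0], [0, 0]])); // [ 2, 2 ]  like tf.zeros([2,2]).shape
```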

Operators

To use tensors we need to apply operations to them. Let us say we want to compute the square of a tensor:
const a = tf.tensor([1,2,3]);
a.square().print();
This prints [1, 4, 9]. TensorFlow.js also allows chaining operations. For example, to compute the 4th power of a tensor we can square it twice:
const x = tf.tensor([1,2,3]);
const x2 = x.square().square();
The x2 tensor will have value [1,16,81].
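The arithmetic behind the two snippets above can be checked with plain JavaScript arrays. This sketch does not use TensorFlow.js; the square helper is hypothetical and only mirrors what square() computes element by element.

```javascript
// Plain-JavaScript sketch of what the chained ops compute element by element.
// TensorFlow.js performs the same arithmetic on its own backend.
const square = (values) => values.map((v) => v * v);

const x = [1, 2, 3];
const once = square(x);      // [1, 4, 9]    like x.square()
const twice = square(once);  // [1, 16, 81]  like x.square().square()
console.log(once, twice);
```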

Tensor Disposal

We usually generate lots of intermediate tensors. For instance, in the previous example, after evaluating x2 we no longer need the value of x. To free its memory we call dispose():
const x = tf.tensor([1,2,3]);
x.dispose();
Note that we can no longer use the tensor x in later operations. Calling dispose() for every tensor quickly becomes inconvenient, yet not disposing tensors leaks memory (with the WebGL backend, tensor memory is not garbage-collected automatically). TensorFlow.js offers a special function, tf.tidy(), that disposes intermediate tensors automatically:
function f(x) {
  return tf.tidy(() => {
    const y = x.square();
    const z = x.mul(y);
    return z;
  });
}
Here tidy() disposes the tensor y when the callback returns, since we don't need it after computing z. The returned tensor z is kept.
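To build some intuition for what tidy() does, here is a toy model of scope-based cleanup in plain JavaScript. It is only a sketch under simplifying assumptions, not how TensorFlow.js actually tracks tensors: FakeTensor, makeTensor and toyTidy are hypothetical names, and the "tensors" wrap ordinary arrays. Tensors created inside the callback are recorded, and everything except the returned value is disposed when the callback finishes.

```javascript
// Toy model of scope-based cleanup, loosely inspired by tf.tidy().
// NOT the TensorFlow.js implementation; all names here are hypothetical.
class FakeTensor {
  constructor(values, scope) {
    this.values = values;
    this.disposed = false;
    if (scope) scope.push(this); // register with the active scope, if any
  }
  dispose() {
    this.disposed = true;
  }
}

let activeScope = null;

// Creates a tensor that is tracked by whichever toyTidy() call is running.
function makeTensor(values) {
  return new FakeTensor(values, activeScope);
}

function toyTidy(fn) {
  const outerScope = activeScope;
  const scope = [];
  activeScope = scope;
  try {
    const result = fn();
    // Dispose every tensor created in this scope except the returned one.
    for (const t of scope) {
      if (t !== result) t.dispose();
    }
    // The returned tensor survives and is handed to the enclosing scope.
    if (outerScope && result instanceof FakeTensor) outerScope.push(result);
    return result;
  } finally {
    activeScope = outerScope; // restore the previous scope
  }
}

const x = makeTensor([1, 2, 3]); // created outside any scope: not tracked
let y;
const z = toyTidy(() => {
  y = makeTensor(x.values.map((v) => v * v));                 // intermediate
  return makeTensor(x.values.map((v, i) => v * y.values[i])); // result
});
console.log(y.disposed, z.disposed, z.values); // true false [ 1, 8, 27 ]
```

As with tf.tidy(), the tensor x created outside the scope is left alone.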

Optimization problem

Here we learn how to solve an optimization problem. Given a function f(x), we are asked to find the value x = a that minimizes f(x). To do that we need an optimizer: an algorithm that minimizes a function by following its gradient. There are many optimizers in the literature, like SGD, Adam, etc. These optimizers differ in their speed and accuracy. TensorFlow.js supports the most important ones, for example tf.train.sgd, tf.train.momentum and tf.train.adam.

We will take a simple example where f(x) = x⁶+2x⁴+3x²+x+1. The graph of the function is shown below. We see that the minimum of the function is in the interval [-0.5,0]. We will use an optimizer to find the exact value.
graph of the function f(x)
First we define the function to minimize
function f(x) 
{
  const f1 = x.pow(tf.scalar(6, 'int32')) //x^6
  const f2 = x.pow(tf.scalar(4, 'int32')).mul(tf.scalar(2)) //2x^4
  const f3 = x.pow(tf.scalar(2, 'int32')).mul(tf.scalar(3)) //3x^2
  const f4 = tf.scalar(1) //1
  return f1.add(f2).add(f3).add(x).add(f4)
}
Now we can iteratively minimize the function to find the location of the minimum. We start from the initial value a = 2. The learning rate defines how large each step towards the minimum is. We will use an Adam optimizer:
function minimize(epochs, lr) {
  const y = tf.variable(tf.scalar(2)); // initial value
  const optim = tf.train.adam(lr);    // Adam optimizer
  for (let i = 0; i < epochs; i++) {  // start minimization
    optim.minimize(() => f(y));       // updates the variable y in place
  }
  return y;
}
Using a learning rate of 0.9, we find the value of the minimum after 200 iterations (minimize(200, 0.9)) to be -0.16092407703399658.
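As a sanity check on this number, here is a plain-JavaScript sketch that minimizes the same f(x) without TensorFlow.js. It swaps Adam for plain gradient descent and uses the derivative f′(x) = 6x⁵+8x³+6x+1 worked out by hand; the learning rate of 0.01 and the 2000 steps are our own choices, picked so that the first large step from x = 2 does not diverge.

```javascript
// Plain gradient descent (NOT Adam) on f(x) = x^6 + 2x^4 + 3x^2 + x + 1,
// using the hand-derived derivative f'(x) = 6x^5 + 8x^3 + 6x + 1.
const f = (x) => x ** 6 + 2 * x ** 4 + 3 * x ** 2 + x + 1;
const df = (x) => 6 * x ** 5 + 8 * x ** 3 + 6 * x + 1;

function gradientDescent(x0, lr, steps) {
  let x = x0;
  for (let i = 0; i < steps; i++) {
    x -= lr * df(x); // step against the gradient
  }
  return x;
}

// lr = 0.01 is an assumption: f'(2) = 269, so a much larger learning rate
// would throw x far away from the minimum on the very first step.
const xMin = gradientDescent(2, 0.01, 2000);
console.log(xMin, f(xMin), df(xMin));
```

The result lands very close to the Adam result above, inside the interval [-0.5, 0] read off the graph. Adam is less sensitive to the large initial gradient because its step size is scaled by running gradient statistics, which is why a learning rate of 0.9 works there.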

A Simple Neural Network

Now we learn how to create a neural network that learns XOR, which is a nonlinear operation. The code is similar to a Keras implementation. We first create the training set, in which each example has two inputs and one output; it contains all 4 input combinations:
const xs = tf.tensor2d([[0,0],[0,1],[1,0],[1,1]]);
const ys = tf.tensor2d([[0],[1],[1],[0]]);
Then we create two dense layers with two different nonlinear activation functions (tanh and sigmoid). We use stochastic gradient descent with a learning rate of 0.1 and binary cross-entropy loss:
function createModel() {
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 8, inputShape: [2], activation: 'tanh'}));
  model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
  // the learning rate is set on the optimizer, not passed to compile()
  model.compile({optimizer: tf.train.sgd(0.1), loss: 'binaryCrossentropy'});
  return model;
}
Then we fit the model for 5000 epochs (fit() returns a Promise, so this runs inside an async function):
const model = createModel();
await model.fit(xs, ys, {
  batchSize: 1,
  epochs: 5000
});
Finally we predict on the training set:
model.predict(xs).print();
In our run the output was [[0.0064339], [0.9836861], [0.9835356], [0.0208658]], which is close to the expected XOR values [0, 1, 1, 0].
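To see how the binaryCrossentropy loss scores these predictions, here is a plain-JavaScript sketch (no TensorFlow.js) that computes the average binary cross-entropy by hand for the four outputs above and thresholds them at 0.5. The clipping constant eps is our own assumption, added only to avoid log(0).

```javascript
// Binary cross-entropy computed by hand on the XOR predictions above.
// Plain JavaScript, no TensorFlow.js calls.
const labels = [0, 1, 1, 0];
const preds = [0.0064339, 0.9836861, 0.9835356, 0.0208658];

function binaryCrossEntropy(yTrue, yPred, eps = 1e-7) {
  let sum = 0;
  for (let i = 0; i < yTrue.length; i++) {
    // Clip the probability to avoid log(0).
    const p = Math.min(Math.max(yPred[i], eps), 1 - eps);
    sum += -(yTrue[i] * Math.log(p) + (1 - yTrue[i]) * Math.log(1 - p));
  }
  return sum / yTrue.length; // average over the examples
}

const loss = binaryCrossEntropy(labels, preds);
// Threshold at 0.5 to turn probabilities into hard XOR outputs.
const hard = preds.map((p) => (p >= 0.5 ? 1 : 0));
console.log(loss.toFixed(4), hard); // 0.0151 [ 0, 1, 1, 0 ]
```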

CNN Model

TensorFlow.js computes gradients with automatic differentiation, so we just need to create the layers, choose an optimizer and compile the model. Let us create a sequential model:
const model = tf.sequential();
Now we can add different layers to the model. Let us add the first convolutional layer, with input shape [28,28,1]:
const convlayer = tf.layers.conv2d({
  inputShape: [28,28,1],
  kernelSize: 5,
  filters: 8,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'VarianceScaling'
});
Here we created a conv layer that takes input of size [28,28,1], i.e. a grayscale image of size 28 x 28. It applies 8 kernels of size 5x5 with stride 1, initialized with VarianceScaling. After that, it applies the ReLU activation function, which replaces the negative values in the tensor with zeros. Now we can add this convlayer to the model:
model.add(convlayer);
What is nice about TensorFlow.js is that we don't need to specify the input size of the next layers, since it is inferred automatically from the output of the previous layer. We can also add max-pooling layers, dense layers and so on. Here is a simple model:
const model = tf.sequential();

//create the first layer 
model.add(tf.layers.conv2d({
  inputShape: [28, 28, 1],
  kernelSize: 5,
  filters: 8,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'VarianceScaling'
}));

//create a max pooling layer 
model.add(tf.layers.maxPooling2d({
  poolSize: [2, 2],
  strides: [2, 2]
}));

//create the second conv layer
model.add(tf.layers.conv2d({
  kernelSize: 5,
  filters: 16,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'VarianceScaling'
}));

//create a max pooling layer 
model.add(tf.layers.maxPooling2d({
  poolSize: [2, 2],
  strides: [2, 2]
}));

//flatten the layers to use it for the dense layers 
model.add(tf.layers.flatten());

//dense layer with output 10 units 
model.add(tf.layers.dense({
  units: 10,
  kernelInitializer: 'VarianceScaling',
  activation: 'softmax'
}));
We can apply any layer to a tensor to inspect its output. There is one catch: the input needs to have shape [BATCH_SIZE,28,28,1], where BATCH_SIZE is the number of dataset elements we feed to the model at a time. Here is an example of how to evaluate a convolutional layer:
const convlayer = tf.layers.conv2d({
  inputShape: [28, 28, 1],
  kernelSize: 5,
  filters: 8,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'VarianceScaling'
});

const input = tf.zeros([1,28,28,1]);
const output = convlayer.apply(input);
Inspecting output.shape, we see that the output tensor has shape [1,24,24,8]. With the default 'valid' padding, the spatial size is given by the formula
const outputSize = Math.floor((inputSize - kernelSize) / stride + 1);
which results in 24 in our case. Returning to our model, we used flatten(), which converts its input from shape [BATCH_SIZE,a,b,c] to shape [BATCH_SIZE,a×b×c]. This is necessary because the dense layers here expect a 2D input of shape [BATCH_SIZE, features]. Finally, we used a dense layer with 10 output units, one for each class in our recognition system. This model is designed for recognizing handwritten digits from the so-called MNIST dataset.
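The same formula can be applied to every layer of the model above. This plain-JavaScript sketch (no TensorFlow.js calls) propagates the shape of a single 28 x 28 x 1 image through the two conv/pool pairs, assuming 'valid' padding everywhere, and then computes the size produced by flatten().

```javascript
// Walk the CNN above layer by layer and compute output shapes by hand,
// assuming 'valid' padding (the TensorFlow.js default) everywhere.
const outSize = (inputSize, kernelSize, stride) =>
  Math.floor((inputSize - kernelSize) / stride + 1);

// [height, width, channels] for a single example (batch dimension omitted).
let shape = [28, 28, 1];
const layers = [
  { name: 'conv2d (5x5, 8 filters)', kernel: 5, stride: 1, filters: 8 },
  { name: 'maxPooling2d (2x2)', kernel: 2, stride: 2 },
  { name: 'conv2d (5x5, 16 filters)', kernel: 5, stride: 1, filters: 16 },
  { name: 'maxPooling2d (2x2)', kernel: 2, stride: 2 },
];

for (const layer of layers) {
  const [h, w, c] = shape;
  // Conv layers set the channel count to `filters`; pooling keeps it.
  shape = [
    outSize(h, layer.kernel, layer.stride),
    outSize(w, layer.kernel, layer.stride),
    layer.filters ?? c,
  ];
  console.log(layer.name, shape);
}

// flatten() multiplies the remaining dimensions together.
const flattened = shape.reduce((a, b) => a * b, 1);
console.log('flatten', flattened);
```

It prints [24, 24, 8], [12, 12, 8], [8, 8, 16] and [4, 4, 16], so the final dense layer receives 4 × 4 × 16 = 256 features per image.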

Optimization and Compilation

After creating the model, we need a way to optimize its parameters. There are different optimizers for that, like SGD and Adam. For instance, we can create an optimizer using
const LEARNING_RATE = 0.0001;
const optimizer = tf.train.adam(LEARNING_RATE);
This creates an Adam optimizer with the specified learning rate. Now we are ready to compile the model, which attaches the optimizer, the loss and the metrics to it:
model.compile({
  optimizer: optimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});
Here we created a model that uses Adam to minimize a loss function computing the cross-entropy between the predicted output and the true label, and that also reports the accuracy.
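For intuition about the categoricalCrossentropy loss, here is a hand-written version in plain JavaScript for a single example: the softmax turns raw scores into probabilities, and the loss is minus the log-probability assigned to the true class. The scores [1, 2, 3] and the eps clip are made up for illustration; TensorFlow.js computes the same quantity over a whole batch.

```javascript
// Categorical cross-entropy for a single example, written out by hand:
// loss = -sum_k yTrue[k] * log(yPred[k]), where yPred comes from a softmax.
function softmax(logits) {
  const maxLogit = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map((z) => Math.exp(z - maxLogit));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

function categoricalCrossEntropy(yTrue, yPred, eps = 1e-7) {
  let loss = 0;
  for (let k = 0; k < yTrue.length; k++) {
    loss -= yTrue[k] * Math.log(Math.max(yPred[k], eps)); // clip to avoid log(0)
  }
  return loss;
}

// Hypothetical 3-class example: the true class is index 2 (one-hot [0, 0, 1]).
const probs = softmax([1.0, 2.0, 3.0]);
const loss = categoricalCrossEntropy([0, 0, 1], probs);
console.log(probs, loss);
```

Because the label is one-hot, only the probability of the true class contributes to the loss; pushing that probability towards 1 drives the loss towards 0.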

Training

After compiling the model, we are ready to train it on a dataset. We use the fit() function for that (BATCH_SIZE, NUM_CLASSES, BATCH_EPOCHS and validationData are assumed to be defined elsewhere; validationData is an [xs, ys] pair of tensors):
const batch = tf.zeros([BATCH_SIZE,28,28,1]);
const labels = tf.zeros([BATCH_SIZE, NUM_CLASSES]);

const h = await model.fit(batch, labels, {
  batchSize: BATCH_SIZE,
  validationData: validationData,
  epochs: BATCH_EPOCHS
});
Note that we are feeding a single batch of the training set to the fit function. The second argument of fit holds the true labels for that batch. Lastly, we pass configuration parameters like batchSize and epochs. Note that epochs is how many times fit iterates over the tensors it is given, here the current batch, NOT the whole dataset. Hence we can, for example, wrap this code inside a for loop that iterates over all the batches of the training set.

Note that we used the special keyword await, which pauses the enclosing async function until fit() (which returns a Promise) has finished. Unlike a blocking call, it does not freeze the browser: the page stays responsive while training runs, and execution resumes after the await once training is done.
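The for loop over batches mentioned above needs a way to split the training set into consecutive chunks. Here is a small plain-JavaScript generator for that; batchRanges is a hypothetical helper, not part of TensorFlow.js, and it only produces [start, end) index ranges.

```javascript
// Hypothetical helper (not a TensorFlow.js API): split numExamples training
// examples into [start, end) index ranges of at most batchSize items each.
function* batchRanges(numExamples, batchSize) {
  for (let start = 0; start < numExamples; start += batchSize) {
    yield [start, Math.min(start + batchSize, numExamples)];
  }
}

// Example: 10 examples in batches of 4 -> [0,4), [4,8), [8,10).
const ranges = [...batchRanges(10, 4)];
console.log(ranges);
```

Inside an async function, each range could be used to slice the training tensors (for example with tensor.slice) and the slices passed to await model.fit(...), one batch at a time. The last batch may be smaller than batchSize.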

One Hot Encoding

Usually the given labels are numbers that represent the class. For instance, suppose we have two classes, an orange class and an apple class. Then we give the orange class the label 0 and the apple class the label 1. But our network expects a tensor of size [BATCH_SIZE,NUM_CLASSES]. Hence we need to use what we call one-hot encoding:
const output = tf.oneHot(tf.tensor1d([0,1,0], 'int32'), 2);

//the output will be [[1, 0],[0, 1],[1, 0]]
Hence we converted the 1D tensor of labels into a tensor of shape [BATCH_SIZE,NUM_CLASSES]. Note that tf.oneHot expects int32 indices, which is why the labels tensor is created with the 'int32' dtype.
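For reference, here is what one-hot encoding computes, written as a plain-JavaScript sketch rather than with tf.oneHot. The oneHot helper here is hypothetical, and its treatment of out-of-range labels (an all-zero row) is our own choice.

```javascript
// Plain-JavaScript sketch of one-hot encoding for a 1-D list of integer labels.
// Labels outside [0, depth) produce an all-zero row in this sketch.
function oneHot(labels, depth) {
  return labels.map((label) => {
    const row = new Array(depth).fill(0);
    if (Number.isInteger(label) && label >= 0 && label < depth) row[label] = 1;
    return row;
  });
}

// orange = 0, apple = 1
console.log(oneHot([0, 1, 0], 2)); // [ [ 1, 0 ], [ 0, 1 ], [ 1, 0 ] ]
```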

Loss and Accuracy

In order to inspect the performance of our model, we need to know the loss and the accuracy. We can fetch these results from the History object returned by fit():
//h is the output of model.fit(); index 0 is the first epoch
const loss = h.history.loss[0];
const accuracy = h.history.acc[0];
Note that these are the loss and accuracy on the training batch. The corresponding values for the validationData passed to fit() are stored in h.history.val_loss and h.history.val_acc.
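Accuracy itself is simple to compute: it is the fraction of examples whose highest-probability class matches the label. This plain-JavaScript sketch (no TensorFlow.js) shows the computation; the predictions are invented numbers, not the output of a trained model.

```javascript
// Index of the largest entry in a row (ties keep the first index).
const argMax = (row) => row.reduce((best, v, i) => (v > row[best] ? i : best), 0);

// Fraction of examples whose predicted class matches the one-hot label.
function accuracy(yTrueOneHot, yPred) {
  let correct = 0;
  for (let i = 0; i < yTrueOneHot.length; i++) {
    if (argMax(yTrueOneHot[i]) === argMax(yPred[i])) correct++;
  }
  return correct / yTrueOneHot.length;
}

// Made-up labels and predicted probabilities for 4 examples and 2 classes.
const yTrue = [[1, 0], [0, 1], [1, 0], [0, 1]];
const yPred = [[0.9, 0.1], [0.3, 0.7], [0.4, 0.6], [0.2, 0.8]];
console.log(accuracy(yTrue, yPred)); // 0.75 (the third example is misclassified)
```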

Prediction

Suppose that we are done training a model and it gives good loss and accuracy. It is time to predict the results for unseen data. Suppose we are given an image in our browser, or one we took directly from our webcam; then we can use our trained model to predict its class. First, we need to convert the image into a tensor:
//retrieve the canvas
const canvas = document.getElementById("myCanvas");
const ctx = canvas.getContext("2d");

//get image data
const imageData = ctx.getImageData(0, 0, 28, 28);

//convert to a tensor with a single (grayscale) channel
const tensor = tf.browser.fromPixels(imageData, 1);
Here we retrieved a canvas, read its imageData and converted it to a tensor (tf.browser.fromPixels was called tf.fromPixels in early versions of TensorFlow.js). By default fromPixels returns three channels ([28,28,3]); passing numChannels = 1 gives a tensor of size [28,28,1], matching the model's input. However, the model takes 4-dimensional tensors, so we need to add an extra dimension to the tensor using expandDims:
const eTensor = tensor.expandDims(0);
The output tensor will have size [1,28,28,1], since we added a dimension at index 0. Now for prediction we simply use predict():
model.predict(eTensor);
The predict function returns the output of the last layer of our network, usually a softmax activation function.
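To see what the canvas-to-tensor step works with, recall that ImageData.data is a flat Uint8ClampedArray holding 4 bytes per pixel in RGBA order, row by row. The plain-JavaScript sketch below (no TensorFlow.js) pulls one channel out into a nested [height][width][1] array and then adds a leading batch dimension, mimicking the layout that the expandDims(0) call produces. channelToNested is a hypothetical helper, and the 2 x 2 image is synthetic.

```javascript
// Hypothetical helper: extract one channel of RGBA pixel data into a nested
// [height][width][1] array, the layout of a single-channel image tensor.
function channelToNested(data, width, height, channel = 0) {
  const out = [];
  for (let y = 0; y < height; y++) {
    const row = [];
    for (let x = 0; x < width; x++) {
      // 4 bytes per pixel (R, G, B, A), pixels stored row by row.
      row.push([data[(y * width + x) * 4 + channel]]);
    }
    out.push(row);
  }
  return out;
}

// Tiny synthetic 2x2 "image": red channel values 10, 20, 30, 40.
const data = new Uint8ClampedArray([
  10, 0, 0, 255,   20, 0, 0, 255,
  30, 0, 0, 255,   40, 0, 0, 255,
]);
const pixels = channelToNested(data, 2, 2, 0); // shape [2, 2, 1]
// Adding the batch dimension (what expandDims(0) does) is one more wrapper.
const batched = [pixels];                      // shape [1, 2, 2, 1]
console.log(JSON.stringify(batched)); // [[[[10],[20]],[[30],[40]]]]
```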

Transfer Learning

In the previous sections we had to train our model from scratch. However, this is an expensive operation, since it requires many training iterations. Hence, we use a pretrained model called MobileNet. It is a lightweight CNN optimized to run in mobile applications. MobileNet was trained on ImageNet, which has 1,000 classes, so its layers already compute activations that are useful for recognizing images.

To load the model we use the following (tf.loadLayersModel was called tf.loadModel in earlier versions of TensorFlow.js):
let mobilenet = await tf.loadLayersModel(
  'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
We can use inputs and outputs to inspect the structure of the model:
//The input size is [null, 224, 224, 3]
const input_s = mobilenet.inputs[0].shape;

//The output size is [null, 1000]
const output_s = mobilenet.outputs[0].shape;
Hence we need images of size [1,224,224,3] (null marks the batch dimension, which can be any size), and the output will be a tensor of size [1,1000] holding the probability of each class in the ImageNet dataset.

For the sake of simplicity, we will feed an array of zeros and predict its class out of the 1,000 classes. We take the argMax along axis 1, the class axis:
const pred = mobilenet.predict(tf.zeros([1, 224, 224, 3]));
pred.argMax(1).print();
After running the code I get class 21, which represents a kite.

Now we need to inspect the contents of the model. To do that we can get the model's layers and their names:
//The number of layers in the model: 88
const len = mobilenet.layers.length;

//the name of the layer at index 3: 'conv1_relu'
const name3 = mobilenet.layers[3].name;
We see that the model has 88 layers, which would be very expensive to train again on another dataset. Hence, the basic trick is to use this model only to compute activations (we will not retrain it) and to add dense layers on top that we can train on a different set of classes.

For instance, suppose we need a model that differentiates between carrots and cucumbers. We will use the MobileNet model to calculate the activations up to a layer we choose. Then we use dense layers with output size 2 to predict the correct class. Hence the MobileNet model will, in some sense, be 'frozen', and we only train the dense layers.

First, we need to get rid of the top layers of the model, which perform the ImageNet classification. We choose to extract an intermediate layer, say number 81, named conv_pw_13_relu:
const layer = mobilenet.getLayer('conv_pw_13_relu');
Now let us update our model to have this layer as its output:
mobilenet = tf.model({inputs: mobilenet.inputs, outputs: layer.output});
Finally, we create the trainable model, but first we need to know the output shape of the last layer:
//this outputs a layer of size [null, 7, 7, 256]
const layerOutput = layer.output.shape;
We see that the shape is [null,7,7,256]. Now we can feed this into our dense layers:
const trainableModel = tf.sequential({
  layers: [
    tf.layers.flatten({inputShape: [7, 7, 256]}),
    tf.layers.dense({
      units: 100,
      activation: 'relu',
      kernelInitializer: 'varianceScaling',
      useBias: true
    }),
    tf.layers.dense({
      units: 2,
      kernelInitializer: 'varianceScaling',
      useBias: false,
      activation: 'softmax'
    })
  ]
});
As you can see, we created a dense layer with 100 neurons and an output layer of size 2. To predict, we pass an input image tensor of shape [1,224,224,3] (named input here) through the truncated MobileNet and feed the resulting activations to the trainable model:
const activation = mobilenet.predict(input);
const predictions = trainableModel.predict(activation);
We can then use the previous sections to train the last model with an optimizer of our choice.
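Before training, it helps to know how many parameters the trainable model actually has. This plain-JavaScript sketch counts them by hand: a dense layer has inputs × units weights, plus units biases when useBias is true. The counts follow directly from the layer definitions of trainableModel.

```javascript
// Count the trainable parameters of the two dense layers in trainableModel.
// A dense layer has inputs * units weights, plus `units` biases if useBias is true.
const denseParams = (inputs, units, useBias) => inputs * units + (useBias ? units : 0);

const flattened = 7 * 7 * 256;                     // flatten({inputShape: [7, 7, 256]})
const hidden = denseParams(flattened, 100, true);  // dense, 100 units, with bias
const output = denseParams(100, 2, false);         // dense, 2 units, no bias
console.log(flattened, hidden, output, hidden + output); // 12544 1254500 200 1254700
```

Calling trainableModel.summary() should report the same totals. These roughly 1.25 million weights are the only ones trained; the MobileNet layers stay fixed.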
