April 03, 2018 —
Posted by Zaid Alyafeai
TensorFlow.js is a library built on deeplearn.js for creating deep learning models directly in the browser. With it you can create CNNs, RNNs, etc. in the browser and train these models using the client's GPU processing power. Hence, a server GPU is not needed to train the NN. This tutorial starts by explaining the basic building blocks of TensorFlow.js and the operatio…
We can load TensorFlow.js in a web page with a script tag:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
This will load the latest published version of the bundle. We can then create tensors, starting with a scalar:
const tensor = tf.scalar(2);
This creates a scalar tensor. We can also convert arrays to tensors:
const input = tf.tensor([2,2]);
This creates a constant tensor from the array [2,2]. In other words, we converted the one-dimensional array to a tensor by applying the tensor function. We can use input.shape to retrieve the shape of the tensor:
const tensor_s = tf.tensor([2,2]).shape;
This has the shape [2]. We can also create a tensor with a specific shape. For instance, here we create a tensor of zeros with shape [2,2]:
const input = tf.zeros([2,2]);
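Note the difference between the two calls: tf.tensor([2,2]) reads [2,2] as values (two numbers, so shape [2]), while tf.zeros([2,2]) reads [2,2] as a shape. The plain-JavaScript helper below is a sketch for intuition only; shapeOf is a hypothetical name, not a TensorFlow.js function, and it assumes the nested array is rectangular:

```javascript
// Hypothetical helper (not a TensorFlow.js API): the shape described by a
// rectangular nested array, e.g. [[1,2,3],[4,5,6]] has shape [2,3].
function shapeOf(values) {
  const shape = [];
  let level = values;
  // Walk down the first element of each nesting level.
  while (Array.isArray(level)) {
    shape.push(level.length);
    level = level[0];
  }
  return shape;
}

console.log(shapeOf([2, 2]));           // [2]: a 1D array holding two numbers
console.log(shapeOf([[0, 0], [0, 0]])); // [2, 2]: the contents of tf.zeros([2,2])
console.log(shapeOf(5));                // []: a scalar has an empty shape
```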
We can apply operations to tensors. For example, to square each element of a tensor:
const a = tf.tensor([1,2,3]);
a.square().print();
The printed value will be [1,4,9]. TensorFlow.js also allows chaining operations. For example, to evaluate the 4th power of a tensor we use:
const x = tf.tensor([1,2,3]);
const x2 = x.square().square();
The x2 tensor will have value [1,16,81].
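Since square() acts element by element, chaining it twice raises every element to the 4th power. The same computation on a plain JavaScript array, as an illustration only (square here is a local helper, not the TensorFlow.js method):

```javascript
// Element-wise square of a plain array, mirroring tensor.square().
const square = (arr) => arr.map((v) => v * v);

const x = [1, 2, 3];
const x2 = square(x);         // [1, 4, 9]
const x4 = square(square(x)); // [1, 16, 81]: each element to the 4th power

console.log(x2, x4);
```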
After evaluating x2, we no longer need the value of x, so we can free its memory by calling dispose():
const x = tf.tensor([1,2,3]);
x.dispose();
Note that we can no longer use the tensor x in later operations. It can be inconvenient to do this for every tensor, but not disposing tensors wastes memory: tensors on the WebGL backend are not freed by JavaScript's garbage collector. TensorFlow.js offers a special operator, tidy(), to dispose intermediate tensors automatically:
function f(x)
{
return tf.tidy(()=>{
const y = x.square();
const z = x.mul(y);
return z
});
}
Notice that the tensor y will be disposed when tidy() finishes, since we don't need it after we evaluate z; the returned tensor z is kept.
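To illustrate the bookkeeping behind tidy(), here is a toy model in plain JavaScript. It is not how TensorFlow.js is implemented, and ToyTensor and toyTidy are hypothetical names: tensors created inside the scope are recorded, and when the function returns, all of them except the returned one are disposed.

```javascript
// Toy model of tidy(), NOT the TensorFlow.js implementation.
const scopeStack = [];

class ToyTensor {
  constructor(values) {
    this.values = values;
    this.disposed = false;
    // Register with the innermost active scope, if there is one.
    if (scopeStack.length > 0) {
      scopeStack[scopeStack.length - 1].push(this);
    }
  }
  square() {
    return new ToyTensor(this.values.map((v) => v * v));
  }
  mul(other) {
    return new ToyTensor(this.values.map((v, i) => v * other.values[i]));
  }
  dispose() {
    this.disposed = true;
  }
}

function toyTidy(fn) {
  scopeStack.push([]);
  let result;
  try {
    result = fn();
  } finally {
    const created = scopeStack.pop();
    for (const t of created) {
      if (t !== result) t.dispose(); // keep only the returned tensor
    }
  }
  // Simplification: a real implementation would also hand `result`
  // over to the enclosing scope when scopes are nested.
  return result;
}

// Same computation as f(x) above: z = x * x^2.
const x = new ToyTensor([1, 2, 3]); // created outside any scope, so never auto-disposed
let intermediate;
const z = toyTidy(() => {
  const y = x.square();
  intermediate = y; // kept only so we can check that it was disposed
  return x.mul(y);
});

console.log(z.values, intermediate.disposed); // z.values is [1, 8, 27]; intermediate.disposed is true
```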
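Next, let us minimize the function f(x) = x^6 + 2x^4 + 3x^2 + x + 1, implemented below. From its graph, the minimum lies in the interval [-0.5,0]. We will use an optimizer to find the exact value.
(Figure: graph of the function f(x))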
function f(x)
{
const f1 = x.pow(tf.scalar(6, 'int32')) //x^6
const f2 = x.pow(tf.scalar(4, 'int32')).mul(tf.scalar(2)) //2x^4
const f3 = x.pow(tf.scalar(2, 'int32')).mul(tf.scalar(3)) //3x^2
const f4 = tf.scalar(1) //1
return f1.add(f2).add(f3).add(x).add(f4)
}
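We can double-check the interval [-0.5,0] without a plot. The derivative is f'(x) = 6x^5 + 8x^3 + 6x + 1, and f''(x) = 30x^4 + 24x^2 + 6 > 0, so f is convex and has a single minimum where f' changes sign. A plain-JavaScript sketch (fPrime is a hypothetical helper name):

```javascript
// Derivative of f(x) = x^6 + 2x^4 + 3x^2 + x + 1.
const fPrime = (x) => 6 * x ** 5 + 8 * x ** 3 + 6 * x + 1;

// f' is negative at -0.5 and positive at 0, so (since f is convex)
// the unique minimum lies between them.
console.log(fPrime(-0.5)); // -3.1875
console.log(fPrime(0));    // 1
```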
Now we can iteratively minimize the function to find the location of the minimum. We will start from an initial value of x = 2. The learning rate defines how large the steps toward the minimum are. We will use an Adam optimizer:
function minimize(epochs, lr)
{
let y = tf.variable(tf.scalar(2)) //initial value
const optim = tf.train.adam(lr); //Adam optimizer (a gradient descent variant)
for(let i = 0 ; i < epochs ; i++) //start minimization
optim.minimize(() => f(y));
return y
}
Using a learning rate of 0.9, i.e. calling minimize(200, 0.9), we find the minimum after 200 iterations to be at x = -0.16092407703399658.
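For intuition about what each optim.minimize() call does, here is plain gradient descent on the same f, using the hand-derived gradient f'(x) = 6x^5 + 8x^3 + 6x + 1. This swaps Adam for the simpler update x <- x - lr * f'(x); the learning rate of 0.01 and the 1000 iterations are assumptions chosen so that plain gradient descent converges from x = 2 (the Adam settings above do not carry over directly).

```javascript
// Plain gradient descent (not Adam) on f(x) = x^6 + 2x^4 + 3x^2 + x + 1.
const fPrime = (x) => 6 * x ** 5 + 8 * x ** 3 + 6 * x + 1;

function gradientDescent(x0, lr, iterations) {
  let x = x0;
  for (let i = 0; i < iterations; i++) {
    x -= lr * fPrime(x); // step against the gradient
  }
  return x;
}

const xMin = gradientDescent(2, 0.01, 1000);
console.log(xMin); // close to -0.161, in line with the Adam result above
```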
Next, we train a neural network to learn the XOR function. First, we create the training data:
const xs = tf.tensor2d([[0,0],[0,1],[1,0],[1,1]])
const ys = tf.tensor2d([[0],[1],[1],[0]])
Then we create two dense layers with two different nonlinear activation functions (tanh and sigmoid). We use stochastic gradient descent with a binary cross-entropy loss and a learning rate of 0.1:
function createModel()
{
var model = tf.sequential()
model.add(tf.layers.dense({units:8, inputShape:[2], activation: 'tanh'}))
model.add(tf.layers.dense({units:1, activation: 'sigmoid'}))
//compile() has no learning-rate option, so the rate is set on the optimizer
model.compile({optimizer: tf.train.sgd(0.1), loss: 'binaryCrossentropy'})
return model
}
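Why does the hidden tanh layer matter? XOR is not linearly separable, but a tanh layer followed by a sigmoid unit can represent it. The hand-picked weights below are an illustration with only 2 hidden units; they are not the weights that training would find. The forward pass in plain JavaScript:

```javascript
// Forward pass of a tiny 2-2-1 network with hand-picked weights.
// Hidden units: h1 = tanh(10 * (x1 + x2 - 0.5)), h2 = tanh(10 * (x1 + x2 - 1.5)).
// h1 - h2 is about 2 when exactly one input is 1, and about 0 otherwise.
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

function xorNet([x1, x2]) {
  const s = x1 + x2;
  const h1 = Math.tanh(10 * (s - 0.5));
  const h2 = Math.tanh(10 * (s - 1.5));
  return sigmoid(10 * (h1 - h2) - 10); // output unit with sigmoid activation
}

const inputs = [[0, 0], [0, 1], [1, 0], [1, 1]];
const outputs = inputs.map(xorNet);
console.log(outputs.map((p) => Math.round(p))); // [0, 1, 1, 0]
```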
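Then we create the model and fit it for 5000 epochs (fit() returns a promise, so this code must run inside an async function):
const model = createModel();
await model.fit(xs, ys, {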
batchSize: 1,
epochs: 5000
})
Finally, we predict on the training set:
model.predict(xs).print()
The output should be approximately [[0.0064339], [0.9836861], [0.9835356], [0.0208658]], which matches the XOR targets [0, 1, 1, 0], as expected.
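To quantify how good these outputs are, we can compute the binary cross-entropy loss, -[y log p + (1 - y) log(1 - p)] averaged over the examples, for the predictions reported above against the XOR targets. A plain-JavaScript sketch:

```javascript
// Mean binary cross-entropy between targets and predicted probabilities.
function binaryCrossentropy(targets, preds) {
  let total = 0;
  for (let i = 0; i < targets.length; i++) {
    const y = targets[i];
    const p = preds[i];
    total += -(y * Math.log(p) + (1 - y) * Math.log(1 - p));
  }
  return total / targets.length;
}

const targets = [0, 1, 1, 0];
const preds = [0.0064339, 0.9836861, 0.9835356, 0.0208658]; // outputs reported above

console.log(binaryCrossentropy(targets, preds));   // about 0.015
console.log(preds.map((p) => (p > 0.5 ? 1 : 0))); // [0, 1, 1, 0]
```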
To build a convolutional neural network, we again start with a sequential model:
const model = tf.sequential();
Now we can add layers to the model. Let us add the first convolutional layer, with input shape [28,28,1]:
const convlayer = tf.layers.conv2d({
inputShape: [28,28,1],
kernelSize: 5,
filters: 8,
strides: 1,
activation: 'relu',
kernelInitializer: 'VarianceScaling'
});
Here we created a conv layer that takes input of shape [28,28,1], i.e. a grayscale image of size 28 x 28. The layer applies 8 kernels of size 5x5 with a stride of 1, initialized with VarianceScaling. After that, it applies a ReLU activation, which replaces the negative values in the tensor with zeros. Now we can add this convlayer to the model:
model.add(convlayer);
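To make the convolution and ReLU steps concrete, here is a naive single-channel, single-filter 2D convolution with 'valid' padding and stride 1, followed by ReLU, in plain JavaScript. Like TensorFlow.js conv layers, it computes a cross-correlation (the kernel is not flipped). The function names, the 3x3 input, and the 2x2 kernel are illustrative:

```javascript
// Naive 2D convolution ("valid" padding, stride 1) on one channel and one filter.
function conv2dValid(input, kernel) {
  const outRows = input.length - kernel.length + 1;
  const outCols = input[0].length - kernel[0].length + 1;
  const out = [];
  for (let i = 0; i < outRows; i++) {
    const row = [];
    for (let j = 0; j < outCols; j++) {
      let sum = 0;
      for (let r = 0; r < kernel.length; r++) {
        for (let c = 0; c < kernel[0].length; c++) {
          sum += input[i + r][j + c] * kernel[r][c];
        }
      }
      row.push(sum);
    }
    out.push(row);
  }
  return out;
}

// ReLU replaces negative values with zero.
const relu = (m) => m.map((row) => row.map((v) => Math.max(0, v)));

const image = [[1, 5, 2], [0, 3, 8], [4, 1, 6]];
const kernel = [[1, -1], [-1, 1]];

const pre = conv2dValid(image, kernel); // [[-1, 8], [-6, 0]]
const post = relu(pre);                 // [[0, 8], [0, 0]]
console.log(pre, post);
```

With a 28 x 28 input and a 5 x 5 kernel, the same function returns a 24 x 24 output, which is why one filter of the layer above produces a 24 x 24 map.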
What is nice about TensorFlow.js is that we don't need to specify the input shape of the following layers, since it is inferred automatically from the output of the previous layer. We can also add max-pooling layers, dense layers, and so on. Here is a simple model:
const model = tf.sequential();
//create the first layer
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 5,
filters: 8,
strides: 1,
activation: 'relu',
kernelInitializer: 'VarianceScaling'
}));
//create a max pooling layer
model.add(tf.layers.maxPooling2d({
poolSize: [2, 2],
strides: [2, 2]
}));
//create the second conv layer
model.add(tf.layers.conv2d({
kernelSize: 5,
filters: 16,
strides: 1,
activation: 'relu',
kernelInitializer: 'VarianceScaling'
}));
//create a max pooling layer
model.add(tf.layers.maxPooling2d({
poolSize: [2, 2],
strides: [2, 2]
}));
//flatten the layers to use it for the dense layers
model.add(tf.layers.flatten());
//dense layer with output 10 units
model.add(tf.layers.dense({
units: 10,
kernelInitializer: 'VarianceScaling',
activation: 'softmax'
}));
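The shape of the activations can be traced by hand through this model. With the default 'valid' padding, conv and pooling layers produce floor((inputSize - kernelSize)/stride) + 1 rows and columns, and flatten multiplies the remaining dimensions. A plain-JavaScript sketch of that bookkeeping (the layer descriptions are a hypothetical format mirroring the model above; the batch dimension is omitted):

```javascript
// Output size along one spatial dimension for "valid" padding.
const outSize = (inputSize, kernelSize, stride) =>
  Math.floor((inputSize - kernelSize) / stride) + 1;

// Hypothetical layer descriptions mirroring the model above.
const layers = [
  { type: "conv", kernel: 5, stride: 1, filters: 8 },
  { type: "pool", size: 2, stride: 2 },
  { type: "conv", kernel: 5, stride: 1, filters: 16 },
  { type: "pool", size: 2, stride: 2 },
  { type: "flatten" },
  { type: "dense", units: 10 },
];

let shape = [28, 28, 1]; // [height, width, channels]
const trace = [];
for (const layer of layers) {
  if (layer.type === "conv") {
    shape = [outSize(shape[0], layer.kernel, layer.stride),
             outSize(shape[1], layer.kernel, layer.stride), layer.filters];
  } else if (layer.type === "pool") {
    // Pooling keeps the number of channels.
    shape = [outSize(shape[0], layer.size, layer.stride),
             outSize(shape[1], layer.size, layer.stride), shape[2]];
  } else if (layer.type === "flatten") {
    shape = [shape.reduce((a, b) => a * b, 1)];
  } else if (layer.type === "dense") {
    shape = [layer.units];
  }
  trace.push(shape);
}
console.log(trace); // [24,24,8], [12,12,8], [8,8,16], [4,4,16], [256], [10]
```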
We can apply any layer to a tensor to inspect its output. But here is a catch: the input needs to be of shape [BATCH_SIZE,28,28,1], where BATCH_SIZE represents the number of dataset elements we feed to the model at a time. Here is an example of how to evaluate a convolutional layer:
const convlayer = tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 5,
filters: 8,
strides: 1,
activation: 'relu',
kernelInitializer: 'VarianceScaling'
});
const input = tf.zeros([1,28,28,1]);
const output = convlayer.apply(input);
After inspecting the shape of the output tensor, we see it has shape [1,24,24,8]. This is computed with the following formula (for the default 'valid' padding):
const outputSize = Math.floor((inputSize-kernelSize)/stride +1);
which gives Math.floor((28-5)/1 + 1) = 24 in our case. Returning to our model, we used flatten(), which converts the input from shape [BATCH_SIZE,a,b,c] to shape [BATCH_SIZE,a*b*c]. This is needed because the dense layers here expect each example as a flat 1D vector of features rather than a 3D array. Finally, we used a dense layer with 10 output units, which is the number of classes our recognition system distinguishes. In fact, this model is used to recognize handwritten digits in the so-called MNIST dataset. To train it, we first create an optimizer:
const LEARNING_RATE = 0.0001;
const optimizer = tf.train.adam(LEARNING_RATE);
This creates an Adam optimizer with the specified learning rate. Now we are ready to compile the model, which attaches the optimizer, the loss, and the metrics to it:
model.compile({
optimizer: optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
Here we created a model that uses Adam to minimize the loss function, the categorical cross-entropy between the predicted output and the true label.
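For a single example, categorical cross-entropy is -sum_i y_i log(p_i), where y is the one-hot true label and p is the softmax output; with a one-hot label it reduces to -log of the probability assigned to the correct class. A plain-JavaScript sketch with made-up numbers:

```javascript
// Categorical cross-entropy for one example: -sum_i y_i * log(p_i).
function categoricalCrossentropy(oneHotLabel, probs) {
  let loss = 0;
  for (let i = 0; i < oneHotLabel.length; i++) {
    if (oneHotLabel[i] !== 0) loss -= oneHotLabel[i] * Math.log(probs[i]);
  }
  return loss;
}

// Illustrative values: the true class is index 1, predicted with probability 0.7.
const loss = categoricalCrossentropy([0, 1, 0], [0.2, 0.7, 0.1]);
console.log(loss); // equals -Math.log(0.7), about 0.357
```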
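Now we can train the model with the fit() function. The constant values below are placeholders for illustration:
const BATCH_SIZE = 64; //placeholder value
const NUM_CLASSES = 10; //one class per digit
const BATCH_EPOCHS = 1; //placeholder value
const batch = tf.zeros([BATCH_SIZE,28,28,1]);
const labels = tf.zeros([BATCH_SIZE, NUM_CLASSES]);
//validation inputs and labels (zeros as placeholders)
const validationData = [tf.zeros([BATCH_SIZE,28,28,1]), tf.zeros([BATCH_SIZE, NUM_CLASSES])];
const h = await model.fit(batch, labels,
{
batchSize: BATCH_SIZE,
validationData: validationData,
epochs: BATCH_EPOCHS
});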
Note that we are feeding a batch of the training set to the fit function. The second argument of fit represents the true labels. Lastly, we have configuration parameters such as batchSize and epochs. Note that epochs represents how many times we iterate over the current batch, NOT the whole dataset. Hence we can, for example, wrap that code inside a for loop that iterates over all the batches of the training set.
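Because epochs counts passes over the batch given to fit(), covering the whole training set means calling fit() once per batch. The index bookkeeping for such a loop is sketched below in plain JavaScript; makeBatches is a hypothetical helper, and in TensorFlow.js each [start, end) range would then be used to slice the data tensors before calling fit():

```javascript
// Split `numExamples` indices into [start, end) ranges of at most `batchSize`.
function makeBatches(numExamples, batchSize) {
  const batches = [];
  for (let start = 0; start < numExamples; start += batchSize) {
    batches.push([start, Math.min(start + batchSize, numExamples)]);
  }
  return batches;
}

// 10 examples with a batch size of 4: the last batch is smaller.
const batches = makeBatches(10, 4);
console.log(batches); // [[0, 4], [4, 8], [8, 10]]
```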
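Note that we used the keyword await, which pauses the enclosing async function until fit() has finished. Unlike a truly blocking call, it does not freeze the browser: other code, such as UI updates, can run between training steps, and execution continues after the await once fit() completes.
Labels are usually given as class numbers. For example, with two classes, one class gets the label 0 and the apple class gets the label 1. But our network expects labels as a tensor of shape [BATCH_SIZE,NUM_CLASSES]. Hence we need to use what we call one-hot encoding:
const output = tf.oneHot(tf.tensor1d([0,1,0], 'int32'), 2);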
//the output will be [[1, 0],[0, 1],[1, 0]]
Hence we converted the 1D tensor of labels into a tensor of shape [BATCH_SIZE,NUM_CLASSES].
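The same encoding on plain arrays, as a sketch of what tf.oneHot computes. The oneHot helper here is hypothetical and assumes the labels are integers in [0, numClasses):

```javascript
// One-hot encode integer class labels into rows of length numClasses.
function oneHot(labels, numClasses) {
  return labels.map((label) => {
    const row = new Array(numClasses).fill(0);
    row[label] = 1;
    return row;
  });
}

console.log(oneHot([0, 1, 0], 2)); // [[1, 0], [0, 1], [1, 0]]
```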
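To inspect the loss and accuracy, we read the History object returned by fit():
//h is the History object returned by model.fit()
const loss = h.history.loss[0];
const accuracy = h.history.acc[0];
const valLoss = h.history.val_loss[0];
const valAccuracy = h.history.val_acc[0];
Note that loss and acc hold the values on the training batch, while val_loss and val_acc hold the values on the validationData that was passed to the fit() function; each array has one entry per epoch.
To classify a digit drawn by the user, we first read the pixels from a canvas:
//retrieve the canvas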
const canvas = document.getElementById("myCanvas");
const ctx = canvas.getContext("2d");
//get image data
const imageData = ctx.getImageData(0, 0, 28, 28);
//convert to a tensor with a single channel
//(TensorFlow.js 1.0 and later name this function tf.browser.fromPixels)
const tensor = tf.fromPixels(imageData, 1);
Here we retrieved the canvas, read imageData from it, and converted it to a tensor. Passing 1 as the number of channels gives a tensor of shape [28,28,1], matching the model's input; without it, fromPixels returns three (RGB) channels, i.e. shape [28,28,3]. The model, however, takes 4-dimensional tensors, so we need to add an extra dimension using expandDims:
const eTensor = tensor.expandDims(0);
The resulting tensor has shape [1,28,28,1], since we have added a dimension at index 0. Now for prediction we simply use predict():
model.predict(eTensor);
The function predict returns the output of the last layer in our network, which is usually a softmax activation.
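A softmax turns the raw scores (logits) of the last layer into probabilities that are positive and sum to 1; the predicted class is the index of the largest one. A plain-JavaScript sketch with illustrative scores (subtracting the maximum before exponentiating is a standard trick for numerical stability and does not change the result):

```javascript
// Softmax: exp(z_i) / sum_j exp(z_j), shifted by max(z) for numerical stability.
function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Index of the largest value, i.e. the predicted class.
const argMax = (values) => values.indexOf(Math.max(...values));

const probs = softmax([2.0, 1.0, 0.1]); // illustrative scores for 3 classes
console.log(probs);         // roughly [0.659, 0.242, 0.099]
console.log(argMax(probs)); // 0
```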
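We can also reuse a pretrained model. Here we load MobileNet, which was trained on the ImageNet dataset with 1,000 different classes (we use let because mobilenet is reassigned later):
//(TensorFlow.js 1.0 and later name this function tf.loadLayersModel)
let mobilenet = await tf.loadModel(
'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');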
We can use inputs and outputs to inspect the structure of the model:
//The input size is [null, 224, 224, 3]
const input_s = mobilenet.inputs[0].shape;
//The output size is [null, 1000]
const output_s = mobilenet.outputs[0].shape;
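Hence we need images of shape [1,224,224,3], and the output will be a tensor of shape [1,1000], which holds the probability of each class in the ImageNet dataset. For simplicity, we can feed an image of zeros and check which of the 1,000 classes is predicted:
var pred = mobilenet.predict(tf.zeros([1, 224, 224, 3]));
//argMax along axis 1 (the class axis) gives the predicted class index
pred.argMax(1).print();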
After running the code I get class 21, which represents a kite. We can also inspect the layers of the model:
//The number of layers in the model is 88
const len = mobilenet.layers.length;
//this outputs the name of the layer at index 3, 'conv1_relu'
const name3 = mobilenet.layers[3].name;
We see that the model has 88 layers, which would be very expensive to train again on another dataset. Hence, the basic trick is to use this model only to compute activations (we will not retrain it) and to create dense layers that we can train on a different number of classes. For example, to distinguish between two classes, we use MobileNet to compute the activations and then a dense layer with output size 2 to predict the correct class. Hence the mobilenet model will be, in some sense, frozen, and we only train the dense layers.
First, we take the layer at index 81, named conv_pw_13_relu:
const layer = mobilenet.getLayer('conv_pw_13_relu');
Now let us update our model so that this layer is its output:
mobilenet = tf.model({inputs: mobilenet.inputs, outputs: layer.output});
Finally, we create the trainable model, but first we need to know the output shape of that layer:
//this outputs the shape [null, 7, 7, 256]
const layerOutput = layer.output.shape;
We see that the shape is [null,7,7,256]. Now we can feed this into our dense layers:
const trainableModel = tf.sequential({
layers: [
tf.layers.flatten({inputShape: [7, 7, 256]}),
tf.layers.dense({
units: 100,
activation: 'relu',
kernelInitializer: 'varianceScaling',
useBias: true
}),
tf.layers.dense({
units: 2,
kernelInitializer: 'varianceScaling',
useBias: false,
activation: 'softmax'
})
]
});
As you can see, we created a hidden dense layer with 100 neurons and an output layer of size 2.
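To get a sense of what is trained here, the number of trainable parameters in the two dense layers follows from the shapes: flatten turns [7,7,256] into 7*7*256 = 12,544 features, the 100-unit layer has a 12,544 x 100 kernel plus 100 biases, and the output layer has a 100 x 2 kernel and no bias (useBias: false). A plain-JavaScript check of that arithmetic (denseParams is a hypothetical helper):

```javascript
// Parameters of a dense layer: inputs * units weights, plus `units` biases if used.
const denseParams = (inputs, units, useBias) => inputs * units + (useBias ? units : 0);

const flatFeatures = 7 * 7 * 256;                    // 12544
const hidden = denseParams(flatFeatures, 100, true); // 1254500
const output = denseParams(100, 2, false);           // 200
console.log(flatFeatures, hidden, output, hidden + output); // 12544 1254500 200 1254700
```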
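To make a prediction, we first compute the activations with the truncated MobileNet model and then pass them to the trainable model (here input is an image tensor of shape [1,224,224,3]):
const activation = mobilenet.predict(input);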
const predictions = trainableModel.predict(activation);
We can then train this last model with an optimizer, as described in the previous sections.