July 02, 2018 —
                                          
Posted by Zaid Alyafeai
We will create a simple tool that recognizes drawings and outputs the name of the current drawing. This app will run directly in the browser without any installation. We will use Google Colab to train the model, and we will deploy it in the browser using TensorFlow.js.
Code and Demo: Find the live demo and the code on GitHub. Also make sure to test the notebook on Google Colab.

|  | 
| A subset of the classes | 
|  | 
| The pipeline | 
import os
import glob
import numpy as np
from tensorflow.keras import layers
from tensorflow import keras 
import tensorflow as tf[N,784] where N is the number of of the images for that particular class. We first download the datasetimport urllib.request
def download(root='data'):
    """Download the QuickDraw numpy-bitmap file for every name in ``classes``.

    ``classes`` is a module-level list of class names defined outside this
    function. Underscores in a name are sent as URL-encoded spaces (``%20``)
    in the remote file name; the local copy keeps the underscore form and is
    saved as ``<root>/<name>.npy``.

    Args:
        root: Directory the files are saved into (default ``'data'``).
            Created if it does not exist yet.
    """
    base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
    # urlretrieve does not create missing parent directories.
    os.makedirs(root, exist_ok=True)
    for c in classes:
        cls_url = c.replace('_', '%20')
        path = base + cls_url + '.npy'
        print(path)
        urllib.request.urlretrieve(path, os.path.join(root, c + '.npy'))


def load_data(root, vfold_ratio=0.2, max_items_per_class=5000):
    """Load a per-class subset of the drawings and split it into train/test.

    Every ``*.npy`` file in ``root`` is one class. Its file name (without
    extension) is the class name, and its index in ``class_names`` is the
    integer label. Files are visited in sorted order, so the name -> label
    mapping does not depend on the file system's listing order.

    Args:
        root: Directory containing one ``<class>.npy`` file per class. Each
            file must hold a 2-D array with 784 columns (one flattened
            bitmap per row).
        vfold_ratio: Fraction of the shuffled samples put in the test split.
        max_items_per_class: Maximum number of rows taken from each file.

    Returns:
        ``(x_train, y_train, x_test, y_test, class_names)``:
            - ``x_*`` are float64 arrays of shape ``(n, 784)``.
            - ``y_*`` are float64 arrays holding the integer labels.
            - ``class_names`` is a list of str.
        The shuffle uses NumPy's global random state.
    """
    all_files = sorted(glob.glob(os.path.join(root, '*.npy')))
    x_parts = []
    y_parts = []
    class_names = []
    # Load a subset of each class into memory.
    for idx, file in enumerate(all_files):
        data = np.load(file)[:max_items_per_class, :]
        x_parts.append(data)
        y_parts.append(np.full(data.shape[0], idx))
        class_names.append(os.path.splitext(os.path.basename(file))[0])

    # Concatenate once instead of growing the arrays inside the loop, which
    # re-copies everything loaded so far on every iteration. The empty
    # float64 seeds keep the result dtype float64. They also make an empty
    # directory yield empty (0, 784) / (0,) arrays.
    x = np.concatenate([np.empty([0, 784])] + x_parts, axis=0)
    y = np.concatenate([np.empty([0])] + y_parts)

    # Shuffle, then hold out the first vfold_ratio of the samples for testing.
    permutation = np.random.permutation(y.shape[0])
    x = x[permutation, :]
    y = y[permutation]
    vfold_size = int(x.shape[0] * vfold_ratio)
    x_test = x[:vfold_size, :]
    y_test = y[:vfold_size]
    x_train = x[vfold_size:, :]
    y_train = y[vfold_size:]
    return x_train, y_train, x_test, y_test, class_names

# The model takes a batch of the shape [N, 28, 28, 1] and outputs probabilities of the shape [N, 100]
# Reshape and normalize
# Turn every flat row into an (image_size, image_size, 1) single-channel
# image, cast it to float32 and divide by 255 to scale the pixel values.
x_train = x_train.reshape((x_train.shape[0], image_size, image_size, 1)).astype('float32') / 255.0
x_test = x_test.reshape((x_test.shape[0], image_size, image_size, 1)).astype('float32') / 255.0
# One-hot encode the integer labels into (num_samples, num_classes) matrices.
y_train, y_test = (keras.utils.to_categorical(labels, num_classes) for labels in (y_train, y_test))
# Define model
model = keras.Sequential()
# Three conv (3x3, 'same' padding) -> 2x2 max-pool stages with 16, 32 and 64
# filters. Each pooling step halves the spatial size of the feature maps.
model.add(layers.Conv2D(16, (3, 3),
                        padding='same',
                        input_shape=x_train.shape[1:], activation='relu'))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Conv2D(32, (3, 3), padding='same', activation='relu'))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
# The output width must match the one-hot label width produced by
# to_categorical(..., num_classes). A hard-coded 100 breaks training as soon
# as a different number of classes is loaded.
model.add(layers.Dense(num_classes, activation='softmax'))
# Train model
# tf.train.AdamOptimizer is a TF1-only API that no longer exists in TF 2.x.
# keras.optimizers.Adam uses the same default learning rate (0.001).
adam = keras.optimizers.Adam()
model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['top_k_categorical_accuracy'])
# summary() prints the table itself and returns None, so wrapping it in
# print() would only emit a stray "None".
model.summary()
# Train for 5 epochs with batches of 256 samples, holding out 10% of the
# training data for validation.
# fit the model
model.fit(x=x_train, y=y_train, validation_split=0.1, batch_size=256, verbose=2, epochs=5)
# Evaluate on unseen data. The model is compiled with a single metric,
# 'top_k_categorical_accuracy' (Keras default k=5), so score is
# [loss, top-5 accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
print('Test top-5 accuracy: {:0.2f}%'.format(score[1] * 100))
92.20% top 5 accuracy.model.save('keras.h5')!pip install tensorflowjs !mkdir model
!tensorflowjs_converter --input_format keras keras.h5 model/!zip -r model.zip model from google.colab import files
files.download('model.zip')300 x 300. I will not go over the details of the interface and focus on TensorFlow.js part.<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>model = await tf.loadLayersModel('model/model.json')await keyword waits for the model to be loaded by the browser.//the minimum boudning box around the current drawing
const mbb = getMinBox()
//calculate the device pixel ratio of the current window
const dpi = window.devicePixelRatio
//extract the image data. Each side is clamped to at least 1 pixel because
//getImageData throws an IndexSizeError when the width or height is 0
//(e.g. a single click or a perfectly horizontal/vertical stroke).
const imgData = canvas.contextContainer.getImageData(mbb.min.x * dpi, mbb.min.y * dpi,
             Math.max(1, (mbb.max.x - mbb.min.x) * dpi), Math.max(1, (mbb.max.y - mbb.min.y) * dpi));
// getMinBox() will be explained later. The variable dpi is used to stretch the crop of
// the canvas according to the density of the pixels of the screen.

/**
 * Convert cropped canvas pixels into a batch the model can consume.
 *
 * @param imgData ImageData of the cropped drawing.
 * @returns A float32 tensor of shape [1, 28, 28, 1].
 *   Values are computed as 1 - pixel / 255, so dark pixels map to ~1 and
 *   white pixels to 0 (presumably dark ink on a white canvas — confirm).
 *   All intermediate tensors are released by tf.tidy.
 */
function preprocess(imgData)
{
  return tf.tidy(() => {
    //convert the image data to a single-channel tensor. The channel count is
    //passed positionally: writing `numChannels= 1` inside the call would
    //assign an implicit global (a ReferenceError in strict mode).
    const tensor = tf.browser.fromPixels(imgData, 1)
    //resize to 28 x 28 and cast to float32 (Tensor.toFloat() no longer exists in TF.js 2.x)
    const resized = tf.cast(tf.image.resizeBilinear(tensor, [28, 28]), 'float32')
    // Normalize the image and invert it
    const offset = tf.scalar(255.0);
    const normalized = tf.scalar(1.0).sub(resized.div(offset));
    //We add a dimension to get a batch shape
    const batched = normalized.expandDims(0)
    return batched
  })
}

// model.predict returns probabilities of the shape [N, 100]. Running it inside
// tf.tidy disposes both the input batch and the prediction tensor once the
// values have been copied out, so repeated predictions do not leak memory.
const pred = tf.tidy(() => model.predict(preprocess(imgData)).dataSync())
// The model expects inputs of the shape [N, 28, 28, 1]. The drawing canvas is
// 300 x 300, which might be too large for the drawing, or the user might draw a
// small figure, so it is better to crop only the box that contains the current
// drawing. To do that we extract the minimum bounding box around the drawing by
// finding the top-left and the bottom-right points.
//record the current drawing coordinates
/**
 * Record the current pointer position into the global `coords` array.
 *
 * The point is recorded only while `mousePressed` is true and the pointer lies
 * inside the canvas on all four sides. Presumably this is registered as a
 * mouse-move handler on the drawing canvas — confirm at the call site.
 *
 * @param event Event wrapper whose `.e` property is passed to
 *   `canvas.getPointer` (fabric.js-style API — confirm).
 */
function recordCoor(event)
{
  //get current mouse coordinate
  const pointer = canvas.getPointer(event.e);
  const posX = pointer.x;
  const posY = pointer.y;

  //check both the lower and the upper canvas bounds: points past the right or
  //bottom edge would otherwise stretch the bounding box beyond the canvas
  const insideCanvas = posX >= 0 && posY >= 0 &&
                       posX <= canvas.getWidth() && posY <= canvas.getHeight();

  //record the point if within the canvas and the mouse is pressed
  if (insideCanvas && mousePressed)
  {
    coords.push(pointer)
  }
}
   
//get the tightest bounding box by finding the top-left and bottom-right corners
/**
 * Compute the axis-aligned bounding box of every point in the global `coords`.
 *
 * @returns {{min: {x: number, y: number}, max: {x: number, y: number}}}
 *   `min` is the top-left corner and `max` the bottom-right corner.
 *   When `coords` is empty, `min` is {Infinity, Infinity} and `max` is
 *   {-Infinity, -Infinity}.
 */
function getMinBox(){
  // A single pass replaces Math.min/max.apply. apply() passes every point as a
  // separate function argument and can throw a RangeError for very long
  // drawings. This loop also avoids two temporary arrays and four scans.
  let minX = Infinity;
  let minY = Infinity;
  let maxX = -Infinity;
  let maxY = -Infinity;
  for (const p of coords) {
    if (p.x < minX) minX = p.x;
    if (p.y < minY) minY = p.y;
    if (p.x > maxX) maxX = p.x;
    if (p.y > maxY) maxY = p.y;
  }
  return {
    min : { x : minX, y : minY },
    max : { x : maxX, y : maxY }
  }
}
 
July 02, 2018
 —
                                  
Posted by Zaid Alyafeai
We will create a simple tool that recognizes drawings and outputs the name of the current drawing. This app will run directly in the browser without any installation. We will use Google Colab to train the model, and we will deploy it in the browser using TensorFlow.js.
Code and Demo: Find the live demo and the code on GitHub. Also make sure to test the notebook on Google Colab.