Training and serving a realtime mobile object detector in 30 minutes with Cloud TPUs
July 13, 2018
Posted by Sara Robinson, Aakanksha Chowdhery, and Jonathan Huang

What if you could train and serve your object detection models even faster? We’ve heard your feedback, and today we’re excited to announce support for training an object detection model on Cloud TPUs, model quantization, and the addition of new models including RetinaNet and a MobileNet adaptation of RetinaNet. You can check out the announcement post on the AI blog. In this post, we’ll walk you through training a quantized pet breed detector on Cloud TPUs using transfer learning.
[Image: pet breed detector]
The whole process, from training to on-device inference on Android, takes 30 minutes and costs less than $5 on Google Cloud. When you’re done, you’ll have an Android app (iOS tutorial coming soon!) that performs real-time detection of dog and cat breeds and requires no more than 12 MB of space on your phone. Note that in addition to training an object detection model in the cloud, you can alternatively run training on your own hardware or in Colab.

Setting up your environment

We will first set up some of the libraries and prerequisites needed for training and serving our model. Note that this setup process may take significantly longer than training and serving the model itself. For convenience, you may use the Dockerfile here that provides the dependencies for installing TensorFlow from source and downloading the necessary datasets and models for this tutorial. If you decide to use Docker, you should still work through the “Google Cloud Setup” section and then skip to the “Uploading dataset to GCS” section. The Dockerfile will also build the Android dependencies for the TensorFlow Lite section. See the attached README file for more information.

Google Cloud Setup

First, create a project in the Google Cloud Console and enable billing for that project. We’ll use Cloud Machine Learning Engine to run our training job on Cloud TPUs. ML Engine is Google Cloud’s managed platform for TensorFlow, and it simplifies the process of training and serving ML models. To use it, enable the necessary APIs for the project you just created.

Second, we’ll create a Google Cloud Storage bucket to store the training and test data for our model, along with the model checkpoints from our training job.

Note that all of the commands in this tutorial assume you’re running Ubuntu. We’ll use the Google Cloud gcloud CLI for many of the commands in this tutorial, along with the Cloud Storage gsutil CLI to interact with our GCS buckets. If you don’t have those installed, you can install gcloud here and gsutil here.

Run the following to set your current project to the one you just created, replacing YOUR_PROJECT_NAME with the name of your project:
gcloud config set project YOUR_PROJECT_NAME
Then we’ll create a Cloud Storage bucket, for example with gsutil mb. Note that Storage bucket names must be globally unique, so you may get an error if the first name you choose is taken.
This may prompt you to first run gcloud auth login, after which you will need to provide a verification code sent to your browser.

Then set up two environment variables to simplify commands throughout this tutorial: PROJECT, holding your project ID, and YOUR_GCS_BUCKET, holding the name of your bucket.
Next, to give our Cloud TPU access to our project we need to add a TPU-specific service account. First, get the name of your service account with the following command:
curl -H "Authorization: Bearer $(gcloud auth print-access-token)"  \${PROJECT}:getConfig
When this command completes, copy the value of tpuServiceAccount and then save it as an environment variable:
export TPU_ACCOUNT=your-service-account
Finally, grant the ml.serviceAgent role to your TPU service account:
gcloud projects add-iam-policy-binding $PROJECT  \
    --member serviceAccount:$TPU_ACCOUNT --role roles/ml.serviceAgent

Installing TensorFlow

If you don’t have TensorFlow installed, follow the steps here. To follow the on-device section of this tutorial, you’ll need to install TensorFlow from source using Bazel following the instructions here. Compiling TensorFlow may take a while. If you’d just like to follow the Cloud TPU training section of this tutorial, you don’t need to compile TensorFlow from source and can install a released version via pip, Anaconda, etc.

Installing TensorFlow Object Detection

If this is your first time using TensorFlow Object Detection, welcome! To install it, follow the instructions here.

Once you’ve installed Object Detection, be sure to test your installation by running the following:
python object_detection/builders/
If installation is successful, you should see the following output:
Ran 18 tests in 0.079s


Setting up the dataset

To keep things simple, we’ll use the same pet breeds dataset from our last post on training an object detection model. This dataset includes around 7,400 images: roughly 200 images for each of 37 different cat and dog breeds. Each image has an associated annotations file, which includes the bounding box coordinates where the specific pet is located in the image. We can’t feed these images and annotations directly to our model; we need to convert them into a format our model can understand. For this we’ll use the TFRecord format.

To dive right in to training, we’ve made the pet_faces_train.record and pet_faces_val.record files publicly accessible here. You can either use the public TFRecord files, or if you’d like to generate them yourself, follow the steps here.

You can download and extract the public TFRecord files using the following command:
mkdir /tmp/pet_faces_tfrecord/
cd /tmp/pet_faces_tfrecord/
curl "" | tar xzf -
Note that these TFRecord files are sharded, so once you’ve extracted them you’ll have 10 pet_faces_train.record files and 10 pet_faces_val.record files.
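To make the sharding concrete, here is a minimal Python sketch that lists the file names you would expect after extraction. It assumes the shards use the common `name-00000-of-00010` suffix convention; that exact suffix and the helper name `shard_names` are assumptions for illustration, so check your extracted directory for the real names.

```python
import fnmatch


def shard_names(base, num_shards):
    """Build shard file names, assuming a NNNNN-of-NNNNN suffix convention."""
    return [
        "{}-{:05d}-of-{:05d}".format(base, index, num_shards)
        for index in range(num_shards)
    ]


# 10 training shards and 10 validation shards, as described above.
train_files = shard_names("pet_faces_train.record", 10)
val_files = shard_names("pet_faces_val.record", 10)
all_files = train_files + val_files

# A wildcard such as pet_faces_train* matches every training shard.
matched_train = fnmatch.filter(all_files, "pet_faces_train*")

print(len(all_files), "TFRecord shards in total")
print(train_files[0])
```

A wildcard pattern is convenient here because one pattern covers all shards of a split without listing them individually.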

Uploading dataset to GCS

Once you’ve got your TFRecord files available locally, copy them into your GCS bucket under a /data subdirectory:
gsutil -m cp -r /tmp/pet_faces_tfrecord/pet_faces* gs://${YOUR_GCS_BUCKET}/data/
With your TFRecord files in GCS, move back to the models/research directory on your local machine. Next you’ll add the pet_label_map.pbtxt file to your GCS bucket. This maps each of the 37 pet breeds we’ll be detecting to an integer, so that our model can understand them in a numerical format. From the models/research directory, run the following:
gsutil cp object_detection/data/pet_label_map.pbtxt gs://${YOUR_GCS_BUCKET}/data/pet_label_map.pbtxt
At this point you should have 21 files in the /data subdirectory of your GCS bucket: the 20 sharded TFRecord files for training and testing, and the label map file.
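As a rough illustration of what the label map does, the sketch below parses a small label map in the `item { id: ... name: ... }` text layout into a Python dictionary. The two sample entries and the helper `parse_label_map` are illustrative assumptions, not a copy of the real pet_label_map.pbtxt, and a regular expression is only a stand-in for a proper protobuf text parser.

```python
import re

# Illustrative sample only; the real file lists all 37 breeds.
SAMPLE_LABEL_MAP = """
item {
  id: 1
  name: 'Abyssinian'
}
item {
  id: 2
  name: 'american_bulldog'
}
"""

ITEM_PATTERN = re.compile(r"item\s*\{\s*id:\s*(\d+)\s*name:\s*'([^']+)'\s*\}")


def parse_label_map(text):
    """Return a {breed_name: integer_id} mapping from label map text."""
    return {name: int(label_id) for label_id, name in ITEM_PATTERN.findall(text)}


label_ids = parse_label_map(SAMPLE_LABEL_MAP)
print(label_ids)
```

The model only ever sees the integer ids; the names are used to turn predictions back into readable labels.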

Using the SSD MobileNet checkpoint for transfer learning

Training a model to recognize pet breeds from scratch would take thousands of training images for each pet breed and hours or days of training time. To speed this up, we can make use of transfer learning: a process where we take the weights of a model that has already been trained on lots of data to perform a similar task, and then train the model on our own data, fine-tuning the layers from the pre-trained model.

There are many models we can use that have been trained to recognize a wide variety of objects in images. We can use the checkpoints from these trained models and then apply them to our custom object detection task. This works because, to a machine, the task of identifying the pixels in an image that contain basic objects like tables, chairs, or cats isn’t so different from identifying the pixels in an image that contain specific pet breeds.

For this example we’ll use SSD with MobileNet, an object detection model optimized for inference on mobile. First, download and extract the latest MobileNet checkpoint that’s been pretrained on the COCO dataset. To see a list of all the models that the Object Detection API supports, check out the model zoo. Once you’ve extracted the checkpoint, copy the 3 files into your GCS bucket. Run the commands below to download the checkpoint and copy it into your bucket:
cd /tmp
curl -O
tar xzf ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03.tar.gz

gsutil cp /tmp/ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync_2018_07_03/model.ckpt.* gs://${YOUR_GCS_BUCKET}/data/
When we train our model, it’ll use these checkpoints as its starting point for training. Now you should have 24 files in your GCS bucket. We’re almost ready to run our training job, but we need a way to tell ML Engine where our data and model checkpoints are located. We’ll do this with a config file, which we’ll set up in the next step. Our config file provides hyperparameters for our model, the file paths for our training data, test data, and the initial model checkpoint.

Training a quantized model with Cloud TPUs on Cloud ML Engine

Machine learning models have two distinct computational components: training and inference. In this example, we’re making use of Cloud TPUs to accelerate training. There are a few lines in the config file that relate specifically to TPU training. We can use a larger batch size when training on TPUs since they make it easier to handle large datasets (when experimenting with batch size on your own dataset, make sure to use multiples of 8 since data needs to be divided evenly for each of the 8 TPU cores). With a larger batch size for our model, we can reduce the number of training steps (in this example we use 2000). The focal loss function we use for this training job, defined in the following lines in the config, is also a great fit for TPUs:
loss {
  classification_loss {
    weighted_sigmoid_focal {
      alpha: 0.75,
      gamma: 2.0
    }
  }
  # remaining loss settings omitted
}
This loss function computes loss for every example in the dataset and then reweights them, assigning more relative weight to hard, misclassified examples. This logic is better suited for TPUs than the hard example mining operation used in other training jobs. You can read more about focal loss in Lin et al. (2017).
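To see how the reweighting behaves, here is a small per-example sketch of sigmoid focal loss as defined in Lin et al. (2017), using the alpha and gamma values from the config above. It is a plain-Python illustration of the formula, not the Object Detection API’s implementation, and the function name is hypothetical.

```python
import math


def sigmoid_focal_loss(logit, target, alpha=0.75, gamma=2.0):
    """Focal loss for one binary prediction: alpha_t * (1 - p_t)^gamma * CE."""
    probability = 1.0 / (1.0 + math.exp(-logit))
    # p_t is the probability the model assigned to the true class.
    p_t = probability if target == 1 else 1.0 - probability
    alpha_t = alpha if target == 1 else 1.0 - alpha
    cross_entropy = -math.log(p_t)
    # The (1 - p_t)^gamma factor shrinks the loss of well-classified examples.
    return alpha_t * (1.0 - p_t) ** gamma * cross_entropy


easy_example = sigmoid_focal_loss(logit=4.0, target=1)   # confident and correct
hard_example = sigmoid_focal_loss(logit=-4.0, target=1)  # confident and wrong
print(easy_example, hard_example)
```

Because every example is scored with the same arithmetic, there is no data-dependent selection step as there is with hard example mining.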

Recall from above that the process of initializing a pre-trained model checkpoint and then adding our own training data is called transfer learning. The following lines in the config tell our model that we’ll be doing transfer learning:
fine_tune_checkpoint: "gs://your-bucket/data/model.ckpt"
fine_tune_checkpoint_type: "detection"
We also need to consider how our model will be used after it’s been trained. Let’s say our pet detector becomes a global hit, used by animal lovers and pet stores everywhere. We need a scalable way to handle these inference requests with low latency. The output of a machine learning model is a binary file containing the trained weights of our model — these files are often quite large, but since we’ll be serving this model directly on a mobile device we’ll need to make it as small as possible.

This is where model quantization comes in. Quantization compresses the weights and activations in our model to an 8-bit fixed point representation. The following lines in our config file will generate a quantized model:
graph_rewriter {
  quantization {
    delay: 1800
    activation_bits: 8
    weight_bits: 8
  }
}
Typically with quantization, a model will train with full precision for a certain number of steps before switching to quantized training. The delay number above tells ML Engine to begin quantizing our weights and activations after 1800 training steps.
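The idea behind 8-bit quantization can be sketched in a few lines of Python. The example below maps a float range onto 256 integer levels with a scale and zero point, then maps a value back, which is roughly what quantization-aware training simulates. The helper names are hypothetical and the rounding details are assumptions; TensorFlow’s actual rewriter differs in its specifics.

```python
def quantization_params(min_val, max_val, num_bits=8):
    """Return (scale, zero_point) mapping [min_val, max_val] onto integer levels."""
    qmax = 2 ** num_bits - 1
    # Widen the range so that 0.0 is always representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / qmax
    if scale == 0.0:
        scale = 1.0  # degenerate range; any scale works
    zero_point = int(round(-min_val / scale))
    return scale, max(0, min(qmax, zero_point))


def fake_quantize(value, scale, zero_point, num_bits=8):
    """Round a float to its nearest representable level and map it back."""
    qmax = 2 ** num_bits - 1
    level = int(round(value / scale)) + zero_point
    level = max(0, min(qmax, level))
    return (level - zero_point) * scale


def quantization_active(step, delay=1800):
    """Assumed reading of `delay`: full precision until this step is reached."""
    return step >= delay


scale, zero_point = quantization_params(-1.0, 1.0)
print(scale, zero_point, fake_quantize(0.42, scale, zero_point))
```

With 8 bits over the range [-1, 1], neighbouring levels are about 0.008 apart, which bounds the rounding error introduced for values inside the range.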

To tell ML Engine where to find our training and test files and model checkpoint, you’ll need to update a few lines in the config file we’ve created for you to point to your bucket. From the research directory, find object_detection/samples/configs/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config. Update all the PATH_TO_BE_CONFIGURED strings with the full path of the data directory in your GCS bucket. For example, the train_input_reader section of the config would look like the following (make sure to replace YOUR_GCS_BUCKET with the name of your bucket):
train_input_reader: {
  tf_record_input_reader {
    input_path: "gs://YOUR_GCS_BUCKET/data/pet_faces_train*"
  }
  label_map_path: "gs://YOUR_GCS_BUCKET/data/pet_label_map.pbtxt"
}
Then copy this quantized config file into your GCS bucket:
gsutil cp object_detection/samples/configs/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config gs://${YOUR_GCS_BUCKET}/data/pipeline.config
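If you would rather script the placeholder replacement than edit the config by hand, a sketch along these lines may help. The `SAMPLE_CONFIG` text is an abbreviated stand-in for the real config file, and `configure_paths` and the bucket name `my-bucket` are hypothetical.

```python
# Abbreviated stand-in for the real pipeline config.
SAMPLE_CONFIG = """
train_config {
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
}
train_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/pet_faces_train*"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/pet_label_map.pbtxt"
}
"""


def configure_paths(config_text, bucket):
    """Point every placeholder at the data directory of the given bucket."""
    data_dir = "gs://{}/data".format(bucket)
    return config_text.replace("PATH_TO_BE_CONFIGURED", data_dir)


configured = configure_paths(SAMPLE_CONFIG, "my-bucket")
print(configured)
```

To apply this to the real file, read it into a string, pass it through the function, and write the result back before copying it to your bucket.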
Before we kick off our training job on Cloud ML Engine, we need to package the Object Detection API, pycocotools, and TF Slim. We can do that with the following commands (run these from the research/ directory, and note that the parentheses are part of the command):
bash object_detection/dataset_tools/ /tmp/pycocotools
python sdist
(cd slim && python sdist)
We’re ready to train our model! To kick off training, run the following gcloud command:
gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \
--job-dir=gs://${YOUR_GCS_BUCKET}/train \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz \
--module-name object_detection.model_tpu_main \
--runtime-version 1.8 \
--scale-tier BASIC_TPU \
--region us-central1 \
-- \
--model_dir=gs://${YOUR_GCS_BUCKET}/train \
--tpu_zone us-central1 \
Note that if you receive an error saying that no Cloud TPUs are available, we recommend simply trying again in a different zone (Cloud TPUs are currently available in us-central1-b, us-central1-c, europe-west4-a, and asia-east1-c).

Right after we kick off our training job, run the following command to start an evaluation job:
gcloud ml-engine jobs submit training `whoami`_object_detection_eval_validation_`date +%s` \
--job-dir=gs://${YOUR_GCS_BUCKET}/train \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz \
--module-name object_detection.model_main \
--runtime-version 1.8 \
--scale-tier BASIC_GPU \
--region us-central1 \
-- \
--model_dir=gs://${YOUR_GCS_BUCKET}/train \
--pipeline_config_path=gs://${YOUR_GCS_BUCKET}/data/pipeline.config \
Both training and evaluation should complete within about 30 minutes. While they are running, you can use TensorBoard to see the accuracy of your model. To start TensorBoard, run the following:
tensorboard --logdir=gs://${YOUR_GCS_BUCKET}/train
Note that you may need to first run gcloud auth application-default login.

Navigate to localhost:6006 to view your TensorBoard output. Here you’ll see some common ML metrics used to analyze the accuracy of your model. Note that these graphs only have 2 points plotted since the model trains quickly in very few steps (if you’ve used TensorBoard before you may be used to seeing more of a curve here). The first point here is early in the training process and the last point shows metrics at the last step (step 2000).

First, let’s look at the graph for mean average precision at 0.5 IOU (mAP@.50IOU):
[Figure: screenshot of TensorBoard]
Mean average precision measures our model’s percentage of correct predictions for all 37 labels. IoU is specific to object detection models and stands for Intersection-over-Union. This measures the overlap between the bounding box generated by our model and the ground truth bounding box, represented as a percentage. This graph is measuring the percentage of correct bounding boxes and labels our model returned, with “correct” in this case referring to bounding boxes that had 50% or more overlap with their corresponding ground truth boxes. After training, our model achieved 82% mean average precision.
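Intersection-over-Union is straightforward to compute by hand. The sketch below assumes boxes are given as (ymin, xmin, ymax, xmax) tuples and applies the 0.5 threshold mentioned above; the function name and the example boxes are made up for illustration.

```python
def intersection_over_union(box_a, box_b):
    """IoU of two boxes given as (ymin, xmin, ymax, xmax)."""
    inter_height = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_width = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_height * inter_width
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


ground_truth = (0.0, 0.0, 2.0, 2.0)
prediction = (0.0, 1.0, 2.0, 3.0)  # shifted right by half the box width

overlap = intersection_over_union(ground_truth, prediction)
counts_as_correct = overlap >= 0.5
print(overlap, counts_as_correct)
```

Note that a prediction shifted by half a box width already falls below the 0.5 threshold, since the union grows as the intersection shrinks.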

Next, look at the Images tab in TensorBoard:
[Figure: Images tab in TensorBoard]
In the left image, we see our model’s predictions for this image and on the right we see the correct, ground truth box. The bounding box is very accurate, but our model’s label prediction is incorrect in this particular case. No ML model can be perfect. 😉

Running on mobile with TensorFlow Lite

At this point you have a fully trained pet detector, which you can use to test your own images in the browser with zero setup using this Colab notebook. Running this model in real time on a phone requires some extra work. In this section, we will show you how to use TensorFlow Lite to get a smaller model and take advantage of ops that have been optimized for mobile devices. TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables on-device machine learning inference with low latency and a small binary size. TensorFlow Lite uses many techniques for this, such as quantized kernels that allow smaller and faster (fixed-point math) models.

As mentioned above, for this section, you will need to use the provided Dockerfile, or build TensorFlow from source (with GCP support) and install the bazel build tool. Note that if you’d like to only work through this second part of the tutorial without training a model, we’ve made a pre-trained model available here.

To make these commands easier to run, let’s set up some environment variables:
export CONFIG_FILE=gs://${YOUR_GCS_BUCKET}/data/pipeline.config
export CHECKPOINT_PATH=gs://${YOUR_GCS_BUCKET}/train/model.ckpt-2000
export OUTPUT_DIR=/tmp/tflite
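The checkpoint prefix above ends in 2000 because training ran for 2000 steps. If your job ran for a different number of steps, a sketch like the following can pick the newest prefix from a listing of the train directory. The file listing here is hypothetical, and the code assumes TensorFlow’s usual `model.ckpt-STEP.*` naming.

```python
import re

CHECKPOINT_PATTERN = re.compile(r"model\.ckpt-(\d+)\.")


def latest_checkpoint_prefix(filenames):
    """Return the model.ckpt-STEP prefix with the highest step, or None."""
    steps = set()
    for name in filenames:
        match = CHECKPOINT_PATTERN.match(name)
        if match:
            steps.add(int(match.group(1)))
    if not steps:
        return None
    return "model.ckpt-{}".format(max(steps))


# Hypothetical listing of the train directory.
listing = [
    "checkpoint",
    "model.ckpt-0.index",
    "model.ckpt-1000.index",
    "model.ckpt-2000.data-00000-of-00001",
    "model.ckpt-2000.index",
    "model.ckpt-2000.meta",
]
print(latest_checkpoint_prefix(listing))
```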
We start by getting a TensorFlow frozen graph with compatible ops that we can use with TensorFlow Lite. First, you’ll need to install these Python libraries. Then, to get the frozen graph, run the export_tflite_ssd_graph script from the models/research directory with this command:
python object_detection/ \
--pipeline_config_path=$CONFIG_FILE \
--trained_checkpoint_prefix=$CHECKPOINT_PATH \
--output_directory=$OUTPUT_DIR \
In the /tmp/tflite directory, you should now see two files: tflite_graph.pb and tflite_graph.pbtxt (sample frozen graphs are here). Note that the add_postprocessing flag enables the model to take advantage of a custom optimized detection post-processing operation which can be thought of as a replacement for tf.image.non_max_suppression. Make sure not to confuse export_tflite_ssd_graph with export_inference_graph in the same directory. Both scripts output frozen graphs: export_tflite_ssd_graph will output the frozen graph that we can input to TensorFlow Lite directly and is the one we’ll be using.

Next we’ll use TensorFlow Lite to get the optimized model by using TOCO, the TensorFlow Lite Optimizing Converter. This will convert the resulting frozen graph (tflite_graph.pb) to the TensorFlow Lite flatbuffer format (detect.tflite) via the following command. Run this from the tensorflow/ directory:
bazel run -c opt tensorflow/contrib/lite/toco:toco -- \
--input_file=$OUTPUT_DIR/tflite_graph.pb \
--output_file=$OUTPUT_DIR/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'  \
--inference_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_values=128 \
--change_concat_input_ranges=false \
This command takes the input tensor normalized_input_image_tensor after resizing each camera image frame to 300x300 pixels. The outputs of the quantized model are named 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', 'TFLite_Detection_PostProcess:2', and 'TFLite_Detection_PostProcess:3' and represent four arrays: detection_boxes, detection_classes, detection_scores, and num_detections. The documentation for other flags used in this command is here. If things ran successfully, you should now see a third file in the /tmp/tflite directory called detect.tflite (sample tflite file is here). This file contains the graph and all model parameters, can be run via the TensorFlow Lite interpreter on the Android device, and should be less than 4 MB in size.
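The mean_values and std_values flags describe how 8-bit pixel values relate to the real-valued inputs the model was trained on, following the converter’s relation real_value = (quantized_value - mean_value) / std_value. The short sketch below evaluates that relation for the values used above; the function name is hypothetical.

```python
def quantized_to_real(pixel, mean_value=128.0, std_value=128.0):
    """Map a uint8 pixel to the real-valued input it represents."""
    return (pixel - mean_value) / std_value


darkest = quantized_to_real(0)
middle = quantized_to_real(128)
brightest = quantized_to_real(255)
print(darkest, middle, brightest)
```

With both values set to 128, the uint8 range 0 to 255 covers roughly -1 to 1, which matches a model that expects normalized image input.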

Running our model on Android

To run our final model on device, we will need to use the provided Dockerfile, or install the Android NDK and SDK. The current recommended Android NDK version is 14b and can be found on the NDK Archives page. Please note that the current version of Bazel is incompatible with NDK revisions 15 and above. Android SDK and build tools can be downloaded separately or used as part of Android Studio. To build the TensorFlow Lite Android demo, build tools require API >= 23 (but it will run on devices with API >= 21). Additional details are available on the TensorFlow Lite Android App page.

Before trying to get the pets model that you just trained, start by running the demo app with its default model, which was trained on the COCO dataset. To build the demo app, run this bazel command from the tensorflow directory:
bazel build -c opt --config=android_arm{,64} --cxxopt='--std=c++11' \
The apk above will be built for 64-bit architecture, and you may replace it with --config=android_arm for 32-bit support. Now install the demo on a debug-enabled Android phone via Android Debug Bridge (adb):
adb install bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk
Try running this starter app (called TFLDetect) and pointing your camera at people, furniture, cars, pets, etc. The working test app should look something like this. You will see boxes around the detected objects with their labels. The working test app was trained using the COCO dataset.

Once you’ve got the generic detector working, replacing it with your custom pet detector is fairly simple. All we need to do is point the app to our new detect.tflite file and give it the names of our new labels. Specifically, we will copy our TensorFlow Lite flatbuffer to the app assets directory with the following command:
cp /tmp/tflite/detect.tflite \
We will now edit the BUILD file to point to this new model. First, open the BUILD file tensorflow/contrib/lite/examples/android/BUILD. Then find the assets section, and replace the line "@tflite_mobilenet_ssd_quant//:detect.tflite" (which by default points to a COCO pretrained model) with the path to your TFLite pets model “//tensorflow/contrib/lite/examples/android/app/src/main/assets:detect.tflite”. Finally, change the last line in assets section to use the new label map. Your final assets section should look like this:
assets = [
We will also need to tell our app to use the new label map. In order to do this, open up the tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ file in a text editor and find the definition of TF_OD_API_LABELS_FILE. Update this path to point to your pets label map file: “file:///android_asset/pets_labels_list.txt”. Note that we have already made the pets_labels_list.txt file available for your convenience. The relevant section of the file (around line 50) should now look as follows:
// Configuration values for the prepackaged SSD model.
private static final int TF_OD_API_INPUT_SIZE = 300;
private static final boolean TF_OD_API_IS_QUANTIZED = true;
private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/pets_labels_list.txt";
Once you’ve copied the TensorFlow Lite file and edited your BUILD file and the Java source file, rebuild and reinstall your app with the following commands:
bazel build -c opt --config=android_arm{,64} --cxxopt='--std=c++11' \
adb install -r bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk
Now for the best part: find the nearest dog or cat and try detecting it. On a Pixel 2, we get greater than 15 frames per second.

Get started

Want to use your own training data to train an object detection model on Cloud TPUs? Dive into the object detection docs here. For labeling your own image library, check out these resources for generating TFRecord files for other well known datasets. We’d love to have you contribute and hear your feedback. Leave a comment on this post, or submit a PR or issue on GitHub.

Thank you to everyone who worked on this release: Derek Chow, Aakanksha Chowdhery, Jonathan Huang, Zhichao Lu, Vivek Rathod, Ronny Votel, Pengchong Jin, Xiangxin Zhu as well as the following colleagues for their guidance and advice: Vasu Agrawal, Sourabh Bajaj, Chiachen Chou, Tom Jablin, Wenzhe Li, Tsung-Yi Lin, Hernan Moraldo, Kevin Murphy, Sara Robinson, Andrew Selle, Shashi Shekhar, Yash Sonthalia, Zak Stone, Pete Warden, Menglong Zhu.