November 09, 2021
Posted by the TensorFlow Lite team
TensorFlow Lite is Google’s machine learning framework for deploying machine learning models on multiple devices and surfaces such as mobile (iOS and Android), desktops and other edge devices. Recently, we added support to run TensorFlow Lite models in a browser as well. To build apps using TensorFlow Lite, you can either use an off-the-shelf model from TensorFlow Hub, or convert an existing TensorFlow model to a TensorFlow Lite model using the converter. Once the model is deployed in an app, you can run inference on input data.
TensorFlow Lite now supports training your models on-device, in addition to running inference. On-device training enables interesting personalization use cases where models can be fine-tuned based on user needs. For instance, you could deploy an image classification model and allow one user to fine-tune it with transfer learning to recognize bird species, while another user retrains the same model to recognize fruits. This feature is available in TensorFlow 2.7 and later and is currently supported only for Android apps. (iOS support will be added in the future.)
On-device training is also a necessary foundation for Federated Learning use cases to train global models on decentralized data. This blog post does not cover Federated Learning and instead focuses on helping you integrate on-device training in your Android apps.
Later in this article, we reference a Colab notebook and an Android sample app as we walk you through the end-to-end implementation of on-device learning to fine-tune an image classification model.
In our 2019 blog post, we introduced on-device training concepts and an example of on-device training in TensorFlow Lite. However, there were several limitations. For example, it was not easy to customize the model structure and optimizers. You also had to deal with multiple physical TensorFlow Lite (.tflite) models instead of a single TensorFlow Lite model. Similarly, there was no easy way to store and update the training weights. Our latest TensorFlow Lite version streamlines this process by providing more convenient options for on-device training, as explained below.
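To deploy a TensorFlow Lite model with on-device training built in, follow these high-level steps:
1. Build a TensorFlow model for training and inference.
2. Convert the TensorFlow model to the TensorFlow Lite format.
3. Integrate the model in your Android app.
4. Invoke model training in the app, similar to how you would invoke model inference.
These steps are explained below.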
The TensorFlow Lite model should not only support model inference, but also model training, which typically involves saving the model’s weights to the file system and restoring the weights from the file system. This is done to save the training weights after each training epoch, so that the next training epoch can use the weights from the previous one, instead of starting training from scratch.
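The save-and-resume cycle can be illustrated without TensorFlow. The following is a minimal sketch in plain Python, under stated assumptions: a hypothetical one-weight linear model trained with SGD, whose weights are written to a JSON checkpoint after every epoch and reloaded at the start of the next session, so a second session continues from the saved state rather than from scratch. The toy data, the helpers `train_epoch`, `save_weights`, `load_weights` and `run_session`, and the JSON format are all illustrative; a TensorFlow Lite model stores its weights with TensorFlow checkpoint ops instead.

```python
import json
import os
import tempfile

# Hypothetical toy data for the target function y = 3x.
DATA = [(k / 10.0, 3.0 * k / 10.0) for k in range(1, 11)]
LEARNING_RATE = 0.1


def train_epoch(weights):
    """Run one epoch of per-example SGD on squared error; return the last example's loss."""
    loss = 0.0
    for x, y in DATA:
        err = weights["w"] * x - y
        loss = err * err
        weights["w"] -= LEARNING_RATE * 2 * err * x  # d(err^2)/dw = 2 * err * x
    return loss


def save_weights(weights, path):
    with open(path, "w") as f:
        json.dump(weights, f)


def load_weights(path):
    with open(path) as f:
        return json.load(f)


def run_session(path, epochs):
    """One app session: restore saved weights if present, train, save after every epoch."""
    weights = load_weights(path) if os.path.exists(path) else {"w": 0.0}
    loss = None
    for _ in range(epochs):
        loss = train_epoch(weights)
        save_weights(weights, path)  # persist so the next epoch/session can resume
    return weights, loss


checkpoint = os.path.join(tempfile.mkdtemp(), "weights.json")
w_first, loss_first = run_session(checkpoint, epochs=2)
w_second, loss_second = run_session(checkpoint, epochs=2)  # resumes, not from scratch
```

Because the second session starts from the checkpoint, its weight ends closer to the target than the first session's, which it would not if training restarted from the initial value each time.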
Our suggested approach is to implement these tf.functions to represent training, inference, saving weights, and loading weights:
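import tensorflow as tf

IMG_SIZE = 28  # assumed: 28x28 inputs, matching the 28 * 28 test buffers in the Android code

class Model(tf.Module):
  def __init__(self):
    # Example classifier (an assumption; substitute your own architecture).
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10),
    ])
    self._LOSS_FN = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    self._OPTIM = tf.keras.optimizers.SGD()

  # The `train` function takes a batch of input images and one-hot labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      # Keras losses take (y_true, y_pred), in that order.
      loss = self._LOSS_FN(y, prediction)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self._OPTIM.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    for grad in gradients:
      result[grad.name] = grad
    return result

  # The `infer` function returns the model's outputs for a batch of images.
  @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
  def infer(self, x):
    return {"output": self.model(x)}

  # The `save` function writes every weight to a checkpoint file, keyed by name.
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {"checkpoint_path": checkpoint_path}

  # The `restore` function reads each weight back by name and assigns it.
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors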
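When the model is exported, functions like these can be registered as named signatures of a single model, so one file exposes several entry points that share the same weights. The toy below is a plain-Python analogy of that idea, not the TensorFlow Lite API: a hypothetical `ToySignatureModel` keeps a table from signature keys to functions, and `run_signature` reads named inputs and fills only the named outputs the caller asked for, loosely mirroring the `runSignature(inputs, outputs, key)` call used on Android.

```python
class ToySignatureModel:
    """Hypothetical stand-in for one model that exposes several named entry points."""

    def __init__(self):
        self.w = 0.0  # the single weight shared by every entry point
        self._signatures = {"train": self._train, "infer": self._infer}

    def _train(self, inputs):
        err = self.w * inputs["x"] - inputs["y"]
        self.w -= 0.1 * 2 * err * inputs["x"]  # one SGD step on squared error
        return {"loss": err * err}

    def _infer(self, inputs):
        return {"output": self.w * inputs["x"]}

    def run_signature(self, inputs, outputs, key):
        """Run the entry point named `key`, copying requested results into `outputs`."""
        if key not in self._signatures:
            raise KeyError(f"unknown signature: {key}")
        results = self._signatures[key](inputs)
        for name in outputs:  # only fill the outputs the caller provided
            outputs[name] = results[name]


model = ToySignatureModel()
for _ in range(50):
    train_out = {"loss": None}
    model.run_signature({"x": 1.0, "y": 2.0}, train_out, "train")
infer_out = {"output": None}
model.run_signature({"x": 1.0}, infer_out, "infer")
```

Training through the "train" key changes the weight that the "infer" key then reads, which is the property that lets a single model file support both fine-tuning and inference.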
You may already be familiar with the workflow to convert your TensorFlow model to the TensorFlow Lite format. Some of the low-level features for on-device training (e.g., variables to store the model parameters) are still experimental, and others (e.g., weight serialization) currently rely on TF Select operators, so you will need to set the following flags during conversion. You can find an example of all the flags you need to set in the Colab.
import tensorflow as tf

# Convert the model. SAVED_MODEL_DIR (defined elsewhere) is the directory the
# model was exported to, with its tf.functions saved as signatures.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()
Once you have converted your model to the TensorFlow Lite format, you’re ready to integrate the model into your app! Refer to the Android app samples for more details.
On Android, TensorFlow Lite on-device training can be performed using either the Java or the C++ API. You create an instance of the TensorFlow Lite Interpreter to load the model and drive training. The tf.functions defined earlier are invoked through TensorFlow Lite’s support for signatures, which let a single TensorFlow Lite model expose multiple entry points. For example, the train function for on-device training is one of the model’s signatures, and it can be invoked with the Interpreter’s runSignature method by specifying the signature name (‘train’):
// Run training for a few steps.
// Assumes `interpreter`, NUM_EPOCHS, NUM_BATCHES, trainImageBatches and
// trainLabelBatches are defined elsewhere, as in the sample app.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
  for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", trainImageBatches.get(batchIdx));
    inputs.put("y", trainLabelBatches.get(batchIdx));
    Map<String, Object> outputs = new HashMap<>();
    FloatBuffer loss = FloatBuffer.allocate(1);
    outputs.put("loss", loss);
    interpreter.runSignature(inputs, outputs, "train");
    // Record the loss of the last batch in each epoch.
    if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
  }
}
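The training loop above assumes the data is already split into `trainImageBatches` and `trainLabelBatches`. Because the `train` signature expects labels shaped `[None, 10]`, integer class labels have to be one-hot encoded before batching. Below is a minimal plain-Python sketch of that preparation step; `one_hot` and `make_batches` are hypothetical helpers, and on Android the resulting nested lists would still need to be packed into whatever buffers or arrays you pass to `runSignature`.

```python
NUM_CLASSES = 10  # matches the [None, 10] label spec of the `train` signature


def one_hot(label, num_classes=NUM_CLASSES):
    """Turn an integer class label into a one-hot float vector."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def make_batches(images, labels, batch_size):
    """Split parallel lists of images and integer labels into batches.

    The last batch may be smaller than batch_size; drop it if your app
    requires equal-sized batches.
    """
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    image_batches, label_batches = [], []
    for start in range(0, len(images), batch_size):
        image_batches.append(images[start:start + batch_size])
        label_batches.append(
            [one_hot(label) for label in labels[start:start + batch_size]])
    return image_batches, label_batches


# Five hypothetical 28x28 images (all zeros) with integer class labels.
images = [[[0.0] * 28 for _ in range(28)] for _ in range(5)]
labels = [3, 0, 9, 3, 1]
image_batches, label_batches = make_batches(images, labels, batch_size=2)
```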
Similarly, the following example shows how to invoke inference using the model’s ‘infer’ signature:
try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
  // Restore the weights from the checkpoint file.

  int NUM_TESTS = 10;
  // Direct buffers in native byte order; 4 bytes per float.
  FloatBuffer testImages = ByteBuffer.allocateDirect(NUM_TESTS * 28 * 28 * 4)
      .order(ByteOrder.nativeOrder()).asFloatBuffer();
  FloatBuffer output = ByteBuffer.allocateDirect(NUM_TESTS * 10 * 4)
      .order(ByteOrder.nativeOrder()).asFloatBuffer();

  // Fill the test data.

  // Run the inference.
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("x", testImages.rewind());
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("output", output);
  anotherInterpreter.runSignature(inputs, outputs, "infer");
  output.rewind();

  // Process the result: the predicted category is the index of the largest score in each row.
  int[] testLabels = new int[NUM_TESTS];
  for (int i = 0; i < NUM_TESTS; ++i) {
    int index = 0;
    for (int j = 1; j < 10; ++j) {
      if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = j;
    }
    testLabels[i] = index;
  }
}
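The post-processing step treats `output` as a flat, row-major buffer: the 10 scores for test image i sit at indices i * 10 through i * 10 + 9, and the predicted class is the index of the largest score. The following plain-Python sketch shows the same logic, plus an optional softmax in case the outputs are logits and you want probabilities rather than just a label. Softmax is monotonic, so it never changes the winning index. The helper names and the sample scores are hypothetical.

```python
import math

NUM_CLASSES = 10


def argmax_rows(flat, num_classes=NUM_CLASSES):
    """Return the index of the largest value in each row of a flat, row-major buffer."""
    if len(flat) % num_classes:
        raise ValueError("buffer length is not a multiple of num_classes")
    labels = []
    for i in range(len(flat) // num_classes):
        row = flat[i * num_classes:(i + 1) * num_classes]
        # max() keeps the first index on ties, like the strict `<` in the Java loop.
        labels.append(max(range(num_classes), key=row.__getitem__))
    return labels


def softmax(row):
    """Numerically stable softmax: subtract the row maximum before exponentiating."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Two hypothetical rows of scores, flattened row by row.
scores = [0.1] * 10 + [0.0] * 10
scores[4] = 2.5        # row 0 peaks at class 4
scores[10 + 7] = 1.2   # row 1 peaks at class 7
predicted = argmax_rows(scores)
probs_row0 = softmax(scores[0:10])
```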
And that’s it! You now have a TensorFlow Lite model that can be trained on-device. We hope this code walkthrough gives you a good idea of how to run on-device training in TensorFlow Lite, and we’re excited to see where you take it.
In theory, you should be able to apply on-device training in TensorFlow Lite to any use case that TensorFlow supports. However, in reality there are a few practical considerations that you need to keep in mind before you deploy on-device training in your apps:
Future work includes (but is not limited to) on-device training support on iOS; performance improvements that leverage on-device accelerators (e.g., GPUs) for training; a smaller binary size through implementing more training ops natively in TensorFlow Lite; higher-level API support (e.g., via the TensorFlow Lite Task Library) that abstracts away the implementation details; and examples covering other on-device training use cases (e.g., NLP). Our long-term roadmap may also include end-to-end on-device Federated Learning solutions.
Thank you for reading! We’re excited to see what you build using on-device learning. Once again, here are links to the sample app and Colab. If you have any feedback, please let us know on the TensorFlow Forum, or on GitHub.
This post reflects the significant contributions of many people in Google’s TensorFlow Lite team including Michelle Carney, Lawrence Chan, Jaesung Chung, Jared Duke, Terry Heo, Jared Lim, Yu-Cheng Ling, Thai Nguyen, Karim Nosseir, Arun Venkatesan, Haoliang Zhang, other TensorFlow Lite team members, and our collaborators in Google Research.