On-device training in TensorFlow Lite
November 09, 2021

Posted by the TensorFlow Lite team

TensorFlow Lite is Google’s machine learning framework for deploying machine learning models on multiple devices and surfaces, such as mobile (iOS and Android), desktops, and other edge devices. Recently, we added support for running TensorFlow Lite models in a browser as well. To build apps using TensorFlow Lite, you can either use an off-the-shelf model from TensorFlow Hub or convert an existing TensorFlow model to a TensorFlow Lite model using the converter. Once the model is deployed in an app, you can run inference on it with input data.

TensorFlow Lite now supports training your models on-device, in addition to running inference. On-device training enables interesting personalization use cases where models can be fine-tuned based on user needs. For instance, you could deploy an image classification model and allow one user to fine-tune it to recognize bird species using transfer learning, while another user retrains the same model to recognize fruits. This feature is available in TensorFlow 2.7 and later, and currently supports Android apps only. (iOS support will be added in the future.)

On-device training is also a necessary foundation for Federated Learning use cases to train global models on decentralized data. This blog post does not cover Federated Learning and instead focuses on helping you integrate on-device training in your Android apps.

Later in this article, we will reference a Colab and an Android sample app as we walk you through the end-to-end implementation of on-device learning to fine-tune an image classification model.

Improvements over the earlier approach

In our 2019 blog post, we introduced on-device training concepts and an example of on-device training in TensorFlow Lite. However, there were several limitations. For example, it was not easy to customize the model structure and optimizers. You also had to deal with multiple physical TensorFlow Lite (.tflite) models instead of a single TensorFlow Lite model. Similarly, there was no easy way to store and update the training weights. Our latest TensorFlow Lite version streamlines this process by providing more convenient options for on-device training, as explained below.

How does it work?

To deploy a TensorFlow Lite model with built-in on-device training, follow these high-level steps:

  • Build a TensorFlow model for training and inference
  • Convert the TensorFlow model to TensorFlow Lite format
  • Integrate the model in your Android app
  • Invoke model training in the app, similar to how you would invoke model inference

These steps are explained below.

Build a TensorFlow model for training and inference

The TensorFlow Lite model should not only support model inference, but also model training, which typically involves saving the model’s weights to the file system and restoring the weights from the file system. This is done to save the training weights after each training epoch, so that the next training epoch can use the weights from the previous one, instead of starting training from scratch.
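To make the save-and-restore reasoning concrete, here is a minimal, framework-free sketch in Python. It assumes a hypothetical toy model (y = w * x + b) and uses a JSON file as a stand-in for a TensorFlow checkpoint; the helper names are illustrative only. Because the weights are saved after every epoch, a second session resumes where the first one stopped instead of starting from scratch, and ends up with exactly the same weights as one uninterrupted run.

```python
import json
import os
import tempfile


def train_epoch(weights, data, lr=0.05):
    """One epoch of full-batch gradient descent on mean squared error
    for a hypothetical toy model y = w * x + b."""
    w, b = weights["w"], weights["b"]
    n = len(data)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
    return {"w": w - lr * grad_w, "b": b - lr * grad_b}


def save(weights, path):
    """Writes the weights to disk (JSON stands in for a TF checkpoint)."""
    with open(path, "w") as f:
        json.dump(weights, f)


def restore(path):
    """Reads weights previously written by save()."""
    with open(path) as f:
        return json.load(f)


def run_session(data, path, epochs):
    """Resumes from the checkpoint if one exists, then saves after every epoch."""
    weights = restore(path) if os.path.exists(path) else {"w": 0.0, "b": 0.0}
    for _ in range(epochs):
        weights = train_epoch(weights, data)
        save(weights, path)
    return weights


if __name__ == "__main__":
    data = [(x, 2 * x + 1) for x in range(-2, 3)]
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = os.path.join(tmp, "weights.json")
        run_session(data, ckpt, epochs=50)         # first app session
        print(run_session(data, ckpt, epochs=50))  # resumes from epoch 50
```

In the TensorFlow Lite model, the same role is played by the save and restore tf.functions described below, with the checkpoint written to the app's file system.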

Our suggested approach is to implement these tf.functions to represent training, inference, saving weights, and loading weights:

  • A train function that trains the model using training data. The train function below makes a prediction, calculates the loss (or error), uses tf.GradientTape() to record operations for automatic differentiation, and then applies the resulting gradients to update the model’s parameters.
    # The `train` function takes a batch of input images and one-hot labels.
    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
        tf.TensorSpec([None, 10], tf.float32),
    ])
    def train(self, x, y):
      with tf.GradientTape() as tape:
        prediction = self.model(x)
        # Assuming _LOSS_FN is a Keras loss, it takes (y_true, y_pred),
        # so the labels come first.
        loss = self._LOSS_FN(y, prediction)
      gradients = tape.gradient(loss, self.model.trainable_variables)
      self._OPTIM.apply_gradients(
          zip(gradients, self.model.trainable_variables))
      result = {"loss": loss}
      for grad in gradients:
        result[grad.name] = grad
      return result
    
  • An infer (or predict) function that invokes model inference. This is similar to how you currently use TensorFlow Lite for inference.
    @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
    def infer(self, x):
      return {
          "output": self.model(x)
      }
    
  • Save and restore functions that write the training weights (i.e., the parameters used by the model) to the file system in checkpoint format and load them back. The save function’s code is shown below.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def save(self, checkpoint_path):
      tensor_names = [weight.name for weight in self.model.weights]
      tensors_to_save = [weight.read_value() for weight in self.model.weights]
      tf.raw_ops.Save(
          filename=checkpoint_path, tensor_names=tensor_names,
          data=tensors_to_save, name='save')
      return {
          "checkpoint_path": checkpoint_path
      }
    

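For readers who want to see what the train and infer functions compute without TensorFlow, here is a minimal pure-Python sketch of a hypothetical linear softmax classifier (the names TinyClassifier and softmax are illustrative, not part of any API). It is not how TensorFlow works internally: tf.GradientTape derives gradients automatically, whereas this sketch hand-codes the well-known softmax cross-entropy gradient (predicted probability minus one-hot label).

```python
import math


def softmax(logits):
    """Converts a list of logits into probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class TinyClassifier:
    """Hypothetical linear classifier whose train/infer mirror the tf.functions above."""

    def __init__(self, num_features, num_classes, lr=0.5):
        self.w = [[0.0] * num_features for _ in range(num_classes)]
        self.b = [0.0] * num_classes
        self.lr = lr

    def infer(self, x):
        """Returns {"output": logits} for a batch of flattened inputs."""
        logits = [[sum(wj * xj for wj, xj in zip(wk, xi)) + bk
                   for wk, bk in zip(self.w, self.b)]
                  for xi in x]
        return {"output": logits}

    def train(self, x, y):
        """Predicts, computes mean cross-entropy against one-hot labels y,
        then applies one gradient-descent step. Returns the pre-update loss."""
        logits = self.infer(x)["output"]
        n = len(x)
        loss = 0.0
        grad_w = [[0.0] * len(row) for row in self.w]
        grad_b = [0.0] * len(self.b)
        for xi, yi, zi in zip(x, y, logits):
            p = softmax(zi)
            loss -= sum(t * math.log(pk) for t, pk in zip(yi, p) if t > 0) / n
            for k, (pk, t) in enumerate(zip(p, yi)):
                d = (pk - t) / n  # gradient of the loss w.r.t. logit k
                grad_b[k] += d
                for j, xj in enumerate(xi):
                    grad_w[k][j] += d * xj
        for k, row in enumerate(self.w):
            self.b[k] -= self.lr * grad_b[k]
            for j in range(len(row)):
                row[j] -= self.lr * grad_w[k][j]
        return {"loss": loss}
```

Calling train repeatedly on the same batch should drive the returned loss down, just as repeated calls to the train signature do in the app.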
Convert to TensorFlow Lite format

You may already be familiar with the workflow to convert your TensorFlow model to the TensorFlow Lite format. Some of the low-level features for on-device training (e.g., variables to store the model parameters) are still experimental, and others (e.g., weight serialization) currently rely on TF Select operators, so you need to enable both during conversion, as in the snippet below. You can find an example of all the flags you need to set in the Colab.

import tensorflow as tf

# Convert the model. SAVED_MODEL_DIR is the SavedModel directory exported
# with the tf.functions above as signatures.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

Integrate the model in your Android app

Once you have converted your model to the TensorFlow Lite format, you’re ready to integrate the model into your app! Refer to the Android app samples for more details.

Invoke model training and inference in the app

On Android, TensorFlow Lite on-device training can be performed using either the Java or the C++ APIs. You can create an instance of the TensorFlow Lite Interpreter to load a model and drive model training tasks. The tf.functions we defined earlier can be invoked through TensorFlow Lite’s support for signatures, which allow a single TensorFlow Lite model to expose multiple ‘entry’ points. For example, the train function for on-device training is one of the model’s signatures, and it can be invoked with the Interpreter’s runSignature method by specifying the name of the signature (‘train’):

import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;

// Run training for NUM_EPOCHS epochs of NUM_BATCHES batches each.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
    for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
        // Inputs are keyed by the argument names of the train tf.function.
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("x", trainImageBatches.get(batchIdx));
        inputs.put("y", trainLabelBatches.get(batchIdx));

        // Outputs are keyed by the names in the dict that train returns.
        Map<String, Object> outputs = new HashMap<>();
        FloatBuffer loss = FloatBuffer.allocate(1);
        outputs.put("loss", loss);

        interpreter.runSignature(inputs, outputs, "train");

        // Record the loss of the last batch in each epoch.
        if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
    }
}
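The loop above assumes that trainImageBatches and trainLabelBatches were prepared in advance, and the train signature expects labels as float vectors of length 10, which this sketch assumes are one-hot encoded. As a language-neutral illustration (written in Python, with hypothetical helper names one_hot and make_batches), that preparation might look like this:

```python
def one_hot(label, num_classes=10):
    """Encodes an integer class id as a float one-hot vector of length num_classes."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def make_batches(images, labels, batch_size):
    """Splits examples into full batches of batch_size, dropping any remainder,
    and one-hot encodes the labels of each batch."""
    num_batches = len(images) // batch_size
    image_batches, label_batches = [], []
    for i in range(num_batches):
        start, end = i * batch_size, (i + 1) * batch_size
        image_batches.append(images[start:end])
        label_batches.append([one_hot(label) for label in labels[start:end]])
    return image_batches, label_batches
```

Dropping the remainder keeps every batch the same size, so the number of batches (NUM_BATCHES in the Java loop) is simply the dataset size divided by the batch size; a real app might instead pad or keep a smaller final batch.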


Similarly, the following example shows how to invoke inference using the model’s ‘infer’ signature:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import org.tensorflow.lite.Interpreter;

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    // Direct buffers are allocated in bytes (Float.BYTES per value) in native
    // byte order, then viewed as floats.
    FloatBuffer testImages = ByteBuffer.allocateDirect(NUM_TESTS * 28 * 28 * Float.BYTES)
            .order(ByteOrder.nativeOrder()).asFloatBuffer();
    FloatBuffer output = ByteBuffer.allocateDirect(NUM_TESTS * 10 * Float.BYTES)
            .order(ByteOrder.nativeOrder()).asFloatBuffer();

    // Fill the test data.

    // Run the inference.
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", testImages.rewind());
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");
    output.rewind();

    // Process the result: the predicted category of each test image is the
    // index of the highest of its 10 output values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j)) {
                index = j;
            }
        }
        testLabels[i] = index;
    }
}

And that’s it! You now have a TensorFlow Lite model that supports on-device training. We hope this code walkthrough gives you a good idea of how to run on-device training in TensorFlow Lite, and we’re excited to see where you take it.

Practical considerations

In theory, you should be able to apply on-device training in TensorFlow Lite to any use case that TensorFlow supports. However, in reality there are a few practical considerations that you need to keep in mind before you deploy on-device training in your apps:

  • Use cases: The Colab shows on-device training for a vision use case. If you run into issues with specific models or use cases, please let us know on GitHub.
  • Performance: Depending on the use case, on-device training could take anywhere from a few seconds to much longer. If you run on-device training as part of a user-facing feature (e.g., your end user is interacting with the feature), you should measure how long training takes across a wide range of possible training inputs in your app and limit the training time accordingly. If your use case requires very long on-device training times, consider first training a model on a desktop or in the cloud, then fine-tuning it on-device.
  • Battery usage: Just like model inference, invoking model training on the device can drain the battery. If model training is part of a feature that is not user-facing, we recommend following Android’s guidelines for implementing background tasks.
  • Training from scratch vs. retraining: In theory, it should be possible to train a model from scratch on the device using the above features. In reality, however, training from scratch typically requires an enormous amount of training data and can take several days even on servers with powerful processors. Consequently, for on-device applications, we recommend retraining an already trained model (i.e., transfer learning), as shown in the Colab example.
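One way to act on the performance advice is to cap each training session with a time budget measured on a monotonic clock. The Python sketch below is illustrative only: train_with_budget and its train_step callback are hypothetical names, and in an Android app the same idea could wrap the loop that calls runSignature.

```python
import time


def train_with_budget(train_step, batches, budget_seconds, clock=time.monotonic):
    """Runs train_step over batches until they run out or the budget is spent.

    Returns (steps_completed, last_loss). The clock is injectable so the
    budget logic can be tested deterministically."""
    deadline = clock() + budget_seconds
    steps, last_loss = 0, None
    for batch in batches:
        if clock() >= deadline:  # check before starting each step
            break
        last_loss = train_step(batch)
        steps += 1
    return steps, last_loss
```

Because the deadline is checked before each step, a single slow step can still overrun the budget, so the budget itself should come from measurements like those described above.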

Roadmap

Future work includes, but is not limited to:

  • On-device training support on iOS
  • Performance improvements that leverage on-device accelerators (e.g., GPUs) for on-device training
  • A smaller binary size, by implementing more training ops natively in TensorFlow Lite
  • Higher-level API support (e.g., via the TensorFlow Lite Task Library) that abstracts away the implementation details
  • Examples covering other on-device training use cases (e.g., NLP)

Our long-term roadmap involves potentially providing on-device, end-to-end Federated Learning solutions.

Next steps

Thank you for reading! We’re excited to see what you build using on-device learning. Once again, here are links to the sample app and Colab. If you have any feedback, please let us know on the TensorFlow Forum, or on GitHub.

Acknowledgements

This post reflects the significant contributions of many people in Google’s TensorFlow Lite team including Michelle Carney, Lawrence Chan, Jaesung Chung, Jared Duke, Terry Heo, Jared Lim, Yu-Cheng Ling, Thai Nguyen, Karim Nosseir, Arun Venkatesan, Haoliang Zhang, other TensorFlow Lite team members, and our collaborators in Google Research.
