August 31, 2022
Posted by Andreas Steiner and Marc van Zee, Google Research, Brain Team

Introduction
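In this blog post we demonstrate how to convert and run Python-based JAX functions and Flax machine learning models in the browser using TensorFlow.js. We have produced three examples of JAX-to-TensorFlow.js conversion, each of increasing complexity:
1. A simple JAX function
2. An image classification Flax model trained on the MNIST dataset
3. The LiT demo, which matches text prompts to precomputed image embeddings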
For each example, there are Google Colab notebooks you can use to try the JAX-to-TensorFlow.js conversion yourself.
Figure 1. TensorFlow.js model matching user-provided text prompts to a precomputed image embedding (try it out yourself). See Example 3: LiT Demo below for implementation details.
….json) directly, so that the model can be used in the browser with TensorFlow.js.
…tf.Variables) and computation.

Figure 2. High-level visualization of the conversion steps inside converters.convert_jax(), which converts a JAX function to a TensorFlow.js model.
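As a rough illustration of the split between weights and computation shown in Figure 2, the following NumPy sketch flattens a nested parameter dict into named arrays (roughly, what ends up stored as weights) and binds those weights to a pure function so that only the inputs remain as arguments (roughly, what ends up as the computation). The helpers flatten_params and make_predict_fn and the "/"-joined names are hypothetical; they do not reflect how convert_jax names or stores variables internally.

```python
import numpy as np


def flatten_params(tree, prefix=""):
    """Flattens a nested dict of arrays into {"a/b/c": array} entries."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_params(value, name))
        else:
            flat[name] = np.asarray(value)
    return flat


def make_predict_fn(apply_fn, params):
    """Binds the weights, leaving a function of the model inputs only."""
    def predict(xs):
        return apply_fn(params, xs)
    return predict


def prod(params, xs):
    return params["weight"] * xs


# A Flax-style nested parameter tree with made-up shapes.
tree = {"params": {"Dense_0": {"kernel": np.zeros((2, 3)), "bias": np.zeros(3)}}}
print(sorted(flatten_params(tree)))  # ['params/Dense_0/bias', 'params/Dense_0/kernel']

predict = make_predict_fn(prod, {"weight": np.array([0.5, 1.0])})
print(predict(np.arange(6).reshape((3, 2))))
```

Keeping the parameters as an explicit argument is what makes such a split possible: the function itself holds no state, so the weights can be saved and loaded independently of the computation.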
The following example defines a single parameter, weight, and implements a function prod, which multiplies the input with the parameter (in a real example, params will contain all the weights of the modules used in the neural network):
```python
import numpy as np

def prod(params, xs):
  return params['weight'] * xs

params = {'weight': np.array([0.5, 1])}
# This represents a batch of 3 inputs, each of length 2.
xs = np.arange(6).reshape((3, 2))
prod(params, xs)
```
This multiplies each input in the batch element-wise with [0.5, 1]:

```
[[0. 1.]
 [1. 3.]
 [2. 5.]]
```
Next, we convert the function to a TensorFlow.js model with convert_jax and use the helper function get_tfjs_predict_fn (which can be found in the Colab), allowing us to verify that the outputs of the JAX function and the web model match. (Note: this helper function only works in Colab, as it uses some tooling to run the web model using JavaScript.)
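```python
import tensorflow as tf
import tensorflowjs as tfjs

# model_dir is the output directory for model.json and the weight files.
# get_tfjs_predict_fn is a helper defined in the Colab (it only works there).
tfjs.converters.convert_jax(
    prod,
    params,
    input_signatures=[tf.TensorSpec((3, 2), tf.float32)],
    model_dir=model_dir)
tfjs_predict_fn = get_tfjs_predict_fn(model_dir)
tfjs_predict_fn(xs)  # Same output as JAX.
```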
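When checking that the web model matches JAX, exact equality is usually too strict: the converted model is declared with tf.float32 inputs, while NumPy defaults to float64, so small rounding differences are expected. A minimal sketch of a tolerance-based comparison; outputs_match is a hypothetical helper, the tolerances are illustrative assumptions, and a float32 copy of the NumPy result stands in for the actual web model output.

```python
import numpy as np


def outputs_match(jax_out, web_out, rtol=1e-5, atol=1e-6):
    """Returns True if two outputs have the same shape and agree up to a tolerance."""
    jax_out = np.asarray(jax_out)
    web_out = np.asarray(web_out)
    if jax_out.shape != web_out.shape:
        return False
    return bool(np.allclose(jax_out, web_out, rtol=rtol, atol=atol))


# Stand-ins: the float64 NumPy result and a float32 copy mimicking the web model.
reference = np.array([0.5, 1.0]) * np.arange(6).reshape((3, 2))
simulated_web = reference.astype(np.float32)
print(outputs_match(reference, simulated_web))      # True
print(outputs_match(reference, simulated_web[:2]))  # False (shape mismatch)
```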
To support a variable batch size, use None for the dynamic dimensions in input_signature. Additionally, pass the argument polymorphic_shapes, specifying names for the dynamic dimensions. Note that "polymorphism" is a term from type theory; here we use it to mean that the function works for multiple related shapes, e.g., for multiple batch sizes. This is necessary for shape checking in the JAX function (see the Colab for more examples, and here for more documentation on this notation).
```python
tfjs.converters.convert_jax(
    prod,
    params,
    # None marks the batch dimension as dynamic...
    input_signatures=[tf.TensorSpec((None, 2), tf.float32)],
    # ...and 'b' names it for shape checking in JAX.
    polymorphic_shapes=['(b, 2)'],
    model_dir=model_dir)
tfjs_predict_fn = get_tfjs_predict_fn(model_dir)
tfjs_predict_fn(np.array([[1., 2.]]))  # Outputs: [[0.5, 2. ]]
```
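To make the notation polymorphic_shapes=['(b, 2)'] more concrete, here is a small sketch of how such a spec relates to concrete shapes: a named dimension like b may take any size, but the same name must have the same size everywhere it appears, while a number must match exactly. The functions parse_spec and shape_matches are hypothetical and only handle names and integers; the actual jax2tf notation supports more features (see its documentation).

```python
def parse_spec(spec):
    """Parses a spec such as '(b, 2)' into ['b', 2]: names stay str, sizes become int."""
    dims = []
    for token in spec.strip().strip("()").split(","):
        token = token.strip()
        if token:
            dims.append(int(token) if token.isdigit() else token)
    return dims


def shape_matches(specs, shapes):
    """Checks concrete shapes against specs; each dimension name binds to one size."""
    if len(specs) != len(shapes):
        return False
    bindings = {}
    for spec, shape in zip(specs, shapes):
        dims = parse_spec(spec)
        if len(dims) != len(shape):
            return False
        for dim, size in zip(dims, shape):
            if isinstance(dim, int):
                if dim != size:  # fixed dimension must match exactly
                    return False
            elif bindings.setdefault(dim, size) != size:  # named dimension
                return False
    return True


print(shape_matches(["(b, 2)"], [(1, 2)]))                # True
print(shape_matches(["(b, 2)"], [(7, 2)]))                # True
print(shape_matches(["(b, 2)"], [(3, 3)]))                # False
print(shape_matches(["(b, 2)", "(b,)"], [(4, 2), (5,)]))  # False: b is 4, then 5
```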
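```python
# `train` is the training module of the Flax MNIST example and `config` its
# training configuration (both as set up in the Colab).
train_ds, test_ds = train.get_datasets()
state = train.train_and_evaluate(config, workdir='./workdir')
```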
The returned state has a function state.apply_fn that can be used to compute logits for input images. Note that the function expects the first argument to be the model weights (state.params, wrapped in a {'params': ...} dict). Given a batch of input images shaped [batch_size, 28, 28, 1], this produces the logits of the probability distribution over the ten labels for every image (shaped [batch_size, 10]).
```python
# imgs: a batch of images shaped [batch_size, 28, 28, 1].
logits = state.apply_fn({'params': state.params}, imgs)  # [batch_size, 10]
```
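Because the logits have shape [batch_size, 10], the predicted label for each image is the index of its largest logit, and test accuracy is the fraction of predictions that equal the true labels. Applying softmax first would not change the result, since softmax preserves the ordering of the logits. A small NumPy sketch, using made-up logits rather than real model outputs; accuracy is a hypothetical helper.

```python
import numpy as np


def accuracy(logits, labels):
    """Fraction of rows whose argmax over the 10 classes equals the label."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    predictions = np.argmax(logits, axis=-1)  # shape [batch_size]
    return float(np.mean(predictions == labels))


# Made-up logits for a batch of 3 images (shape [3, 10]), not real model outputs.
logits = np.zeros((3, 10))
logits[0, 7] = 5.0  # image 0 is predicted as 7
logits[1, 2] = 3.0  # image 1 is predicted as 2
logits[2, 1] = 4.0  # image 2 is predicted as 1
print(accuracy(logits, [7, 2, 0]))  # two of three predictions are correct
```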
state.apply_fn() is then converted exactly the same way as in the previous section; after all, it is a pure function that takes params and images as inputs and returns logits:
```python
tfjs.converters.convert_jax(
    state.apply_fn,
    {'params': state.params},
    input_signatures=[tf.TensorSpec((1, 28, 28, 1), tf.float32)],
    model_dir=tfjs_model_dir,
)
```
In the web app, the model is loaded with tf.loadGraphModel(), which updates the status text to give some feedback while the model weights are transferred:
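```javascript
// Runs inside an async function (or a module); `status` is a DOM element.
const model = await tf.loadGraphModel(modelDir + '/model.json', {
  // Called with the fraction (0 to 1) of the weights loaded so far.
  onProgress: p => status.innerText = `loading model: ${Math.round(p*100)}%`,
})
```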
Whenever the drawing changes, the ui.onUpdate() callback receives img, a Uint8Array of length 28*28, which is first converted to a TensorFlow.js tf.tensor before computing the model outputs and converting them to probabilities via the tf.softmax() function. tf.topk() then sorts the ten probabilities in descending order. The code waits synchronously for the output values by calling .dataSync() and converts them to JavaScript arrays before displaying them.
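```javascript
ui.onUpdate(img => {
  // tf.tidy() disposes the intermediate tensors after each update.
  tf.tidy(() => {
    const imgs = tf.tensor(img).cast('float32').reshape([1, 28, 28, 1])
    const logits = model.predict(imgs)
    const preds = tf.softmax(logits)
    // Sort all 10 class probabilities in descending order.
    const { values, indices } = tf.topk(preds, 10)
    ui.showPreds([...values.dataSync()], [...indices.dataSync()])
  })
})
```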
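For debugging, it can help to reproduce the same post-processing in Python and compare it with what the browser displays. The following NumPy sketch mirrors the steps of the callback above (cast to float32, reshape to [1, 28, 28, 1], softmax, top-k sort). It is an illustration under stated assumptions, not code from the demo: the model call is replaced by a hypothetical fake_model, and softmax and topk are hand-written stand-ins for tf.softmax() and tf.topk().

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax, like tf.softmax over the last axis."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def topk(values, k):
    """Returns the k largest values and their indices, in descending order."""
    indices = np.argsort(-values)[:k]
    return values[indices], indices


def fake_model(imgs):
    """Hypothetical stand-in for model.predict(): returns logits of shape [1, 10]."""
    scale = imgs.mean() / 255.0
    return np.linspace(-1.0, 1.0, 10)[None, :] * scale


img = np.full(28 * 28, 128, dtype=np.uint8)  # like the Uint8Array from the UI
imgs = img.astype(np.float32).reshape((1, 28, 28, 1))
preds = softmax(fake_model(imgs))[0]
values, indices = topk(preds, 10)
print(indices)  # class indices sorted by probability
```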
Figure 3. Our model from the Colab with 99.1% accuracy on the MNIST test dataset is still surprisingly bad at recognizing hand-written digits. On the left, the model predicts all kinds of digits instead of "one". On the right side, the "one" is drawn more like the data from the training set.

Example 3: LiT Demo

Writing a more realistic application with a TensorFlow.js model is a bit more involved. This section goes through the main steps that were used to create the demo app from the Google AI blog post Locked-Image Tuning: Adding Language Understanding to Image Models. Refer to that post for technical details on the implementation of the ML model. Also make sure to check out the final LiT Demo.

Adapting the model

Before starting to implement an ML demo, it is a good moment to think carefully about the different options and their respective strengths and weaknesses. At a high level, you have two options: running the ML model on server-side infrastructure, or running it on the edge (i.e., on the visiting user's device).
The model used for the demo consists of two parts: an image encoder and a text encoder (see Figure 4). A large model computes the image embeddings, and a small model computes the text embeddings. To make the demo run faster and produce better results, the expensive image embeddings are precomputed, so the TensorFlow.js model only needs to compute the text embeddings and then compare the image and text embeddings to compute similarities.
For the demo, we now get those powerful ViT-Large image representations for free, because we can precompute them for all demo images. This makes for a compelling demo with a limited compute budget. In addition to the "tiny" text encoder, we have also prepared a "small" text encoder for the same image embeddings (LiT-L16S), which performs a bit better but uses more bandwidth to download the model weights and requires more GPU memory to run on-device. We have evaluated the different models with the code from this Colab:
Note, though, that the "zeroshot performance" should only be taken as a proxy. In the end, the model performance needs to be good enough for the demo, and in this case our manual testing showed that even the tiny text transformer computed similarities that were good enough. Next, we tested the performance of the tiny and small text encoders on different platforms using this TensorFlow.js benchmark tool (with the "custom model" option, benchmarking 5x16 tokens on the WebGL backend):
Note that the results for the model with the "small" text encoder are missing for "Samsung S21 G5" in the above table because the model did not fit into memory. In terms of performance, the model with the "tiny" text encoder produces results within approximately 0.1 to 1 seconds, which still feels quite responsive, even on the smallest platform tested.

The Lit-LiT web app

Preparing the model for this application is a bit more complicated, because we need to convert not only the text transformer model weights, but also a matching tokenizer and the precomputed image embeddings. The Colab loads a LiT model, showcases how to use it, and then prepares the contents needed by the web app:
These contents are written to a data/ directory and then downloaded as a ZIP file. This file can then be uploaded to a web host, from where it is loaded by the web app (for example on GitHub Pages: vision_transformer/lit/data).

The code for the entire client-side application is available on GitHub: https://github.com/google-research/big_vision/tree/main/big_vision/tools/lit_demo/. The application is built using Lit web components. The main index.html declares the demo application:
```html
<lit-demo-app></lit-demo-app>
```
This web component is defined in lit-demo-app.ts in the src/components subdirectory, next to all the other web components (image carousel, model controls, etc.). The component image-prompts.ts calls functions from the module src/lit_demo/compute.ts, which wraps all the TensorFlow.js-specific code.
```typescript
import * as tf from '@tensorflow/tfjs';

export class Model {
  // Fields (`model`, `zimgs` with the precomputed image embeddings, and `def`,
  // which holds among other things the `temperature`) and loading are omitted.

  /** Tokenizes text. */
  tokenize(texts: string[]): tf.Tensor { /* ... */ }

  /** Computes text embeddings. */
  embed(tokens: tf.Tensor): tf.Tensor {
    return this.model!.execute({inputs: tokens}) as tf.Tensor;
  }

  /** Computes similarities between texts and pre-computed image embeddings. */
  computeSimilarities(texts: string[], imgidxs: number[]) {
    const textEmbeddings = this.embed(this.tokenize(texts));
    // Gather the selected image embeddings and transpose them to [dim, numImages].
    const imageEmbeddingsTransposed = tf.transpose(
        tf.concat(imgidxs.map(idx => tf.slice(this.zimgs!, idx, 1))));
    // [numTexts, dim] x [dim, numImages] -> [numTexts, numImages].
    return tf.matMul(textEmbeddings, imageEmbeddingsTransposed);
  }

  /** Applies softmax to `computeSimilarities()`. */
  computeProbabilities(texts: string[], imgidx: number): number[] {
    const sims = this.computeSimilarities(texts, [imgidx]);
    const row = tf.squeeze(tf.slice(tf.transpose(sims), 0, 1));
    return [...tf.softmax(tf.mul(this.def!.temperature, row)).dataSync()];
  }
}
```
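The arithmetic in computeSimilarities() and computeProbabilities() is a matrix product followed by a temperature-scaled softmax over the prompts. A NumPy sketch of the same steps, assuming the text embeddings have shape [num_texts, dim] and the precomputed image embeddings zimgs have shape [num_images, dim]; the function names are hypothetical Python counterparts, and the random embeddings and the temperature value are made up for illustration.

```python
import numpy as np


def compute_similarities(text_emb, zimgs, imgidxs):
    """Dot products between text embeddings and the selected image embeddings.

    Returns shape [num_texts, len(imgidxs)], like tf.matMul(texts, images^T).
    """
    images = zimgs[imgidxs]  # [len(imgidxs), dim]
    return text_emb @ images.T


def compute_probabilities(text_emb, zimgs, imgidx, temperature):
    """Softmax over the texts of the temperature-scaled similarities to one image."""
    row = compute_similarities(text_emb, zimgs, [imgidx])[:, 0]  # [num_texts]
    scaled = temperature * row
    exp = np.exp(scaled - scaled.max())  # subtract the max for numerical stability
    return exp / exp.sum()


rng = np.random.default_rng(0)
zimgs = rng.normal(size=(4, 8))     # 4 made-up image embeddings of size 8
text_emb = rng.normal(size=(2, 8))  # 2 made-up text embeddings, e.g. "a dog", "a cat"
probs = compute_probabilities(text_emb, zimgs, imgidx=1, temperature=10.0)
print(probs, probs.sum())
```

Because the image side is a fixed matrix, each new set of prompts only costs one text-encoder pass plus a small matrix product, which is what makes precomputing the image embeddings attractive.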
The data/ directory exported by the Colab above is referenced via the baseUrl in the file src/lit/constants.ts. By default, it refers to the models from the official demo. When replacing the baseUrl with a different server, make sure to enable cross-origin resource sharing (CORS).

To use only the model in your own page, see playground.html as an example, and refer to the instructions in README.md for how to compile the entire application or only the functional part before deploying it.
```html
<!-- Loads global symbol `lit`. -->
<script src="exports_bin.js"></script>
<script>
  async function demo() {
    lit.setBaseUrl('https://google-research.github.io/vision_transformer/lit');
    const model = new lit.Model('tiny');
    await model.load();
    console.log(model.computeProbabilities(['a dog', 'a cat'], /*imgIdx=*/1));
  }
  demo();
</script>
```
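When pointing the baseUrl at your own server, the browser only lets the demo fetch the model files if the responses carry a suitable CORS header. For local testing, one option is a small static file server built on Python's standard library that adds Access-Control-Allow-Origin to every response. This is a development-only sketch: the class name CORSRequestHandler, the helper make_server, and the permissive "*" origin are assumptions, and a production host should use its own CORS configuration.

```python
import functools
import http.server


class CORSRequestHandler(http.server.SimpleHTTPRequestHandler):
    """Serves files from a directory and allows cross-origin requests."""

    def end_headers(self):
        # Allow any origin to fetch these static files (fine for local testing).
        self.send_header("Access-Control-Allow-Origin", "*")
        super().end_headers()


def make_server(directory=".", port=8000, host="127.0.0.1"):
    """Creates (but does not start) a threaded server rooted at `directory`."""
    handler = functools.partial(CORSRequestHandler, directory=directory)
    return http.server.ThreadingHTTPServer((host, port), handler)
```

For example, make_server("data").serve_forever() would serve a local data/ directory on port 8000, and the baseUrl could then presumably point at http://127.0.0.1:8000.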