Pushing the limits of GPU performance with XLA
November 14, 2018
Posted by Toby Boyd, Yanan Cao, Sanjoy Das, Thomas Joerg, Justin Lebar

XLA is a compiler for TensorFlow graphs that you can use to accelerate your TensorFlow ML models today with minimal source code changes. This post describes what XLA is and shows how you can try it out on your own code.

TensorFlow 1.12 (with XLA) achieves significant performance gains over TF 1.11 (without XLA) on ResNet50 v1.0 training on NVIDIA® Tesla® V100 GPUs: 10,526 images/sec with synthetic data and 10,267 images/sec with real data (see appendix for reproduction instructions). We have observed speedups ranging from 1.13x to 3.04x on a variety of internal models.
Chart 1: Bar graph showing performance on ResNet50v1 training with synthetic data, comparing TensorFlow v1.11 without XLA vs TensorFlow v1.12 with XLA.
One GPU: 888 images/sec without XLA, 1,401 images/sec with.
8 GPUs: 6,818 images/sec without XLA, 10,526 images/sec with.
Chart 2: Bar graph showing performance on ResNet50v1 training with real data, comparing TensorFlow v1.11 without XLA vs TensorFlow v1.12 with XLA.
One GPU: 871 images/sec without XLA, 1,395 images/sec with.
8 GPUs: 6,413 images/sec without XLA, 10,268 images/sec with.
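The speedups implied by the chart figures above can be worked out directly. The short Python sketch below does the arithmetic using only the throughput numbers quoted in the two charts; the dictionary layout and the function name are illustrative and not part of any benchmark tooling.

```python
# Throughput in images/sec, copied from the two charts above:
# (data kind, GPU count) -> (TF 1.11 without XLA, TF 1.12 with XLA).
THROUGHPUT = {
    ("synthetic", 1): (888, 1401),
    ("synthetic", 8): (6818, 10526),
    ("real", 1): (871, 1395),
    ("real", 8): (6413, 10268),
}


def speedup(without_xla, with_xla):
    """Ratio of throughput with XLA to throughput without XLA."""
    return with_xla / without_xla


if __name__ == "__main__":
    for (data, gpus), (base, xla) in sorted(THROUGHPUT.items()):
        print(f"{data:9s} {gpus} GPU(s): {speedup(base, xla):.2f}x")
```

For all four configurations the ratio comes out between roughly 1.5x and 1.6x.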

XLA: TensorFlow, Compiled!

Normally when you run a TensorFlow graph, all of the operations are executed individually by the TensorFlow graph executor. Each op has a precompiled GPU kernel implementation (shipped as part of the TensorFlow binary) that the graph executor dispatches to.

XLA provides an alternative mode of running TF models: It compiles your TensorFlow graph into a sequence of GPU kernels generated specifically for your model. Because these kernels are unique to your program, they can exploit model-specific information for optimization.

As an example, let’s look at an optimization XLA does in the context of a simple TensorFlow computation:
def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)
Run without XLA, the graph launches three kernels: one for the multiplication, one for the addition and one for the reduction.

However, XLA can optimize the graph so that it computes the result in a single kernel launch. It does this by “fusing” the addition, multiplication and reduction into a single GPU kernel. Moreover, this fused operation does not write out the intermediate values produced by y*z and x+y*z to memory; instead it “streams” the results of these intermediate computations directly to their users while keeping them entirely in GPU registers.

Fusion is XLA’s single most important optimization. Memory bandwidth is typically the scarcest resource on hardware accelerators, so removing memory operations is one of the best ways to improve performance.
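To make the fusion idea concrete, here is a plain-Python sketch that contrasts the two evaluation strategies for tf.reduce_sum(x + y * z). It is not XLA output and not how GPU kernels are written: Python lists stand in for buffers in GPU memory, local variables stand in for registers, and all function names are made up for illustration.

```python
def unfused_reduce_sum(x, y, z):
    """Three separate passes, each materializing a full intermediate buffer."""
    product = [b * c for b, c in zip(y, z)]        # "kernel" 1: y * z
    total = [a + p for a, p in zip(x, product)]    # "kernel" 2: x + (y * z)
    return sum(total)                              # "kernel" 3: reduction


def fused_reduce_sum(x, y, z):
    """One pass; intermediates live only in local variables."""
    acc = 0.0
    for a, b, c in zip(x, y, z):
        acc += a + b * c  # y*z and x+y*z are never written to a buffer
    return acc


def intermediate_elements_written(n, fused):
    """Intermediate elements written to memory for inputs of length n."""
    # The unfused version writes y*z and x+y*z, n elements each.
    return 0 if fused else 2 * n
```

Both functions return the same value; the difference is that the unfused version writes 2n intermediate elements to memory and reads them back, which is the traffic that fusion removes.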

Using XLA in your models

XLA exposes an API, xla.compile, that lets you explicitly invoke the XLA compiler on a part of your TensorFlow graph. xla.compile accepts a Python function that generates a TensorFlow computation and wires up the generated computation to be compiled by XLA. xla.compile returns a list of tensors, each corresponding to an output from the computation constructed by the function passed in, but now optimized by XLA.

So the computation generated by model_fn above can be run with XLA by invoking xla.compile as follows:
import tensorflow as tf
from tensorflow.contrib.compiler import xla

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

def create_and_run_graph():
  with tf.Session() as sess:
    x = tf.placeholder(tf.float32, name='x')
    y = tf.placeholder(tf.float32, name='y')
    z = tf.placeholder(tf.float32, name='z')
    result = xla.compile(computation=model_fn, inputs=(x, y, z))[0]
    # `result` is a normal Tensor (albeit one that is computed by an XLA
    # compiled executable) and can be used like any other Tensor.
    result = tf.add(result, result)
    return sess.run(result, feed_dict={ ... })
You can use a command line flag (or other arbitrary logic) to control whether your computation is compiled by XLA or not. It is common for models to call xla.compile as
if should_use_xla():
  result = xla.compile(model_fn, (x, y, z))[0]
else:
  result = model_fn(x, y, z)
which allows for easy experimentation.
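The should_use_xla helper above is not defined in this post. One way it might be wired up is sketched below. The --use_xla flag name and the maybe_compile wrapper are hypothetical, and the compile function is passed in as a parameter so that the snippet runs without TensorFlow; in a TF 1.12 program you would pass xla.compile.

```python
import argparse


def parse_flags(argv=None):
    """Parse a hypothetical --use_xla switch (the flag name is an assumption)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_xla", action="store_true",
                        help="Compile the model function with XLA.")
    return parser.parse_args(argv)


def maybe_compile(model_fn, inputs, use_xla, compile_fn):
    """Build the computation via compile_fn when use_xla is set, else directly.

    compile_fn is expected to behave like xla.compile: it takes the function
    and a tuple of inputs and returns a list of outputs.
    """
    if use_xla:
        return compile_fn(model_fn, inputs)[0]
    return model_fn(*inputs)
```

A call site would then look like result = maybe_compile(model_fn, (x, y, z), parse_flags().use_xla, xla.compile), which keeps the XLA and non-XLA paths one flag apart.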

We have set up a colab in which you can play with xla.compile on a slightly more complex model.

xla.compile is not the only way to invoke XLA on a TensorFlow subgraph; specifically, there are ways to ask TensorFlow to automatically find XLA compatible subgraphs and compile them using XLA, but we won’t discuss them in this post.

Caveats to using XLA

Firstly, the XLA GPU backend is experimental at this time — while we’re not aware of any major problems, it hasn’t been tested with extensive production use.

Secondly, xla.compile does not yet work with Keras high-level APIs like model.fit (though you can use Keras ops), or in eager mode. We’re actively working on APIs to enable XLA in these modes; stay tuned.

Thirdly, XLA cannot compile all TensorFlow graphs; only graphs with the following properties can be passed to xla.compile.

All operations must have inferrable shapes

XLA needs to be able to infer the shapes of all of the operations it compiles, given the inputs to the computation. So a model function that produces a Tensor with an unpredictable shape will fail with an error when run. For example, a model function in which the shape of the output of tf.expand_dims depends on a value random_dim_size, which cannot be inferred given x, y and z, fails this requirement.

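Note that because XLA is a JIT compiler, the shapes can vary across runs, as long as they can be inferred given the inputs to the cluster. So a model function whose output shapes follow from the shapes of its inputs is fine, even if those input shapes change between runs.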
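A toy model of the two rules above, in plain Python rather than TensorFlow: output shapes are computed from input shapes alone, and a "compiled" result is cached per input-shape signature, so a new set of input shapes simply triggers another compilation. Everything here (ToyJit, infer_elementwise_shape, the cache layout) is invented for illustration and is not XLA's actual implementation.

```python
def infer_elementwise_shape(*shapes):
    """Infer the output shape of an elementwise op from its input shapes.

    Inputs must have identical shapes. A dimension of None means "unknown
    until the op runs", which is the case a compiler cannot handle.
    """
    first = shapes[0]
    for shape in shapes:
        if any(dim is None for dim in shape):
            raise ValueError(f"shape {shape} is not inferrable from the inputs")
        if shape != first:
            raise ValueError(f"shape mismatch: {shape} vs {first}")
    return first


class ToyJit:
    """Caches one 'compiled' entry per distinct tuple of input shapes."""

    def __init__(self):
        self.cache = {}
        self.compilations = 0

    def run(self, *input_shapes):
        key = tuple(input_shapes)
        if key not in self.cache:
            # "Compile": every shape must be known at this point.
            self.cache[key] = infer_elementwise_shape(*input_shapes)
            self.compilations += 1
        return self.cache[key]
```

Running the same shapes twice reuses the cached entry, a different but fully known set of shapes compiles again, and a shape containing None is rejected.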

All operations must be supported by XLA

Not all TensorFlow operations can be compiled by XLA, and if your model has an operation that XLA does not support, XLA compilation will fail. For instance, XLA does not support the tf.where op, so if your model function includes this op, it will fail when run with xla.compile.

Every TensorFlow operation supported by XLA has a REGISTER_XLA_OP invocation in tensorflow/compiler/tf2xla/kernels/ and so you can grep for instances of the REGISTER_XLA_OP macro to find the list of supported TensorFlow operations.
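The grep suggested above can also be done from Python. The sketch below walks a directory and collects the op names it finds inside REGISTER_XLA_OP invocations. The regular expression assumes the op name appears as Name("...") directly inside the macro call, which is an assumption about how the registrations are written; adjust it if your checkout differs. The function name is illustrative.

```python
import os
import re

# Assumed layout of a registration: REGISTER_XLA_OP(Name("OpName")...
_REGISTER_RE = re.compile(r'REGISTER_XLA_OP\(\s*Name\("([^"]+)"\)')


def find_registered_xla_ops(kernels_dir):
    """Return the sorted op names registered in C++ files under kernels_dir."""
    ops = set()
    for root, _dirs, files in os.walk(kernels_dir):
        for filename in files:
            if not filename.endswith((".cc", ".h")):
                continue
            path = os.path.join(root, filename)
            with open(path, encoding="utf-8", errors="replace") as f:
                ops.update(_REGISTER_RE.findall(f.read()))
    return sorted(ops)
```

From the root of a TensorFlow checkout, find_registered_xla_ops("tensorflow/compiler/tf2xla/kernels") would return the list to check your model's ops against.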

Appendix

Performance on Google benchmarks

Below is a plot of the relative speedup/slowdown of TensorFlow with XLA vs TensorFlow without XLA on all of the XLA team’s benchmark models, run on a V100 GPU. We aren’t holding anything back; this is the full set of benchmarks that we use in evaluating the compiler today.
Chart showing the speedup/slowdown of TensorFlow plus XLA vs TensorFlow without XLA on Google-internal benchmarks. The data is a list of results for fp16 and fp32 models, sorted by speedup. fp32 results: [0.86 0.94 0.94 0.97 0.98 0.99 0.99 0.99 1.00 1.01 1.01 1.01 1.01 1.02 1.04 1.05 1.06 1.06 1.07 1.07 1.08 1.08 1.08 1.09 1.09 1.10 1.10 1.11 1.11 1.11 1.12 1.12 1.12 1.13 1.15 1.15 1.18 1.18 1.20 1.27 1.30 1.30 1.32 1.37 1.40 1.41 1.43 1.44 1.52], fp16 results: [1.10 1.32 1.41 1.47 1.48 1.55 1.56 1.59 1.63 1.64 1.64 1.67 2.07 2.51 3.09]
Each bar represents a full model, e.g. “resnet50 training images/sec” or “inference throughput on a Google-internal model”. The X axis is sorted by speedup.
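For a quick numerical read of the chart, the sketch below summarizes the two result lists quoted above using Python's statistics module. The helper name and the choice of summary statistics are illustrative and not part of the original analysis.

```python
import statistics

# Speedups copied from the chart description above.
FP32 = [
    0.86, 0.94, 0.94, 0.97, 0.98, 0.99, 0.99, 0.99, 1.00, 1.01,
    1.01, 1.01, 1.01, 1.02, 1.04, 1.05, 1.06, 1.06, 1.07, 1.07,
    1.08, 1.08, 1.08, 1.09, 1.09, 1.10, 1.10, 1.11, 1.11, 1.11,
    1.12, 1.12, 1.12, 1.13, 1.15, 1.15, 1.18, 1.18, 1.20, 1.27,
    1.30, 1.30, 1.32, 1.37, 1.40, 1.41, 1.43, 1.44, 1.52,
]
FP16 = [
    1.10, 1.32, 1.41, 1.47, 1.48, 1.55, 1.56, 1.59, 1.63, 1.64,
    1.64, 1.67, 2.07, 2.51, 3.09,
]


def summarize(speedups):
    """Return count, slowdown count, min, median and max of a speedup list."""
    return {
        "models": len(speedups),
        "slower_with_xla": sum(1 for s in speedups if s < 1.0),
        "min": min(speedups),
        "median": statistics.median(speedups),
        "max": max(speedups),
    }
```

On this data, 8 of the 49 fp32 models run slower with XLA and none of the 15 fp16 models do; the medians are 1.09 for fp32 and 1.59 for fp16.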

Your mileage may vary, especially since we’ve made optimizations to XLA specifically motivated by many of these benchmarks! Nonetheless many of them have worked well out-of-the-box, and we continue to improve.

Reproducing ResNet50 v1.0 benchmark

The sections below walk through setting up a Google Cloud instance and executing the ResNet50 benchmark.

Prepare the data

This step is only needed for a real data test and can take a few hours. We recommend doing this on a CPU-only instance to reduce compute cost. Using the instructions for imagenet_to_gcs.py, create the ImageNet data in TFRecord format and push it to a Google Cloud Storage bucket.

Create GCE Instance

The snippet below creates an instance of the Google Deep Learning VM on the Google Cloud Platform with eight Tesla® V100 GPUs.
export INSTANCE_NAME="xla-benchmark-8xV100"
export IMAGE_FAMILY="tf-1-12-cu100"
export PROJECT_NAME=""

gcloud beta compute instances create $INSTANCE_NAME \
  --project=$PROJECT_NAME \
  --machine-type=n1-standard-64 \
  --maintenance-policy=TERMINATE \
  --accelerator=type=nvidia-tesla-v100,count=8 \
  --tags=http-server,https-server \
  --image-family=$IMAGE_FAMILY \
  --image-project=deeplearning-platform-release \
  --boot-disk-size=100GB \
  --boot-disk-type=pd-ssd \
  --local-ssd interface=nvme \
  --local-ssd interface=nvme \
  --local-ssd interface=nvme \
  --local-ssd interface=nvme \
  --metadata install-nvidia-driver=True

## Combines the 4 local nvme SSD drives into a single RAID 0 drive.
# Install raid management tool.
sudo apt-get update && sudo apt-get install mdadm --no-install-recommends

# Creates RAID 0 array.
sudo mdadm --create /dev/md0 --level=0 --raid-devices=4 \
/dev/nvme0n1 /dev/nvme0n2 /dev/nvme0n3 /dev/nvme0n4 

# Formats and mounts the array.
sudo mkfs.ext4 -F /dev/md0
sudo mkdir -p /data/imagenet
sudo mount /dev/md0 /data
sudo chmod a+w /data

# Installs custom TensorFlow 1.12 binary with AVX2.  Binary included on
# the image already has XLA but the custom binary is compiled with AVX2.
sudo pip install --force-reinstall https://storage.googleapis.com/tf-performance/tf_binary/tensorflow-1.12.0.a6d8ffa.AVX2.CUDA10-cp27-cp27mu-linux_x86_64.whl

Execute benchmarks

gcloud compute ssh $INSTANCE_NAME

# Clone TensorFlow benchmark repository.
git clone https://github.com/tensorflow/benchmarks.git && cd benchmarks
git reset --hard 1e7d788042dfc6d5e5cd87410c57d5eccee5c664
cd scripts/tf_cnn_benchmarks

## Synthetic data test

# 8 GPUs
python tf_cnn_benchmarks.py \
    --batch_size=364 \
    --num_batches=100 \
    --model=resnet50 \
    --optimizer=momentum \
    --variable_update=replicated \
    --all_reduce_spec=nccl \
    --use_fp16=True \
    --nodistortions \
    --gradient_repacking=2 \
    --compute_lr_on_cpu=True \
    --single_l2_loss_op=True \
    --xla_compile=True \
    --num_gpus=8 \
    --loss_type_to_report=base_loss

# 1 GPU
python tf_cnn_benchmarks.py \
    --batch_size=364 \
    --num_batches=100 \
    --model=resnet50 \
    --optimizer=momentum \
    --use_fp16=True \
    --nodistortions \
    --compute_lr_on_cpu=True \
    --single_l2_loss_op=True \
    --xla_compile=True \
    --loss_type_to_report=base_loss

## Real data test
# add --data_dir=/data/imagenet to the 1 or 8 GPU command.
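
The 1-GPU and 8-GPU commands above differ only in a few flags, and the real data run adds --data_dir. A small Python helper, sketched here with a made-up function name, can assemble the argument list for each variant; every flag and value is copied from the commands above.

```python
def build_benchmark_command(num_gpus=1, data_dir=None):
    """Assemble the tf_cnn_benchmarks.py command line used in this post."""
    cmd = [
        "python", "tf_cnn_benchmarks.py",
        "--batch_size=364",
        "--num_batches=100",
        "--model=resnet50",
        "--optimizer=momentum",
        "--use_fp16=True",
        "--nodistortions",
        "--compute_lr_on_cpu=True",
        "--single_l2_loss_op=True",
        "--xla_compile=True",
        "--loss_type_to_report=base_loss",
    ]
    if num_gpus == 8:
        # Multi-GPU flags taken from the 8-GPU command above.
        cmd += [
            "--variable_update=replicated",
            "--all_reduce_spec=nccl",
            "--gradient_repacking=2",
            "--num_gpus=8",
        ]
    elif num_gpus != 1:
        raise ValueError("this post only gives commands for 1 or 8 GPUs")
    if data_dir is not None:
        cmd.append(f"--data_dir={data_dir}")  # real data test
    return cmd
```

The returned list can be handed to subprocess.run from inside scripts/tf_cnn_benchmarks, for example build_benchmark_command(8, "/data/imagenet") for the 8-GPU real data test.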