November 14, 2018 —
Posted by Toby Boyd, Yanan Cao, Sanjoy Das, Thomas Joerg, Justin Lebar
XLA is a compiler for TensorFlow graphs that you can use to accelerate your TensorFlow ML models today with minimal source code changes. This post describes what XLA is and shows how you can try it out on your own code.
TensorFlow 1.12 (with XLA) achieves significant performance gains over TF 1.11 (without XLA) on ResNet50 v1.0…
Chart 1: Bar graph showing performance on ResNet50v1 training with synthetic data, comparing TensorFlow v1.11 without XLA vs TensorFlow v1.12 with XLA. One GPU: 888 images/sec without XLA, 1,401 images/sec with. 8 GPUs: 6,818 images/sec without XLA, 10,526 images/sec with.
Chart 2: Bar graph showing performance on ResNet50v1 training with real data, comparing TensorFlow v1.11 without XLA vs TensorFlow v1.12 with XLA. One GPU: 871 images/sec without XLA, 1,395 images/sec with. 8 GPUs: 6,413 images/sec without XLA, 10,268 images/sec with.
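To put the chart numbers in relative terms, here is a small sketch that turns the images/sec figures quoted above into speedup ratios. The dictionary layout and the `speedup` helper are illustrative names chosen for this example, not part of TensorFlow or the benchmark scripts.

```python
# Throughput in images/sec, copied from the two charts above.
# Keys are (data source, number of GPUs); values are (without XLA, with XLA).
results = {
    ("synthetic", 1): (888, 1401),
    ("synthetic", 8): (6818, 10526),
    ("real", 1): (871, 1395),
    ("real", 8): (6413, 10268),
}


def speedup(without_xla, with_xla):
    """Ratio of XLA throughput to baseline throughput (hypothetical helper)."""
    return with_xla / without_xla


for (data, gpus), (base, xla) in results.items():
    print(f"{data:9s} {gpus} GPU(s): {speedup(base, xla):.2f}x")
```

On these four configurations the ratio works out to roughly 1.5x to 1.6x. Note that the two columns also differ in TensorFlow version (1.11 vs 1.12), so the ratio is not attributable to XLA alone.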
def model_fn(x, y, z):
    return tf.reduce_sum(x + y * z)
Run without XLA, the graph launches three kernels: one for the multiplication, one for the addition, and one for the reduction.

import tensorflow as tf
from tensorflow.contrib.compiler import xla

def model_fn(x, y, z):
    return tf.reduce_sum(x + y * z)

def create_and_run_graph():
    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, name='x')
        y = tf.placeholder(tf.float32, name='y')
        z = tf.placeholder(tf.float32, name='z')
        result = xla.compile(computation=model_fn, inputs=(x, y, z))[0]
        # `result` is a normal Tensor (albeit one that is computed by an XLA
        # compiled executable) and can be used like any other Tensor.
        result = tf.add(result, result)
        return sess.run(result, feed_dict={ ... })
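The difference between launching three kernels and running one fused kernel can be mimicked in plain Python. The sketch below is only an analogy and does not call XLA: `unfused` makes three passes and stores two intermediate lists, the way three separate kernels would write intermediate results to memory, while `fused` computes the same sum in a single pass with no intermediates. Both function names are made up for this illustration.

```python
def unfused(x, y, z):
    """Three separate passes, each producing a full intermediate result."""
    product = [b * c for b, c in zip(y, z)]       # pass 1: y * z
    total = [a + p for a, p in zip(x, product)]   # pass 2: x + (y * z)
    return sum(total)                             # pass 3: reduction


def fused(x, y, z):
    """One pass over the inputs; no intermediate list is stored."""
    return sum(a + b * c for a, b, c in zip(x, y, z))


x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]
z = [0.5, 0.5, 0.5]

# Both versions compute the same value for reduce_sum(x + y * z).
print(unfused(x, y, z), fused(x, y, z))
```

The results match. What changes is how much intermediate data gets stored along the way, and that is the kind of saving a fusing compiler is after.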
You can use a command line flag (or other arbitrary logic) to control whether your computation is compiled by XLA or not. It is common for models to call xla.compile as

if should_use_xla():
    result = xla.compile(model_fn, (x, y, z))[0]
else:
    result = model_fn(x, y, z)

which allows for easy experimentation.

Chart showing the speedup/slowdown of TensorFlow plus XLA vs TensorFlow without XLA on Google-internal benchmarks. The data is a list of results for fp16 and fp32 models, sorted by speedup.
fp32 results: [0.86 0.94 0.94 0.97 0.98 0.99 0.99 0.99 1.00 1.01 1.01 1.01 1.01 1.02 1.04 1.05 1.06 1.06 1.07 1.07 1.08 1.08 1.08 1.09 1.09 1.10 1.10 1.11 1.11 1.11 1.12 1.12 1.12 1.13 1.15 1.15 1.18 1.18 1.20 1.27 1.30 1.30 1.32 1.37 1.40 1.41 1.43 1.44 1.52]
fp16 results: [1.10 1.32 1.41 1.47 1.48 1.55 1.56 1.59 1.63 1.64 1.64 1.67 2.07 2.51 3.09]
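A long list of ratios is hard to read at a glance, so here is a short sketch that summarizes the benchmark data from the chart above using only the Python standard library. The `summarize` helper and its field names are choices made for this example. The numbers themselves are copied from the chart description.

```python
import statistics

# Speedup of TensorFlow with XLA over TensorFlow without XLA, from the chart.
fp32 = [
    0.86, 0.94, 0.94, 0.97, 0.98, 0.99, 0.99, 0.99, 1.00, 1.01,
    1.01, 1.01, 1.01, 1.02, 1.04, 1.05, 1.06, 1.06, 1.07, 1.07,
    1.08, 1.08, 1.08, 1.09, 1.09, 1.10, 1.10, 1.11, 1.11, 1.11,
    1.12, 1.12, 1.12, 1.13, 1.15, 1.15, 1.18, 1.18, 1.20, 1.27,
    1.30, 1.30, 1.32, 1.37, 1.40, 1.41, 1.43, 1.44, 1.52,
]
fp16 = [
    1.10, 1.32, 1.41, 1.47, 1.48, 1.55, 1.56, 1.59,
    1.63, 1.64, 1.64, 1.67, 2.07, 2.51, 3.09,
]


def summarize(speedups):
    """Basic summary of a list of speedup ratios (hypothetical helper)."""
    return {
        "models": len(speedups),
        "median": statistics.median(speedups),
        # Geometric mean is the usual way to average ratios.
        "geomean": statistics.geometric_mean(speedups),
        # Ratios below 1.0 mean the model ran slower with XLA.
        "slower_with_xla": sum(1 for s in speedups if s < 1.0),
    }


for name, data in (("fp32", fp32), ("fp16", fp16)):
    print(name, summarize(data))
```

In this data the median speedup is 1.09 for the 49 fp32 models and 1.59 for the 15 fp16 models. Eight fp32 models ran slower with XLA, and none of the fp16 models did. That spread is a good reason to measure on your own model before enabling XLA by default.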
export INSTANCE_NAME="xla-benchmark-8xV100"
export IMAGE_FAMILY="tf-1-12-cu100"
export PROJECT_NAME=""
gcloud beta compute instances create $INSTANCE_NAME \
    --project=$PROJECT_NAME \
    --machine-type=n1-standard-64 \
    --maintenance-policy=TERMINATE \
    --accelerator=type=nvidia-tesla-v100,count=8 \
    --tags=http-server,https-server \
    --image-family=$IMAGE_FAMILY \
    --image-project=deeplearning-platform-release \
    --boot-disk-size=100GB \
    --boot-disk-type=pd-ssd \
    --local-ssd interface=nvme \
    --local-ssd interface=nvme \
    --local-ssd interface=nvme \
    --local-ssd interface=nvme \
    --metadata install-nvidia-driver=True
## Combines the 4 local nvme SSD drives into a single RAID 0 drive.
# Install raid management tool.
sudo apt-get update && sudo apt-get install mdadm --no-install-recommends
# Creates RAID 0 array.
sudo mdadm --create /dev/md0 --level=0 --raid-devices=4 \
    /dev/nvme0n1 /dev/nvme0n2 /dev/nvme0n3 /dev/nvme0n4
# Formats and mounts the array.
sudo mkfs.ext4 -F /dev/md0
sudo mkdir -p /data/imagenet
sudo mount /dev/md0 /data
sudo chmod a+w /data
# Installs custom TensorFlow 1.12 binary with AVX2. Binary included on
# the image already has XLA but the custom binary is compiled with AVX2.
sudo pip install --force-reinstall https://storage.googleapis.com/tf-performance/tf_binary/tensorflow-1.12.0.a6d8ffa.AVX2.CUDA10-cp27-cp27mu-linux_x86_64.whl
gcloud compute ssh $INSTANCE_NAME
# Clone TensorFlow benchmark repository.
git clone https://github.com/tensorflow/benchmarks.git && cd benchmarks
git reset --hard 1e7d788042dfc6d5e5cd87410c57d5eccee5c664
cd scripts/tf_cnn_benchmarks
## Synthetic data test
# 8 GPUs
python tf_cnn_benchmarks.py \
    --batch_size=364 \
    --num_batches=100 \
    --model=resnet50 \
    --optimizer=momentum \
    --variable_update=replicated \
    --all_reduce_spec=nccl \
    --use_fp16=True \
    --nodistortions \
    --gradient_repacking=2 \
    --compute_lr_on_cpu=True \
    --single_l2_loss_op=True \
    --xla_compile=True \
    --num_gpus=8 \
    --loss_type_to_report=base_loss
# 1 GPU
python tf_cnn_benchmarks.py \
    --batch_size=364 \
    --num_batches=100 \
    --model=resnet50 \
    --optimizer=momentum \
    --use_fp16=True \
    --nodistortions \
    --compute_lr_on_cpu=True \
    --single_l2_loss_op=True \
    --xla_compile=True \
    --loss_type_to_report=base_loss
## Real data test
# add --data_dir=/data/imagenet to the 1 or 8 GPU command.
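If you switch between the four variants often (1 or 8 GPUs, synthetic or real data), a tiny helper can assemble the command line for you. This is a sketch, not part of the benchmark repository: `build_command` and the two flag lists are names invented here, the flags themselves are copied from the commands above, and it assumes that flag order does not matter to the script.

```python
import shlex

# Flags shared by the 1-GPU and 8-GPU commands above.
COMMON_FLAGS = [
    "--batch_size=364",
    "--num_batches=100",
    "--model=resnet50",
    "--optimizer=momentum",
    "--use_fp16=True",
    "--nodistortions",
    "--compute_lr_on_cpu=True",
    "--single_l2_loss_op=True",
    "--xla_compile=True",
    "--loss_type_to_report=base_loss",
]

# Flags that appear only in the 8-GPU command.
MULTI_GPU_FLAGS = [
    "--variable_update=replicated",
    "--all_reduce_spec=nccl",
    "--gradient_repacking=2",
    "--num_gpus=8",
]


def build_command(num_gpus=1, data_dir=None):
    """Return the benchmark command as an argument list (hypothetical helper)."""
    if num_gpus not in (1, 8):
        raise ValueError("only the 1-GPU and 8-GPU setups are described here")
    args = ["python", "tf_cnn_benchmarks.py"] + COMMON_FLAGS
    if num_gpus == 8:
        args += MULTI_GPU_FLAGS
    if data_dir is not None:
        # Real data test: point the benchmark at the dataset directory.
        args.append(f"--data_dir={data_dir}")
    return args


# Print a shell-ready command for the 8-GPU real data test.
print(shlex.join(build_command(num_gpus=8, data_dir="/data/imagenet")))
```

Returning a list instead of a single string keeps the helper usable with `subprocess.run` as well as for printing.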