Pushing the limits of GPU performance with XLA

XLA is a compiler for TensorFlow graphs that you can use to accelerate your TensorFlow ML models today with minimal source code changes. This post describes what XLA is and shows how you can try it out on your own code.

XLA: TensorFlow, Compiled!

Normally when you run a TensorFlow graph, all of the operations are executed individually by the TensorFlow graph executor. Each op has a precompiled GPU kernel implementation (shipped as part of the TensorFlow binary) that the graph executor dispatches to.

XLA provides an alternative mode of running TF models: It compiles your TensorFlow graph into a sequence of GPU kernels generated specifically for your model. Because these kernels are unique to your program, they can exploit model-specific information for optimization.

As an example, let’s look at an optimization XLA does in the context of a simple TensorFlow computation: a Python function, model_fn(x, y, z), that returns tf.reduce_sum(x + y * z).

Run without XLA, the graph launches three kernels: one for the multiplication, one for the addition and one for the reduction.

However, XLA can optimize the graph so that it computes the result in a single kernel launch. It does this by “fusing” the addition, multiplication and reduction into a single GPU kernel. Moreover, this fused operation does not write out the intermediate values produced by y*z and x+y*z to memory; instead it “streams” the results of these intermediate computations directly to their users while keeping them entirely in GPU registers.
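To make the fusion idea concrete, here is a minimal NumPy sketch of the same computation. It illustrates the concept only and is not how XLA generates code. The unfused version mimics three kernels, each of which writes a full intermediate array. The fused version makes a single pass and keeps each intermediate value in a local variable, the CPU analogue of a GPU register. The function names are hypothetical.

```python
import numpy as np

def unfused(x, y, z):
    # "Kernel" 1: multiply; writes a full intermediate array to memory.
    t1 = y * z
    # "Kernel" 2: add; reads t1 back and writes another full array.
    t2 = x + t1
    # "Kernel" 3: reduce; reads t2 back and produces a scalar.
    return t2.sum()

def fused(x, y, z):
    # One pass over the inputs: each intermediate is a scalar local
    # (analogous to a register) and is never stored in an array.
    acc = 0.0
    for xi, yi, zi in zip(x, y, z):
        acc += xi + yi * zi
    return acc

x = np.arange(4, dtype=np.float64)
y = np.full(4, 2.0)
z = np.full(4, 3.0)
print(unfused(x, y, z), fused(x, y, z))  # prints: 30.0 30.0
```

The Python loop is much slower than NumPy here. The point is the memory access pattern: the fused version never materializes y*z or x+y*z.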

Fusion is XLA’s single most important optimization. Memory bandwidth is typically the scarcest resource on hardware accelerators, so removing memory operations is one of the best ways to improve performance.
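A back-of-the-envelope count shows how much memory traffic fusion removes in the example above. The sketch below counts array elements read plus written for sum(x + y*z) over n-element inputs. It assumes a simplified model in which each kernel reads every input element once and writes every output element once, with no caching. The function name is hypothetical.

```python
def elements_moved(n):
    """Return (unfused, fused) counts of array elements read plus written
    for sum(x + y*z) over n-element inputs, under a simple one-read,
    one-write-per-element model with no caching."""
    # Unfused: three kernels, each round-tripping through memory.
    multiply = 2 * n + n      # read y and z, write y*z
    add = 2 * n + n           # read x and y*z, write x + y*z
    reduce_ = n + 1           # read x + y*z, write one scalar
    unfused = multiply + add + reduce_
    # Fused: read x, y and z once, write one scalar.
    fused = 3 * n + 1
    return unfused, fused

unfused, fused = elements_moved(1_000_000)
print(unfused, fused, round(unfused / fused, 2))  # prints: 7000001 3000001 2.33
```

Under this simplified model, the fused kernel moves roughly 3/7 of the data. A purely bandwidth-bound computation like this one could therefore run up to about 2.3x faster from fusion alone, before counting the saved kernel launches.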

Using XLA in your models

XLA exposes an API, xla.compile (in the tensorflow.contrib.compiler.xla module), that lets you explicitly invoke the XLA compiler on part of your TensorFlow graph. xla.compile accepts a Python function that generates a TensorFlow computation and wires up the generated computation to be compiled by XLA. It returns a list of tensors, each corresponding to an output of the computation constructed by the function passed in, but now optimized by XLA.

So the computation generated by model_fn above can be run with XLA by calling xla.compile(model_fn, inputs=[x, y, z]), which returns a list containing the single result tensor.

You can use a command-line flag (or other arbitrary logic) to control whether your computation is compiled by XLA or not. It is common for models to call xla.compile only when such a flag is set and to call the model function directly otherwise, which allows for easy experimentation.
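One way to structure this pattern is sketched below in plain Python. It assumes a hypothetical --use_xla command-line flag. It also takes the compile function as a parameter so that the sketch runs without TensorFlow; in a TensorFlow 1.x program you would pass xla.compile from tensorflow.contrib.compiler. Because xla.compile returns a list of outputs, the sketch wraps the uncompiled result in a list too, so both branches return the same structure. The names build_outputs and parse_flags are hypothetical.

```python
import argparse

def build_outputs(model_fn, inputs, use_xla, compile_fn=None):
    """Build the model's outputs, optionally through an XLA-style compiler.

    compile_fn is assumed to behave like xla.compile(computation, inputs=...),
    returning a sequence with one entry per output of model_fn.
    """
    if use_xla:
        if compile_fn is None:
            raise ValueError("use_xla is set but no compile_fn was given")
        return list(compile_fn(model_fn, inputs=inputs))
    # Without XLA, call the function directly; wrap a single output in a
    # list so callers see the same structure either way.
    outputs = model_fn(*inputs)
    return list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]

def parse_flags(argv=None):
    parser = argparse.ArgumentParser()
    # Hypothetical flag name; any boolean switch works.
    parser.add_argument("--use_xla", action="store_true",
                        help="compile the model function with XLA")
    return parser.parse_args(argv)
```

In a TensorFlow 1.x script, this might be called as build_outputs(model_fn, [x, y, z], parse_flags().use_xla, xla.compile). Toggling the flag then switches between the XLA and non-XLA paths without any other code changes.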

We have set up a colab in which you can play with xla.compile on a slightly more complex model.

xla.compile is not the only way to invoke XLA on a TensorFlow subgraph. TensorFlow can also be asked to automatically find XLA-compatible subgraphs and compile them using XLA, but we won’t discuss those mechanisms in this post.

Caveats to using XLA

Firstly, the XLA GPU backend is experimental at this time — while we’re not aware of any major problems, it hasn’t been tested with extensive production use.

Secondly, xla.compile does not yet work with Keras high-level APIs like model.fit (though you can use Keras ops), or in eager mode. We’re actively working on APIs to enable XLA in these modes; stay tuned.

Thirdly, XLA cannot compile all TensorFlow graphs; only graphs with the following properties can be passed to xla.compile.

All operations must have inferrable shapes

XLA needs to be able to infer the shapes of all of the operations it compiles from the inputs to the computation. So a model function that produces a Tensor with an unpredictable shape will fail with an error when run. (For example, if the shape of the output of a tf.expand_dims call depends on a random_dim_size value generated inside the model function, that shape cannot be inferred given x, y and z.)

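Note that because XLA is a JIT compiler, the shapes can vary across runs, as long as they can be inferred given the inputs to the cluster (the subgraph being compiled). So a model function whose shapes all follow from the shapes of x, y and z is fine, even if those input shapes change from one run to the next.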
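The toy model below illustrates the JIT point. It is plain Python with NumPy, not XLA's actual caching logic. A compiled program is specialized to the shapes of its inputs, so it is cached under those shapes, and a new input shape simply triggers another compilation. This works only when every intermediate shape is a function of the input shapes. A shape that depends on a value produced at run time, such as a random number, cannot be part of the cache key. The class name is hypothetical.

```python
import numpy as np

class ShapeSpecializingJit:
    """Toy JIT: 'compiles' fn once per distinct tuple of input shapes."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}          # input shapes -> specialized program
        self.compile_count = 0

    def _compile(self, shapes):
        self.compile_count += 1
        # A real compiler would infer every intermediate shape here from
        # `shapes` alone and emit code for exactly those sizes.
        return self.fn

    def __call__(self, *args):
        key = tuple(np.shape(a) for a in args)
        if key not in self.cache:
            self.cache[key] = self._compile(key)
        return self.cache[key](*args)

def model_fn(x, y, z):
    # Every intermediate shape follows from the input shapes.
    return np.sum(x + y * z)

jit_model = ShapeSpecializingJit(model_fn)
jit_model(np.ones(3), np.ones(3), np.ones(3))   # compiles for shape (3,)
jit_model(np.ones(3), np.ones(3), np.ones(3))   # reuses the cached program
jit_model(np.ones(5), np.ones(5), np.ones(5))   # new shape, compiles again
print(jit_model.compile_count)  # prints: 2
```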

All operations must be supported by XLA

Not all TensorFlow operations can be compiled by XLA, and if your model has an operation that XLA does not support, XLA compilation will fail. For instance, XLA does not support the single-argument form of tf.where (which returns the indices of the true elements of its input), so if your model function includes this op, it will fail when run with xla.compile.

Every TensorFlow operation supported by XLA has a REGISTER_XLA_OP invocation in tensorflow/compiler/tf2xla/kernels/, so you can grep for instances of the REGISTER_XLA_OP macro to find the list of supported TensorFlow operations.
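As a sketch, the same search can be done from Python. The function below scans the .cc files under a directory for REGISTER_XLA_OP(Name("...") invocations and returns the registered op names. It assumes that registrations spell the op name as a string literal in that form, so it may miss ops registered through helper macros that build the name some other way. The function name is hypothetical.

```python
import pathlib
import re

# Matches e.g. REGISTER_XLA_OP(Name("Relu"), ReluOp); whitespace, including
# newlines, is allowed between the opening parenthesis and Name(.
PATTERN = re.compile(r'REGISTER_XLA_OP\(\s*Name\("([^"]+)"\)')

def registered_xla_ops(kernels_dir):
    """Return the sorted, de-duplicated op names registered under kernels_dir."""
    names = set()
    for path in pathlib.Path(kernels_dir).rglob("*.cc"):
        text = path.read_text(encoding="utf-8", errors="replace")
        names.update(PATTERN.findall(text))
    return sorted(names)
```

From the root of a TensorFlow source checkout, print("\n".join(registered_xla_ops("tensorflow/compiler/tf2xla/kernels"))) lists the names it finds.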

Appendix

Performance on Google benchmarks

Below is a plot of the relative speedup/slowdown of TensorFlow with XLA vs TensorFlow without XLA on all of the XLA team’s benchmark models, run on a V100 GPU. We aren’t holding anything back; this is the full set of benchmarks that we use in evaluating the compiler today.

Each bar represents a full model, e.g. “resnet50 training images/sec” or “inference throughput on a Google-internal model”. The X axis is sorted by speedup.

Your mileage may vary, especially since we’ve made optimizations to XLA specifically motivated by many of these benchmarks! Nonetheless, many of these models have worked well out of the box, and we continue to improve XLA.

Reproducing the ResNet50 v1.0 benchmark

The sections below walk through setting up a Google Cloud instance and executing the ResNet50 benchmark.

Prepare the data

This step is only needed for a real-data test and can take a few hours. We recommend doing it on a CPU-only instance to reduce compute cost. Following the instructions for imagenet_to_gcs.py, create the ImageNet data in TFRecord format and push it to a Google Cloud Storage bucket.