the personal blog of Akshay Agrawal

From September 2017 to October 2018, I worked on TensorFlow 2.0 alongside many engineers. In this post, I’ll explain what TensorFlow 2.0 is and how it differs from TensorFlow 1.x. Towards the end, I’ll briefly compare TensorFlow 2.0 to PyTorch 1.0. This post represents my own views; it does not represent the views of Google, my former employer.

Execution model. In TF 2.0, all operations execute imperatively by default. Graphs and the graph runtime are both abstracted away by a just-in-time tracer that translates Python functions executing TF operations into executable graph functions. This means in TF 2.0, there is no Session, and no global graph state. The tracer is exposed as a Python decorator, tf.function. This decorator is for advanced users. Using it is completely optional.

API. TF 2.0 makes tf.keras the high-level API for constructing and training neural networks. But you don’t have to use Keras if you don’t want to. You can instead use lower-level operations and automatic differentiation directly.

To follow along with the code examples in this post, install the TF 2.0 alpha.

I. Why TF 2.0?

TF 2.0 largely exists to make TF easier to use, for newcomers and researchers alike.

TF 1.x requires metaprogramming

TF 1.x was designed to train extremely large, static neural networks. Representing a model as a dataflow graph and separating its specification from its execution simplifies training at scale, which explains why TF 1.x uses Python as a declarative metaprogramming tool for graphs.

But most people don’t need to train Google-scale models, and most people find metaprogramming difficult. Constructing a TF 1.x graph is like writing assembly code, and this abstraction is so low-level that it is hard to produce anything but the simplest differentiable programs using it. Programs that have data-dependent control flow are particularly hard to express as graphs.

Metaprogramming is (often) unnecessary

It is possible to implement automatic differentiation by tracing computations while they are executed, without static graphs; Chainer, PyTorch, and autograd do exactly that. These libraries are substantially easier to use than TF 1.x, since imperative programming is so much more natural than declarative programming. Moreover, when training models with large operations on a single machine, these graph-free libraries are competitive with TF 1.x performance. For these reasons, TF 2.0 privileges imperative execution.

Graphs are still sometimes useful, for distribution, serialization, code generation, deployment, and (sometimes) performance. That’s why TF 2.0 provides the just-in-time tracer tf.function, which transparently converts Python functions into functions backed by graphs. This tracer also rewrites tensor-dependent Python control flow to TF control flow, and it automatically adds control dependencies to order reads and writes to TF state. This means that constructing graphs via tf.function is much easier than constructing TF 1.x graphs manually.

Multi-stage programming

The ability to create polymorphic graph functions via tf.function at runtime makes TF 2.0 similar to a multi-stage programming language.

For TF 2.0, I recommend the following multi-stage workflow. Start by implementing your program in imperative mode. Once you’re satisfied that your program is correct, measure its performance. If the performance is unsatisfactory, analyze your program using cProfile or a comparable tool to find bottlenecks consisting of TF operations. Next, refactor the bottlenecks into Python functions, and stage these functions in graphs with tf.function.

Figure: A multi-stage workflow for TF 2.0.

If you mostly use TF 2.0 to train large deep models, you probably won’t need to analyze or stage your programs. If, on the other hand, you write programs that execute lots of small operations, like MCMC samplers or reinforcement learning algorithms, you’ll likely find this workflow useful. In such cases, the Python overhead incurred by executing operations eagerly actually matters.
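A minimal sketch of this workflow, assuming a toy program in which the hypothetical helpers step and program stand in for your own code: profile the eager run with cProfile, then stage only the hot function with tf.function and check that the result is unchanged.

```python
import cProfile
import pstats

import tensorflow as tf


def step(x):
    # Hypothetical bottleneck: many tiny TF ops, each paying Python dispatch overhead.
    for _ in range(50):
        x = tf.tanh(x) * 0.5 + 0.1
    return x


def program(step_fn):
    # Hypothetical outer program that calls the bottleneck repeatedly.
    x = tf.ones([4])
    for _ in range(20):
        x = step_fn(x)
    return x


# Stage 1: run imperatively and profile to find where the time goes.
profiler = cProfile.Profile()
profiler.enable()
eager_result = program(step)
profiler.disable()
pstats.Stats(profiler).sort_stats("cumulative").print_stats(10)

# Stage 2: stage only the bottleneck in a graph; the rest stays imperative.
staged_step = tf.function(step)
staged_result = program(staged_step)
```

Staging only the bottleneck keeps the rest of the program imperative and easy to debug; the actual speed-up depends on your hardware and TF version, so measure rather than assume.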

II. Imperative execution

In TF 2.0, all operations are executed imperatively, or “eagerly”, by default. If you’ve used NumPy or PyTorch, TF 2.0 will feel familiar. For example, a single line that adds two constant tensors will immediately construct both tensors, backed by concrete numerical values, and then execute the add operation.
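A minimal example of eager execution (the specific values are illustrative):

```python
import tensorflow as tf

# Runs immediately: no graph construction, no Session, no feed dict.
c = tf.constant([1.0, 2.0]) + tf.constant([3.0, 4.0])
print(c)          # a tf.Tensor that already holds concrete values
print(c.numpy())  # [4. 6.]
```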

In TF 2.0, there are no placeholders, no sessions, and no feed dicts. Because operations are executed immediately, you can use (and differentiate through) if statements and for loops (no more tf.cond or tf.while_loop). You can also use whatever Python data structures you like, and debug your programs with print statements and pdb.
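A small illustration, using a hypothetical function f, of tensor-dependent if and for statements that run as plain Python, can be debugged with print, and can still be differentiated with tf.GradientTape:

```python
import tensorflow as tf


def f(x):
    # Ordinary Python control flow that depends on the value of a tensor.
    if x > 0:  # evaluated eagerly to a Python bool
        y = x * x
    else:
        y = -x
    for i in range(3):
        y = y + x
        print("iteration", i, "y =", y.numpy())  # plain print debugging works
    return y


x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = f(x)

dy_dx = tape.gradient(y, x)  # differentiates through the branch actually taken
print(y.numpy(), dy_dx.numpy())  # 10.0 7.0
```

Because the branch is chosen by running Python, the gradient reflects only the path taken for this input: here 2x + 3 at x = 2.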

If TF detects that a GPU is available, it will automatically run operations on the GPU when possible. The target device can also be controlled explicitly.
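A sketch of automatic versus explicit placement. The device string printed for the first matmul depends on your machine; tf.device("/CPU:0") is the standard way to pin operations to the first CPU.

```python
import tensorflow as tf

a = tf.random.normal([256, 256])

# Automatic placement: if a GPU is visible, TF runs this matmul there.
b = tf.matmul(a, a)
print("auto-placed on:", b.device)

# Explicit placement: pin these operations to the CPU.
with tf.device("/CPU:0"):
    c = tf.matmul(a, a)
print("pinned to:", c.device)  # a device string ending in "CPU:0"
```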

III. State

Using tf.Variable objects in TF 1.x required wrangling global collections of graph state, with confusing APIs like tf.get_variable, tf.variable_scope, and tf.initializers.global_variables. TF 2.0 does away with global collections and their associated APIs. If you need a tf.Variable in TF 2.0, you just construct and initialize it directly.
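For instance (names and values are illustrative):

```python
import tensorflow as tf

# Construct and initialize in one step: no variable scopes, collections,
# or separate initializer ops to run.
v = tf.Variable(tf.zeros([3]), name="v")
v.assign_add([1.0, 2.0, 3.0])  # stateful update, applied immediately
print(v.numpy())  # [1. 2. 3.]

# A variable is an ordinary Python object; it lives as long as something references it.
w = tf.Variable(2.0)
print((w * v).numpy())  # [2. 4. 6.]
```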

IV. Automatic differentiation

TF 2.0 implements reverse-mode automatic differentiation (also known as backpropagation) using a trace-based mechanism. This trace, or tape, is exposed as a context manager, tf.GradientTape. The tape’s watch method designates a Tensor as something we’ll later need to differentiate with respect to. Tapes can also be nested: if the computation of a first derivative dy_dx is itself recorded by an outer tape, the outer tape can differentiate it again to compute the second derivative d2y_dx2.
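A minimal sketch of nested tapes computing the first and second derivatives of y = x^3 at x = 3 (the function and point are chosen for illustration):

```python
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as outer_tape:
    outer_tape.watch(x)  # constants must be watched explicitly
    with tf.GradientTape() as inner_tape:
        inner_tape.watch(x)
        y = x * x * x  # y = x^3
    # Computed while outer_tape is still recording, so it can be differentiated again.
    dy_dx = inner_tape.gradient(y, x)
d2y_dx2 = outer_tape.gradient(dy_dx, x)

print(dy_dx.numpy())    # 27.0 (3x^2 at x = 3)
print(d2y_dx2.numpy())  # 18.0 (6x at x = 3)
```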

V. Keras

TF 1.x is notorious for having many mutually incompatible high-level APIs for neural networks. TF 2.0 has just one high-level API: tf.keras, which essentially implements the Keras API but is customized for TF. Several standard layers for neural networks are available in the tf.keras.layers namespace.

Keras layers can be composed via tf.keras.Sequential() to obtain an object representing their composition. For example, a handful of layers can be stacked into a toy CNN and trained on MNIST. (Of course, MNIST can be solved by much simpler methods, like least squares.)
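The sketch below shows the Sequential pattern under stated assumptions: random arrays shaped like MNIST stand in for the dataset so that it runs offline, and the layer sizes are arbitrary choices, not the author’s. Replacing the synthetic arrays with tf.keras.datasets.mnist.load_data() (and scaling pixels to [0, 1]) would train on real digits.

```python
import numpy as np
import tensorflow as tf

# Assumption: synthetic stand-in for MNIST (28x28 grayscale images, 10 classes).
rng = np.random.default_rng(0)
x_train = rng.random((256, 28, 28, 1), dtype=np.float32)
y_train = rng.integers(0, 10, size=256)

# Compose layers into a single model object.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",  # integer class labels
              metrics=["accuracy"])

history = model.fit(x_train, y_train, batch_size=64, epochs=1, verbose=0)
probs = model.predict(x_train[:5], verbose=0)  # per-class probabilities
```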

If you don’t want to use tf.keras, you can use low-level APIs like tf.reshape, tf.nn.conv2d, tf.nn.max_pool, tf.nn.dropout, and tf.matmul directly.
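A sketch of one forward and backward pass through a similar network built only from those low-level operations; all shapes, initial scales, and the random labels are illustrative assumptions:

```python
import tensorflow as tf

images = tf.random.normal([8, 28, 28, 1])                       # NHWC batch
labels = tf.random.uniform([8], maxval=10, dtype=tf.int32)      # hypothetical labels
filters = tf.Variable(tf.random.normal([3, 3, 1, 16]) * 0.1)    # HWIO conv kernel
dense_w = tf.Variable(tf.random.normal([14 * 14 * 16, 10]) * 0.01)


def forward(x, training=True):
    h = tf.nn.relu(tf.nn.conv2d(x, filters, strides=1, padding="SAME"))  # [8, 28, 28, 16]
    h = tf.nn.max_pool(h, ksize=2, strides=2, padding="VALID")          # [8, 14, 14, 16]
    if training:
        h = tf.nn.dropout(h, rate=0.25)
    h = tf.reshape(h, [-1, 14 * 14 * 16])
    return tf.matmul(h, dense_w)  # logits, [8, 10]


with tf.GradientTape() as tape:  # trainable variables are watched automatically
    logits = forward(images)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
grads = tape.gradient(loss, [filters, dense_w])
```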

VI. Graph functions

For advanced users who need graphs, TF 2.0 provides tf.function, a just-in-time tracer that converts Python functions that execute TensorFlow operations into graph functions. A graph function is a TF graph with named inputs and outputs. Graph functions are executed by a C++ runtime that automatically partitions graphs across devices and parallelizes and optimizes them before execution.

Calling a graph function is syntactically equivalent to calling a Python function.
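A very simple example (the function name add_and_square is hypothetical):

```python
import tensorflow as tf


@tf.function  # equivalently: add_and_square = tf.function(python_fn)
def add_and_square(a, b):
    return tf.square(a + b)


# Called exactly like a Python function, but executed as a graph.
result = add_and_square(tf.constant(1.0), tf.constant(2.0))
print(result.numpy())  # 9.0
```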

Every time a graph function is called, its “input signature” is analyzed. If the input signature doesn’t match one it has seen before, tf.function re-traces the Python function and constructs another concrete graph function. (In programming-language terms, this is like multiple dispatch or lightweight modular staging.) This means that a single Python function may give rise to many concrete graph functions. It also means that every call that triggers a trace will be slow, while subsequent calls with the same input signature will be much faster.
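One way to observe retracing, assuming a hypothetical trace_log list: Python side effects inside a tf.function run only while tracing, so appending to a list counts how many concrete functions were built.

```python
import tensorflow as tf

trace_log = []  # hypothetical counter; appended to only at trace time


@tf.function
def double(x):
    trace_log.append(x.dtype)  # Python side effect: runs while tracing, not on every call
    return x + x


double(tf.constant([1.0, 2.0]))  # new signature (float32, shape [2]): traces
double(tf.constant([3.0, 4.0]))  # same signature: reuses the concrete function
double(tf.constant([1, 2]))      # new dtype (int32): traces again
print(len(trace_log))  # 2
```

Changing only the tensor values reuses the existing graph; changing the dtype (or, by default, the shape) produces a new concrete function.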

Lexical closure, state, and control dependencies

Graph functions support lexically closing over tf.Tensor and tf.Variable objects. You can mutate tf.Variable objects inside a graph function, and tf.function will automatically add the control dependencies needed to ensure that your reads and writes happen in program-order.
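A sketch of both features, with hypothetical names: the function closes over a tf.Variable and a tf.Tensor, and interleaves writes and reads whose order tf.function preserves.

```python
import tensorflow as tf

counter = tf.Variable(0)   # closed-over tf.Variable (mutable state)
scale = tf.constant(10)    # closed-over tf.Tensor


@tf.function
def bump_and_read():
    counter.assign_add(1)
    first = counter.read_value()   # must observe the first write
    counter.assign_add(1)
    second = counter.read_value()  # automatic control dependencies keep program order
    return first * scale, second * scale


a, b = bump_and_read()
print(a.numpy(), b.numpy())  # 10 20
```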

Python control flow

tf.function automatically rewrites Python control flow that depends on tf.Tensor data into graph control flow, using autograph. This means that you no longer need to use constructs like tf.cond and tf.while_loop. For example, consider a function matmul_many that performs matrix multiplications inside a for loop over tf.range(100). If we translated matmul_many into a graph function via tf.function, autograph would convert that for loop into a tf.while_loop, because the loop iterates over tf.range(100), which is a tf.Tensor.
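The following is a sketch of matmul_many consistent with that description (a loop over tf.range(100) performing matrix multiplications, called on tf.ones([2, 2]) in the benchmark), not the author’s verbatim code; tf.autograph.to_code prints the graph-construction code that autograph generates from it.

```python
import tensorflow as tf


def matmul_many(tensor):
    # Reconstructed sketch: repeatedly multiply the accumulator by the input.
    accum = tensor
    for _ in tf.range(100):  # loop over a tf.Tensor: autograph emits a tf.while_loop
        accum = tf.matmul(accum, tensor)
    return accum


eager_out = matmul_many(tf.ones([2, 2]))                # plain Python loop, run eagerly
staged_out = tf.function(matmul_many)(tf.ones([2, 2]))  # traced; loop becomes tf.while_loop

# Inspect the code autograph generates from matmul_many.
print(tf.autograph.to_code(matmul_many))
```

With a 2x2 matrix of ones, each multiplication doubles every entry, so both versions return a matrix filled with 2^100.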

Performance

Graph functions can provide significant speed-ups for programs that execute many small TF operations. For these programs, the Python overhead incurred by executing operations imperatively outstrips the time spent actually running them. As an example, let’s benchmark the matmul_many function imperatively and as a graph function.

```python
graph_fn = tf.function(matmul_many)
```

Here’s the imperative (Python) performance.

```python
%%timeit
matmul_many(tf.ones([2, 2]))
```

```
100 loops, best of 3: 13.5 ms per loop
```

The first call to graph_fn is slow, since this is when the graph function is generated.

But subsequent calls are nearly seven times faster than imperatively executing matmul_many.

```python
%%timeit
graph_fn(tf.ones([2, 2]))
```

```
1000 loops, best of 3: 1.97 ms per loop
```
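For readers outside IPython, here is a plain-Python sketch of the same comparison using the standard timeit module. It redefines a hypothetical reconstruction of matmul_many so that it runs on its own, and it makes the slow first (tracing) call explicit. Absolute timings will differ from the numbers above, which depend on the machine and TF version.

```python
import timeit

import tensorflow as tf


def matmul_many(tensor):
    # Hypothetical reconstruction of the benchmarked function.
    accum = tensor
    for _ in tf.range(100):
        accum = tf.matmul(accum, tensor)
    return accum


graph_fn = tf.function(matmul_many)
x = tf.ones([2, 2])

graph_fn(x)  # warm-up: the first call traces and builds the graph function

# Best of 3 repeats, 10 calls each, reported per call (mirrors %%timeit's "best of 3").
eager_s = min(timeit.repeat(lambda: matmul_many(x), number=10, repeat=3)) / 10
graph_s = min(timeit.repeat(lambda: graph_fn(x), number=10, repeat=3)) / 10
print(f"eager: {eager_s * 1e3:.2f} ms/call, graph: {graph_s * 1e3:.2f} ms/call")
```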

VII. Comparison to other Python libraries

There are many libraries for machine learning. Out of all of them, PyTorch 1.0 is the one that’s most similar to TF 2.0. Both TF 2.0 and PyTorch 1.0 execute imperatively by default, and both provide ways to transform Python functions into graph-backed functions (compare tf.function and torch.jit). The PyTorch JIT tracer, torch.jit.trace, doesn’t implement the multiple-dispatch semantics that tf.function does, and it doesn’t rewrite the AST either. TorchScript (torch.jit.script), by contrast, does let you use Python control flow, but unlike tf.function, it doesn’t let you mix in arbitrary Python code that parametrizes the construction of your graph. That means that, compared with tf.function, TorchScript makes it harder for you to shoot yourself in the foot, while potentially limiting your creative expression.

So should you use TF 2.0, or PyTorch 1.0? It depends. Because TF 2.0 is in alpha, it still has some kinks, and its imperative performance still needs work. But you can probably count on TF 2.0 becoming stable sometime this year. If you’re in industry, TensorFlow has TFX for production pipelines, TFLite for deploying to mobile, and TensorFlow.js for the web. PyTorch recently made a commitment to production; since then, they’ve added C++ inference and deployment solutions for several cloud providers. For research, I’ve found that TF 2.0 and PyTorch 1.0 are sufficiently similar that I’m comfortable using either one, and my choice of framework depends on my collaborators.

The multi-stage approach of TF 2.0 is similar to what’s done in JAX. JAX is great if you want a functional programming model that looks exactly like NumPy, but with automatic differentiation and GPU support; this is, in fact, what many researchers want. If you don’t like functional programming, JAX won’t be a good fit.

VIII. Domain-specific languages for machine learning

TF 2.0 and PyTorch 1.0 are very unusual libraries. It has been observed that these libraries resemble domain-specific languages (DSLs) for automatic differentiation and machine learning, embedded in Python (see also our paper on TF Eager, TF 2.0’s precursor). What TF 2.0 and PyTorch 1.0 accomplish in Python is impressive, but they’re pushing the language to its limits.

There is now significant work underway to embed ML DSLs in languages that are more amenable to compilation than Python, like Swift (DLVM, Swift for TensorFlow, MLIR) and Julia (Flux, Zygote). So while TF 2.0 and PyTorch 1.0 are great libraries, do stay tuned: over the next year (or two, or three?), the ecosystem of programming languages for machine learning will continue to evolve rapidly.
