Jasper Van den Bossche
Software Engineer
Training your neural network ten times faster using JAX on a TPU
All the cool kids seem to be raving about JAX these days. DeepMind is using it extensively for their research and is even building their own ecosystem on top of it. Boris Dayma and his team built DALL·E Mini in no time using JAX and TPUs; it is definitely worth checking out on Hugging Face, where you can already find over 5,000 models written in JAX. But what exactly is JAX and why is it so special? According to its website, JAX offers automatic differentiation, vectorization and just-in-time compilation to both GPUs and TPUs via composable transformations. Sounds complicated? Don’t worry, in this blog post we will take you on a tour and show you how JAX works, how it differs from TensorFlow/PyTorch and why we think it is a super interesting framework.
JAX is a high-performance numerical computing and machine learning framework from Google Research that runs super fast on GPUs and TPUs, without you having to worry about low-level details. The goal of JAX is to combine high performance with Python’s expressiveness and ease of use, so that researchers can experiment with new models and techniques without needing highly optimised low-level C/C++ implementations. It achieves this goal by using Google’s XLA (Accelerated Linear Algebra) compiler to generate efficient machine code rather than relying on precompiled kernels. One of the cool things about JAX is that it is accelerator agnostic, meaning that the same Python code can run efficiently on both GPUs and TPUs.
JAX works via composable function transformations: JAX takes a function and produces a new function that is interpreted differently, and multiple transformations can be chained together. Automatic differentiation, for instance, is a transformation that generates the derivative of a function, while automatic vectorization takes a function that operates on a single data point and transforms it into a function that operates on a batch of data points. Through these transformations JAX allows the programmer to stay in the high-level Python world and let the compiler do the hard work of generating the highly efficient code needed to train complex models. We will go over these transformations and apply them in an example where we build a simple multilayer perceptron.
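To give you a taste of what that composition looks like in code, here is a tiny sketch (the loss function is just a made-up example):

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

# grad returns a new function (the gradient of loss); jit returns a compiled version of it
fast_grad = jax.jit(jax.grad(loss))
print(fast_grad(jnp.ones(3)))
```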
JAX is a compiler-oriented framework, which means that a compiler is responsible for transforming Python functions into efficient machine code. TensorFlow and PyTorch, on the other hand, ship precompiled GPU and TPU kernels for each operation. During the execution of a TensorFlow program, each operation is dispatched individually. While the operations themselves are very well optimised, chaining them together requires a lot of memory operations for intermediate results, which causes a performance bottleneck. The XLA compiler instead generates code for the entire function: it can use all of that information to fuse operations together, save a ton of memory operations and thus produce overall faster code.
JAX is also more lightweight than TensorFlow and PyTorch, because there is no need to implement each operation, function or model separately. Instead, JAX implements the NumPy API with simpler, more low-level operations that serve as building blocks and are fused together into complex models and functions by the compiler.
The compiler-oriented design is a lot more powerful than you might think at first. With the compiler, there is no longer a need to implement low-level accelerator code. It allows researchers to vastly improve their productivity and opens doors to experiment with new model architectures. Researchers are even able to experiment with GPUs and TPUs without the need to rewrite their code. But how does that work?
JAX doesn’t compile directly to machine code, but rather to an intermediate representation (IR) that is independent of both the high-level Python code and the machine code. The compiler is split into a frontend that compiles Python functions to the IR and a backend that compiles the IR to platform-specific machine code. This design isn’t new: LLVM, for example, follows the same approach. There are frontends for both C and Rust that translate high-level code to the LLVM IR, and the backend can then generate machine code for a variety of supported machine types, no matter whether the original code was written in C or Rust.
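You can actually peek at part of this pipeline yourself: `jax.make_jaxpr` shows JAX’s own intermediate representation of a traced function, before it is lowered further to XLA. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# Print the "jaxpr", JAX's intermediate representation of the traced function
print(jax.make_jaxpr(f)(jnp.ones(3)))
```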
This is huge, because thanks to this flexible design you could build a new accelerator, write an XLA backend for it, and the JAX code that previously ran on GPUs/TPUs can be executed on the new accelerator. Conversely, you could build a framework in another programming language that compiles to the same IR and make use of GPUs and TPUs thanks to XLA.
If this compiler-based approach works so much better than precompiled kernels, why didn’t TensorFlow and PyTorch make use of it from the start? The answer is pretty simple: it is really hard to design a good numerical compiler. With its automatic differentiation, vectorization and JIT compilation, JAX has some really powerful tools under its belt. However, JAX isn’t a silver bullet either: all of these goodies come at a small price, namely that you need to learn a few new tricks and concepts related to functional programming.
JAX can’t transform just any Python function, it can only transform pure functions. A pure function can be defined as a function that only depends on its inputs, meaning that for a given input x it will always return the same output y and that it doesn’t produce any side effects such as IO operations or mutation of global variables. Python’s dynamism means that the behaviour of a function changes based on the types of its inputs and JAX wants to exploit this dynamism by transforming functions at runtime. At the start of a transformation JAX checks what the function does for a set of given inputs and transforms the function based on that information. Under the hood JAX traces the function, just like the Python interpreter. By only allowing pure functions, transforming functions just in time becomes a lot easier and faster.
Imagine that the tracer had to deal with side effects such as IO: unexpected behaviour could occur, for example a user entering invalid data, which would make it a lot harder to generate efficient code, especially when accelerators are in the game. Global variables can change between two function calls and thus completely change the behaviour of the function they’re used in, making a transformed function invalid. If you are interested in compilers and the nitty-gritty details of how JAX’s tracing works, we recommend checking out the documentation for more details on its inner workings.
The only unfortunate thing is that JAX can’t verify whether a function is pure. It is up to the programmer to make sure they write pure functions; otherwise the transformed function may behave unexpectedly.
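Here is a small sketch of the kind of surprise an impure function can cause (the global variable `g` is just an illustration):

```python
import jax

g = 0.0  # global state used inside the function below

def impure_add(x):
    # Not pure: the result depends on the global variable g
    return x + g

fast_add = jax.jit(impure_add)
print(fast_add(1.0))   # 1.0 -- g was baked in as 0.0 when the function was traced

g = 10.0
print(fast_add(1.0))   # still 1.0 -- the cached, compiled trace ignores the new value of g
```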
Working with pure functions also has an impact on how data structures are used. In other frameworks machine learning models are often represented in a stateful way; this, however, clashes with the functional programming paradigm because it amounts to mutation of global state. To overcome this problem JAX introduces pytrees: tree-like structures built out of container-like Python objects. Container-like classes can be registered in the pytree registry, which by default contains lists, tuples and dicts. Pytrees can contain other pytrees, and classes not registered in the pytree registry are considered leaves. Leaves can be treated as immutable inputs for a pure function. For each class in the pytree registry, there is a function that converts a pytree to a tuple with its children and optional metadata, as well as a function that converts children and metadata back to a container-like type. These functions can be used to update the model or any other stateful objects you use.
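A minimal sketch of working with pytrees (the dict of arrays is just a toy model):

```python
import jax
import jax.numpy as jnp

# A model represented as a pytree: a dict of arrays
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Flatten into a list of leaves plus the tree structure, and rebuild it
leaves, treedef = jax.tree_util.tree_flatten(params)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# "Update" the model functionally by producing a new pytree instead of mutating the old one
updated = jax.tree_util.tree_map(lambda p: 0.9 * p, params)
```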
Before we dive into our MLP example, we will show the most important transformations in JAX.
The first transformation is automatic differentiation, where we take a Python function as input and get back a function that represents its gradient. The neat thing about JAX’s autodiff is that it can differentiate Python functions that make use of Python containers, conditionals, loops etc. In the following example we create a function that represents the gradient of the `tanh` function. Because JAX transformations are composable, we can nest n calls of the grad function to calculate the nth derivative.
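A minimal version of that example could look like this:

```python
import jax
import jax.numpy as jnp

# grad transforms tanh into a new function that computes its derivative
dtanh = jax.grad(jnp.tanh)
print(dtanh(1.0))                      # ~0.42, i.e. 1 - tanh(1)**2

# Transformations compose: nesting grad gives higher-order derivatives
d2tanh = jax.grad(jax.grad(jnp.tanh))
print(d2tanh(1.0))
```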
JAX’s automatic differentiation is a powerful and extensive tool; if you want to learn more about how it works we recommend reading The JAX Autodiff Cookbook.
When training a model, you typically propagate a batch of training samples through your model. When implementing a model, you would thus have to write your prediction function as one that takes in a batch of samples and returns a prediction for each sample. This, however, can significantly complicate the implementation and reduce the readability of the function compared to one that operates on a single sample. In comes the second transformation: automatic vectorization. We write our function as if we were processing only a single sample, and vmap transforms it into a vectorized version.
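A small sketch of vmap in action (the dot-product predictor is just an illustration):

```python
import jax
import jax.numpy as jnp

def predict_one(w, x):
    # Operates on a single sample x of shape (features,)
    return jnp.dot(w, x)

# Vectorize over the first axis of x, keeping w fixed
predict_batch = jax.vmap(predict_one, in_axes=(None, 0))

w = jnp.ones(3)
xs = jnp.ones((8, 3))               # a batch of 8 samples
print(predict_batch(w, xs).shape)   # (8,)
```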
In the beginning vmap can be a bit tricky, especially when working with higher dimensions, but it is a really powerful transformation. We recommend checking out some examples in the documentation to fully understand its potential.
The third function transformation is just-in-time compilation. The goal of this transformation is to improve the performance, parallelise the code and run it on an accelerator. JAX doesn’t compile directly to machine code but rather to an intermediate representation. That intermediate representation is independent from the Python code and the machine code of the accelerator. The XLA compiler will then take the intermediate representation and compile it to efficient machine code.
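A minimal sketch of jit in action (the function below is just an illustration):

```python
import jax
import jax.numpy as jnp

def scaled_logits(w, x, b):
    return jnp.tanh(jnp.dot(x, w) + b)

# jit compiles the whole function with XLA the first time it is called
fast_scaled_logits = jax.jit(scaled_logits)

x = jnp.ones((128, 64))
w = jnp.ones((64, 10))
b = jnp.zeros(10)
out = fast_scaled_logits(w, x, b)   # compiled on the first call, cached afterwards
print(out.shape)                    # (128, 10)
```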
It is not always easy to decide when and what code you should compile; to make optimal use of the compiler, we recommend checking out the documentation. Earlier in this blog post we went deeper into the design of the compiler and why it makes JAX such a powerful framework.
Now that we have learned about the most important transformations, we are ready to put that knowledge into practice. We will implement an MLP from scratch to classify MNIST images and train it super fast on a TPU. Our neural network will have an input layer of 784 input variables (one per pixel), followed by two hidden layers with 512 and 256 neurons respectively and an output layer with a node for each class.
The first thing we need to do is create a structure that represents our model. As input to our initialisation function, we take a list with the number of nodes in each layer of our neural network: an input layer equal to the number of pixels of an image, followed by two hidden layers with 512 and 256 neurons respectively and an output layer equal to the number of classes. We use JAX NumPy arrays to initialise the model directly on the accelerator, which avoids having to copy the data over manually.
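A minimal sketch of such an initialisation function (the helper name `init_mlp` and the scale factor are our own choices):

```python
import jax
import jax.numpy as jnp

layer_sizes = [784, 512, 256, 10]

def init_mlp(sizes, key, scale=0.01):
    """Initialise a list of (weights, biases) pairs, one per layer."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for in_dim, out_dim, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = jax.random.split(k)
        params.append((scale * jax.random.normal(w_key, (out_dim, in_dim)),
                       scale * jax.random.normal(b_key, (out_dim,))))
    return params

params = init_mlp(layer_sizes, jax.random.PRNGKey(0))
```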
Note that generating random numbers works slightly differently from NumPy. We want to be able to generate random numbers on parallel accelerators, and we need a random number generator that works well with the functional programming paradigm. NumPy’s algorithm for generating random numbers isn’t well suited for these purposes. Check out the JAX design notes and documentation for more information.
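In short, you pass around explicit PRNG keys and split them instead of relying on hidden global state. A tiny sketch:

```python
import jax

key = jax.random.PRNGKey(42)          # explicit, immutable PRNG state
key, subkey = jax.random.split(key)   # split instead of mutating hidden state
x = jax.random.normal(subkey, (3,))   # the same key and shape always give the same numbers
```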
Our next step is to write a prediction function that assigns labels to a batch of images. We will use automatic vectorization to transform a function that takes a single image as input and outputs a label into a function that predicts labels for a batch of inputs. Writing a prediction function is not super hard: we flow through the hidden layers of the network, applying the weights and biases via a matrix multiplication and vector addition followed by the ReLU activation function. At the end we calculate the output label using the RealSoftMax (LogSumExp) function. Once we have our function to label a single image, we can transform it using vmap so it can process a batch of inputs.
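Continuing the initialisation sketch above, a minimal version of the prediction function and its vectorised counterpart could look like this (`relu`, `predict` and `batched_predict` are our own names):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    """Predict log-probabilities for a single flattened image of shape (784,)."""
    activations = image
    for w, b in params[:-1]:
        activations = relu(jnp.dot(w, activations) + b)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)   # log-softmax via the RealSoftMax/LogSumExp function

# Vectorize over the batch dimension of the images, keeping the parameters fixed
batched_predict = jax.vmap(predict, in_axes=(None, 0))
```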
The loss function takes a batch of images and its ground-truth labels: we call our batched prediction function to get a prediction for each image, compare these against the one-hot encoded ground-truth labels and take the mean over the batch.
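A sketch of such a loss, assuming the log-probability outputs from the prediction sketch above and a negative log-likelihood formulation:

```python
def loss(params, images, targets):
    """Mean negative log-likelihood over a batch; targets are one-hot encoded."""
    preds = batched_predict(params, images)   # log-probabilities, shape (batch, classes)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))
```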
Now that we have our prediction and loss function, we will implement an update function to iteratively update our model in each training step. Our update function takes in a batch of images and its ground-truth labels, together with the current model and a learning rate. We calculate both the value of the loss and its gradient, and update the model using the learning rate and the loss gradients. As we want to compile this function, the updated model has to be returned as a new pytree rather than mutated in place. We also return the value of the loss so we can monitor training.
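A minimal sketch of such an update function, continuing the loss sketch above (plain SGD, with the learning rate as a keyword argument):

```python
def update(params, images, targets, lr=0.01):
    """One SGD step: returns the updated parameters and the current loss value."""
    loss_value, grads = jax.value_and_grad(loss)(params, images, targets)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss_value
```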
Now that we have the update function, we are going to compile it so it can run on a TPU and greatly improve its performance. The nested functions called inside the update function will also be compiled and optimised. The reason we only apply the compile transformation to update, and not to each function separately, is that we want to give the compiler as much information as possible to work with, so it can optimise the code as much as possible.
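With the sketch above, that is a one-liner:

```python
# Compile the whole training step; the nested predict and loss calls are traced and fused in as well
update = jax.jit(update)
```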
We can define an accuracy function (and optionally other metrics) and create a training loop using our update function and initial model as input. We are now ready to train our model using a TPU or GPU.
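A sketch of what the accuracy function and training loop could look like; `training_batches`, `train_images` and `train_labels` are hypothetical names standing in for whatever data pipeline you use:

```python
def accuracy(params, images, targets):
    """Fraction of images whose predicted class matches the one-hot encoded label."""
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == jnp.argmax(targets, axis=1))

num_epochs = 10
for epoch in range(num_epochs):
    for images, targets in training_batches:           # batches of (images, one-hot labels)
        params, loss_value = update(params, images, targets)
    print(f"epoch {epoch}: train accuracy {accuracy(params, train_images, train_labels):.3f}")
```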
Phew, we have learned a lot today. We started by describing JAX as a framework with composable function transformations. The four core transformations are automatic vectorisation, automatic parallelisation over multiple accelerators, automatic differentiation of Python functions and JIT-compiling functions to run them on accelerators. We then dug into the inner workings of JAX and learned how it is able to create such efficient functions that work on both GPUs and TPUs by compiling to an IR that is then transformed to XLA calls. This approach allows researchers to experiment with new machine learning techniques without having to worry about a low-level, highly optimised version of their code. We hope software engineers are excited as well, so that new libraries get built on top of JAX and new accelerators can be adopted quickly.