As a learning exercise, I recently implemented simple-jax-nn, a simple neural net for the MNIST dataset, written with JAX.
This post collects my notes on using JAX (as a PyTorch / NumPy user), gathered while working through the JAX docs and a neural-net tutorial.
Simple JAX neural net
I managed to get it running on my MacBook’s GPU through the Metal support in jax-metal. The setup is a little tricky, but it should prove useful for future transformer projects that can make use of the hardware.
uv run simple_nn.py 
Metal device set to: Apple M2 Max
DEVICE:METAL:0
Retrieving dataset...
Building generator...
Constructing train set...
Retrieving test dataset...
Constructing test set...
X_train.shape=(60000, 784)	y_train.shape=(60000, 10)
X_test.shape=(10000, 784)	y_test.shape=(10000, 10)
TRAINING
Initial training set accuracy 0.089
Initial test set accuracy 0.090
[0/20] (1.26s)	Train acc: 0.943	Test acc: 0.941
[1/20] (1.12s)	Train acc: 0.965	Test acc: 0.959
[2/20] (1.12s)	Train acc: 0.975	Test acc: 0.966
[3/20] (1.11s)	Train acc: 0.981	Test acc: 0.971
[4/20] (1.06s)	Train acc: 0.985	Test acc: 0.973
[5/20] (1.06s)	Train acc: 0.988	Test acc: 0.975
[6/20] (1.06s)	Train acc: 0.990	Test acc: 0.976
[7/20] (1.06s)	Train acc: 0.992	Test acc: 0.976
[8/20] (1.06s)	Train acc: 0.994	Test acc: 0.977
[9/20] (1.07s)	Train acc: 0.995	Test acc: 0.977
[10/20] (1.06s)	Train acc: 0.996	Test acc: 0.978
[11/20] (1.06s)	Train acc: 0.996	Test acc: 0.977
[12/20] (1.07s)	Train acc: 0.997	Test acc: 0.978
[13/20] (1.06s)	Train acc: 0.998	Test acc: 0.978
[14/20] (1.07s)	Train acc: 0.998	Test acc: 0.978
[15/20] (1.06s)	Train acc: 0.999	Test acc: 0.978
[16/20] (1.06s)	Train acc: 0.999	Test acc: 0.979
[17/20] (1.20s)	Train acc: 1.000	Test acc: 0.980
[18/20] (1.20s)	Train acc: 1.000	Test acc: 0.980
[19/20] (1.09s)	Train acc: 1.000	Test acc: 0.981
DONE in 23.01s
systemMemory: 96.00 GB
maxCacheSize: 36.00 GB
JAX notes
My notes from reading through “Quickstart: How to think in JAX” in the JAX documentation, and the docs that follow it.
JAX compiles down to XLA (Accelerated Linear Algebra), which lets the same code run on many kinds of hardware (CPU, GPU, TPU). jax.numpy is the NumPy-like high-level API, designed to be familiar and friendly. Under the hood, it calls the stricter jax.lax API, which in turn compiles to XLA.
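To see the jax.numpy / jax.lax split in practice, here is a small sketch (the arrays are arbitrary): jnp.add promotes mixed dtypes the way NumPy would, while with lax.add I convert to a common dtype first. Tracing the jnp call with jax.make_jaxpr shows the lax-level primitives it is built from.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)   # int32
y = jnp.ones(3)     # float32

# jax.numpy promotes int32 + float32 -> float32, like NumPy would
z_np = jnp.add(x, y)

# jax.lax expects matching dtypes, so convert explicitly before adding
z_lax = lax.add(lax.convert_element_type(x, jnp.float32), y)

# The jaxpr of the jnp version shows the lax primitives underneath
# (a convert_element_type followed by an add).
jaxpr = jax.make_jaxpr(lambda a, b: jnp.add(a, b))(x, y)
print(z_np, z_lax)
print(jaxpr)
```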
- JAX provides a NumPy-inspired interface for convenience.
- Through duck-typing, JAX arrays can often be used as drop-in replacements for NumPy arrays.
- Unlike NumPy arrays, JAX arrays are always immutable.
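A quick illustration of the last two points (the array is arbitrary): a NumPy function accepts a JAX array, but item assignment is rejected with a TypeError.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(5)

# Duck typing: a NumPy function accepts the JAX array
total = np.sum(x)
as_numpy = np.asarray(x)  # explicit conversion to a NumPy ndarray

# Immutability: in-place item assignment raises a TypeError
mutation_error = None
try:
    x[0] = 10
except TypeError as err:
    mutation_error = err

print(total, type(as_numpy).__name__, mutation_error is not None)
```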
JAX is lower-level than PyTorch or Keras, so many common ML tasks and the bookkeeping of training aren’t abstracted away. The recipes in “The Training Cookbook” in the JAX documentation are a good guide to follow.
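As a rough illustration of that bookkeeping, here is a minimal sketch of a training step for a small MLP. The layer sizes (784 → 128 → 10, matching flattened MNIST inputs and one-hot labels), the learning rate, and the random stand-in data are assumptions for illustration; this is not the cookbook’s recipe or the code in simple-jax-nn.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    # Hypothetical helper: returns a list of (W, b) pairs with He-style init
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, w_key = jax.random.split(key)
        w = jax.random.normal(w_key, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def forward(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b  # logits

def loss_fn(params, x, y_onehot):
    # Cross-entropy against one-hot labels
    log_probs = jax.nn.log_softmax(forward(params, x))
    return -jnp.mean(jnp.sum(y_onehot * log_probs, axis=-1))

LEARNING_RATE = 0.01  # assumed value

@jax.jit
def train_step(params, x, y_onehot):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y_onehot)
    # Plain SGD: no optimizer object, we update every leaf ourselves
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

# Random stand-in for one batch of MNIST-shaped data
x_key, y_key, init_key = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(x_key, (128, 784))
y = jax.nn.one_hot(jax.random.randint(y_key, (128,), 0, 10), 10)

params = init_mlp(init_key, [784, 128, 10])
losses = []
for _ in range(100):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
print(losses[0], losses[-1])
```

Keeping the parameters as a plain list of tuples (a pytree) is what lets jax.grad and tree_map handle them without any framework classes.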
For convenience, JAX provides jax.numpy, which closely mirrors the NumPy API and provides an easy entry point into JAX. Almost anything that can be done with NumPy can be done with jax.numpy, which is typically imported under the jnp alias:
import jax.numpy as jnp
x.devices(), where x is a jax.Array, tells you which device(s) (i.e. which hardware) hold the array’s contents.
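A small sketch of inspecting and choosing placement; on my Mac with jax-metal the device list would show the METAL device, while elsewhere it will typically just be a CPU device.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)

# All devices JAX can see
print(jax.devices())

# The set of devices holding x's data
print(x.devices())

# Explicitly place x on a chosen device (here simply the first one)
y = jax.device_put(x, jax.devices()[0])
print(y.devices())
```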
By default, JAX executes eagerly, sending each op off to XLA one at a time, which is often slow. Wrapping a function with jax.jit(<function>) (or using the @jax.jit decorator) enables just-in-time compilation, which comes with a whole host of performance tricks, like fusing ops, that can massively improve performance.
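A minimal sketch of jit on a SELU-style activation (the function and its alpha/lmbda constants are just for illustration). Timings vary by machine, so I only print them; block_until_ready() is there because JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)  # equivalent to decorating selu with @jax.jit

x = jax.random.normal(jax.random.key(0), (1_000_000,))

# Warm up both: the first jitted call traces and compiles the whole function
selu(x).block_until_ready()
selu_jit(x).block_until_ready()

start = time.perf_counter()
selu(x).block_until_ready()
eager_s = time.perf_counter() - start

start = time.perf_counter()
selu_jit(x).block_until_ready()
jit_s = time.perf_counter() - start

print(f"eager: {eager_s:.4f}s  jit: {jit_s:.4f}s")
```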
A key constraint is that JIT requires all arrays to have static shapes (known at trace time), so code whose output shapes depend on the values inside arrays fails.
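A small sketch of that constraint (function names are mine): boolean-mask filtering works eagerly, but under jit its output length would depend on the data, so JAX raises; jnp.where keeps the shape fixed instead.

```python
import jax
import jax.numpy as jnp

@jax.jit
def keep_positives_masked(x):
    # Fixed output shape: negatives are replaced by 0 instead of removed
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def keep_positives_filtered(x):
    # Output length would depend on the values in x, so tracing fails
    return x[x > 0]

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
print(keep_positives_masked(x))  # [0. 2. 0. 4.]
print(x[x > 0])  # fine eagerly, outside jit

shape_error = None
try:
    keep_positives_filtered(x)
except jax.errors.NonConcreteBooleanIndexError as err:
    shape_error = err
print(type(shape_error).__name__)
```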
jax.make_jaxpr(f)(x, y) shows the JAX expression (jaxpr) for f: the sequence of primitive operations recorded by the tracer objects that JIT mode uses. Tracing depends only on the shapes and dtypes of the inputs, not their values, so conditional code that depends on array values breaks under jit.
If there are arguments you don’t want traced, you can mark them as static for JIT compilation with jax.jit’s static_argnums or static_argnames, often applied as a decorator via functools.partial.
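A sketch tying the last two paragraphs together (f and flip are made-up examples): print a jaxpr, watch a value-dependent Python if fail under jit, then fix it by marking the flag static.

```python
from functools import partial

import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.dot(x + 1, y + 1)

# The jaxpr lists the primitive ops recorded by tracing with abstract
# (shape/dtype-only) values
jaxpr = jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3))
print(jaxpr)

def flip(x, negate):
    if negate:  # a Python `if` needs a concrete bool
        return -x
    return x

# Under plain jit, `negate` becomes a tracer, so the `if` fails
trace_error = None
try:
    jax.jit(flip)(jnp.ones(3), True)
except jax.errors.ConcretizationTypeError as err:
    trace_error = err

# Static argument: a compile-time constant (one compilation per value)
flip_static = jax.jit(flip, static_argnums=(1,))
print(flip_static(jnp.ones(3), True))

# The same idea in decorator form via functools.partial
@partial(jax.jit, static_argnames=("negate",))
def flip_decorated(x, negate):
    return -x if negate else x

print(flip_decorated(jnp.ones(3), negate=True))
```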
JAX provides automatic differentiation via the jax.grad transformation.
The jax.jacobian transformation can be used to compute the full Jacobian matrix of vector-valued functions.
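A quick sketch of both transformations on toy functions (both are mine): grad of a scalar-valued sum of sigmoids, and the Jacobian of an element-wise square, which should come out as diag(2x).

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    # Scalar-valued: sum of element-wise sigmoids
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def square(x):
    # Vector-valued: element-wise square
    return x ** 2

x = jnp.arange(3.0)

# grad: gradient of a scalar function, same shape as x
dfdx = jax.grad(sum_logistic)(x)

# jacobian: full matrix of partial derivatives, here diag(2 * x)
J = jax.jacobian(square)(x)

print(dfdx)
print(J)
```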
jax.vmap(<function>) is the vectorizing map: it transforms a function written for a single example into a batch-aware version. Instead of looping over a batch in Python, you write the per-example function, wrap it in vmap, and it typically gets much more efficient.
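A minimal sketch comparing a Python loop with vmap on a toy per-example function (W and apply_single are made up for illustration); both give the same result, but vmap produces one batched computation.

```python
import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(2, 3)

def apply_single(x):
    # Written for one example: x has shape (3,), output has shape (2,)
    return W @ x

batch = jnp.arange(12.0).reshape(4, 3)

# Manual batching with a Python loop
looped = jnp.stack([apply_single(x) for x in batch])

# vmap maps apply_single over axis 0 of batch (in_axes=0 by default)
vmapped = jax.vmap(apply_single)(batch)

print(vmapped.shape)  # (4, 2)
```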
One major difference from NumPy is how JAX handles random numbers.
NumPy uses a global random state, but this creates thread-safety issues when adapted to JAX. Instead, you must explicitly create and keep track of your own keys with jax.random.key(). When you pass a key to a random function like random.normal(key), the key is used but not modified, so the same key gives the same output:
from jax import random
key = random.key(43)
print(random.normal(key))  # 0.07520543
print(random.normal(key))  # 0.07520543 (same key, same sample)
The rule of thumb: never reuse keys (unless you want identical outputs). To generate different, independent samples, explicitly split the key with jax.random.split before passing it to a random function, e.g. new_key, subkey = random.split(key).
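A minimal sketch of the split-then-use pattern: carry one key forward and use each subkey exactly once. random.split can also produce several keys at once, e.g. one per batch element.

```python
from jax import random

key = random.key(43)

samples = []
for _ in range(3):
    # Carry `key` forward, use `subkey` exactly once
    key, subkey = random.split(key)
    samples.append(float(random.normal(subkey)))

print(samples)  # three different values

# Split into several keys at once
keys = random.split(random.key(0), 4)
uniforms = [float(random.uniform(k)) for k in keys]
print(uniforms)
```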
🔪 JAX - The Sharp Bits 🔪 (JAX documentation)
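- Only use JAX with “functionally pure” Python functions: no printing, no globals, no side effects, no iterators. Under jit, Python side effects such as print run only while the function is being traced, not on every call.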
- Do not mutate arrays in place. Use the provided at property (with set, add, etc.), like x.at[idx].set(y), which creates and returns a new array (though inside jit, the compiler can often turn this into an in-place update when the original array isn’t reused).
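- Out-of-bounds indexes DO NOT raise errors. Raising an error from code running on an accelerator is difficult, so JAX silently picks a non-error behavior and carries on: retrievals clamp the index to the array’s bounds, and updates at out-of-bounds indices are dropped.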
- Only give JAX array inputs. In NumPy you can write np.sum([1, 2, 3]), but jnp.sum([1, 2, 3]) raises a TypeError, because implicitly converting the list would create individual JAX objects for each element. Write jnp.sum(jnp.array([1, 2, 3])) instead.
- Random numbers (see above).
- Control flow: When executing eagerly (outside of jit), JAX code works with Python control flow and logical operators just like NumPy code. Using control flow and logical operators with jit is more complicated; jax.lax provides structured primitives for this, such as lax.cond, lax.while_loop, lax.fori_loop, and lax.scan.
- Don’t write code whose array shapes depend on values: under jit, shapes must be static.
- Debug your NaNs by setting the environment variable JAX_DEBUG_NANS=True, or with jax.config.update("jax_debug_nans", True).
- JAX doesn’t promote to 64-bit floats unless you explicitly set the jax_enable_x64 config variable (e.g. jax.config.update("jax_enable_x64", True) at startup); by default, floats are 32-bit.
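A compact sketch of several of these sharp bits in one place. All function names are made up for illustration, and the dtype check assumes default settings (jax_enable_x64 not set).

```python
import jax
import jax.numpy as jnp
from jax import lax

# Impure code: Python side effects run during tracing, not on every call
trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # side effect: only executed while tracing
    return 2 * x

double(jnp.ones(3))
double(jnp.zeros(3))  # same shape/dtype: reuses the compiled version

# Functional updates: .at[...] returns a new array, x itself is unchanged
x = jnp.zeros(5)
y = x.at[1].set(10.0).at[2].add(3.0)

# Out-of-bounds indexing never raises
a = jnp.arange(10)
clamped = a[11]              # retrieval clamps to the last element (9)
dropped = a.at[11].set(100)  # out-of-bounds update is silently skipped

# Python lists are rejected; build an array first
list_error = None
try:
    jnp.sum([1, 2, 3])
except TypeError as err:
    list_error = err
total = jnp.sum(jnp.array([1, 2, 3]))

# Control flow on traced values under jit: use lax primitives
@jax.jit
def relu_or_square(v):
    # lax.cond(pred, true_fun, false_fun, operand)
    return lax.cond(v.sum() > 0, lambda u: jnp.maximum(u, 0.0), lambda u: u**2, v)

@jax.jit
def sum_of_squares(n):
    # lax.fori_loop(lower, upper, body(i, carry), init_carry)
    return lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))

# Default precision: 32-bit floats unless jax_enable_x64 is set
default_float = jnp.array(1.0).dtype

print(trace_count, x, y, clamped, dropped, list_error is not None, total)
print(relu_or_square(jnp.array([1.0, -2.0, 3.0])), sum_of_squares(5), default_float)
```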