Notes for learning JAX
As a learning exercise, I recently implemented simple-jax-nn, a simple neural net for the MNIST dataset, written with JAX. This post collects my notes on using JAX as a PyTorch / NumPy user, gathered from working through the JAX docs and a NN tutorial.
Simple JAX neural net
I managed to get it running on my MacBook’s GPU/accelerator through the Metal support in jax-metal. The setup is a little tricky, but it should prove useful for future transformer projects that can make use of the hardware. ...
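For readers coming from PyTorch, here is a minimal sketch of what a small MNIST-style MLP can look like in JAX. It is not the code from simple-jax-nn. The layer sizes, learning rate, and function names (`init_params`, `predict`, `loss`, `sgd_step`) are hypothetical, and a random batch stands in for real MNIST images. The sketch shows the typical JAX pattern: parameters live in a plain pytree (a list of tuples), `jax.grad` differentiates a pure loss function, and `jax.jit` compiles the update step. `jax.devices()` lists whatever devices the installed backend exposes, so it is a quick way to check whether an accelerator backend such as jax-metal was picked up.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: 784 inputs (28x28 pixels), one hidden layer, 10 classes.
LAYER_SIZES = [784, 512, 10]
LEARNING_RATE = 0.01  # assumed value, not tuned


def init_params(key, sizes):
    """Return a list of (weights, biases) tuples, one per layer (He-style init)."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def predict(params, x):
    """Forward pass: ReLU hidden layers, then log-softmax over the classes."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return jax.nn.log_softmax(x @ w + b)


def loss(params, x, y_onehot):
    """Mean cross-entropy, given one-hot labels."""
    return -jnp.mean(jnp.sum(predict(params, x) * y_onehot, axis=-1))


@jax.jit
def sgd_step(params, x, y_onehot):
    """One plain gradient-descent step; grads have the same pytree shape as params."""
    grads = jax.grad(loss)(params, x, y_onehot)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


print(jax.devices())  # shows which backend/devices JAX is using

params = init_params(jax.random.PRNGKey(0), LAYER_SIZES)
# Fake batch in [0, 1) standing in for flattened MNIST images and their labels.
x = jax.random.uniform(jax.random.PRNGKey(1), (32, 784))
y = jax.nn.one_hot(jnp.arange(32) % 10, 10)

before = loss(params, x, y)
for _ in range(20):
    params = sgd_step(params, x, y)
after = loss(params, x, y)
print(f"loss before: {before:.4f}, after: {after:.4f}")
```

Unlike a PyTorch `nn.Module`, nothing here holds mutable state: `sgd_step` takes the old parameters and returns new ones, which is what lets `jax.jit` and `jax.grad` treat it as a pure function.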