Blog
SPMD in JAX #2: Transformers in Bare-Metal JAX
In this second post in a series on programming with JAX for training models like transformers, we write and train a transformer with a decent set of bells and whistles, then benchmark and scale it as much as we can on our TPU v4 half-cube!
SPMD in JAX #1: Sharding
This post will be the first in a series on programming with JAX for training models like transformers. I'm experimenting with these posts as an alternative to my usual scratch paper or LaTeX notes while learning, in the hope that they will help me with recall and perhaps be useful to others learning this material.