SPMD in JAX #1: Sharding
This post is the first in a series on programming with JAX for training models like transformers. I'm experimenting with these posts as an alternative to my usual scratch paper or LaTeX notes while learning, in the hope that they will help me with recall and perhaps be useful to others learning this material.