A Faster Manifold Muon with ADMM

In a recent blog post, Jeremy Bernstein (Thinking Machines) gave an algorithm for optimizing matrix-valued parameters, such as the weight matrices of a transformer, that keeps both the update directions and the parameters themselves well-conditioned (Bernstein, 2025). This algorithm, called manifold Muon, has an inner loop that converges relatively slowly, and it requires some hyperparameter tuning to get right.
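For context, base Muon gets its well-conditioned update directions by replacing each raw gradient update with (approximately) its nearest semi-orthogonal matrix, computed via a Newton-Schulz iteration. Below is a minimal JAX sketch of the simple cubic variant of that iteration; it is only background, not the ADMM inner loop the post is about, and production Muon uses tuned quintic coefficients rather than this textbook cubic.

```python
import jax.numpy as jnp

def orthogonalize(G, steps=10):
    """Approximate the nearest semi-orthogonal matrix to G, i.e. U @ V.T
    from the SVD G = U @ S @ V.T, via a cubic Newton-Schulz iteration.
    Background sketch only; not the ADMM inner loop from the post."""
    # Scale so every singular value lies in (0, sqrt(3)), the basin of
    # attraction of the iteration (Frobenius norm >= spectral norm).
    X = G / (jnp.linalg.norm(G) + 1e-7)
    for _ in range(steps):
        # Each step maps singular value s to 1.5*s - 0.5*s**3, which
        # converges to 1, flattening the update's spectrum.
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X
```

The result has all singular values near one, which is exactly the well-conditioning property the update direction needs.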

Read more...

SPMD in JAX #2: Transformers in Bare-Metal JAX

In this second post in a series on programming with JAX for training models like transformers, we write and train a transformer with a decent set of bells and whistles, then benchmark and scale it as much as we can on our TPU v4 half-cube!
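As a taste of what "bare-metal" means here, everything is built from raw JAX primitives rather than a layers library. The toy stand-in below (a single linear layer with plain SGD; any resemblance to the post's actual code is assumed, not quoted) shows the basic pattern: parameters as a pytree, jax.grad for gradients, jax.jit for compilation.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stand-in for the post's transformer: a single linear
# layer trained with plain SGD. The pattern, not the model, is the point.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=1e-2):
    # Differentiate the loss w.r.t. the whole parameter pytree,
    # then apply an SGD step leaf by leaf.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (16, 1)) * 0.01, "b": jnp.zeros((1,))}
x = jax.random.normal(key, (8, 16))
y = jnp.zeros((8, 1))
params = train_step(params, x, y)
```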

Read more...

SPMD in JAX #1: Sharding

This post is the first in a series on programming with JAX for training models like transformers. I'm experimenting with these posts as an alternative to my usual scratch paper or LaTeX notes while learning, in the hope that writing them will help me with recall and perhaps be useful to others learning this material.
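The core object the post revolves around can be previewed in a few lines; a minimal sketch, where the mesh axis name "data" and the array shapes are made up for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are visible (e.g. one TPU host);
# the axis name "data" is an arbitrary label chosen here.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the batch dimension of an array across the mesh; the device
# count must divide the batch size evenly.
x = jnp.ones((32, 128))
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
print(x.sharding)  # inspect how rows are laid out across devices
```

Computations jitted over arrays sharded this way run SPMD-style, with JAX inserting the necessary collectives between devices.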

Read more...