SPMD in JAX #2: Transformers in Bare-Metal JAX
In this second post in a series on programming with JAX for training models like transformers, we write and train a transformer with a decent set of bells and whistles, then benchmark and scale it as much as we can on our TPU v4 half-cube!
As in previous posts in the series, we make good use of the Google scaling playbook (Austin et al., 2025), as well as some useful JAX tutorials, especially (JAX Team, 2025).
- Setup
- Design Choices
- Building the Model: First Cut
- Sampling from the Model: First Cut
- Building the Model: Attention and Positional Encodings
- Building the Model: Sampling with a Cache
- Conclusions
- Accumulated Configs
- Acknowledgments
Setup
As usual, we’ll test things out in Python 3.13 and JAX 0.7.1 on a single TPU v4 host.
import jax
import jax.numpy as jnp
jax.devices()
Output:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Design Choices
A project like this is for learning and pedagogy—it’s inevitably behind the research/infra frontiers. So the code we write is going to lean into this somewhat. Notably,
We’ll implement everything in bare-metal JAX, rather than using neural network libraries like Flax/nnx/Equinox.
This will let us expose and understand all the implementation details, both for infra and architecture, that we might normally overlook. On the flip side, we’ll take some performance hits for this (which we’ll attempt to profile and characterize), and it will lead to some slightly unfortunate code repetition.1 We’ll also be reinventing the wheel throughout – but for pedagogy’s sake!
Model Architecture Overview
Our approach to implementing the network architecture is a modification of the approach used in the JAX training cookbook (JAX Team, 2025), which uses a data structure like those in Google’s ML Collections library to store the model. At a high level, this approach consists of these standard JAX neural net patterns:
- Storing the model parameters as a pytree.
- Defining separate functions for initialization and the model’s forward pass which take
such a pytree as input (as well as other relevant inputs).
This pattern promotes functionally-pure code for the model’s initialization and forward pass,
making it easy to
jit
for performance.
Our implementation tries to move this slightly in the direction of Equinox, without all the powerful features that library entails. In particular, we will:
- Use NamedTuples for parameters and layers (i.e., where you’d normally expect to use an nn.Module in Pytorch), enabling good static type checking with e.g. Pyright. NamedTuples are automatically registered as pytrees in JAX (see the sketch below)!
- Write a top-level function for model initialization, and layer-level functions for each separate layer’s operation (which, as above, take parameters and inputs as arguments). This is a ‘bottom-up’ version of the previous design, where we’ll end up with many different functions, one for each ‘layer’, which are combined bottom-up to implement the overall model’s forward pass (more like in a typical neural network library).
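To make the NamedTuple-as-pytree point concrete, here’s a minimal standalone sketch (the toy Linear layer and _linear function are hypothetical, not part of the model we build below) showing that a NamedTuple of arrays is already a pytree that jax.tree utilities and jit understand:
from typing import NamedTuple
import jax
import jax.numpy as jnp
class Linear(NamedTuple):  # toy "layer": just a pytree of parameters
    w: jax.Array
    b: jax.Array
def _linear(params: Linear, x: jax.Array) -> jax.Array:
    return x @ params.w + params.b
layer = Linear(w=jnp.ones((2, 3)), b=jnp.zeros((3,)))
print(jax.tree.leaves(layer))           # the two arrays: NamedTuples are pytrees
print(jax.tree.map(jnp.shape, layer))   # Linear(w=(2, 3), b=(3,))
print(jax.jit(_linear)(layer, jnp.ones((2,))))  # [2. 2. 2.]: pytrees pass straight through jit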
Building the Model: First Cut
When writing this post, I wrote and debugged the initial model/training code iteratively in a notebook. Rather than just jumping to the ‘final answer’, we’ll walk through the development process below.
In this part of the post, we’ll build up to the transformer through a few simple sub-modules of the overall model. Our targets will be:
- We’ll work with a very simple data distribution, described below, which requires minimal dataloading code (and no tokenization).
- We’ll focus on the training loss, leaving sampling for later. Although simple sampling gives us an important correctness test (“did we accidentally give the model access to future tokens?”, for example…), doing it ‘fast’ requires a bit of code, so we’ll save it for later.
On the way to these targets, we’ll build up the following:
- Just the MLP, giving us a position-wise model (bigrams with a nonlinearity!).
- A training loop with SGD and cross-entropy (next token prediction) loss.
- Just causal attention, letting us build a transformer that can solve our toy task (discussed in the next section) with two layers.
- Rotary positional encodings, so that we can solve our task with a one-layer model.
We’ll then step back and refactor things with an eye towards training larger models on actual language data in a distributed setting.
If you’re following along and you notice anything I’ve implemented that seems suboptimal, please let me know! Drop me an email or a DM on Twitter (links at the bottom of the page).
A Quick Note on Configuration
For overall convenience, we eventually want to combine all our different configuration options in a large dataclass, then pass our runtime instantiation of the config (with defaults, overrides, etc.) to all our functions, which can extract just the parameters they need.
When writing this code, I tend to define the config options I need locally, then go back later and turn them into attributes of the overall config dataclass. To avoid frontloading all the complexity of the config, and to avoid too much unnecessary description of the refactoring process, we’ll reference different config options below as attributes of an (as yet undefined) Config class; the code outputs shown were generated with this class already defined. You can Ctrl+F for the type definitions and defaults, which appear later (in “Accumulated Configs”), if it feels necessary.
Brief Setup
We are going to make reference to different sources of randomness below, which we seed from our Config class.
# Test it out
config = Config()
key = jax.random.key(config.seed)
key_params, key_data, key_sampling = jax.random.split(key, 3)
Data
The data will consist of an (infinite) sequence of consecutive digits (from 0 to 9), which ‘reflect’ when they reach 0 or 9. I.e.,
01234567898765432101...
We’ll generate a dataset of fixed-length subsequences of this infinite sequence for training.
text = "012345678987654321" * 1024
seqs = [
[int(c) for c in text[i : i + config.seq_len + 1]]
for i in range(len(text) - config.seq_len - 1)
]
data = jnp.array(seqs, dtype=jnp.int32)
key_data, sk = jax.random.split(key_data)
data = jax.random.permutation(sk, data, axis=0)
num_data = len(data)
print(data[128])
Xtr = data[: int(0.8 * num_data)]
Xdev = data[int(0.8 * num_data) : int(0.9 * num_data)]
Xte = data[int(0.9 * num_data) :]
Xtr.shape
Output:
[8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8
9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9
8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8
7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7
6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6
5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5
4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6]
(14540, 257)
We shuffle the sequences above and perform a simple train-dev-test split (although we’ll just use the train split for now). We take one more character than the sequence length, so that each row can easily be split into inputs and targets, which we do in a simple dataloader:
import itertools as it
def dataloader(key, config: Config):
for step in it.count():
key = jax.random.fold_in(key, step)
        offsets = jax.random.randint(key, (config.global_batch_size,), 0, len(Xtr))  # sample rows from the training split only
yield (Xtr.at[offsets, :-1].get(), Xtr.at[offsets, 1:].get())
key_data, sk = jax.random.split(key_data)
batch = map(
lambda batch: jax.device_put(batch, config.sharding_data),
dataloader(sk, config),
)
Here batch
is a map iterator which wraps the dataloader with a device_put
statement to enforce
our desired sharding (e.g., data parallel). When we create our config class, we create a
global mesh and set it as the default mesh, so we can pass a PartitionSpec
object to
device_put
in our single-host setting (see the previous post).
The batches being loaded have shape (config.global_batch_size, config.seq_len), so we expect config.sharding_data to be, for example, jax.P('dp') for a “data-parallel” mesh axis 'dp', which shards each batch across the accelerators along its batch dimension.
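As a quick check (a sketch, not from the original notebook), we can pull one batch and confirm its shape and sharding; note that this consumes one batch from the iterator.
inputs, targets = next(batch)
print(inputs.shape)     # (config.global_batch_size, config.seq_len)
print(inputs.sharding)  # the leading (batch) axis should be sharded over 'dp'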
Defining the Model: MLP
We’ll implement the overarching structure of our model here, described above in the “Design Choices” section, but keep the low-level implementations restricted to just the MLP.
We’ll follow these conventions:
- “Layers” correspond to NamedTuples (which are pytrees). The elements of such a “layer” can be either jax.Arrays (parameters/pytree leaves) or other NamedTuples (for composing layers).
- We’ll name our layers with capital letters, and the functions that implement them (i.e., their “forward” method in Pytorch lingo) in lower-case, with a leading underscore.
Here’s our MLP layer following these conventions:
from typing import NamedTuple
class Mlp(NamedTuple):
w_up: jax.Array
bias_up: jax.Array
w_down: jax.Array
bias_down: jax.Array
def _mlp(config: Config, params: Mlp, x: jax.Array):
preact = jnp.dot(x, params.w_up, out_sharding=config.sharding_mlp_hidden)
if config.use_bias_mlp:
preact += params.bias_up
act = jax.nn.gelu(preact, approximate=True)
out = jnp.dot(act, params.w_down, out_sharding=config.sharding_res_stream)
if config.use_bias_mlp:
out += params.bias_down
return out
We would create parameters for such a layer and apply it as follows:
my_mlp = Mlp(
w_up=jnp.zeros((2, 4)),
bias_up=jnp.zeros((4,)),
w_down=jnp.zeros((4, 2)),
bias_down=jnp.zeros((2,)),
)
_mlp(config, my_mlp, jnp.ones((2,)))
Output:
Array([0., 0.], dtype=float32)
It’s important to note here that we are not able to enforce shape information at
this level of implementation, since abstractly speaking an MLP can be implemented for
any compatible weight matrix dimensions. These shapes get enforced later, when we
initialize an Mlp
with settings specified in our config
instance.
More importantly, there is some ambiguity in the types of the arguments x
and params
that _mlp
will apply to – e.g. the above function could apply to x
being a vector
of shape (d_model,)
, or a matrix of shape (seq_len, d_model)
, as long as the
parameters of Mlp
are initialized correctly. In other words, it’s up to us to make
sure we apply this implementation correctly.
In our subsequent implementation, we’re going to follow a helpful principle that JAX
affords – we’ll operate at the finest level of granularity possible at every stage of
the model, and ‘move up’ to the next level of granularity as necessary, using JAX’s
powerful primitives like vmap
and scan
. In particular, we’ll see soon that we’ll
only use the above _mlp
function in cases where x
is a vector of shape (d_model,)
.
Since the MLPs in transformers operate the same on every position in the sequence,
and every sequence in the batch, we’ll just vmap
our _mlp
function when we need to
use it at the next level of granularity! This is a pretty empowering design strategy
that JAX enables: it lets us write our code’s core numerical functionality without
having to plan ahead for exactly how it will be used later, making implementation
simpler and separating complexity nicely.
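As a small illustration of this (reusing the toy my_mlp parameters from above; a hedged sketch, not part of the model code), the very same _mlp applies to a single position, and a vmap lifts it to a whole sequence of positions:
x_single = jnp.ones((2,))   # one 'position' of dimension 2
x_seq = jnp.ones((5, 2))    # a toy 'sequence' of 5 positions
print(_mlp(config, my_mlp, x_single).shape)                      # (2,)
print(jax.vmap(lambda x: _mlp(config, my_mlp, x))(x_seq).shape)  # (5, 2)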
A few other notes (some drawbacks of this bare-metal approach, which seem unavoidable):
- There’s no ‘state’ in the _mlp function – it just performs the operation of an MLP given parameters/inputs. This makes it easy to jit.
- Initialization is not handled here – we’ll have to do it later, since we want to initialize the entire transformer (top-down) in one shot.
- We enforce shardings for the hidden activations/output of the MLP in the _mlp function. We enforce shardings for the parameters of the MLP when we initialize them (see later).
Defining the Model: Embeddings and Placeholders
To have a basic working model, we need token embeddings (mapping the raw integer tokens we generated above to vectors the transformer can operate on), unembeddings (mapping the vector outputs of the transformer to logits, which we can map to a probability distribution over tokens to predict with), and normalization layers. We’ll have attention layers too, eventually; for now, we’ll make a placeholder function for the attention layer.
class Attn(NamedTuple):
w_qkv: jax.Array
w_o: jax.Array
def _attn(
config: Config,
params: Attn,
x_seq: jax.Array,
kv_cache: jax.Array,
cache_size: int,
):
return x_seq, kv_cache
class Embedding(NamedTuple):
w: jax.Array
def _embedding(config: Config, params: Embedding, token: jax.Array):
emb = params.w.at[token].get(out_sharding=config.sharding_res_stream)
return emb
class Unembedding(NamedTuple):
w: jax.Array
def _unembedding(config: Config, params: Unembedding, x: jax.Array):
logits = jnp.dot(x, params.w, out_sharding=config.sharding_res_stream)
return logits
class LayerNorm(NamedTuple):
gamma: jax.Array
beta: jax.Array
def _layernorm(config: Config, params: LayerNorm, x: jax.Array):
x_std = jax.nn.standardize(x.astype(config.compute_dtype), epsilon=config.eps_ln)
out = params.gamma * x_std.astype(config.param_dtype)
if config.use_bias_ln:
out += params.beta
return out
These are standard unoptimized implementations for these different building blocks.
Looking ahead slightly, we make sure to provide the ability to compute the layer
normalization operation in a higher precision (config.compute_dtype
), since the
variance operation in standardize
can lead to underflow and other losses of precision.
The API for _attn
involves arguments for utilizing a cache for faster inference (we’d
normally leave this out to begin with, and add it only later).
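To see why the higher-precision compute_dtype for normalization matters, here’s a hedged little illustration (not from the original post): standardizing a vector with a large mean in bfloat16 loses the signal entirely, while float32 behaves as expected.
x = 1000.0 + 0.01 * jnp.arange(8, dtype=jnp.bfloat16)
print(jax.nn.standardize(x, epsilon=1e-6))                       # all zeros: bf16 can't resolve 0.01 steps on top of 1000
print(jax.nn.standardize(x.astype(jnp.float32), epsilon=1e-6))   # sensible standardized values around 0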
Defining the Model: Building Blocks to Transformer
A transformer consists of an embedding, then a sequence of transformer blocks (which
involve normalization + attention and a residual connection, then normalization + MLP
and a residual connection), then the unembedding.
Accordingly, we make a Block
class which combines our low-level building blocks.
from functools import partial
class Block(NamedTuple):
norm_attn: LayerNorm
attn: Attn
norm_mlp: LayerNorm
mlp: Mlp
def _block(
config: Config,
params: Block,
x_seq: jax.Array,
cache_in: jax.Array,
cache_size: int,
):
att_skip = x_seq
out = jax.vmap(partial(_layernorm, config, params.norm_attn))(x_seq)
out, cache_out = _attn(
config, params.attn, out, kv_cache=cache_in, cache_size=cache_size
)
out += att_skip
mlp_skip = out
out = jax.vmap(partial(_layernorm, config, params.norm_mlp))(out)
out = jax.vmap(partial(_mlp, config, params.mlp))(out)
out += mlp_skip
return out, cache_out
Notice that we’re following the design principle we mentioned above with the MLP: here
we expect x_seq
to have two dimensions, the first for the sequence length and the
second for the model dimension, and we vmap
the _mlp
and _layernorm
functions to
apply to it correctly.
We can now define the transformer. We’ll do this in a slightly refined way versus just
using a for loop: we’ll expect the blocks
parameter below, of type Block
, to have a
mapped leading axis corresponding to the layer dimension, and we’ll consume this
leading axis using a scan
. That means we expect every parameter in the Block
pytree
we pass to not just have its usual shape, say param.shape
(expected by the _mlp
,
etc. APIs above), but additionally be of shape (num_layers,) + param.shape
. This turns
out to be very easy to guarantee with a vmap
at initialization (see below).
If you aren’t familiar with scan, you can read about the API here. It takes a function that accepts two arguments – a ‘carry’ and an ‘input’ – and returns a new carry along with a per-step output. scan itself also takes an initial carry and the stacked inputs, which must have an additional mapped leading axis relative to the input shape the function expects. It then iteratively applies the function to the current carry and the current input; the carry returned at each step is passed to the next call of the function, and the per-step outputs are stacked into the final result. In this way, it’s like a compiled for loop – you can also draw a direct analogy to a recurrent neural network, or other dynamical systems.
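Here’s a tiny standalone example of scan’s semantics (a sketch for readers new to it, unrelated to the transformer): the carry is a running total and the per-step output is the running sum so far.
def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry   # (next carry, per-step output)
final_carry, partial_sums = jax.lax.scan(step, 0.0, jnp.arange(5.0))
print(final_carry)    # 10.0
print(partial_sums)   # [ 0.  1.  3.  6. 10.]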
In our case, the ‘carry’ for scan is our network’s input and activations (we pass these
from layer to layer sequentially), and the ‘input’ for scan is the mapped Block
pytree
(as well as the mapped KV cache for inference – the updated cache becomes the output –
but we can ignore that for now).
class Transformer(NamedTuple):
emb: Embedding
blocks: Block # vmapped at init
unemb: Unembedding
def _transformer(
config: Config,
params: Transformer,
    tokens: jax.Array,
cache_in: jax.Array,
cache_size: int = 0,
):
x_seq = jax.vmap(partial(_embedding, config, params.emb))(tokens)
    def _block_fun(x_seq: jax.Array, params__cache_in: tuple[Block, jax.Array]):
params, cache_in = params__cache_in
return _block(config, params, x_seq, cache_in, cache_size)
out, cache_out = jax.lax.scan(_block_fun, x_seq, (params.blocks, cache_in))
out = jax.vmap(partial(_unembedding, config, params.unemb))(out)
return out, cache_out
Using scan
leads to a very concise forward pass definition for the transformer. Note
that our transformer is still operating on sequences of embeddings – we’ll apply it to
a batch of sequences with vmap
later!
Defining the Model: Initialization
The last step before testing is to initialize the model. This leads to a long definition, but the intrinsic complexity is not high.
def init_model_params(key, config: Config) -> Transformer:
def init_embedding(key) -> Embedding:
emb = config.param_std * jax.random.normal(
key,
(config.num_vocab, config.d_model),
config.param_dtype,
out_sharding=config.sharding_res_stream,
)
return Embedding(w=emb)
def init_unembedding(key) -> Unembedding:
unemb = config.param_std * jax.random.normal(
key,
(config.d_model, config.num_vocab),
config.param_dtype,
out_sharding=config.sharding_res_stream,
)
return Unembedding(w=unemb)
def init_mlp(key) -> Mlp:
k_w_up, k_w_down = jax.random.split(key, 2)
w_up = config.param_std * jax.random.normal(
k_w_up,
(config.d_model, config.mlp_factor * config.d_model),
config.param_dtype,
out_sharding=config.sharding_wup,
)
w_down = config.param_std * (
jax.random.normal(
k_w_down,
(config.mlp_factor * config.d_model, config.d_model),
config.param_dtype,
out_sharding=config.sharding_wdown,
)
)
bias_up = jnp.zeros(
(config.mlp_factor * config.d_model,),
config.param_dtype,
out_sharding=config.sharding_mlp_hidden,
)
bias_down = jnp.zeros(
(config.d_model,),
config.param_dtype,
out_sharding=config.sharding_res_stream,
)
return Mlp(w_up=w_up, bias_up=bias_up, w_down=w_down, bias_down=bias_down)
def init_attn(key) -> Attn:
k_qkv, k_o = jax.random.split(key, 2)
w_qkv = config.param_std * jax.random.normal(
k_qkv,
(config.d_model, 3, config.num_heads, config.d_head),
config.param_dtype,
out_sharding=jax.P(*config.sharding_wqkv),
)
w_out = config.param_std * (
jax.random.normal(
k_o,
(config.d_head, config.num_heads, config.d_model),
config.param_dtype,
out_sharding=jax.P(*config.sharding_wo),
)
)
return Attn(w_qkv=w_qkv, w_o=w_out)
def init_layernorm() -> LayerNorm:
gamma = jnp.ones(
(config.d_model,),
config.param_dtype,
out_sharding=config.sharding_res_stream,
)
beta = jnp.zeros(
(config.d_model,),
config.param_dtype,
out_sharding=config.sharding_res_stream,
)
return LayerNorm(gamma=gamma, beta=beta)
def init_block(key) -> Block:
key_attn, key_mlp = jax.random.split(key)
return Block(
norm_attn=init_layernorm(),
attn=init_attn(key_attn),
norm_mlp=init_layernorm(),
mlp=init_mlp(key_mlp),
)
# Make the full network
key_emb, key_blocks, key_unemb = jax.random.split(key, 3)
keys_blocks = jax.random.split(key_blocks, config.num_layers)
return Transformer(
emb=init_embedding(key_emb),
blocks=jax.vmap(init_block)(keys_blocks),
unemb=init_unembedding(key_unemb),
)
Notice above how we produce the mapped axis needed for the scan
in the _transformer
function above: we just split our PRNGKey into an array of keys, one for each layer, and
vmap
the init_block
function over the array of keys. Easy!
Testing the Basic Model
We can perform a quick test to make sure everything above is at least syntactically correct. There are two things we want to see:
- Can we run the model in eager mode?
- Can we
jit
the model?
We really only care about 2 (it implies 1), so let’s test it.
key_params, sk = jax.random.split(key_params)
tf = init_model_params(sk, config)
print("Number of layers in model: ", config.num_layers)
print("Model dimension: ", config.d_model)
cache = None
tokens = jnp.arange(10)
@jax.jit
def model(params, tokens):
return _transformer(config, params, tokens, cache, 0)
out, out_cache = model(tf, tokens)
out
Output:
Number of layers in model: 2
Model dimension: 768
Array([[0.894531, -0.482422, 1.41406, 0.194336, -2.76562, -1.27344,
1.21094, 0.417969, -0.460938, -0.304688],
[0.0610352, 2.8125, 1.08594, -0.351562, -0.326172, -1.46094,
0.0354004, 0.283203, 0.796875, -1.42969],
[0.186523, 1.21875, 0.429688, -1.10156, 0.667969, -2.07812,
-0.078125, 0.410156, 0.0273438, -2.60938],
[0.201172, 1.46094, 0.124512, -1.27344, -1.05469, -0.667969,
0.453125, 1.78906, -1.00781, 1.67188],
[-0.925781, 0.988281, -1.14844, 0.0942383, 0.0898438, 1.70312,
1.125, 1.03125, -1.17969, -1.69531],
[-0.710938, -0.679688, -1.53125, 0.800781, 0.443359, -0.71875,
-2.29688, 0.488281, -0.169922, -1.38281],
[1.44531, 0.949219, -0.589844, 0.357422, -1.30469, 2.85938,
-0.388672, 0.46875, -0.695312, -0.425781],
[-1.85156, 0.679688, 1.25, 0.109863, -0.667969, -0.361328,
-0.0405273, -0.964844, 3.01562, 0.333984],
[-0.304688, 1.83594, 0.734375, 1.16406, -0.114746, -0.800781,
0.486328, 1, -1.00781, -1.01562],
[-1.66406, 0.925781, -0.0908203, 1, 0.298828, -2.40625, 0.574219,
0.945312, -2.54688, -0.123047]], dtype=bfloat16)
Looks good! There might be a small issue with the unembedding scaling, as we can see:
jax.nn.softmax(out, axis=-1)
Output:
Array([[0.163086, 0.0410156, 0.275391, 0.0810547, 0.00418091, 0.0186768,
0.224609, 0.101562, 0.0419922, 0.0493164],
[0.0390625, 0.613281, 0.108887, 0.026001, 0.0264893, 0.00848389,
0.0380859, 0.0488281, 0.081543, 0.00872803],
[0.100098, 0.279297, 0.126953, 0.02771, 0.161133, 0.010376,
0.0766602, 0.124512, 0.0849609, 0.006073],
[0.0583496, 0.204102, 0.0539551, 0.0133057, 0.0164795, 0.0244141,
0.0751953, 0.285156, 0.017334, 0.253906],
[0.0227051, 0.15332, 0.0181885, 0.0629883, 0.0629883, 0.314453,
0.176758, 0.160156, 0.0177002, 0.010437],
[0.0588379, 0.060791, 0.026123, 0.267578, 0.1875, 0.0588379,
0.012146, 0.195312, 0.101562, 0.0300293],
[0.141602, 0.0864258, 0.0184326, 0.0476074, 0.00909424, 0.582031,
0.022583, 0.0534668, 0.0164795, 0.0218506],
[0.00500488, 0.0629883, 0.112305, 0.0358887, 0.0164795, 0.0224609,
0.0307617, 0.012207, 0.65625, 0.0444336],
[0.0395508, 0.335938, 0.111328, 0.171875, 0.0473633, 0.0239258,
0.0869141, 0.145508, 0.0194092, 0.0194092],
[0.0145874, 0.193359, 0.0693359, 0.208008, 0.102539, 0.00689697,
0.135742, 0.196289, 0.00598145, 0.0673828]], dtype=bfloat16)
In particular, the initial output distributions are not very close to uniform, which is what we would like at initialization. However, we should still be able to learn – our network is not very deep.
A Minimal Training Loop
Now we have enough to train a model. A simple training loop follows, using SGD.
First, we set up the training step, jit
ting it for performance.
class TrainState(NamedTuple):
params: Transformer
opt: None
@jax.jit
def init_train_state(key, config: Config) -> TrainState:
return TrainState(params=init_model_params(key, config), opt=None)
train_state = init_train_state(key_params, config)
cache = None
@partial(jax.jit, donate_argnums=2)
def train_step(config: Config, batch, train_state: TrainState):
def loss_fn(params: Transformer):
inputs, targets = batch
logits, cache_out = jax.vmap(partial(_transformer, config, params))(
inputs, cache
)
logits = logits.astype(config.compute_dtype)
logprobs = jax.nn.log_softmax(logits, axis=-1)
return -jnp.take_along_axis(logprobs, targets[..., None], axis=-1).mean()
loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
new_state = TrainState(
params=jax.tree.map(lambda x, y: x - config.lr * y, train_state.params, grad),
opt=None,
)
metrics = {"loss": loss}
return metrics, new_state
In the training step, after loading a batch of sequences from our previously-written
batch loader, we vmap _transformer
over it, then calculate the next-token prediction
loss with the standard bare-metal approach (with possible up-casting for numerical
stability of the softmax). The donate_argnums
argument to jit
lets us mark that the
train_state
buffer can be discarded after the step finishes (since we create and
return a new TrainState
with the results of the SGD step), letting the XLA
compiler
reuse that memory for the output state.
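If buffer donation is new to you, here’s a minimal standalone sketch (the add_one function is hypothetical, separate from the training code). On backends that support donation, like TPU, the donated input buffer is invalidated after the call and may be reused for the output.
@partial(jax.jit, donate_argnums=0)
def add_one(x):
    return x + 1.0
a = jnp.zeros(4)
b = add_one(a)   # XLA may reuse a's buffer to store b
# a + 1.0        # touching `a` now raises an error: its buffer was donated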
The loop is then simple to write:
prev_metrics = None
for step in range(config.num_steps):
cur_metrics, train_state = train_step(config, next(batch), train_state)
log_metrics, prev_metrics = prev_metrics, cur_metrics
if log_metrics and step % 50 == 0:
log_metrics |= {"step": step}
print(*[f"{metric}: {val}" for metric, val in log_metrics.items()], sep="\t")
Output:
loss: 1.1560431718826294 step: 50
loss: 0.9151328206062317 step: 100
loss: 1.004440188407898 step: 150
loss: 0.8609314560890198 step: 200
loss: 0.9039498567581177 step: 250
loss: 0.8058883547782898 step: 300
loss: 0.7688642144203186 step: 350
loss: 0.6995882987976074 step: 400
loss: 0.6380590796470642 step: 450
loss: 0.6304700374603271 step: 500
loss: 0.6276624798774719 step: 550
loss: 0.6255137920379639 step: 600
loss: 0.6255109310150146 step: 650
loss: 0.6271587610244751 step: 700
loss: 0.6273760795593262 step: 750
loss: 0.6262620687484741 step: 800
loss: 0.6236293911933899 step: 850
loss: 0.6241434216499329 step: 900
loss: 0.6261416673660278 step: 950
It’s not great, but it’s indeed training! We don’t expect to be able to solve this task
perfectly with our simple model, since a position-wise MLP can only use one token of
context to predict the next token, and our task requires at least two. To see why, just
note that, for example, the character following 5
could be either 6
or 4
,
depending on whether we’re at a rising or falling part of the sequence. Two characters
of context is enough to estimate this, meaning adding attention should help here!
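We can verify this ambiguity directly from the raw text (a quick hedged check, not in the original post): conditioned on seeing a 5, the next character is 4 half the time and 6 the other half.
from collections import Counter
next_after_5 = Counter(b for a, b in zip(text, text[1:]) if a == "5")
print(next_after_5)   # equal counts for '4' and '6'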
As an aside: above, we make sure to print metrics for the previous iteration of the training loop at each step. This helps with correctly pipelining device-to-host transfers, since we avoid blocking on the current step’s metrics before the next step has been dispatched: see (JAX Team, 2025) for a detailed description.
Sampling from the Model: First Cut
To get a sense of the behavior of our learned model, we can write our sampler for it now. Normally, efficient sampling requires a cache of past outputs to be maintained (we already have the API for this), but we’ll discuss this later after we implement attention.
@partial(jax.jit, donate_argnums=(4,))
def sample_one_token(
config: Config,
key,
params: Transformer,
x: jax.Array,
cache_in: jax.Array,
cache_size: int,
temperature: float,
):
y, cache_out = _transformer(config, params, x, cache_in, cache_size)
logits = y.astype(config.compute_dtype)
cache_size = cache_size + x.shape[-1]
next_token = jnp.array((jax.random.categorical(key, logits[-1] / temperature),))
return next_token, cache_out, cache_size
config_sampling_args = {"sharding_data": jax.P()}
config_sampling = Config(**config_sampling_args)
output = jnp.array((1,))
cache = None
tokens_to_generate = 63
temperature = 0.1
for step in range(tokens_to_generate):
key_sampling, sk = jax.random.split(key_sampling)
next_token, cache, cache_size = sample_one_token(
config_sampling, sk, train_state.params, output, cache, 0, temperature
)
output = jnp.concatenate((output, next_token))
output
Output:
Array([1, 2, 3, 2, 3, 4, 3, 2, 3, 2, 3, 5, 6, 5, 4, 3, 4, 5, 2, 1, 2, 3,
2, 1, 0, 1, 2, 3, 4, 3, 4, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5,
6, 5, 1, 2, 0, 1, 0, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 4, 3, 2], dtype=int32)
We sample at a relatively low temperature in order to better highlight what the model has learned (low temperature keeps the sampler close to the model’s most likely predictions). The generated sequence is in line with our hypothesis – the model has lowered the loss by learning that it needs to predict the number that’s one above or one below the current input, but it can’t do much better than guessing randomly between these two possibilities (which is why our final loss is close to $\log(2)$).
Let’s improve this with attention!
Building the Model: Attention and Positional Encodings
Below is an attention implementation that follows the jax.nn.dot_product_attention shape conventions, with both a manual implementation and one that calls this function (selected via config.use_fa). For now, we omit the cache and the positional encodings.
def _attn(
config: Config,
params: Attn,
x_seq: jax.Array,
kv_cache: jax.Array,
cache_size: int,
):
# s: sequence length
# d: embedding dim (config.d_model)
# n: attention heads (config.num_heads)
# h: head dim (config.d_head)
# x_seq: s x d
qkv = jnp.einsum(
"sd,d3nh->3snh",
x_seq,
params.w_qkv,
out_sharding=jax.P(*config.sharding_att_qkv),
)
q, k, v = [qkv[i] for i in range(3)]
s = q.shape[0]
# Attention computation
t = k.shape[0]
mask = (cache_size + jnp.arange(s))[:, None] >= jnp.arange(t)[None, :]
mask = mask[None, ...] # broadcast over heads
if config.use_fa:
attn_out = jax.nn.dot_product_attention(q, k, v, scale=None, mask=mask)
else:
logits = jnp.einsum("snh,tnh->nst", q, k).astype(config.compute_dtype)
# Scale and causal mask
logits *= 1.0 / config.d_head**0.5
logits = jnp.where(mask, logits, -jnp.inf)
probs = jax.nn.softmax(logits, axis=2) # type: ignore[reportArgumentType]
probs = probs.astype(config.param_dtype)
attn_out = jnp.einsum("nst,tnh->snh", probs, v)
out = jnp.einsum(
"snh,hnd->sd",
attn_out,
params.w_o,
out_sharding=jax.P(*config.sharding_res_stream),
)
return out, kv_cache
Attention takes the embeddings of the input sequences $\vX \in \bbR^{S \times D}$ as input. It computes $N$ “attention heads”, then combines them to form the output. Each head, indexed by $n \in [N]$, is generated from different projections of the input embedding sequences: we calculate
\[\begin{equation*} \vQ_n = \vX \vW_{Q, n}, \quad \vK_n = \vX \vW_{K, n}, \quad \vV_n = \vX \vW_{V, n}, \end{equation*}\]where the output dimension of the projections is $H$, then generate the head output as
\[\begin{equation*} \vA_n = \mathrm{softmax}\left(\vQ_n \vK_n^\top / \sqrt{H} + \vM\right) \vV_n, \end{equation*}\]with the softmax being taken along rows, $1/\sqrt{H}$ the usual logit scaling (as in the code above), and $\vM$ denoting a causal mask:
\[M_{ij} = \begin{cases} 0 & j \leq i \\ -\infty & \ow. \end{cases}\]The heads are then concatenated together along the embedding dimension, producing an $S
\times NH$ dimensional result (usually $NH = D$), and then an output projection $\vW_O$
is applied to produce the attention output, which has the same shape as the input $\vX$.
The code above implements these operations concisely with jnp.einsum
.
We can replace our previous placeholder attention implementation with a complete one and then rerun our training loop, since we already added the proper initialization code above.
Output:
loss: 0.6685549020767212 step: 50
loss: 0.58647620677948 step: 100
loss: 0.5948932766914368 step: 150
loss: 0.5258455872535706 step: 200
loss: 0.5094199776649475 step: 250
loss: 0.40264642238616943 step: 300
loss: 0.42422279715538025 step: 350
loss: 0.48532211780548096 step: 400
loss: 0.45087528228759766 step: 450
loss: 0.3302219808101654 step: 500
loss: 0.39742058515548706 step: 550
loss: 0.3917175531387329 step: 600
loss: 0.3619387745857239 step: 650
loss: 0.34305423498153687 step: 700
loss: 0.2451237142086029 step: 750
loss: 0.27099400758743286 step: 800
loss: 0.4500095248222351 step: 850
loss: 0.26931577920913696 step: 900
loss: 0.2725009322166443 step: 950
Output:
loss: 0.15929250419139862 step: 50
loss: 0.14883831143379211 step: 100
loss: 0.165802463889122 step: 150
loss: 0.14946813881397247 step: 200
loss: 0.15420110523700714 step: 250
loss: 0.14391759037971497 step: 300
loss: 0.2680056691169739 step: 350
loss: 0.14069046080112457 step: 400
loss: 0.27090030908584595 step: 450
loss: 0.14207065105438232 step: 500
loss: 0.13676664233207703 step: 550
loss: 0.13658639788627625 step: 600
loss: 0.13079527020454407 step: 650
loss: 0.39731210470199585 step: 700
loss: 0.1367562860250473 step: 750
loss: 0.13199977576732635 step: 800
loss: 0.1323813498020172 step: 850
loss: 0.12948714196681976 step: 900
loss: 0.13398176431655884 step: 950
Two back-to-back output traces are posted above. We can get a sense that this model
should be able to learn to solve the task, but the poor optimizer (constant learning
rate SGD) is holding it back. If we rerun the sampling code above with the new model at
low temperature (temperature=0.1
), we do get correct outputs:
Output:
Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4,
5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8], dtype=int32)
Yet the final loss suggests that the model we’ve learned is not very good. We can improve the model further by incorporating positional encodings. We’ll use the modern choice, rotary positional encodings (RoPE). There are already many great blogs explaining the intuition behind these encodings, so I’ll just focus here on the algebraic ideas.
The basic idea is to consider the dot product of a query-key pair forming one head’s attention matrix $\vA$, as we defined it above (dropping the subscript $n$ for concision). Ignoring the mask, for the $ij$-th entry of this attention matrix, we consider the dot product (“attention logits”) $\vq_i \vk_j^\top$. A notable property of attention is that the only dependence of its output on the position $i$ is through the causal mask $\vM$: if we drop the mask, then permuting the input positions produces a corresponding permutation of the output positions, because every position receives the same projections $\vW_{QKV}$.2
This is arguably undesirable: the model benefits in many cases from the ability to learn a sequence-to-sequence mapping that is more position sensitive. For example, in our toy data setting, the model would be able to learn a good mapping faster if it were able to selectively focus on only a small number of past tokens (since this is enough to determine whether we’re in a “rising” or “falling” part of the sequence).
A nice way to incorporate this positional information is through a relative positional embedding. In such a method, we modify the attention operation so that the attention logits incorporate the magnitude of the difference in positions $\abs{i - j}$. In RoPE, this is done by inducing a quadratic form
\[\begin{equation*} \vq_i \vR_{\abs{i-j}} \vk_j^\top, \end{equation*}\]where $\vR_{\abs{i-j}}$ is a specific rotation matrix.3 RoPE’s choice of this matrix has the useful property that for fixed-norm keys and queries, the magnitude of the encoded attention logits decays as a function of $\abs{i-j}$, giving a priori less weight to logits of positions that are further apart.
It also has the essential property that the encoded logits can be computed efficiently. It turns out that there are (real-valued) sequences $\vc_i$, $\vs_i$ corresponding to each position, such that the following holds. Let $\vPi \in \bbR^{H \times H}$ denote the permutation matrix that swaps the top $H/2$ coordinates of a vector with the bottom $H/2$ coordinates (assume $H$ is even).4 Then one has
\[\begin{equation*} \vq_i \vR_{\abs{i-j}} \vk_j^\top = (\vq_i \circ \vc_i + \vPi \vq_i \circ \vs_i) (\vk_j \circ \vc_j + \vPi \vk_j \circ \vs_j)^\top, \end{equation*}\]where matrix multiplication binds before elementwise multiplication $\circ$. In particular, we can compute the encoded logits via a small number of elementwise multiplications, adds, and concatenations involving the raw keys and queries, which leads to a minimal-complexity implementation. The sequences $\vc_i, \vs_i$ depend only on position, and can be precomputed and reused.
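Before the implementation, here’s a tiny numerical check (a sketch with made-up numbers, for a single $2 \times 2$ rotation block) of the key property: rotating the query and key each by their own position’s angle gives the same logit as rotating one of them by the angle difference, so the logit depends only on the relative position.
def rot(t):   # 2x2 rotation matrix by angle t
    c, s = jnp.cos(t), jnp.sin(t)
    return jnp.array([[c, -s], [s, c]])
theta_i, theta_j = 0.3, 1.1   # hypothetical per-position angles
q2 = jnp.array([0.5, -1.2])
k2 = jnp.array([2.0, 0.7])
lhs = (rot(theta_i) @ q2) @ (rot(theta_j) @ k2)   # rotate q and k independently, then dot
rhs = q2 @ (rot(theta_j - theta_i) @ k2)          # one rotation by the angle difference
print(lhs, rhs)   # equal up to float error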
We give a simple implementation of RoPE below, along with its integration into the attention layer we implemented previously.
def _precompute_rope_sincos(config: Config):
freqs = jnp.exp(
-jnp.log(config.rope_theta).astype(config.compute_dtype)
* 2
/ config.d_head
* jnp.arange(config.d_head // 2, out_sharding=jax.P())
)
positions = jnp.arange(config.max_seq_len, out_sharding=jax.P())
cycles = positions[:, None] * freqs[None, :]
return jnp.cos(cycles), jnp.sin(cycles)
def _apply_rope(
config: Config,
cos: jax.Array,
sin: jax.Array,
positions: jax.Array,
x: jax.Array, # S x H
):
x_1, x_2 = x[:, : config.d_head // 2], x[:, config.d_head // 2 :]
c, s = cos.at[positions].get(), sin.at[positions].get()
return jnp.concatenate(
(c * x_1 - s * x_2, c * x_2 + s * x_1), axis=-1, dtype=config.param_dtype
)
def _attn(
config: Config,
params: Attn,
x_seq: jax.Array,
kv_cache: jax.Array,
cache_size: int,
):
# s: sequence length
# d: embedding dim (config.d_model)
# n: attention heads (config.num_heads)
# h: head dim (config.d_head)
# x_seq: s x d
qkv = jnp.einsum(
"sd,d3nh->3snh",
x_seq,
params.w_qkv,
out_sharding=jax.P(*config.sharding_att_qkv),
)
q, k, v = [qkv[i] for i in range(3)]
s = q.shape[0]
# Apply RoPE
if config.use_rope:
if not config.update_cache:
cache_size = 0 # ignore passed value
with jax.ensure_compile_time_eval():
rope_cos, rope_sin = _precompute_rope_sincos(config)
positions = jnp.arange(s, out_sharding=jax.P())
_apply_rope_one_head = partial(
_apply_rope, config, rope_cos, rope_sin, positions + cache_size
)
_apply_rope_all_heads = jax.vmap(_apply_rope_one_head, in_axes=1, out_axes=1)
q, k = _apply_rope_all_heads(q), _apply_rope_all_heads(k)
# Attention computation
t = k.shape[0]
mask = (cache_size + jnp.arange(s))[:, None] >= jnp.arange(t)[None, :]
mask = mask[None, ...] # broadcast over heads
if config.use_fa:
attn_out = jax.nn.dot_product_attention(q, k, v, scale=None, mask=mask)
else:
logits = jnp.einsum("snh,tnh->nst", q, k).astype(config.compute_dtype)
# Scale and causal mask
logits *= 1.0 / config.d_head**0.5
logits = jnp.where(mask, logits, -jnp.inf)
probs = jax.nn.softmax(logits, axis=2) # type: ignore[reportArgumentType]
probs = probs.astype(config.param_dtype)
attn_out = jnp.einsum("nst,tnh->snh", probs, v)
out = jnp.einsum(
"snh,hnd->sd",
attn_out,
params.w_o,
out_sharding=jax.P(*config.sharding_res_stream),
)
return out, kv_cache
The changes to the _attn
function are all within the code block marked with # Apply
RoPE
(everything else is unchanged); the _apply_rope
function takes care of the
necessary permutation/concatenation operation. We use JAX’s ensure_compile_time_eval
context manager to ensure the RoPE multiplying sequences are precomputed at trace time. The multiplying sequences themselves use geometrically spaced frequencies, with the base rope_theta set to 1e4 by default (in general, chosen based on the target context length, here 1024).
With RoPE, we find it much easier to learn a strong model for our toy data. After rerunning training with these modifications, we observe the following loss progression:
Output:
loss: 0.6326758861541748 step: 50
loss: 0.4087381660938263 step: 100
loss: 0.14004486799240112 step: 150
loss: 0.08325601369142532 step: 200
loss: 0.06802305579185486 step: 250
loss: 0.05945558100938797 step: 300
loss: 0.05321500450372696 step: 350
loss: 0.052829645574092865 step: 400
loss: 0.048602353781461716 step: 450
loss: 0.04696694016456604 step: 500
loss: 0.04576549679040909 step: 550
loss: 0.04564153775572777 step: 600
loss: 0.04341501370072365 step: 650
loss: 0.043622132390737534 step: 700
loss: 0.04205687716603279 step: 750
loss: 0.04401983320713043 step: 800
loss: 0.04256521165370941 step: 850
loss: 0.041673485189676285 step: 900
loss: 0.0421551875770092 step: 950
Much better! For a model that predicts perfectly everywhere except the inherently ambiguous first token of each sequence, we might expect a loss of around $\log(2) / 256 \approx 0.003$, corresponding to an unavoidable random guess at the first position (we train with sequence length 256). Our final loss is within about an order of magnitude of this.
Building the Model: Sampling with a Cache
The implementation of sampling we gave above is extremely slow: although we jit
the
sample_one_token
function, we are re-forwarding the entire updated sequence through
the model after every generation, which is very wasteful. A quick test on our
previously-implemented sampling function (on 1x TPU v4 host) takes about 50 seconds to
sample 63 tokens, including the time to jit
compile the sampling function.
This makes sense, given our naive implementation. Indeed, if we recall the attention definition, at the $i$-th token each head computes
\[\begin{equation*} \mathrm{softmax}( \vq_i \vK^T ) \vV, \end{equation*}\]and the output projection just accumulates these independent heads and ‘lifts’ the head dimension $H$ to the output dimension $D$. In other words, if we were to save the previous key and value tokens $(\vk_j, \vv_j)_{j=1}^{i-1}$ in a cache, we could reuse these to compute each head’s attention output with no wasted operations: we’d just need to compute $\vk_i$ and $\vv_i$ for the current position, then load the previous positions’ keys and values from the cache to compute the full attention output.
This data structure is called a KV cache, and it’s essential for fast inference, since autoregressive operation implies we can only compute one token at a time (in contrast to training, where we forward an entire batch of sequences at once). A back-of-the-envelope analysis suggests that the naive sampling strategy (re-forward the entire sequence on every new token generation) takes $O(S^2D^2 + DS^3)$ FLOPs to generate an $S$-length sequence starting from an empty prompt, whereas the KV caching strategy takes only $O(SD^2 + DS^2)$ FLOPs. This is a significant savings!
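Plugging in the numbers from our toy sampling run (a rough, hedged estimate that ignores constant factors and the small vocabulary), the gap is already substantial for short generations:
S, D = 64, config.d_model   # 1 prompt token + 63 generated tokens; d_model = 768
naive = S**2 * D**2 + D * S**3
cached = S * D**2 + D * S**2
print(f"{naive / cached:.0f}x fewer FLOPs with the KV cache")   # the ratio works out to exactly S here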
The catch is that the KV caching strategy requires significantly more memory accesses
and usage compared to the naive approach. In JAX, we also need to be careful to
implement the KV cache ‘correctly’ – since the cache grows after every generation,
naive implementations can entail inputs of different sizes to _attn
(triggering a
re-compilation every time we call it, when it’s jit
ted!), or dynamically-sized arrays
(not allowed within a jit
context!).5
Our simple implementation of the KV cache is below. Since we left the API for the cache
in the code we’ve written already, we just need to implement its functionality in the
_attn
function, and add a function to initialize an empty cache. To play nice with
jit
, we will implement a static-sized cache with a fixed size config.max_seq_len
(the same parameter we used for RoPE), ensuring we can jit
the _attn
function only
once, and maintain an auxiliary cache_size
variable to keep track of how much we’ve
written to the cache. To avoid having dynamically-sized arrays in the _attn
function,
we are a bit wasteful and zero-initialize the cache, then just multiply the current
query $\vq_i$ against the entire cache on every attention operation.6 To avoid having
these extra keys influence the attention computation (since softmaxing them produces a
nonzero output!), we augment the attention mask to mask them out.
The implementation is below.
def _attn(
config: Config,
params: Attn,
x_seq: jax.Array,
kv_cache: jax.Array, # 2 x config.max_seq_len x n x h (see below)
cache_size: int,
):
# s: sequence length
# d: embedding dim (config.d_model)
# n: attention heads (config.num_heads)
# h: head dim (config.d_head)
# x_seq: s x d
qkv = jnp.einsum(
"sd,d3nh->3snh",
x_seq,
params.w_qkv,
out_sharding=jax.P(*config.sharding_att_qkv),
)
q, k, v = [qkv[i] for i in range(3)]
s = q.shape[0] # we save k shape later, after possibly prepending cache
# Apply RoPE
if config.use_rope:
if not config.update_cache:
cache_size = 0 # ignore passed value
with jax.ensure_compile_time_eval():
rope_cos, rope_sin = _precompute_rope_sincos(config)
positions = jnp.arange(s, out_sharding=jax.P())
_apply_rope_one_head = partial(
_apply_rope, config, rope_cos, rope_sin, positions + cache_size
)
_apply_rope_all_heads = jax.vmap(_apply_rope_one_head, in_axes=1, out_axes=1)
q, k = _apply_rope_all_heads(q), _apply_rope_all_heads(k)
# Read/update/concatenate the cache
if config.update_cache:
k_cache, v_cache = kv_cache[0], kv_cache[1]
k = jax.lax.dynamic_update_slice(k_cache, k, (cache_size, 0, 0))
v = jax.lax.dynamic_update_slice(v_cache, v, (cache_size, 0, 0))
kv_cache_out = jnp.concatenate((k[None], v[None]), axis=0)
else:
kv_cache_out = kv_cache
cache_size = 0 # ignore passed value
# Attention computation
t = k.shape[0]
mask = (cache_size + jnp.arange(s))[:, None] >= jnp.arange(t)[None, :]
# with static kv cache, must mask unused memory!
mask = mask & (jnp.arange(t)[None, :] < cache_size + s)
mask = mask[None, ...] # broadcast over heads
if config.use_fa:
attn_out = jax.nn.dot_product_attention(q, k, v, scale=None, mask=mask)
else:
logits = jnp.einsum("snh,tnh->nst", q, k).astype(config.compute_dtype)
# Scale and causal mask
logits *= 1.0 / config.d_head**0.5
logits = jnp.where(mask, logits, -jnp.inf)
probs = jax.nn.softmax(logits, axis=2) # type: ignore[reportArgumentType]
probs = probs.astype(config.param_dtype)
attn_out = jnp.einsum("nst,tnh->snh", probs, v)
out = jnp.einsum(
"snh,hnd->sd",
attn_out,
params.w_o,
out_sharding=jax.P(*config.sharding_res_stream),
)
return out, kv_cache_out
def init_kv_cache(config: Config):
return jnp.zeros(
(
config.global_batch_size,
config.num_layers,
2,
config.max_seq_len,
config.num_heads,
config.d_head,
),
dtype=config.param_dtype,
out_sharding=config.sharding_data + jax.P(None) + config.sharding_att_qkv,
)
Running this block and then rerunning our training script shows that our changes haven’t broken anything. We can test how much it speeds up sampling, as well. We modify the sampler above slightly – we enable the cache in the config, and forward one token per sampling step to the model, rather than the entire output, as the cache will store all the context.
config_sampling_args = {"sharding_data": jax.P(), "update_cache": True}
config_sampling = Config(**config_sampling_args)
next_token = jnp.array((1,))
output = next_token
cache = init_kv_cache(config_sampling)[0]
cache_size = 0
tokens_to_generate = 63
temperature = 0.1
for step in range(tokens_to_generate):
key_sampling, sk = jax.random.split(key_sampling)
next_token, cache, cache_size = sample_one_token(
config_sampling,
sk,
train_state.params,
next_token,
cache,
cache_size,
temperature,
)
output = jnp.concatenate((output, next_token))
output
Output:
Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4,
5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8], dtype=int32)
Correct outputs!
The runtime is about 2 seconds including the jit
compilation time on the first run,
and 0.5
seconds on subsequent runs (since the static cache and
one-token-input-per-sampling-step means the jit
ted sampler sees the same input shapes
on every call). Quite a speedup!
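For reference, timings like these can be reproduced with a pattern along the following lines (a hedged sketch; jax.block_until_ready is needed because dispatch is asynchronous, and we reuse the already-warm jitted sampler and the cache returned by the loop above):
import time
t0 = time.perf_counter()
tok, cache, cache_size = sample_one_token(
    config_sampling, sk, train_state.params, next_token, cache, cache_size, temperature
)
jax.block_until_ready(tok)
print(f"{time.perf_counter() - t0:.4f} s for one cached sampling step")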
Conclusions
In this blog post, we’ve built up:
- A simple jittable transformer model with RoPE;
- A basic training loop with SGD and toy data (digit alphabet);
- A minimal jittable sampler with a static KV cache.
All in bare-metal JAX.
Since this post is already long, we’ll save further extensions – refactoring into a Python package, improving the config scaffolding for running experiments and adding distributed training support, improving the optimizer, and training on actual language data – for the next post. I hope this exposition is helpful to learners of JAX and transformers – writing it has been very helpful to me!
Accumulated Configs
For reference, here’s the implementation of the Config
class we’ve used throughout the
discussion above.
# Architecture: config
from dataclasses import dataclass

@jax.tree_util.register_static
@dataclass(kw_only=True, frozen=True)
class Config:
# Experiment orchestration params
mesh_axis_names: tuple[str, ...] = ("dp",)
mesh_shape: tuple[int, ...] = (4,)
seed: int = 1337
# Data and training params
seq_len: int = 256
global_batch_size: int = 128
num_steps: int = 10**3
lr: float = 1e-2
# Model architecture params
num_vocab: int = 10
d_model: int = 768
num_heads: int = 12
d_head: int = 64
mlp_factor: int = 4
num_layers: int = 2
param_std: float = 0.02
rope_theta: float = 10000.0
max_seq_len: int = 1024
# Model dtypes
param_dtype = jnp.bfloat16 # weights, activations
compute_dtype = jnp.float32 # layernorm, attn logits, rope
optimizer_dtype = jnp.float32 # optimizer state
# Model call-time params
eps_ln: float = 1e-6
use_bias_ln: bool = False
use_fa: bool = True
use_bias_mlp: bool = False
use_rope: bool = True
update_cache: bool = False # default training
# Model sharding params
sharding_data: jax.sharding.PartitionSpec = jax.P("dp")
sharding_wqkv: jax.sharding.PartitionSpec = jax.P()
sharding_wo: jax.sharding.PartitionSpec = jax.P()
sharding_wup: jax.sharding.PartitionSpec = jax.P()
sharding_wdown: jax.sharding.PartitionSpec = jax.P()
sharding_mlp_hidden: jax.sharding.PartitionSpec = jax.P()
sharding_res_stream: jax.sharding.PartitionSpec = jax.P()
sharding_att_qkv: jax.sharding.PartitionSpec = jax.P()
def __post_init__(self):
# Set up and register mesh
mesh = jax.make_mesh(
self.mesh_shape,
self.mesh_axis_names,
len(self.mesh_shape) * (jax.sharding.AxisType.Explicit,),
)
jax.sharding.set_mesh(mesh)
# Checks
assert self.d_head % 2 == 0, (
"Head dimension needs to be divisible by 2 for RoPE"
)
Acknowledgments
Thanks to the TRC program for compute support.
1. Namely, because we won’t have a clean way to combine our model parameter definitions with the actual application of the model. ↩
2. The causal mask actually is enough to enable sequence-to-sequence mappings that depend on position to be learned in multi-layer transformers. The recent trend of “NoPE” (no positional encoding) in open-source large models is evidence of this. ↩
3. It’s actually a block-diagonal rotation matrix, with $2 \times 2$ blocks. This is what leads to the efficient approach to computing its induced quadratic form below! ↩
4. In code, for a list v, this is just v[H//2:] + v[:H//2]. Moreover, the choice of the permutation $\vPi$ is not essential (because the coordinates of the embedded keys and queries follow a learnable linear transformation), and can be changed for either mathematical clarity or code clarity (we choose based on the latter; the original RoPE paper emphasized the former). ↩
5. Actually, our naive sampling implementation above, which re-forwards the entire sequence after concatenating, runs so slow because the input to the transformer keeps growing in size – which triggers a recompilation after every generation! Avoiding these pitfalls is a general challenge in JAX. ↩
6. Although this feels very wasteful, it actually only loses (asymptotically) a factor of $2$ in FLOPs versus the optimal implementation. ↩
References
- JAX Team (2025). The Training Cookbook. Retrieved from https://docs.jax.dev/en/latest/the-training-cookbook.html. (Accessed 2025-08-29)