In this second post in a series on programming with JAX for training models like transformers, we write and train a transformer with a decent set of bells and whistles, then benchmark and scale it as much as we can on our TPU v4 half-cube!

As in previous posts in the series, we make good use of the Google scaling playbook (Austin et al., 2025), as well as some useful JAX tutorials, especially (JAX Team, 2025).

Setup

As usual, we’ll test things out in Python 3.13 and JAX 0.7.1 on a single TPU v4 host.


import jax
import jax.numpy as jnp

jax.devices()
Output:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

Design Choices

A project like this is for learning and pedagogy—it’s inevitably behind the research/infra frontiers. So the code we write is going to lean into this somewhat. Notably,

We’ll implement everything in bare-metal JAX, rather than using neural network libraries like Flax/nnx/Equinox.

This will let us expose and understand all the implementation details, both for infra and architecture, that we might normally overlook. On the flip side, we’ll take some performance hits for this (which we’ll attempt to profile and characterize), and it will lead to some slightly unfortunate code repetition.1 We’ll also be reinventing the wheel throughout – but for pedagogy’s sake!

Model Architecture Overview

Our approach to implementing the network architecture is a modification of the approach used in the JAX training cookbook (JAX Team, 2025), which uses a data structure like those in Google’s ML Collections library to store the model. At a high level, this approach consists of these standard JAX neural net patterns:

  • Storing the model parameters as a pytree.
  • Defining separate functions for initialization and the model’s forward pass which take such a pytree as input (as well as other relevant inputs). This pattern promotes functionally-pure code for the model’s initialization and forward pass, making it easy to jit for performance.

Our implementation tries to move this slightly in the direction of Equinox, without all the powerful features that library entails. In particular, we will:

  • Use NamedTuples for parameters and layers (i.e., where you’d normally expect to use an nn.Module in PyTorch), enabling good static type checking with e.g. Pyright. NamedTuples are automatically registered as pytrees in JAX!
  • Write a top-level function for model initialization, and layer-level functions for each separate layer’s operation (which, as above, take parameters and inputs as arguments). This is a ‘bottom-up’ version of the previous design, where we’ll end up with many different functions, one for each ‘layer’, which are combined bottom-up to implement the overall model’s forward pass (more like in a typical neural network library).

Building the Model: First Cut

When writing this post, I wrote and debugged the initial model/training code iteratively in a notebook. Rather than just jumping to the ‘final answer’, we’ll walk through the development process below.

In this part of the post, we’ll build up to the transformer through a few simple sub-modules of the overall model. Our targets will be:

  • We’ll work with a very simple data distribution, described below, which requires minimal dataloading code (and no tokenization).
  • We’ll focus on the training loss, leaving sampling for later. Although simple sampling gives us an important correctness test (“did we accidentally give the model access to future tokens?”, for example…), doing it ‘fast’ requires a bit of code, so we’ll save it for later.

On the way to these targets, we’ll build up the following:

  • Just the MLP, giving us a position-wise model (bigrams with a nonlinearity!).
  • A training loop with SGD and cross-entropy (next token prediction) loss.
  • Just causal attention, letting us build a transformer that can solve our toy task (discussed in the next section) with two layers.
  • Rotary positional encodings, so that we can solve our task with a one-layer model.

We’ll then step back and refactor things with an eye towards training larger models on actual language data in a distributed setting.

If you’re following along and you notice anything I’ve implemented that seems suboptimal, please let me know! Drop me an email or a DM on Twitter (links at the bottom of the page).

A Quick Note on Configuration

For overall convenience, we eventually want to combine all our different configuration options in a large dataclass, then pass our runtime instantiation of the config (with defaults, overrides, etc.) to all our functions, which can extract just the parameters they need.

When writing this code, I tend to define the config options I need locally, then go back later and turn them into attributes of the overall config dataclass. To avoid front-loading all the complexity of the config, and to avoid too much unnecessary description of the refactoring process, we’ll reference different config options below as attributes of an (as yet undefined) Config class. Code outputs will also have this class defined. You can Ctrl+F for type definitions and defaults, which will appear later, if it feels necessary.

Brief Setup

We are going to make reference to different sources of randomness below, which we seed from our Config class.

# Instantiate the config and derive independent keys for params, data, and sampling
config = Config()
key = jax.random.key(config.seed)
key_params, key_data, key_sampling = jax.random.split(key, 3)

Data

The data will consist of an (infinite) sequence of consecutive digits (from 0 to 9), which ‘reflect’ when they reach 0 or 9. I.e.,

01234567898765432101...

We’ll generate a dataset of fixed-length subsequences of this infinite sequence for training.
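The hard-coded string in the code that follows is one period of this reflecting walk. As a quick sanity check, here is a small sketch that generates the sequence directly from the reflection rule and confirms that its period is that 18-character string. The helper reflecting_digits is hypothetical and not part of the post's code.

```python
def reflecting_digits(n: int, lo: int = 0, hi: int = 9) -> str:
    """Return the first n digits of a walk that steps by +/-1 and reflects at lo and hi."""
    out, d, step = [], lo, 1
    for _ in range(n):
        out.append(str(d))
        if not lo <= d + step <= hi:
            step = -step  # bounce off the boundary
        d += step
    return "".join(out)


period = "012345678987654321"
print(reflecting_digits(20))  # 01234567898765432101
print(len(period), reflecting_digits(18 * 1024) == period * 1024)  # 18 True
```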

text = "012345678987654321" * 1024
seqs = [
    [int(c) for c in text[i : i + config.seq_len + 1]]
    for i in range(len(text) - config.seq_len - 1)
]
data = jnp.array(seqs, dtype=jnp.int32)


key_data, sk = jax.random.split(key_data)
data = jax.random.permutation(sk, data, axis=0)
num_data = len(data)
print(data[128])

Xtr = data[: int(0.8 * num_data)]
Xdev = data[int(0.8 * num_data) : int(0.9 * num_data)]
Xte = data[int(0.9 * num_data) :]
Xtr.shape
Output:
[8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8
 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9
 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8
 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7
 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6
 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5
 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6]
(14540, 257)

We shuffle the sequences above and perform a simple train-dev-test split (although we’ll just use train for now). We take one more character than the sequence length, so that we can easily turn these raw sequences into inputs and targets. We do this with a simple dataloader:

import itertools as it


def dataloader(key, config: Config):
    for step in it.count():
        key = jax.random.fold_in(key, step)
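        # Sample only from the training split: num_data counts all three splits,
        # so indices in [len(Xtr), num_data) would fall outside Xtr.
        offsets = jax.random.randint(key, (config.global_batch_size,), 0, Xtr.shape[0])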
        yield (Xtr.at[offsets, :-1].get(), Xtr.at[offsets, 1:].get())


key_data, sk = jax.random.split(key_data)
batch = map(
    lambda batch: jax.device_put(batch, config.sharding_data),
    dataloader(sk, config),
)

Here batch is a map iterator that wraps the dataloader with a device_put call to enforce our desired sharding (e.g., data parallel). When we instantiate our config, we create a global mesh and set it as the default mesh, so in our single-host setting we can pass a bare PartitionSpec to device_put (see the previous post).

The batches being loaded have shape (config.global_batch_size, config.seq_len), so we expect config.sharding_data to be, for example, jax.P('dp') for a “data-parallel” mesh axis 'dp', which shards the batch dimension across accelerators.
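The input/target construction and the sampling range can be checked in isolation. The sketch below makes a few assumptions for illustration: it uses a small random stand-in for Xtr, and Xtr_toy and the sizes are hypothetical. It draws row indices from the training split's own length. The targets are the inputs shifted by one position, so position t is trained to predict token t + 1.

```python
import jax
import jax.numpy as jnp

seq_len, batch_size, num_rows = 8, 4, 20
# Toy stand-in for Xtr: num_rows sequences of seq_len + 1 digits each.
Xtr_toy = jax.random.randint(jax.random.key(0), (num_rows, seq_len + 1), 0, 10)

# Draw row indices from the training split only, using its own length.
rows = jax.random.randint(jax.random.key(1), (batch_size,), 0, Xtr_toy.shape[0])
inputs, targets = Xtr_toy[rows, :-1], Xtr_toy[rows, 1:]

print(inputs.shape, targets.shape)  # (4, 8) (4, 8)
# Shifting by one: targets[:, t] is the token that follows inputs[:, t].
print(bool(jnp.all(inputs[:, 1:] == targets[:, :-1])))  # True
```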

Defining the Model: MLP

We’ll implement the overarching structure of our model here, described above in the “Design Choices” section, but keep the low-level implementations restricted to just the MLP.

We’ll adopt these conventions:

  • “Layers” correspond to NamedTuples (which are pytrees). The elements of such a “layer” can be either jax.Arrays (parameters/pytree leaves), or other NamedTuples (for composing layers).
  • We’ll name our layers with capital letters, and the functions that implement them (i.e., their “forward” method in PyTorch lingo) in lowercase, with a leading underscore.

Here’s our MLP layer following these conventions:

from typing import NamedTuple


class Mlp(NamedTuple):
    w_up: jax.Array
    bias_up: jax.Array
    w_down: jax.Array
    bias_down: jax.Array


def _mlp(config: Config, params: Mlp, x: jax.Array):
    preact = jnp.dot(x, params.w_up, out_sharding=config.sharding_mlp_hidden)
    if config.use_bias_mlp:
        preact += params.bias_up
    act = jax.nn.gelu(preact, approximate=True)
    out = jnp.dot(act, params.w_down, out_sharding=config.sharding_res_stream)
    if config.use_bias_mlp:
        out += params.bias_down
    return out

We would create parameters for such a layer and apply it as follows:

my_mlp = Mlp(
    w_up=jnp.zeros((2, 4)),
    bias_up=jnp.zeros((4,)),
    w_down=jnp.zeros((4, 2)),
    bias_down=jnp.zeros((2,)),
)
_mlp(config, my_mlp, jnp.ones((2,)))
Output:
Array([0., 0.], dtype=float32)

It’s important to note here that we are not able to enforce shape information at this level of implementation, since abstractly speaking an MLP can be implemented for any compatible weight matrix dimensions. These shapes get enforced later, when we initialize an Mlp with settings specified in our config instance.

More importantly, there is some ambiguity in the shapes of the arguments x and params that _mlp will apply to. For example, the above function could apply to x being a vector of shape (d_model,), or a matrix of shape (seq_len, d_model), as long as the parameters of Mlp are initialized correctly. In other words, it’s up to us to make sure we apply this implementation correctly.

In our subsequent implementation, we’re going to follow a helpful principle that JAX affords – we’ll operate at the finest level of granularity possible at every stage of the model, and ‘move up’ to the next level of granularity as necessary, using JAX’s powerful primitives like vmap and scan. In particular, we’ll see soon that we’ll only use the above _mlp function in cases where x is a vector of shape (d_model,). Since the MLPs in transformers operate the same on every position in the sequence, and every sequence in the batch, we’ll just vmap our _mlp function when we need to use it at the next level of granularity! This is a pretty empowering design strategy that JAX enables: it lets us write our code’s core numerical functionality without having to plan ahead for exactly how it will be used later, making implementation simpler and separating complexity nicely.
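Here is a minimal sketch of this "write it for one vector, vmap it up" pattern. It uses a stripped-down MLP without biases or sharding annotations; mlp_vec and the sizes are assumptions for illustration. Mapping twice lifts the per-vector function to a (batch, seq_len, d_model) input. The weights are held fixed via in_axes=None, which is what partial(_mlp, config, params) achieves in the post's code.

```python
import jax
import jax.numpy as jnp


def mlp_vec(w_up, w_down, x):
    """Position-wise MLP written for a single vector x of shape (d_model,)."""
    return jax.nn.gelu(x @ w_up, approximate=True) @ w_down


d_model, d_hidden, seq_len, batch = 8, 32, 5, 3
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
w_up = jax.random.normal(k1, (d_model, d_hidden)) / d_model**0.5
w_down = jax.random.normal(k2, (d_hidden, d_model)) / d_hidden**0.5
x = jax.random.normal(k3, (batch, seq_len, d_model))

# Lift the per-vector function: once over positions, once more over the batch.
# in_axes=None keeps the weights shared (unmapped) across every call.
mlp_seq = jax.vmap(mlp_vec, in_axes=(None, None, 0))
mlp_batch = jax.vmap(mlp_seq, in_axes=(None, None, 0))
out = mlp_batch(w_up, w_down, x)

print(out.shape)  # (3, 5, 8)
```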

A few other notes (some of them drawbacks of this bare-metal approach that seem unavoidable):

  • There’s no ‘state’ in the _mlp function: it just performs the operation of an MLP given parameters/inputs. This makes it easy to jit.
  • Initialization is not handled here – we’ll have to do it later, since we want to initialize the entire transformer (top-down) in one shot.
  • We enforce shardings for the hidden activations/output of the MLP in the _mlp function. We enforce shardings for the parameters of the MLP when we initialize them (see later).

Defining the Model: Embeddings and Placeholders

To have a basic working model, we need token embeddings (mapping the raw integer tokens we generated above to vectors the transformer can operate on), unembeddings (mapping the vector outputs of the transformer to logits, which we can map to a probability distribution over tokens to predict with), and normalization layers. We’ll have attention layers too, eventually; for now, we’ll make a placeholder function for the attention layer.

class Attn(NamedTuple):
    w_qkv: jax.Array
    w_o: jax.Array


def _attn(
    config: Config,
    params: Attn,
    x_seq: jax.Array,
    kv_cache: jax.Array,
    cache_size: int,
):
    return x_seq, kv_cache


class Embedding(NamedTuple):
    w: jax.Array


def _embedding(config: Config, params: Embedding, token: jax.Array):
    emb = params.w.at[token].get(out_sharding=config.sharding_res_stream)
    return emb


class Unembedding(NamedTuple):
    w: jax.Array


def _unembedding(config: Config, params: Unembedding, x: jax.Array):
    logits = jnp.dot(x, params.w, out_sharding=config.sharding_res_stream)
    return logits


class LayerNorm(NamedTuple):
    gamma: jax.Array
    beta: jax.Array


def _layernorm(config: Config, params: LayerNorm, x: jax.Array):
    x_std = jax.nn.standardize(x.astype(config.compute_dtype), epsilon=config.eps_ln)
    out = params.gamma * x_std.astype(config.param_dtype)
    if config.use_bias_ln:
        out += params.beta
    return out

These are standard unoptimized implementations for these different building blocks. Looking ahead slightly, we make sure to provide the ability to compute the layer normalization operation in a higher precision (config.compute_dtype), since the variance operation in standardize can lead to underflow and other losses of precision. The API for _attn involves arguments for utilizing a cache for faster inference (we’d normally leave this out to begin with, and add it only later).
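The following sketch shows the upcasting pattern. It assumes bfloat16 parameters and a float32 compute dtype; the helper layernorm_upcast is hypothetical and mirrors _layernorm without the config. Standardization runs in float32, and only the result is cast back to bfloat16. The example uses activations with a large mean relative to their spread, which is the regime where low-precision variance computations tend to lose accuracy.

```python
import jax
import jax.numpy as jnp


def layernorm_upcast(gamma, beta, x, eps=1e-5, compute_dtype=jnp.float32):
    """Standardize in compute_dtype, then scale and shift in x's (low-precision) dtype."""
    x_std = jax.nn.standardize(x.astype(compute_dtype), axis=-1, epsilon=eps)
    return gamma * x_std.astype(x.dtype) + beta


d_model = 768
# Activations with a large mean relative to their spread, stored in bfloat16.
x = (50.0 + 2.0 * jax.random.normal(jax.random.key(0), (4, d_model))).astype(jnp.bfloat16)
gamma = jnp.ones((d_model,), jnp.bfloat16)
beta = jnp.zeros((d_model,), jnp.bfloat16)

y = layernorm_upcast(gamma, beta, x)
y32 = y.astype(jnp.float32)
print(y.dtype, y32.mean(-1), y32.var(-1))  # dtype bfloat16; means ~0, variances ~1
```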

Defining the Model: Building Blocks to Transformer

A transformer consists of an embedding, then a sequence of transformer blocks (which involve normalization + attention and a residual connection, then normalization + MLP and a residual connection), then the unembedding. Accordingly, we make a Block class which combines our low-level building blocks.

from functools import partial


class Block(NamedTuple):
    norm_attn: LayerNorm
    attn: Attn
    norm_mlp: LayerNorm
    mlp: Mlp


def _block(
    config: Config,
    params: Block,
    x_seq: jax.Array,
    cache_in: jax.Array,
    cache_size: int,
):
    att_skip = x_seq
    out = jax.vmap(partial(_layernorm, config, params.norm_attn))(x_seq)
    out, cache_out = _attn(
        config, params.attn, out, kv_cache=cache_in, cache_size=cache_size
    )
    out += att_skip

    mlp_skip = out
    out = jax.vmap(partial(_layernorm, config, params.norm_mlp))(out)
    out = jax.vmap(partial(_mlp, config, params.mlp))(out)
    out += mlp_skip

    return out, cache_out

Notice that we’re following the design principle we mentioned above with the MLP: here we expect x_seq to have two dimensions, the first for the sequence length and the second for the model dimension, and we vmap the _mlp and _layernorm functions to apply to it correctly.

We can now define the transformer. We’ll do this in a slightly refined way versus just using a for loop: we’ll expect the blocks parameter below, of type Block, to have a mapped leading axis corresponding to the layer dimension, and we’ll consume this leading axis using a scan. That means we expect every parameter in the Block pytree we pass to not just have its usual shape, say param.shape (expected by the _mlp, etc. APIs above), but additionally be of shape (num_layers,) + param.shape. This turns out to be very easy to guarantee with a vmap at initialization (see below).

If you aren’t familiar with scan, you can read about the API in the jax.lax.scan documentation. It takes three things:

  • A function that accepts two arguments, a ‘carry’ and an ‘input’, and returns two outputs, a new carry and a per-step output.
  • An initial carry.
  • The inputs: a pytree whose leaves all share an additional leading axis relative to the input shape the function expects.

scan then iterates over that leading axis, applying the function to the current carry and the current slice of the inputs. The output carry is passed to the next call, and the per-step outputs are stacked along a new leading axis. In this way, it’s like a for loop that is traced and compiled once rather than unrolled. You can also draw a direct analogy to a recurrent neural network, or other dynamical systems.

In our case, the ‘carry’ for scan is our network’s input and activations (we pass these from layer to layer sequentially), and the ‘input’ for scan is the mapped Block pytree (as well as the mapped KV cache for inference – the updated cache becomes the output – but we can ignore that for now).
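Here is a toy version of this pattern. It assumes a stack of hypothetical tanh layers in place of transformer blocks. The carry is the activation vector, and the scanned-over inputs are a tuple of per-layer weights and biases with a leading num_layers axis (as in Transformer.blocks). scan returns the final carry plus one stacked per-layer output. The equivalent Python loop gives the same result. Under jit, however, the loop would be unrolled num_layers times, whereas scan traces the body once.

```python
import jax
import jax.numpy as jnp

num_layers, d = 4, 8
k_w, k_x = jax.random.split(jax.random.key(0))
ws = jax.random.normal(k_w, (num_layers, d, d)) / d**0.5  # leading layer axis
bs = jnp.zeros((num_layers, d))
x0 = jax.random.normal(k_x, (d,))


def layer(carry, w_b):
    """One scan step: carry is the activation, w_b is this layer's slice of (ws, bs)."""
    w, b = w_b
    new_carry = jnp.tanh(carry @ w + b)
    return new_carry, jnp.linalg.norm(new_carry)  # (next carry, per-step output)


x_scan, norms = jax.lax.scan(layer, x0, (ws, bs))

# Reference: the same computation as an ordinary Python loop.
x_loop = x0
for i in range(num_layers):
    x_loop = jnp.tanh(x_loop @ ws[i] + bs[i])

print(x_scan.shape, norms.shape)  # (8,) (4,)
```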

class Transformer(NamedTuple):
    emb: Embedding
    blocks: Block  # vmapped at init
    unemb: Unembedding


def _transformer(
    config: Config,
    params: Transformer,
    tokens: jax.Array,
    cache_in: jax.Array,
    cache_size: int = 0,
):
    x_seq = jax.vmap(partial(_embedding, config, params.emb))(tokens)

    def _block_fun(x_seq: jax.Array, params__cache_in: tuple[Block, jax.Array]):
        params, cache_in = params__cache_in
        return _block(config, params, x_seq, cache_in, cache_size)

    out, cache_out = jax.lax.scan(_block_fun, x_seq, (params.blocks, cache_in))

    out = jax.vmap(partial(_unembedding, config, params.unemb))(out)

    return out, cache_out

Using scan leads to a very concise forward pass definition for the transformer. Note that our transformer is still operating on sequences of embeddings – we’ll apply it to a batch of sequences with vmap later!

Defining the Model: Initialization

The last step before testing is to initialize the model. This leads to a long definition, but the intrinsic complexity is not high.

def init_model_params(key, config: Config) -> Transformer:
    def init_embedding(key) -> Embedding:
        emb = config.param_std * jax.random.normal(
            key,
            (config.num_vocab, config.d_model),
            config.param_dtype,
            out_sharding=config.sharding_res_stream,
        )
        return Embedding(w=emb)

    def init_unembedding(key) -> Unembedding:
        unemb = config.param_std * jax.random.normal(
            key,
            (config.d_model, config.num_vocab),
            config.param_dtype,
            out_sharding=config.sharding_res_stream,
        )
        return Unembedding(w=unemb)

    def init_mlp(key) -> Mlp:
        k_w_up, k_w_down = jax.random.split(key, 2)
        w_up = config.param_std * jax.random.normal(
            k_w_up,
            (config.d_model, config.mlp_factor * config.d_model),
            config.param_dtype,
            out_sharding=config.sharding_wup,
        )
        w_down = config.param_std * (
            jax.random.normal(
                k_w_down,
                (config.mlp_factor * config.d_model, config.d_model),
                config.param_dtype,
                out_sharding=config.sharding_wdown,
            )
        )
        bias_up = jnp.zeros(
            (config.mlp_factor * config.d_model,),
            config.param_dtype,
            out_sharding=config.sharding_mlp_hidden,
        )
        bias_down = jnp.zeros(
            (config.d_model,),
            config.param_dtype,
            out_sharding=config.sharding_res_stream,
        )
        return Mlp(w_up=w_up, bias_up=bias_up, w_down=w_down, bias_down=bias_down)

    def init_attn(key) -> Attn:
        k_qkv, k_o = jax.random.split(key, 2)
        w_qkv = config.param_std * jax.random.normal(
            k_qkv,
            (config.d_model, 3, config.num_heads, config.d_head),
            config.param_dtype,
            out_sharding=jax.P(*config.sharding_wqkv),
        )
        w_out = config.param_std * (
            jax.random.normal(
                k_o,
                (config.d_head, config.num_heads, config.d_model),
                config.param_dtype,
                out_sharding=jax.P(*config.sharding_wo),
            )
        )
        return Attn(w_qkv=w_qkv, w_o=w_out)

    def init_layernorm() -> LayerNorm:
        gamma = jnp.ones(
            (config.d_model,),
            config.param_dtype,
            out_sharding=config.sharding_res_stream,
        )
        beta = jnp.zeros(
            (config.d_model,),
            config.param_dtype,
            out_sharding=config.sharding_res_stream,
        )
        return LayerNorm(gamma=gamma, beta=beta)

    def init_block(key) -> Block:
        key_attn, key_mlp = jax.random.split(key)
        return Block(
            norm_attn=init_layernorm(),
            attn=init_attn(key_attn),
            norm_mlp=init_layernorm(),
            mlp=init_mlp(key_mlp),
        )

    # Make the full network
    key_emb, key_blocks, key_unemb = jax.random.split(key, 3)
    keys_blocks = jax.random.split(key_blocks, config.num_layers)
    return Transformer(
        emb=init_embedding(key_emb),
        blocks=jax.vmap(init_block)(keys_blocks),
        unemb=init_unembedding(key_unemb),
    )

Notice above how we produce the mapped axis needed for the scan in the _transformer function above: we just split our PRNGKey into an array of keys, one for each layer, and vmap the init_block function over the array of keys. Easy!
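Here is the same trick in miniature, with a hypothetical ToyBlock standing in for Block. Vmapping the initializer over a batch of keys gives every leaf a leading num_layers axis. Leaves that don't depend on the key, like the LayerNorm-style gamma below, are broadcast along that axis by vmap. Each layer's slice should match what a direct call with that layer's key produces.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class ToyBlock(NamedTuple):  # hypothetical stand-in for Block
    w: jax.Array
    gamma: jax.Array


def init_toy_block(key) -> ToyBlock:
    return ToyBlock(
        w=0.02 * jax.random.normal(key, (8, 16)),
        gamma=jnp.ones((8,)),  # does not depend on key; vmap broadcasts it
    )


num_layers = 3
keys = jax.random.split(jax.random.key(0), num_layers)
blocks = jax.vmap(init_toy_block)(keys)

print(blocks.w.shape, blocks.gamma.shape)  # (3, 8, 16) (3, 8)
```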

Testing the Basic Model

We can perform a quick test to make sure everything above is at least syntactically correct. There are two things we want to see:

  1. Can we run the model in eager mode?
  2. Can we jit the model?

We really only care about 2 (it implies 1), so let’s test it.

key_params, sk = jax.random.split(key_params)
tf = init_model_params(sk, config)

print("Number of layers in model: ", config.num_layers)
print("Model dimension: ", config.d_model)

cache = None
tokens = jnp.arange(10)


@jax.jit
def model(params, tokens):
    return _transformer(config, params, tokens, cache, 0)


out, out_cache = model(tf, tokens)
out
Output:
Number of layers in model:  2
Model dimension:  768
Array([[0.894531, -0.482422, 1.41406, 0.194336, -2.76562, -1.27344,
        1.21094, 0.417969, -0.460938, -0.304688],
       [0.0610352, 2.8125, 1.08594, -0.351562, -0.326172, -1.46094,
        0.0354004, 0.283203, 0.796875, -1.42969],
       [0.186523, 1.21875, 0.429688, -1.10156, 0.667969, -2.07812,
        -0.078125, 0.410156, 0.0273438, -2.60938],
       [0.201172, 1.46094, 0.124512, -1.27344, -1.05469, -0.667969,
        0.453125, 1.78906, -1.00781, 1.67188],
       [-0.925781, 0.988281, -1.14844, 0.0942383, 0.0898438, 1.70312,
        1.125, 1.03125, -1.17969, -1.69531],
       [-0.710938, -0.679688, -1.53125, 0.800781, 0.443359, -0.71875,
        -2.29688, 0.488281, -0.169922, -1.38281],
       [1.44531, 0.949219, -0.589844, 0.357422, -1.30469, 2.85938,
        -0.388672, 0.46875, -0.695312, -0.425781],
       [-1.85156, 0.679688, 1.25, 0.109863, -0.667969, -0.361328,
        -0.0405273, -0.964844, 3.01562, 0.333984],
       [-0.304688, 1.83594, 0.734375, 1.16406, -0.114746, -0.800781,
        0.486328, 1, -1.00781, -1.01562],
       [-1.66406, 0.925781, -0.0908203, 1, 0.298828, -2.40625, 0.574219,
        0.945312, -2.54688, -0.123047]], dtype=bfloat16)

Looks good! There might be a small issue with the unembedding scaling, as we can see:

jax.nn.softmax(out, axis=-1)
Output:
Array([[0.163086, 0.0410156, 0.275391, 0.0810547, 0.00418091, 0.0186768,
        0.224609, 0.101562, 0.0419922, 0.0493164],
       [0.0390625, 0.613281, 0.108887, 0.026001, 0.0264893, 0.00848389,
        0.0380859, 0.0488281, 0.081543, 0.00872803],
       [0.100098, 0.279297, 0.126953, 0.02771, 0.161133, 0.010376,
        0.0766602, 0.124512, 0.0849609, 0.006073],
       [0.0583496, 0.204102, 0.0539551, 0.0133057, 0.0164795, 0.0244141,
        0.0751953, 0.285156, 0.017334, 0.253906],
       [0.0227051, 0.15332, 0.0181885, 0.0629883, 0.0629883, 0.314453,
        0.176758, 0.160156, 0.0177002, 0.010437],
       [0.0588379, 0.060791, 0.026123, 0.267578, 0.1875, 0.0588379,
        0.012146, 0.195312, 0.101562, 0.0300293],
       [0.141602, 0.0864258, 0.0184326, 0.0476074, 0.00909424, 0.582031,
        0.022583, 0.0534668, 0.0164795, 0.0218506],
       [0.00500488, 0.0629883, 0.112305, 0.0358887, 0.0164795, 0.0224609,
        0.0307617, 0.012207, 0.65625, 0.0444336],
       [0.0395508, 0.335938, 0.111328, 0.171875, 0.0473633, 0.0239258,
        0.0869141, 0.145508, 0.0194092, 0.0194092],
       [0.0145874, 0.193359, 0.0693359, 0.208008, 0.102539, 0.00689697,
        0.135742, 0.196289, 0.00598145, 0.0673828]], dtype=bfloat16)

In particular, the initial predictive distributions are far from uniform, which is what we would ideally like at initialization. However, we should still be able to learn, since our network is not very deep.
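To quantify "not very uniform", one can compare the average entropy of the initial predictive distribution with its maximum, log(num_vocab). The sketch below is an illustration under two assumptions: unit-variance random activations stand in for the final residual stream, and the two init scales are hypothetical. The logit scale grows like std * sqrt(d_model), so a smaller unembedding init pushes the initial predictions toward uniform. Small- or zero-initializing the unembedding is one common option.

```python
import jax
import jax.numpy as jnp

d_model, num_vocab = 768, 10
k_x, k_w = jax.random.split(jax.random.key(0))
x = jax.random.normal(k_x, (64, d_model))  # stand-in for unit-variance final activations
w = jax.random.normal(k_w, (d_model, num_vocab))


def mean_entropy(logits: jax.Array) -> float:
    """Average entropy (in nats) of softmax(logits) over the leading axis."""
    logp = jax.nn.log_softmax(logits, axis=-1)
    return float(-(jnp.exp(logp) * logp).sum(-1).mean())


max_entropy = float(jnp.log(num_vocab))
ent = {std: mean_entropy(x @ (std * w)) for std in (0.02, 0.001)}
print(max_entropy, ent)
```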

A Minimal Training Loop

Now we have enough to train a model. A simple training loop follows, using SGD. First, we set up the training step, jitting it for performance.

class TrainState(NamedTuple):
    params: Transformer
    opt: None


@jax.jit
def init_train_state(key, config: Config) -> TrainState:
    return TrainState(params=init_model_params(key, config), opt=None)


train_state = init_train_state(key_params, config)
cache = None


@partial(jax.jit, donate_argnums=2)
def train_step(config: Config, batch, train_state: TrainState):
    def loss_fn(params: Transformer):
        inputs, targets = batch
        logits, cache_out = jax.vmap(partial(_transformer, config, params))(
            inputs, cache
        )
        logits = logits.astype(config.compute_dtype)
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return -jnp.take_along_axis(logprobs, targets[..., None], axis=-1).mean()

    loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
    new_state = TrainState(
        params=jax.tree.map(lambda x, y: x - config.lr * y, train_state.params, grad),
        opt=None,
    )

    metrics = {"loss": loss}
    return metrics, new_state

In the training step, after loading a batch of sequences from our previously-written batch loader, we vmap _transformer over it, then calculate the next-token prediction loss with the standard bare-metal approach (with possible up-casting for numerical stability of the softmax). The donate_argnums argument to jit lets us mark that the train_state buffer can be discarded after the step finishes (since we create and return a new TrainState with the results of the SGD step), letting the XLA compiler reuse that memory for the output state.
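As a sanity check on the loss, the gather-based expression in train_step agrees with the textbook one-hot cross-entropy. The sketch below uses random logits and targets with illustrative shapes. The gather form avoids building a one-hot tensor of shape (batch, seq_len, num_vocab), which can matter as the vocabulary grows.

```python
import jax
import jax.numpy as jnp

batch, seq_len, num_vocab = 2, 5, 10
k_l, k_t = jax.random.split(jax.random.key(0))
logits = jax.random.normal(k_l, (batch, seq_len, num_vocab))
targets = jax.random.randint(k_t, (batch, seq_len), 0, num_vocab)

logprobs = jax.nn.log_softmax(logits, axis=-1)
# Gather form (as in train_step): pick log p(target) at every (sequence, position).
loss_gather = -jnp.take_along_axis(logprobs, targets[..., None], axis=-1).mean()
# One-hot form: zero out every log-probability except the target's, then sum over vocab.
loss_onehot = -(jax.nn.one_hot(targets, num_vocab) * logprobs).sum(-1).mean()

print(float(loss_gather), float(loss_onehot))
```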

The loop is then simple to write:

prev_metrics = None
for step in range(config.num_steps):
    cur_metrics, train_state = train_step(config, next(batch), train_state)
    log_metrics, prev_metrics = prev_metrics, cur_metrics
    if log_metrics and step % 50 == 0:
        log_metrics |= {"step": step}
        print(*[f"{metric}: {val}" for metric, val in log_metrics.items()], sep="\t")
Output:
loss: 1.1560431718826294	step: 50
loss: 0.9151328206062317	step: 100
loss: 1.004440188407898	step: 150
loss: 0.8609314560890198	step: 200
loss: 0.9039498567581177	step: 250
loss: 0.8058883547782898	step: 300
loss: 0.7688642144203186	step: 350
loss: 0.6995882987976074	step: 400
loss: 0.6380590796470642	step: 450
loss: 0.6304700374603271	step: 500
loss: 0.6276624798774719	step: 550
loss: 0.6255137920379639	step: 600
loss: 0.6255109310150146	step: 650
loss: 0.6271587610244751	step: 700
loss: 0.6273760795593262	step: 750
loss: 0.6262620687484741	step: 800
loss: 0.6236293911933899	step: 850
loss: 0.6241434216499329	step: 900
loss: 0.6261416673660278	step: 950

It’s not great, but it’s indeed training! We don’t expect to be able to solve this task perfectly with our simple model, since a position-wise MLP can only use one token of context to predict the next token, and our task requires two. To see why, just note that, for example, the character following 5 could be either 6 or 4, depending on whether we’re in a rising or falling part of the sequence. Two characters of context are enough to determine this, meaning adding attention should help here!
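We can make "can't solve it perfectly" quantitative. With only the current token as context, the best achievable cross-entropy is the conditional entropy of the next token given the current one. The plain-Python sketch below reuses the same text string as the data section and computes this quantity from bigram counts. It comes out at about 0.616 nats, essentially $\frac{8}{9}\log 2$: only 0 and 9, which make up 2 of every 18 positions, have a deterministic successor. This is close to the plateau in the training log above.

```python
import math
from collections import Counter, defaultdict

text = "012345678987654321" * 1024
pairs = list(zip(text[:-1], text[1:]))

# Bigram table: how often each next digit follows each current digit.
counts = defaultdict(Counter)
for cur, nxt in pairs:
    counts[cur][nxt] += 1

# Conditional entropy H(next | cur) in nats: the lowest loss a one-token-context model can reach.
best_loss = -sum(
    c * math.log(c / sum(nexts.values()))
    for nexts in counts.values()
    for c in nexts.values()
) / len(pairs)

print(best_loss, 8 / 9 * math.log(2))
```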

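As an aside, note that at each step of the loop above we print the metrics from the previous iteration rather than the current one. Because JAX dispatches computations asynchronously, printing the current step’s loss would block the host until that step finishes and its result is copied back, stalling the dispatch of the next step. Lagging the logging by one step avoids this and keeps the pipeline full. See (JAX Team, 2025) for a detailed description.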

Sampling from the Model: First Cut

To get a sense of the behavior of our learned model, we can write a sampler for it now. Efficient sampling normally maintains a cache of past keys and values (our API already has arguments for this), but we’ll discuss this later, after we implement attention.

@partial(jax.jit, donate_argnums=(4,))
def sample_one_token(
    config: Config,
    key,
    params: Transformer,
    x: jax.Array,
    cache_in: jax.Array,
    cache_size: int,
    temperature: float,
):
    y, cache_out = _transformer(config, params, x, cache_in, cache_size)
    logits = y.astype(config.compute_dtype)
    cache_size = cache_size + x.shape[-1]
    next_token = jnp.array((jax.random.categorical(key, logits[-1] / temperature),))
    return next_token, cache_out, cache_size


config_sampling_args = {"sharding_data": jax.P()}
config_sampling = Config(**config_sampling_args)

output = jnp.array((1,))
cache = None
tokens_to_generate = 63
temperature = 0.1
for step in range(tokens_to_generate):
    key_sampling, sk = jax.random.split(key_sampling)
    next_token, cache, cache_size = sample_one_token(
        config_sampling, sk, train_state.params, output, cache, 0, temperature
    )
    output = jnp.concatenate((output, next_token))
output
Output:
Array([1, 2, 3, 2, 3, 4, 3, 2, 3, 2, 3, 5, 6, 5, 4, 3, 4, 5, 2, 1, 2, 3,
       2, 1, 0, 1, 2, 3, 4, 3, 4, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5,
       6, 5, 1, 2, 0, 1, 0, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 4, 3, 2],      dtype=int32)

We set a low temperature (0.1, which sharpens the sampling distribution) in order to better highlight what the model has learned. The generated sequence is in line with our hypothesis: the model has lowered the loss by learning that it needs to predict the number that’s one above or one below the current input, but it can’t do much better than guessing randomly between these two possibilities (hence why our final loss is close to $\log(2)$).
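The effect of the temperature can be seen on a single hypothetical logit vector. The values are made up: two plausible digits with similar logits, and everything else less likely. Dividing by a temperature of 0.1 multiplies the logit gaps by 10. This all but eliminates the implausible digits while keeping both plausible ones in play, which is consistent with a sampler that keeps stepping up or down by one at random. The sampling call mirrors the jax.random.categorical usage in sample_one_token.

```python
import jax
import jax.numpy as jnp

# Hypothetical logits over digits 0-9: digits 1 and 3 are plausible, the rest are not.
logits = jnp.array([0.0, 2.0, 0.0, 2.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

probs = {t: jax.nn.softmax(logits / t) for t in (1.0, 0.1)}
for t, p in probs.items():
    print(t, round(float(p[1] + p[3]), 3))  # probability mass on the two plausible digits

# Sampling the same way sample_one_token does: categorical over logits / temperature.
samples = jax.random.categorical(jax.random.key(0), logits / 0.1, shape=(1000,))
print(jnp.unique(samples))
```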

Let’s improve this with attention!

Building the Model: Attention and Positional Encodings

Below is an attention implementation that follows the shape conventions of jax.nn.dot_product_attention. It includes both a manual implementation and one that calls this function (selected by config.use_fa). For now, we omit the cache and positional encoding implementations.

def _attn(
    config: Config,
    params: Attn,
    x_seq: jax.Array,
    kv_cache: jax.Array,
    cache_size: int,
):
    # s: sequence length
    # d: embedding dim (config.d_model)
    # n: attention heads (config.num_heads)
    # h: head dim (config.d_head)
    # x_seq: s x d

    qkv = jnp.einsum(
        "sd,d3nh->3snh",
        x_seq,
        params.w_qkv,
        out_sharding=jax.P(*config.sharding_att_qkv),
    )
    q, k, v = [qkv[i] for i in range(3)]
    s = q.shape[0]

    # Attention computation
    t = k.shape[0]
    mask = (cache_size + jnp.arange(s))[:, None] >= jnp.arange(t)[None, :]
    mask = mask[None, ...]  # broadcast over heads
    if config.use_fa:
        attn_out = jax.nn.dot_product_attention(q, k, v, scale=None, mask=mask)
    else:
        logits = jnp.einsum("snh,tnh->nst", q, k).astype(config.compute_dtype)
        # Scale and causal mask
        logits *= 1.0 / config.d_head**0.5
        logits = jnp.where(mask, logits, -jnp.inf)
        probs = jax.nn.softmax(logits, axis=2)  # type: ignore[reportArgumentType]
        probs = probs.astype(config.param_dtype)
        attn_out = jnp.einsum("nst,tnh->snh", probs, v)
    out = jnp.einsum(
        "snh,hnd->sd",
        attn_out,
        params.w_o,
        out_sharding=jax.P(*config.sharding_res_stream),
    )

    return out, kv_cache

Attention takes the embeddings of the input sequences $\vX \in \bbR^{S \times D}$ as input. It computes $N$ “attention heads”, then combines them to form the output. Each head, indexed by $n \in [N]$, is generated from different projections of the input embedding sequences: we calculate

\[\begin{equation*} \vQ_n = \vX \vW_{Q, n}, \quad \vK_n = \vX \vW_{K, n}, \quad \vV_n = \vX \vW_{V, n}, \end{equation*}\]

where each projection maps the embedding dimension $D$ to the head dimension $H$. We then generate the head output as

\[\begin{equation*} \vA_n = \mathrm{softmax}(\vQ_n \vK_n^\top + \vM) \vV_n, \end{equation*}\]

with the softmax being taken along rows, and $\vM$ denoting a causal mask:

\[M_{ij} = \begin{cases} 0 & j \leq i \\ -\infty & \ow. \end{cases}\]

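The heads are then concatenated along the embedding dimension, producing an $S \times NH$ dimensional result (usually $NH = D$), and an output projection $\vW_O$ is applied to produce the attention output, which has the same shape as the input $\vX$. The code above implements these operations concisely with jnp.einsum: params.w_qkv stacks all of the $\vW_{Q,n}, \vW_{K,n}, \vW_{V,n}$ into a single $D \times 3 \times N \times H$ array, and the code also scales the logits by $1/\sqrt{H}$ before the softmax (a standard choice omitted from the formula above).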
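As a sanity check on the index conventions, here is a small sketch comparing the fused einsum path against the per-head formula $\vA_n = \mathrm{softmax}(\vQ_n \vK_n^\top/\sqrt{H} + \vM)\vV_n$, followed by concatenation and $\vW_O$. The toy sizes and random weights are arbitrary assumptions, there is no RoPE, cache, or sharding, and the q/k/v axis is labeled c in the einsum strings instead of 3.

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumptions): sequence S, model dim D, N heads of dim H.
S, D, N, H = 5, 8, 2, 4
kx, kqkv, ko = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(kx, (S, D))
w_qkv = jax.random.normal(kqkv, (D, 3, N, H)) / D**0.5  # fused W_Q, W_K, W_V for all heads
w_o = jax.random.normal(ko, (H, N, D)) / (N * H) ** 0.5  # output projection, split by head

mask = jnp.arange(S)[:, None] >= jnp.arange(S)[None, :]  # query i may attend to key j <= i

# Fused einsum version, mirroring the index layout used in _attn.
q, k, v = jnp.einsum("sd,dcnh->csnh", x, w_qkv)
logits = jnp.einsum("snh,tnh->nst", q, k) / H**0.5
logits = jnp.where(mask[None], logits, -jnp.inf)
attn_out = jnp.einsum("nst,tnh->snh", jax.nn.softmax(logits, axis=-1), v)
out_fused = jnp.einsum("snh,hnd->sd", attn_out, w_o)

# Per-head version: A_n = softmax(Q_n K_n^T / sqrt(H) + M) V_n, then concatenate and project.
M = jnp.where(mask, 0.0, -jnp.inf)
heads = []
for n in range(N):
    Q_n, K_n, V_n = (x @ w_qkv[:, i, n, :] for i in range(3))
    heads.append(jax.nn.softmax(Q_n @ K_n.T / H**0.5 + M, axis=-1) @ V_n)
concat = jnp.concatenate(heads, axis=-1)  # S x NH
W_O = jnp.concatenate([w_o[:, n, :] for n in range(N)], axis=0)  # NH x D
out_per_head = concat @ W_O

print(jnp.allclose(out_fused, out_per_head, rtol=1e-4, atol=1e-4))
```

The einsum form avoids the explicit loop and concatenation, which also makes it straightforward to annotate output shardings on each contraction.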

We can replace our previous placeholder attention implementation with a complete one and then rerun our training loop, since we already added the proper initialization code above.

Output:
loss: 0.6685549020767212	step: 50
loss: 0.58647620677948	step: 100
loss: 0.5948932766914368	step: 150
loss: 0.5258455872535706	step: 200
loss: 0.5094199776649475	step: 250
loss: 0.40264642238616943	step: 300
loss: 0.42422279715538025	step: 350
loss: 0.48532211780548096	step: 400
loss: 0.45087528228759766	step: 450
loss: 0.3302219808101654	step: 500
loss: 0.39742058515548706	step: 550
loss: 0.3917175531387329	step: 600
loss: 0.3619387745857239	step: 650
loss: 0.34305423498153687	step: 700
loss: 0.2451237142086029	step: 750
loss: 0.27099400758743286	step: 800
loss: 0.4500095248222351	step: 850
loss: 0.26931577920913696	step: 900
loss: 0.2725009322166443	step: 950
Output:
loss: 0.15929250419139862	step: 50
loss: 0.14883831143379211	step: 100
loss: 0.165802463889122	step: 150
loss: 0.14946813881397247	step: 200
loss: 0.15420110523700714	step: 250
loss: 0.14391759037971497	step: 300
loss: 0.2680056691169739	step: 350
loss: 0.14069046080112457	step: 400
loss: 0.27090030908584595	step: 450
loss: 0.14207065105438232	step: 500
loss: 0.13676664233207703	step: 550
loss: 0.13658639788627625	step: 600
loss: 0.13079527020454407	step: 650
loss: 0.39731210470199585	step: 700
loss: 0.1367562860250473	step: 750
loss: 0.13199977576732635	step: 800
loss: 0.1323813498020172	step: 850
loss: 0.12948714196681976	step: 900
loss: 0.13398176431655884	step: 950

Above are the loss traces from two back-to-back training runs. They suggest that this model should be able to learn to solve the task, but that the weak optimizer (constant learning rate SGD) is holding it back. If we rerun the sampling code above with the new model at low temperature (temperature=0.1), we do get correct outputs:

Output:
Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8,
       9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8],      dtype=int32)

Yet the final loss suggests that the model we’ve learned is still not very good. We can improve it further by incorporating positional encodings. We’ll use the modern choice, rotary positional encodings (RoPE). There are already many great blogs explaining the intuition behind these encodings, so I’ll just focus here on the algebraic ideas.

The basic idea is to modify the attention logits of a single head (we drop the subscript $n$ for concision). Ignoring the mask, the $ij$-th entry of the logit matrix $\vQ \vK^\top$ is the dot product $\vq_i \vk_j^\top$ of the $i$-th query and the $j$-th key (the $i$-th and $j$-th rows of $\vQ$ and $\vK$). A notable property of attention is that its output depends on the positions only through the causal mask $\vM$: if we drop the mask, then permuting the input positions produces the same permutation of the output positions, because every position is processed with the same projections $\vW_{QKV}$.2
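A quick numerical illustration of this permutation equivariance, as a sketch with a single unmasked head and random weights (the toy sizes and the helper unmasked_attention are assumptions for illustration):

```python
import jax
import jax.numpy as jnp


def unmasked_attention(x, w_q, w_k, w_v):
    # Single-head attention with no causal mask and no positional encoding.
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    probs = jax.nn.softmax(q @ k.T / q.shape[-1] ** 0.5, axis=-1)
    return probs @ v


S, D, H = 6, 8, 4
kx, kq, kk, kv, kp = jax.random.split(jax.random.key(1), 5)
x = jax.random.normal(kx, (S, D))
w_q, w_k, w_v = (jax.random.normal(kw, (D, H)) / D**0.5 for kw in (kq, kk, kv))
perm = jax.random.permutation(kp, S)  # a random reordering of the positions

out = unmasked_attention(x, w_q, w_k, w_v)
out_perm = unmasked_attention(x[perm], w_q, w_k, w_v)
# Permuting the input rows permutes the output rows in exactly the same way.
print(jnp.allclose(out_perm, out[perm], atol=1e-5))
```

In other words, without the mask the layer has no way to tell positions apart.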

This is arguably undesirable: the model benefits in many cases from the ability to learn a sequence-to-sequence mapping that is more position sensitive. For example, in our toy data setting, the model would be able to learn a good mapping faster if it were able to selectively focus on only a small number of past tokens (since this is enough to determine whether we’re in a “rising” or “falling” part of the sequence).

A nice way to incorporate this positional information is through a relative positional encoding. In such a method, we modify the attention operation so that the attention logits depend on the offset between positions, $\abs{i - j}$ (under the causal mask only keys with $j \leq i$ matter, so this is just $i - j$). In RoPE, this is done by inducing a bilinear form

\[\begin{equation*} \vq_i \vR_{\abs{i-j}} \vk_j^\top, \end{equation*}\]

where $\vR_{\abs{i-j}}$ is a specific rotation matrix.3 RoPE’s choice of this matrix has a useful long-term decay property: for fixed-norm keys and queries, the magnitude of the encoded attention logits tends to decrease as $\abs{i-j}$ grows (the RoPE paper makes this precise through an upper bound), giving a priori less weight to logits of positions that are further apart.

It also has the essential property that the encoded logits can be computed efficiently. It turns out that there are (real-valued) sequences $\vc_i$, $\vs_i$ corresponding to each position such that the following holds. Let $\vPi \in \bbR^{H \times H}$ denote the permutation matrix that swaps the top $H/2$ coordinates of a vector with the bottom $H/2$ coordinates (assume $H$ is even).4 Then one has

\[\begin{equation*} \vq_i \vR_{\abs{i-j}} \vk_j^\top = (\vq_i \circ \vc_i + \vPi \vq_i \circ \vs_i) (\vk_j \circ \vc_j + \vPi \vk_j \circ \vs_j)^\top, \end{equation*}\]

where matrix multiplication binds before elementwise multiplication $\circ$. In particular, we can compute the encoded logits via a small number of elementwise multiplications, adds, and concatenations involving the raw keys and queries, which leads to a minimal-complexity implementation. The sequences $\vc_i, \vs_i$ depend only on position, and can be precomputed and reused.
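The identity can be checked numerically. The sketch below uses NumPy in float64 with toy sizes (assumptions), pairs coordinate $p$ with $p + H/2$ as _apply_rope does, and takes one valid choice of the sequences (our assumption, since the text leaves them unspecified): $\vc_i$ repeats $\cos(i\omega_p)$ on both halves, and $\vs_i$ is $-\sin(i\omega_p)$ on the top half and $+\sin(i\omega_p)$ on the bottom half, where $\omega_p$ is the frequency of pair $p$. The minus sign in $\vs_i$ is what turns the swap $\vPi$ into a rotation. The helper names are hypothetical, and in this sign convention the relative matrix rotates each pair by $-(i-j)\omega_p$.

```python
import numpy as np

H, theta = 8, 10000.0
freqs = theta ** (-2.0 * np.arange(H // 2) / H)  # one frequency per coordinate pair


def apply_rope(x, pos):
    # Same split-halves convention as _apply_rope: pair coordinate p with p + H/2.
    c, s = np.cos(pos * freqs), np.sin(pos * freqs)
    x1, x2 = x[: H // 2], x[H // 2 :]
    return np.concatenate([c * x1 - s * x2, c * x2 + s * x1])


def relative_rotation(offset):
    # Block rotation (up to the pairing permutation) by angle -offset * freq on each pair.
    c, s = np.cos(offset * freqs), np.sin(offset * freqs)
    R = np.zeros((H, H))
    p = np.arange(H // 2)
    R[p, p], R[p, p + H // 2] = c, s
    R[p + H // 2, p], R[p + H // 2, p + H // 2] = -s, c
    return R


rng = np.random.default_rng(0)
q, k = rng.normal(size=H), rng.normal(size=H)
i, j = 7, 3

# Encoded logit equals q R k^T with R depending only on the offset i - j.
lhs = apply_rope(q, i) @ apply_rope(k, j)
rhs = q @ relative_rotation(i - j) @ k
shifted = apply_rope(q, i + 100) @ apply_rope(k, j + 100)  # same offset, different positions

# Elementwise form: x * c_i + (Pi x) * s_i, with the sign absorbed into s_i.
c_i = np.concatenate([np.cos(i * freqs)] * 2)
s_i = np.concatenate([-np.sin(i * freqs), np.sin(i * freqs)])
Pi_q = np.concatenate([q[H // 2 :], q[: H // 2]])  # the swap permutation applied to q
elementwise = q * c_i + Pi_q * s_i

print(np.isclose(lhs, rhs), np.isclose(lhs, shifted), np.allclose(apply_rope(q, i), elementwise))
```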

Below is a simple implementation of RoPE, along with its integration into the attention layer we implemented previously.

def _precompute_rope_sincos(config: Config):
    freqs = jnp.exp(
        -jnp.log(config.rope_theta).astype(config.compute_dtype)
        * 2
        / config.d_head
        * jnp.arange(config.d_head // 2, out_sharding=jax.P())
    )
    positions = jnp.arange(config.max_seq_len, out_sharding=jax.P())
    cycles = positions[:, None] * freqs[None, :]
    return jnp.cos(cycles), jnp.sin(cycles)


def _apply_rope(
    config: Config,
    cos: jax.Array,
    sin: jax.Array,
    positions: jax.Array,
    x: jax.Array,  # S x H
):
    x_1, x_2 = x[:, : config.d_head // 2], x[:, config.d_head // 2 :]
    c, s = cos.at[positions].get(), sin.at[positions].get()
    return jnp.concatenate(
        (c * x_1 - s * x_2, c * x_2 + s * x_1), axis=-1, dtype=config.param_dtype
    )


def _attn(
    config: Config,
    params: Attn,
    x_seq: jax.Array,
    kv_cache: jax.Array,
    cache_size: int,
):
    # s: sequence length
    # d: embedding dim (config.d_model)
    # n: attention heads (config.num_heads)
    # h: head dim (config.d_head)
    # x_seq: s x d

    qkv = jnp.einsum(
        "sd,d3nh->3snh",
        x_seq,
        params.w_qkv,
        out_sharding=jax.P(*config.sharding_att_qkv),
    )
    q, k, v = [qkv[i] for i in range(3)]
    s = q.shape[0]

    # Apply RoPE
    if config.use_rope:
        if not config.update_cache:
            cache_size = 0  # ignore passed value
        with jax.ensure_compile_time_eval():
            rope_cos, rope_sin = _precompute_rope_sincos(config)
        positions = jnp.arange(s, out_sharding=jax.P())
        _apply_rope_one_head = partial(
            _apply_rope, config, rope_cos, rope_sin, positions + cache_size
        )
        _apply_rope_all_heads = jax.vmap(_apply_rope_one_head, in_axes=1, out_axes=1)
        q, k = _apply_rope_all_heads(q), _apply_rope_all_heads(k)

    # Attention computation
    t = k.shape[0]
    mask = (cache_size + jnp.arange(s))[:, None] >= jnp.arange(t)[None, :]
    mask = mask[None, ...]  # broadcast over heads
    if config.use_fa:
        attn_out = jax.nn.dot_product_attention(q, k, v, scale=None, mask=mask)
    else:
        logits = jnp.einsum("snh,tnh->nst", q, k).astype(config.compute_dtype)
        # Scale and causal mask
        logits *= 1.0 / config.d_head**0.5
        logits = jnp.where(mask, logits, -jnp.inf)
        probs = jax.nn.softmax(logits, axis=2)  # type: ignore[reportArgumentType]
        probs = probs.astype(config.param_dtype)
        attn_out = jnp.einsum("nst,tnh->snh", probs, v)
    out = jnp.einsum(
        "snh,hnd->sd",
        attn_out,
        params.w_o,
        out_sharding=jax.P(*config.sharding_res_stream),
    )

    return out, kv_cache

The changes to the _attn function are all within the code block marked # Apply RoPE (everything else is unchanged); the _apply_rope function takes care of the necessary permutation/concatenation. We use JAX’s ensure_compile_time_eval context manager so that the cosine and sine tables are computed eagerly at trace time and enter the compiled program as constants, rather than as staged-out operations. The tables use geometrically spaced frequencies: the $k$-th coordinate pair rotates by $\theta^{-2k/H}$ radians per position, where the base $\theta$ is rope_theta, set to 1e4 by default (in general, it is chosen based on the target context length, here config.max_seq_len = 1024).
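To make the frequency schedule concrete, this sketch (using the default d_head = 64 and rope_theta = 1e4 from the Config below) checks that the exp/log form used in _precompute_rope_sincos equals rope_theta ** (-2k / d_head), and computes the shortest and longest wavelengths in positions (2π divided by the frequency):

```python
import math

import jax.numpy as jnp

d_head, rope_theta = 64, 10000.0  # defaults from the Config class
k = jnp.arange(d_head // 2)

# Form used in _precompute_rope_sincos: exp(-log(theta) * 2k / d_head).
freqs_exp = jnp.exp(-jnp.log(rope_theta) * 2 / d_head * k)
# Equivalent closed form: theta^(-2k / d_head).
freqs_pow = rope_theta ** (-2.0 * k / d_head)

wavelengths = 2 * math.pi / freqs_exp  # positions per full turn of each pair
print(jnp.allclose(freqs_exp, freqs_pow, rtol=1e-5))
print(float(wavelengths[0]), float(wavelengths[-1]))
```

The wavelengths range from about 6 positions up to tens of thousands, so the slowest pairs rotate by only a small fraction of a turn across the 1024-token context.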

With RoPE, we find it much easier to learn a strong model for our toy data. After rerunning training with these modifications, we observe the following loss progression:

Output:
loss: 0.6326758861541748	step: 50
loss: 0.4087381660938263	step: 100
loss: 0.14004486799240112	step: 150
loss: 0.08325601369142532	step: 200
loss: 0.06802305579185486	step: 250
loss: 0.05945558100938797	step: 300
loss: 0.05321500450372696	step: 350
loss: 0.052829645574092865	step: 400
loss: 0.048602353781461716	step: 450
loss: 0.04696694016456604	step: 500
loss: 0.04576549679040909	step: 550
loss: 0.04564153775572777	step: 600
loss: 0.04341501370072365	step: 650
loss: 0.043622132390737534	step: 700
loss: 0.04205687716603279	step: 750
loss: 0.04401983320713043	step: 800
loss: 0.04256521165370941	step: 850
loss: 0.041673485189676285	step: 900
loss: 0.0421551875770092	step: 950

Much better! For a model that makes perfectly confident, correct predictions everywhere except the first token of each sequence, where it must guess randomly, we might expect a loss of around $\log(2) / 256 \approx 0.0027$ (we train with sequence length 256). Our final loss of about 0.042 is within roughly an order of magnitude of this.
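The arithmetic behind this estimate, as a tiny sketch (it assumes, as in the text, that the only irreducible uncertainty is a random guess worth $\log 2$ nats on the first token of each 256-token sequence; the final loss is the last value from the trace above):

```python
import math

seq_len = 256
floor = math.log(2) / seq_len  # nats per token, averaged over the sequence
final_loss = 0.0421551875770092  # loss at step 950 in the RoPE run above

print(f"floor ≈ {floor:.4f}")
print(f"final loss is {final_loss / floor:.1f}x the floor")
```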

Building the Model: Sampling with a Cache

The implementation of sampling we gave above is extremely slow: although we jit the sample_one_token function, we re-forward the entire updated sequence through the model after every generated token, which is very wasteful. A quick test of our previously implemented sampling function on a single TPU v4 host takes about 50 seconds to sample 63 tokens, including the time to jit compile the sampling function.

This makes sense, given our naive implementation. Indeed, if we recall the attention definition, at the $i$-th token each head computes

\[\begin{equation*} \mathrm{softmax}( \vq_i \vK^\top ) \vV, \end{equation*}\]

and the output projection just accumulates these independent heads and ‘lifts’ the head dimension $H$ to the output dimension $D$. Because of the causal mask, the key and value at any earlier position $j < i$ (in every layer) depend only on tokens $1, \dots, j$, so they do not change when new tokens are appended. So if we were to save the previous keys and values $(\vk_j, \vv_j)_{j=1}^{i-1}$ in a cache, we could reuse them to compute each head’s attention output with no wasted operations: we’d just need to compute $\vq_i$, $\vk_i$ and $\vv_i$ for the current position, then load the previous positions’ keys and values from the cache to compute the full attention output.
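A small sketch of this idea for a single head in a single layer (the toy sizes and random weights are assumptions): decoding one token at a time with a growing list of cached keys and values reproduces full causal attention, up to floating-point error. In a multi-layer model the same argument applies layer by layer. Note that this Python-list cache changes shape at every step, which is fine outside jit but is exactly what we must avoid inside it.

```python
import jax
import jax.numpy as jnp

S, D, H = 6, 8, 4
kx, kq, kk, kv = jax.random.split(jax.random.key(2), 4)
x = jax.random.normal(kx, (S, D))
w_q, w_k, w_v = (jax.random.normal(kw, (D, H)) / D**0.5 for kw in (kq, kk, kv))


def causal_attention(x):
    # Full-sequence single-head causal attention, as in training.
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    mask = jnp.arange(S)[:, None] >= jnp.arange(S)[None, :]
    logits = jnp.where(mask, q @ k.T / H**0.5, -jnp.inf)
    return jax.nn.softmax(logits, axis=-1) @ v


# Incremental decoding: compute (k_i, v_i) once per position and append them to a cache.
k_cache, v_cache, outputs = [], [], []
for i in range(S):
    x_i = x[i]
    q_i = x_i @ w_q
    k_cache.append(x_i @ w_k)
    v_cache.append(x_i @ w_v)
    K, V = jnp.stack(k_cache), jnp.stack(v_cache)  # (i + 1) x H, only past and current keys
    outputs.append(jax.nn.softmax(q_i @ K.T / H**0.5) @ V)
incremental = jnp.stack(outputs)

print(jnp.allclose(incremental, causal_attention(x), atol=1e-4, rtol=1e-4))
```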

This data structure is called a KV cache, and it’s essential for fast inference, since autoregressive operation implies we can only compute one token at a time (in contrast to training, where we forward an entire batch of sequences at once). A back-of-the-envelope analysis suggests that the naive sampling strategy (re-forward the entire sequence on every new token generation) takes $O(S^2D^2 + DS^3)$ FLOPs to generate an $S$-length sequence starting from an empty prompt, whereas the KV caching strategy takes only $O(SD^2 + DS^2)$ FLOPs. This is a significant savings!
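The back-of-the-envelope counts can be tallied directly. This sketch sums the per-step costs described above for a single layer, dropping constant factors (so only the ratio is meaningful), with D and S taken from the Config defaults below:

```python
def naive_flops(S, D):
    # Re-forward the whole prefix at every step i: ~i*D^2 for projections,
    # ~i^2*D for attention (constants and layer count dropped).
    return sum(i * D**2 + i**2 * D for i in range(1, S + 1))


def cached_flops(S, D):
    # With a KV cache, step i projects one token (~D^2) and attends over i keys (~i*D).
    return sum(D**2 + i * D for i in range(1, S + 1))


D, S = 768, 1024  # d_model and max_seq_len from the Config defaults
print(f"naive / cached ≈ {naive_flops(S, D) / cached_flops(S, D):.0f}x")
```

For a one-token sequence the two strategies coincide, and the gap grows roughly linearly with the sequence length.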

The catch is that the KV caching strategy requires extra memory to store the cache, and it shifts the cost of each generation step from FLOPs toward memory accesses (reading the cached keys and values). In JAX, we also need to be careful to implement the KV cache ‘correctly’: since the cache grows after every generation, naive implementations can entail inputs of different sizes to _attn (triggering a re-compilation on every call when it’s jitted!), or dynamically-sized arrays (not allowed within a jit context!).5
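To see the recompilation pitfall concretely, here is a toy sketch (the function toy_step is hypothetical) that counts how many times jax.jit traces a function. A Python side effect inside a jitted function runs only during tracing, so the counter increments once per new input shape:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def toy_step(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces the function
    return jnp.sum(x**2)


# Growing input, like re-forwarding the whole sequence after each new token.
for n in range(1, 6):
    toy_step(jnp.ones(n))
growing_traces = trace_count

# Fixed-size buffer whose contents change, like a static cache plus a mask.
trace_count = 0
for n in range(1, 6):
    toy_step(jnp.where(jnp.arange(8) < n, jnp.ones(8), jnp.zeros(8)))
fixed_traces = trace_count

print(growing_traces, fixed_traces)
```

With growing inputs the function is traced once per length, while the fixed-size buffer is traced only once.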

Since we left the API for the cache in the code we’ve written already, we just need to implement its functionality in the _attn function and add a function to initialize an empty cache. To play nicely with jit, we implement a static-sized cache with fixed length config.max_seq_len (the same parameter we used for RoPE), so that _attn is compiled only once, and we maintain an auxiliary cache_size variable to keep track of how much of the cache has been written. To avoid dynamically-sized arrays in the _attn function, we are a bit wasteful: we zero-initialize the cache, then multiply the current query $\vq_i$ against the entire cache on every attention operation.6 The unused zero keys must not influence the attention computation (a zero key yields a logit of $0$, which the softmax still assigns nonzero weight!), so we make sure the attention mask excludes them.
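A minimal single-head sketch of this design, independent of the post’s Config and model (the sizes and the names decode_step and num_traces are hypothetical): a fixed-size, zero-initialized cache updated with jax.lax.dynamic_update_slice, a mask that hides unused slots, and a check that the jitted step is traced only once while matching full causal attention.

```python
import jax
import jax.numpy as jnp

MAX_LEN, D, H = 16, 8, 4  # hypothetical toy sizes
kx, kq, kk, kv = jax.random.split(jax.random.key(3), 4)
w_q, w_k, w_v = (jax.random.normal(kw, (D, H)) / D**0.5 for kw in (kq, kk, kv))
num_traces = 0


@jax.jit
def decode_step(x_t, k_cache, v_cache, cache_size):
    """Single-head attention for one new token against a fixed-size cache."""
    global num_traces
    num_traces += 1  # counts (re)traces, not calls
    q, k, v = x_t @ w_q, x_t @ w_k, x_t @ w_v  # each of shape (H,)
    # Write this token's key/value at slot `cache_size`; array shapes never change.
    k_cache = jax.lax.dynamic_update_slice(k_cache, k[None], (cache_size, 0))
    v_cache = jax.lax.dynamic_update_slice(v_cache, v[None], (cache_size, 0))
    # Zero-filled slots past the current token would get logit 0 (nonzero weight): mask them.
    valid = jnp.arange(MAX_LEN) <= cache_size
    logits = jnp.where(valid, k_cache @ q / H**0.5, -jnp.inf)
    out = jax.nn.softmax(logits) @ v_cache
    return out, k_cache, v_cache, cache_size + 1


x = jax.random.normal(kx, (10, D))
k_cache, v_cache = jnp.zeros((MAX_LEN, H)), jnp.zeros((MAX_LEN, H))
cache_size = jnp.int32(0)
outs = []
for t in range(x.shape[0]):
    out, k_cache, v_cache, cache_size = decode_step(x[t], k_cache, v_cache, cache_size)
    outs.append(out)
outs = jnp.stack(outs)

# Reference: full causal attention over the same 10 tokens.
q, k, v = x @ w_q, x @ w_k, x @ w_v
mask = jnp.arange(10)[:, None] >= jnp.arange(10)[None, :]
ref = jax.nn.softmax(jnp.where(mask, q @ k.T / H**0.5, -jnp.inf), axis=-1) @ v
print(num_traces, jnp.allclose(outs, ref, atol=1e-4, rtol=1e-4))
```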

The implementation is below.

def _attn(
    config: Config,
    params: Attn,
    x_seq: jax.Array,
    kv_cache: jax.Array,  # 2 x config.max_seq_len x n x h (see below)
    cache_size: int,
):
    # s: sequence length
    # d: embedding dim (config.d_model)
    # n: attention heads (config.num_heads)
    # h: head dim (config.d_head)
    # x_seq: s x d

    qkv = jnp.einsum(
        "sd,d3nh->3snh",
        x_seq,
        params.w_qkv,
        out_sharding=jax.P(*config.sharding_att_qkv),
    )
    q, k, v = [qkv[i] for i in range(3)]
    s = q.shape[0]  # we save k shape later, after possibly prepending cache

    # Apply RoPE
    if config.use_rope:
        if not config.update_cache:
            cache_size = 0  # ignore passed value
        with jax.ensure_compile_time_eval():
            rope_cos, rope_sin = _precompute_rope_sincos(config)
        positions = jnp.arange(s, out_sharding=jax.P())
        _apply_rope_one_head = partial(
            _apply_rope, config, rope_cos, rope_sin, positions + cache_size
        )
        _apply_rope_all_heads = jax.vmap(_apply_rope_one_head, in_axes=1, out_axes=1)
        q, k = _apply_rope_all_heads(q), _apply_rope_all_heads(k)

    # Read/update/concatenate the cache
    if config.update_cache:
        k_cache, v_cache = kv_cache[0], kv_cache[1]
        k = jax.lax.dynamic_update_slice(k_cache, k, (cache_size, 0, 0))
        v = jax.lax.dynamic_update_slice(v_cache, v, (cache_size, 0, 0))
        kv_cache_out = jnp.concatenate((k[None], v[None]), axis=0)
    else:
        kv_cache_out = kv_cache
        cache_size = 0  # ignore passed value

    # Attention computation
    t = k.shape[0]
    mask = (cache_size + jnp.arange(s))[:, None] >= jnp.arange(t)[None, :]
    # with a static KV cache, unused (zero) slots must be masked; the causal
    # condition above already excludes them, so this line is a defensive guard
    mask = mask & (jnp.arange(t)[None, :] < cache_size + s)
    mask = mask[None, ...]  # broadcast over heads
    if config.use_fa:
        attn_out = jax.nn.dot_product_attention(q, k, v, scale=None, mask=mask)
    else:
        logits = jnp.einsum("snh,tnh->nst", q, k).astype(config.compute_dtype)
        # Scale and causal mask
        logits *= 1.0 / config.d_head**0.5
        logits = jnp.where(mask, logits, -jnp.inf)
        probs = jax.nn.softmax(logits, axis=2)  # type: ignore[reportArgumentType]
        probs = probs.astype(config.param_dtype)
        attn_out = jnp.einsum("nst,tnh->snh", probs, v)
    out = jnp.einsum(
        "snh,hnd->sd",
        attn_out,
        params.w_o,
        out_sharding=jax.P(*config.sharding_res_stream),
    )

    return out, kv_cache_out


def init_kv_cache(config: Config):
    return jnp.zeros(
        (
            config.global_batch_size,
            config.num_layers,
            2,
            config.max_seq_len,
            config.num_heads,
            config.d_head,
        ),
        dtype=config.param_dtype,
        out_sharding=config.sharding_data + jax.P(None) + config.sharding_att_qkv,
    )

Running this block and then rerunning our training script shows that our changes haven’t broken anything. We can also test how much the cache speeds up sampling. We modify the sampler above slightly: we enable the cache in the config, and forward one token per sampling step to the model, rather than the entire output, since the cache stores all the previous context.

config_sampling_args = {"sharding_data": jax.P(), "update_cache": True}
config_sampling = Config(**config_sampling_args)

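next_token = jnp.array((1,))  # one-token prompt
output = next_token
cache = init_kv_cache(config_sampling)[0]  # cache for a single sequence (batch index 0)
cache_size = 0  # number of cache slots written so far
tokens_to_generate = 63
temperature = 0.1
for step in range(tokens_to_generate):
    key_sampling, sk = jax.random.split(key_sampling)
    # Forward only the newest token; the earlier context lives in the cache.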
    next_token, cache, cache_size = sample_one_token(
        config_sampling,
        sk,
        train_state.params,
        next_token,
        cache,
        cache_size,
        temperature,
    )
    output = jnp.concatenate((output, next_token))
output
Output:
Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8,
       9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8],      dtype=int32)

Correct outputs! The runtime is about 2 seconds including jit compilation on the first run, and 0.5 seconds on subsequent runs (the static cache and one-token input per sampling step mean that the jitted sampler sees the same input shapes on every call), compared with about 50 seconds for the naive sampler. Quite a speedup!

Conclusions

In this blog post, we’ve built up:

  • A simple jittable transformer model with RoPE;
  • A basic training loop with SGD and toy data (digit alphabet);
  • A minimal jittable sampler with a static KV cache.

All in bare-metal JAX.

Since this post is already long, we’ll save further extensions (refactoring into a Python package, improving the config scaffolding for running experiments and adding distributed training support, improving the optimizer, and training on actual language data) for the next post. I hope this exposition is helpful to learners of JAX and transformers; writing it has been very helpful to me!

Accumulated Configs

For reference, here’s the implementation of the Config class we’ve used throughout the discussion above.

# Architecture: config
@jax.tree_util.register_static
@dataclass(kw_only=True, frozen=True)
class Config:
    # Experiment orchestration params
    mesh_axis_names: tuple[str, ...] = ("dp",)
    mesh_shape: tuple[int, ...] = (4,)
    seed: int = 1337

    # Data and training params
    seq_len: int = 256
    global_batch_size: int = 128
    num_steps: int = 10**3
    lr: float = 1e-2

    # Model architecture params
    num_vocab: int = 10
    d_model: int = 768
    num_heads: int = 12
    d_head: int = 64
    mlp_factor: int = 4
    num_layers: int = 2
    param_std: float = 0.02
    rope_theta: float = 10000.0
    max_seq_len: int = 1024

    # Model dtypes
    param_dtype = jnp.bfloat16  # weights, activations
    compute_dtype = jnp.float32  # layernorm, attn logits, rope
    optimizer_dtype = jnp.float32  # optimizer state

    # Model call-time params
    eps_ln: float = 1e-6
    use_bias_ln: bool = False
    use_fa: bool = True
    use_bias_mlp: bool = False
    use_rope: bool = True
    update_cache: bool = False  # default training

    # Model sharding params
    sharding_data: jax.sharding.PartitionSpec = jax.P("dp")
    sharding_wqkv: jax.sharding.PartitionSpec = jax.P()
    sharding_wo: jax.sharding.PartitionSpec = jax.P()
    sharding_wup: jax.sharding.PartitionSpec = jax.P()
    sharding_wdown: jax.sharding.PartitionSpec = jax.P()
    sharding_mlp_hidden: jax.sharding.PartitionSpec = jax.P()
    sharding_res_stream: jax.sharding.PartitionSpec = jax.P()
    sharding_att_qkv: jax.sharding.PartitionSpec = jax.P()

    def __post_init__(self):
        # Set up and register mesh
        mesh = jax.make_mesh(
            self.mesh_shape,
            self.mesh_axis_names,
            len(self.mesh_shape) * (jax.sharding.AxisType.Explicit,),
        )
        jax.sharding.set_mesh(mesh)

        # Checks
        assert self.d_head % 2 == 0, (
            "Head dimension needs to be divisible by 2 for RoPE"
        )

Acknowledgments

Thanks to the TRC program for compute support.

  1. Namely, because we won’t have a clean way to combine our model parameter definitions with the actual application of the model. 

  2. The causal mask actually is enough to enable sequence-to-sequence mappings that depend on position to be learned in multi-layer transformers. The recent trend of “NoPE” (no positional encoding) in open-source large models is evidence of this. 

  3. It’s actually a block-diagonal rotation matrix with $2 \times 2$ blocks (up to a permutation of the coordinates, depending on how the pairs are chosen). This is what leads to the efficient approach to computing its induced bilinear form below! 

  4. In code, for a list v, this is just v[H//2:] + v[:H//2]. Moreover, the choice of the permutation $\vPi$ is not essential (because the keys and queries are produced by a learnable linear transformation), and can be changed for either mathematical clarity or code clarity (we choose based on the latter; the original RoPE paper emphasized the former). 

  5. Actually, our naive sampling implementation above, which re-forwards the entire sequence after concatenating, runs so slowly because the input to the transformer keeps growing in size, which triggers a recompilation after every generation! Avoiding these pitfalls is a general challenge in JAX. 

  6. Although this feels very wasteful, it only loses (asymptotically) a factor of $2$ in attention FLOPs versus the optimal implementation, when generating a sequence that fills the cache.