In this third post in a series on programming with JAX for training transformers, we pick up where we left off at the end of the previous post. We’ll extend our previous small transformer model to the NanoGPT scale, incorporating actual text data, distributed training, and a real optimizer, and perform some basic evaluation! We’ll end up with a reasonable substrate for building future experiments on top of.

As in previous posts in the series, we make good use of the Google scaling playbook (Austin et al., 2025) for profiling.

Package Conversion and Improvements

We’re running in Python 3.13, and we’ve upgraded to JAX 0.8.0 from previous posts. Now that we’ve written a good chunk of model, training, and config code, we’ll also be working on top of a packaged version of our code.

The code is hosted in this GitHub repository: baremetal-gpt. There are instructions to install the repository as a package using uv there.

We’ll start by talking through some refactoring that leads us from the state of our code in the previous post to the state of the code in the GitHub repository.

Config Upgrades

To run experiments with our code in a distributed (i.e., multi-host) setting, we’ll benefit from having a robust way to dispatch different experimental configurations from the command line. We’ll use Hydra for this purpose. Hydra is not the most minimal solution, but it’s the most feature-complete library for managing configs that I know of (e.g., specifying lists as arguments), and for our simple purposes we’ll be able to adopt it with only a small amount of added code and complexity.

Small Modifications to our Config Dataclass

Hydra supports registering config dataclasses as so-called “Structured Configs”. This tutorial gives a good description of the core functionalities we need. We only need to make a few small modifications to our previous dataclass:

  1. Use only supported types: instead of random class objects (e.g., for dtypes or PartitionSpecs), we need to create Enums for passing values, and disambiguate them at call time. There are other minor caveats (e.g., can’t arbitrarily nest lists; restrictions on use of Optional).
  2. Make the config mutable (remove frozen=True).
  3. Re-register the config as static with JAX. Hydra will wrap our dataclass so that it ends up duck-typed as a Config object, but we will need to re-register the wrapped config as static. Since the dataclass is no longer frozen, we can do this with a post-init function, which we’ll now call when we launch our training loop.
  4. Create a function to register the Config dataclass in a Hydra ConfigStore instance. In general, we can create arbitrarily stratified configs for easier management (e.g., separate configs for model, experiment, data, etc.). We’ll call this function in our main method.
  5. Refactor the training loop into a train.py file with a main method that we’ll call when we want to launch an experiment, and add the necessary Hydra setup code.

We will also make a larger refactor, which will help us later with setting up config YAML files for experiments: we will change from a flat config to a hierarchical config, with different parameters split into different sub-configs (e.g., for the optimizer, for data, …). This is not an objectively superior design – for example, it’s arguably a bit easier to write new code and integrate static type checking with a flat config – but we’ll follow it in order to add a bit more encapsulation to distinct config parameters.
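
As a concrete illustration of point 1 above, here is a hedged sketch of the Enum-plus-.value pattern that the later snippets rely on (e.g., config.model.compute_dtype.value). The names and values here are my own stand-ins, not necessarily the repo’s exact definitions:

# Hypothetical sketch: a Hydra-friendly stand-in for dtypes
from dataclasses import dataclass
from enum import Enum


class DTypeName(Enum):
    FLOAT32 = "float32"
    BFLOAT16 = "bfloat16"


@dataclass
class ModelConfig:  # fragment for illustration only
    compute_dtype: DTypeName = DTypeName.BFLOAT16


# At call time, .value is a dtype name that jnp accepts directly, e.g.
# logits.astype(config.model.compute_dtype.value)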

Here’s a representative snippet from the new config.py with the modifications discussed above. There are many forward-looking modifications too, of which we only show a few out-of-context selections.

# config.py

# We omit the sub-configs, and other nitty-gritty details...
# Config is now top-level (orchestration parameters, etc.)


@dataclass(kw_only=True, unsafe_hash=True)
class Config:
    """Overall config class, containing top-level orchestration parameters"""

    seed: int = 1337
    logger_type: LoggerType = LoggerType.WANDB
    project_name: str = "bmgpt-debug"
    run_name: str = ""
    val_log_interval: int = 1000  # log validation metrics every <this many> batches

    train_dataset: DatasetConfig = MISSING
    val_list: list[EvaluationConfig] = field(default_factory=list)  # validation metrics
    eval_list: list[EvaluationConfig] = field(default_factory=list)

    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    sharding: ShardingConfig = field(default_factory=ShardingConfig)


def register_configs():
    cs = ConfigStore.instance()
    cs.store(group="optimizer", name="base_optimizer", node=OptimizerConfig)
    cs.store(group="model", name="base_model", node=ModelConfig)
    cs.store(group="inference", name="base_inference", node=InferenceConfig)
    cs.store(group="sharding", name="base_sharding", node=ShardingConfig)
    cs.store(name="base_config", node=Config)


def config_post_init(config: Config):
    """Input validation."""
    # ...

And here are the relevant parts of train.py that we modify. There is some extra code for distributed initialization, which we discuss in more detail later in the post.

from pathlib import Path

import hydra
from bmgpt.config import Config, config_post_init, register_configs

register_configs()

# ...


@hydra.main(
    version_base=None,
    config_path=str(Path("configs").absolute().resolve()),
    config_name="base_config",
)
def main(config: Config):
    try:
        # Launch distributed and register configs
        jax.distributed.initialize()
        jax.tree_util.register_static(type(config))
    except RuntimeError:
        # This implies the distributed backend has already been initialized
        pass
    config_post_init(config)


if __name__ == "__main__":
    main()

The modifications are quite simple: we register our Config dataclass in the ConfigStore instance as 'base_config', then decorate our main method in train.py to use that config. The 'config_path' argument in that decorator specifies a directory where we store different .yaml config files specifying different experimental setups (e.g., different model architectures, different datasets, different distributed training environments, etc.). Then we just call our post-init config method right at the start of main: right now, it just validates some config inputs.

When we want to launch an experiment now, we can run a command like the following (train loop refactoring follows later in the blog):

uv run --extra tpu train model.num_layers=1

Here, we are using uv optional dependencies so that we can run on either CPU or TPU depending on the platform, as well as a build script in our package (see the pyproject.toml). Super easy! Anything that isn’t specified follows our defaults from the Config dataclass definition.

As the experiments we want to run become more diverse and complex, it becomes harder and harder to productively specify all default overrides via the command line. To address this, we can write YAML configs for different experiments as we mentioned above. For example, we can make a file at configs/experiments/text_toy.yaml, with:

# @package _global_

defaults:
  - /dataset/staircase@train_dataset

train_dataset:
  num_steps: 100

eval_list:
  - dataset:
      name: NUMBER_STAIRCASE
      path: ""
      split: TEST
      seq_len: 1  # At inference time, this functions like prompt size
      global_batch_size: 16
    evaluator: AUTOREGRESSIVE_ROLLOUTS

model:
  transformer_type: DISCRETE
  is_causal: True
  num_vocab: 10

See the Hydra docs for more details about how to parse this file – it has some boilerplate to ensure it overrides the config defaults we specify in our config.py. We’ll discuss some of the enums that appear here (you can identify them by the all-caps text) later on. Then to run an experiment, we can simply override as:

uv run --extra tpu train +experiment=text_toy

In addition, any other command line arguments we want to override with can be added!

Model Improvements and Refactors

We’ll add support for the Pallas Splash Attention kernel, which will help us with model scaling later!

Splash Attention Integration

Splash Attention stands for “sparse Flash Attention” (Pagliardini et al., 2023). Flash Attention is a well-known memory-efficient implementation of the attention operation (Dao et al., 2022); the name is also often used to refer to the associated CUDA kernel and surrounding PyTorch package for the algorithm. Compared to flash attention, splash attention adds better support for sparse attention masks, which arise frequently in practice: for example, when we pack many relatively short sequences into one long sequence, or mask precisely between document boundaries, we need to make sure not to attend across those boundaries.

The types of optimizations that flash/splash attention entail require a level of control over hardware-level execution that is not possible within the standard JAX-XLA programming pipeline. So JAX exposes a domain-specific language for writing these kernels, called Pallas! The authoring of Pallas kernels is out-of-scope for this post, but we will be interested in learning how to call pre-written Pallas kernels so that we can integrate splash attention.

The interfaces for achieving this are naturally a bit rough around the edges, with scant documentation and relatively few examples for how to proceed. With the help of multiple LLMs, I eventually pieced together a solution based on example code available in the JAX Pallas tests. Here’s a high-level overview, with the code to follow:

  • We need to configure the Pallas kernel itself based on how we’ll use it at runtime: specifically, the sequence length, number of heads, base mask type, and parallelism configuration. The latter has to be enforced manually with shard_map. Once we figure out where everything needs to go, this works more-or-less okay with our current model implementation structure, which uses Explicit mesh axes and vmap as much as possible to simplify ‘inner’ model code. However, it currently doesn’t work in JAX 0.8.1. It should work in the next release, thanks to a bug report I submitted and a fix from Yash Katariya (thanks!).
  • Once we have the kernel (represented as a tuple of two Callables, one corresponding to the actual kernel and one corresponding to a wrapper with the shard_map configuration), we need to thread it through our top-level _transformer call down to the _attn call. A more ‘elegant’ approach to doing this (at least at the call site) might be to use some metaprogramming, but we will go with the simplest solution.

The code for the kernel setup is in splash_helpers.py:

# splash_helpers.py
from functools import partial

import jax
from bmgpt.config import Config, DatasetConfig
from jax.experimental.pallas.ops.tpu.splash_attention import (
    BlockSizes,
    CausalMask,
    FullMask,
    MultiHeadMask,
    make_splash_mha,
)

# ...


def make_splash_kernel(
    config: Config,
    config_data: DatasetConfig,
    cache_capacity: int,
    mesh,
    head_shards: int = 1,
    q_seq_shards: int = 1,
):
    if not config_data.use_splash:
        # None ends up calling jax-xla attention
        # see _attn
        return None
    # s is Q len (seq_len @ train; variable/1 at prefill/decode)
    # t is K len (s + cache_capacity)
    s = config_data.seq_len
    t = s + cache_capacity

    if config.model.is_causal:
        mask = MultiHeadMask(
            [
                CausalMask(shape=(s, t), offset=cache_capacity)
                for _ in range(config.model.num_heads)
            ]
        )
    else:
        mask = MultiHeadMask(
            [FullMask(shape=(s, t)) for _ in range(config.model.num_heads)]
        )
    BLOCK_SIZE = min(config_data.seq_len, 128)
    if config_data.seq_len % 128 != 0 or BLOCK_SIZE % 128 != 0:
        # splash attention kernel requires block size to be a multiple of 128
        raise NotImplementedError("Splash block size needs to be a multiple of 128")
    block_sizes = BlockSizes(
        block_q=BLOCK_SIZE,
        block_kv=BLOCK_SIZE,
        block_kv_compute=BLOCK_SIZE,
        block_q_dkv=BLOCK_SIZE,
        block_kv_dkv=BLOCK_SIZE,
        block_kv_dkv_compute=BLOCK_SIZE,
        block_q_dq=BLOCK_SIZE,
        block_kv_dq=BLOCK_SIZE,
    )
    splash_spec = jax.P(None, None)
    splash_sharding = jax.sharding.NamedSharding(mesh, splash_spec)
    kernel = make_splash_mha(
        mask,
        head_shards=head_shards,
        q_seq_shards=q_seq_shards,
        block_sizes=block_sizes,
    )
    kernel_spec = kernel.manual_sharding_spec(splash_sharding)

    @partial(
        jax.shard_map,
        mesh=mesh,
        in_specs=(kernel_spec, splash_spec, splash_spec, splash_spec, jax.P()),
        out_specs=splash_spec,
        check_vma=False,
    )
    def splash_sharded(kernel, q, k, v, segment_ids):
        return kernel(q, k, v, segment_ids=segment_ids)

    return (splash_sharded, kernel)

The function argument cache_capacity specifies the maximum KV cache size we’ll use (we’ve split this off from max_seq_len, which previously did double duty). This allows us to specify a cache size of 0 for training, for example. Within the function body, we have a new configuration parameter use_splash to denote whether splash attention should be used or not. For example, we always fall back to XLA attention for autoregressive inference, since the splash attention Pallas kernel requires the block size to be a multiple of 128 (and for decode, we almost necessarily have to process a query block of size 1). In this case, the kernel creation factory just returns None, and we check for this in _attn.1

The “sparse” part of splash attention comes from the proper use of the segment_ids argument: distinct documents/sub-sequences can be assigned different segment_ids. We’ll see an example below, in how to set this up for use during training with respect to the KV cache.
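
For intuition, here is a small hypothetical sketch (not repo code) of segment ids for two documents packed into one length-8 training sequence, assuming SegmentIds is importable from the same splash_attention module used in splash_helpers.py. Wherever the query and key segment ids differ, attention is masked out, so tokens never attend across the document boundary:

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import SegmentIds

# Two packed documents: tokens 0-2 belong to document 0, tokens 3-7 to document 1
doc_ids = jnp.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=jnp.int32)

# At train time (no KV cache), Q and KV cover the same positions,
# so they share the same segment ids
segment_ids = SegmentIds(q=doc_ids, kv=doc_ids)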

First, at the level of train.py, we now create our kernels when we set up. For training, this leads to a simple kernel = make_splash_kernel(config, config.train_dataset, 0, mesh) statement. We then pass the kernel to any model call we have: for example, logits, _ = jax.vmap( partial(_transformer, config, kernel, params, cache_params=cache_params))(inputs, train_state.kv_cache) in the training loop step. Here, we’ve refactored the way we use the KV cache – a CacheParams class is defined simply as follows:

class CacheParams(NamedTuple):
    enabled: bool
    size: int

This way, as we’ll see momentarily, we can easily trace with respect to the size argument (i.e., for inference) while still omitting all KV cache maintenance code at compile time if we want to train (i.e., by setting enabled=False).
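
Concretely, here is a hypothetical sketch of the two modes (the values are illustrative, not repo code):

import jax.numpy as jnp

# Training: enabled is a plain Python bool that the jitted train step closes over,
# so the `if cache_params.enabled:` branch in _attn resolves at trace time and all
# cache maintenance code disappears from the compiled program
train_cache_params = CacheParams(enabled=False, size=0)

# Autoregressive inference (assumed usage): size can be a traced integer, so the
# same compiled function is reused as the cache fills up
decode_cache_params = CacheParams(enabled=True, size=jnp.asarray(17, dtype=jnp.int32))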

The kernel argument and this new cache_params argument get threaded down to the level of _attn, which now reads as follows (abbreviating unchanged components):

# model.py
def _make_causal_mask(seq_len_q: int, seq_len_k: int, cache_size: int) -> Array:
    """(seq_len_q, seq_len_k) bool array with True for past token positions"""
    q_positions = cache_size + jnp.arange(seq_len_q)
    k_positions = jnp.arange(seq_len_k)
    return q_positions[:, None] >= k_positions[None, :]


def _make_cache_mask(seq_len_q: int, seq_len_k: int, cache_size: int) -> Array:
    """(1, seq_len_k) bool array with True for actual cache+context positions"""
    k_positions = jnp.arange(seq_len_k)
    return k_positions[None, :] < cache_size


def _attn(
    config: Config,
    kernel,
    params: Attn,
    x_seq: jax.Array,
    kv_cache: jax.Array,  # 2 x n x cache_capacity x h
    cache_params: CacheParams,
):
    # s: sequence length
    # d: embedding dim (config.d_model)
    # n: attention heads (config.num_heads)
    # h: head dim (config.d_head)
    # x_seq: s x d

    q, k, v = jnp.einsum(
        "sd,d3nh->3nsh",
        x_seq,
        params.w_qkv,
        out_sharding=jax.P(*config.sharding.att_qkv),
    )
    s = q.shape[1]

    # skipping rope...

    k_cache, v_cache = kv_cache[0], kv_cache[1]
    if cache_params.enabled:
        k_cache_out = jax.lax.dynamic_update_slice(
            k_cache, k, (0, cache_params.size, 0)
        )
        v_cache_out = jax.lax.dynamic_update_slice(
            v_cache, v, (0, cache_params.size, 0)
        )
        kv_cache_out = jnp.concatenate((k_cache_out[None], v_cache_out[None]), axis=0)
    else:
        kv_cache_out = kv_cache
    # Cache read scheme: to enable same mask for the same s value (Q seq len),
    #  we concatenate the full cache to K, and mask empty entries
    # For efficient training, set cache size zero + cache_params.enabled=False
    cache_capacity = k_cache.shape[-2]
    k = jnp.concatenate((k_cache, k), axis=1)
    v = jnp.concatenate((v_cache, v), axis=1)

    # Attention
    t = k.shape[1]  # t = s + cache_capacity
    if kernel:
        # Segment-id masking: queries all get segment id 0. Valid KV positions
        # (filled cache slots, plus the freshly appended context) also get 0,
        # while empty cache slots get 1 so the kernel masks them out.
        q_segment_ids = jnp.zeros((s,))
        kv_mask = _make_cache_mask(s, t, cache_params.size) | (
            ~_make_cache_mask(s, t, cache_capacity)
        )
        kv_mask = ~kv_mask[0]
        kv_segment_ids = kv_mask.astype(jnp.int32)
        segment_ids = SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

        splash_sharded, kernel = kernel
        attn_out = splash_sharded(
            kernel,
            q / config.model.d_head**0.25,
            k / config.model.d_head**0.25,
            v,
            segment_ids,
        )
    else:
        # Make mask
        if config.model.is_causal:
            mask = _make_causal_mask(s, t, cache_capacity)
            cache_mask = _make_cache_mask(s, t, cache_params.size) | (
                ~_make_cache_mask(s, t, cache_capacity)
            )
            mask = mask & cache_mask
        else:
            mask = ~_make_cache_mask(s, t, 0)  # full attention
        mask = mask[None, ...]  # broadcast over heads
        # Scale and causal mask
        logits = jnp.einsum("nsh,nth->nst", q, k).astype(
            config.model.compute_dtype.value
        )
        logits *= 1.0 / config.model.d_head**0.5
        logits = jnp.where(mask, logits, -jnp.inf)
        probs = jax.nn.softmax(logits, axis=2)  # type: ignore[reportArgumentType]
        probs = probs.astype(config.model.param_dtype.value)
        attn_out = jnp.einsum("nst,nth->nsh", probs, v)
    out = jnp.einsum(
        "nsh,hnd->sd",
        attn_out,
        params.w_o,
        out_sharding=jax.P(*config.sharding.res_stream),
    )

    return out, kv_cache_out


# ... skipping to the cache init function


def init_kv_cache(config: Config, global_batch_size: int, cache_capacity: int):
    if not config.sharding.data:
        sharding_batch_layer = [None, None]
    else:
        sharding_batch_layer = config.sharding.data + [None]
    sharding = jax.P(*(sharding_batch_layer + config.sharding.att_qkv))

    return jnp.zeros(
        (
            global_batch_size,
            config.model.num_layers,
            2,
            config.model.num_heads,
            cache_capacity,
            config.model.d_head,
        ),
        dtype=config.model.param_dtype.value,
        out_sharding=sharding,
    )

The splash Pallas kernel expects inputs of shape N x S x H, so we refactor some code compared to what we had before. The key change is that, in order to repeatedly reuse the same splash Pallas kernel as we grow the cache, we write the code so that the KV sequence length is always fixed to the cache capacity plus S (the Q sequence length). The cache capacity is specified dynamically at initialization, allowing us to use the same Config object for different training/inference scenarios (e.g., cache size 0 for training, and some fixed positive integer for autoregressive inference). This means we need to construct a SegmentIds object that masks out all the unused portions of the KV cache, which the code above handles. Previously, we updated the cache in place with jax.lax.dynamic_update_slice rather than concatenating, but with the splash kernel structure this would require us to change the mask after every additional cache update (requiring a new kernel!).
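
As a small sanity check of this cache-read scheme on the XLA fallback path (a hypothetical snippet using the helper functions defined above), take s = 2 queries, cache capacity 4, and a current cache fill of size = 1; only the single filled cache slot and the causally-visible context positions survive the combined mask:

import jax.numpy as jnp

s, cache_capacity, size = 2, 4, 1
t = s + cache_capacity  # KV length is always capacity + query length

causal = _make_causal_mask(s, t, cache_capacity)
valid = _make_cache_mask(s, t, size) | (~_make_cache_mask(s, t, cache_capacity))
print(causal & valid)
# [[ True False False False  True False]
#  [ True False False False  True  True]]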

That’s it! We’ve integrated splash attention, allowing us to scale much more effectively to larger model sizes down the road.

Optimizer Improvements

We were using 100% vanilla SGD in the previous post. Let’s upgrade that to a much stronger optimizer, AdamW (Loshchilov & Hutter, 2017)!

In our previous post, we used a very simple tree-mapped update, facilitated by the simple structure of constant-learning-rate SGD:

# ...
loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
new_state = TrainState(
    params=jax.tree.map(lambda x, y: x - config.lr * y, train_state.params, grad),
    opt=None,
)
# ...

For implementing Adam and AdamW (the “W” denotes “decoupled weight decay”), it’s better to use a more flexible API: both because the updates become harder to express in a single concise lambda, and because we generally want to be able to apply different updates to different parameters.2 Here’s how we’ll do this, at a basic level:

  1. We’ll define a handful of optimizer updates that share the same API.
  2. We’ll allow one update to be selected via the Config class (i.e., with the Enum we mentioned above), and then we’ll jax.tree.map it over the model pytree.
  3. We’ll be able to provide a pytree of arguments to the update, as necessary, allowing us to do different things for different parameters.

For example, here’s how this looks for the previous constant-learning-rate SGD update. Since AdamW requires us to save state information (i.e., first-order and second-order momentum buffers), we’ll incorporate that into the API.

# Factory for getting the update, used externally
def opt_update_factory(opt_type: OptType):
    match opt_type:
        case OptType.ADAMW:
            return adamw_update
        case OptType.SGD:
            return sgd_update


class OptState(NamedTuple):
    mu: jax.Array  # 1st moment EMA
    nu: jax.Array  # 2nd moment EMA
    step: jax.Array  # step number


def init_adam_state(config: Config, param: jax.Array) -> OptState:
    return OptState(
        mu=jnp.zeros_like(param, dtype=config.optimizer_dtype.value),
        nu=jnp.zeros_like(param, dtype=config.optimizer_dtype.value),
        step=jnp.array(0, dtype=jnp.int32),
    )


def sgd_update(
    config: Config, param: jax.Array, grad: jax.Array, state: OptState, wd_mask: bool
):
    update = -config.lr * grad
    if wd_mask:
        # Apply weight decay
        update = update - config.lr * config.weight_decay * param
    return update, state

The update API takes parameters, gradients, state, and additional flags, and returns the update direction and the updated state. We add the update direction to the parameters externally since this is a useful debugging abstraction (and matches the optax API).
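
For a single parameter, usage looks something like the following hypothetical sketch (SimpleNamespace stands in for the relevant config fields; it is not the real Config):

from types import SimpleNamespace

import jax.numpy as jnp

cfg = SimpleNamespace(lr=1e-2, weight_decay=0.1)  # stand-in for the config fields

param = jnp.ones((3,))
grad = jnp.full((3,), 0.5)

# SGD carries no optimizer state, so we can pass None through
update, state = sgd_update(cfg, param, grad, state=None, wd_mask=True)
param = param + update  # the update is applied externally, as in the train step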

Here’s how it looks to apply the optimizer now. It gets a bit more cumbersome to do this with only pytrees, since the update returns a tuple – meaning the output pytree structure is different from the model/grad pytree structure. The latter is a prefix of the former, though, so we can still extract what we need fairly easily.
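
Here is the “transpose” trick in isolation, on hypothetical toy trees: since the grad tree is a tree-prefix of the mapped output, passing grad as the first argument to jax.tree.map delivers each (update, state) tuple as a whole subtree, which we can then index:

import jax

grad = {"a": 1.0, "b": 2.0}  # stand-in for the grad pytree
update__opt_state = {"a": (10.0, "state_a"), "b": (20.0, "state_b")}

update, opt_state = map(
    lambda i: jax.tree.map(lambda g, y: y[i], grad, update__opt_state), range(2)
)
print(update)     # {'a': 10.0, 'b': 20.0}
print(opt_state)  # {'a': 'state_a', 'b': 'state_b'}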

# In train.py

# ...
class TrainState(NamedTuple):
    params: Transformer
    opt_state: Any
    kv_cache: jax.Array


@jax.jit
def init_train_state(key, config: Config) -> TrainState:
    model_params = init_transformer(key, config)
    adam_state = jax.tree.map(partial(init_adam_state, config), model_params)
    cache = init_kv_cache(config, config.train_dataset.global_batch_size, 0)
    return TrainState(params=model_params, opt_state=adam_state, kv_cache=cache)


# ...


def main(config: Config):
    # ...
    opt_update = opt_update_factory(config.optimizer.type)

    # ...
    @partial(jax.jit, donate_argnums=2)
    def train_step(config: Config, batch, train_state: TrainState):
        # ...
        loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
        # ...
        update__opt_state = jax.tree.map(
            partial(opt_update, config),
            train_state.params,
            grad,
            train_state.opt_state,
            weight_decay_mask,
        )
        # Transpose the output tree to get update tree and state tree
        update, opt_state = map(
            lambda i: jax.tree.map(lambda x, y: y[i], grad, update__opt_state),
            range(2),
        )
        params = jax.tree.map(lambda x, y: x + y, train_state.params, update)
        new_state = TrainState(
            params=params, opt_state=opt_state, kv_cache=train_state.kv_cache
        )
        # ...

This is pretty straightforward, once we’ve written it!

Now, to add a new optimizer – AdamW in our case – we just need to define the appropriate opt_update and possibly new state initialization functions. We’ve already done the latter above for Adam. Here’s a simple implementation of AdamW that agrees with the optax implementation (though not bitwise):

def adamw_update(
    config: Config, param: jax.Array, grad: jax.Array, state: OptState, wd_mask: bool
):
    beta1 = config.beta1
    beta2 = config.beta2
    lr = config.lr
    eps = config.eps_adam
    weight_decay = config.weight_decay

    mu = beta1 * state.mu + (1 - beta1) * grad.astype(config.optimizer_dtype.value)
    nu = beta2 * state.nu + (1 - beta2) * grad.astype(config.optimizer_dtype.value) ** 2
    new_state = OptState(mu=mu, nu=nu, step=state.step + 1)

    mu_debias = mu / (1 - beta1**new_state.step)
    nu_debias = nu / (1 - beta2**new_state.step)
    update = -lr * mu_debias / (eps + jnp.sqrt(nu_debias))
    if wd_mask:
        # Apply weight decay
        update = update - lr * weight_decay * param
    return update.astype(config.param_dtype.value), new_state
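
To spot-check the claimed agreement with optax, here is a hypothetical comparison on a single parameter (illustrative values; optax’s adamw also folds the decoupled weight decay into the returned update, so the two updates should match up to floating-point differences):

import jax.numpy as jnp
import optax

lr, b1, b2, eps, wd = 3e-4, 0.9, 0.999, 1e-8, 0.1
param = jnp.ones((4,))
grad = jnp.full((4,), 0.5)

tx = optax.adamw(lr, b1=b1, b2=b2, eps=eps, weight_decay=wd)
opt_state = tx.init(param)
optax_update, opt_state = tx.update(grad, opt_state, param)
print(optax_update)  # compare against adamw_update with matching config fields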

The slightly tricky part remaining is how to pass the right weight decay mask pytree. It’s best to do this with model metadata, since the format we store the parameters in doesn’t necessarily match how they’re used. For example, we store our attention QKV parameters as a D x 3 x N x H array, where D is the model dimension, N the number of attention heads, and H the projected attention head dimension. We use these parameters as D x H matrices, but we can’t safely infer this from only the shape information! Similarly, since we vmap our Blocks initialization over the layer dimension L when constructing the model, our MLP output bias parameters are shaped like L x D. This is a matrix, so when iterating over the model pytree, we can’t just check whether the current leaf has ndim > 1!

This is a situation where we clearly see the advantage of the bigger, standard model libraries (e.g. Equinox): there, metadata and parameters can easily be stored together. For us, we’d need to depart from our appealingly simple NamedTuple model architecture to do this (e.g. register custom pytree leaves). So we’ll instead use a slightly unpleasant approach: we’ll just add a function to our model.py file that returns a model “spec” associated to the architecture, containing all relevant metadata.3

# in model.py
def model_spec(model: Transformer) -> Any:
    # Make the spec (we need some way to pass metadata around)
    def _make_spec_from_str(path: str) -> tuple[int, int] | None:
        param_str = path[-1].__str__()
        matrix_axes_dict = {
            ".w_qkv": (-4, -1),
            ".w_o": (-3, -1),
            ".w_up": (-2, -1),
            ".w_down": (-2, -1),
            ".w": (-2, -1),
            # ...
        }
        return matrix_axes_dict.get(param_str, None)

    spec = jax.tree.map_with_path(lambda p, _: _make_spec_from_str(p), model)
    return spec

The keys in the dictionary used to generate the output pytree correspond to the parameter names in our Transformer model. We can get these from the pytree using jax.tree.map_with_path, which gives the pytree structure (easily converted into a string, like .blocks.norm_attn.gamma) in addition to the leaves.

The returned pytree looks like the following:

Transformer(emb=Embedding(w=(-2, -1)), blocks=Block(norm_attn=LayerNorm(gamma=None, beta=None), attn=Attn(w_qkv=(-4, -1), w_o=(-3, -1)), norm_mlp=LayerNorm(gamma=None, beta=None), mlp=Mlp(w_up=(-2, -1), bias_up=None, w_down=(-2, -1), bias_down=None)), unemb=Unembedding(w=(-2, -1)))

Now, back to the matter at hand: we need to generate a mask pytree that tells us which parameters in the model to apply weight decay to. To apply it only to the matrix parameters in the model, we can do the following:

spec = model_spec(train_state.params)
weight_decay_mask = jax.tree.map(lambda x, s: bool(s), train_state.params, spec)

Since None is not truthy, this will give us the mask we desire! Then we use it as above, where we already provided weight_decay_mask as an argument to the update calculation tree map (but didn’t specify how to generate it).

We also add gradient clipping, since this doesn’t require too much effort. First, the clipping helper:

# in optimizers.py
def grad_norm_and_clip(
    config: Config, model: Transformer
) -> tuple[Transformer, Transformer, float]:
    # Gradient norms in param precision (NOTE: might want fp32?)
    grad_norms_squared = jax.tree.map(lambda grad: jnp.sum(grad**2), model)
    global_grad_norm = jax.tree.reduce(operator.add, grad_norms_squared) ** 0.5
    truncated_norm = jax.lax.select(
        global_grad_norm >= config.clip_grad,
        global_grad_norm,
        jnp.ones_like(global_grad_norm),
    )
    return (
        jax.tree.map(lambda grad: grad / truncated_norm, model),
        grad_norms_squared,
        global_grad_norm,
    )

It’s quite straightforward overall! With that, here’s the entire updated setup and train step:

  # Initialize state, configure forward pass and optimization
  with jax.set_mesh(mesh):
      train_state = init_train_state(key_model, config)
  cache_params = CacheParams(enabled=False, size=0)
  kernel = make_splash_kernel(config, config.train_dataset, 0, mesh)
  spec = model_spec(train_state.params)
  opt_update = opt_update_factory(config.optimizer.type)
  weight_decay_mask = jax.tree.map(lambda _, s: bool(s), train_state.params, spec)


  @partial(jax.jit, donate_argnums=2)
  def train_step(config: Config, batch, train_state: TrainState):
      def loss_fn(params: Transformer):
          inputs, targets = batch
          logits, _ = jax.vmap(
              partial(_transformer, config, kernel, params, cache_params=cache_params)
          )(inputs, train_state.kv_cache)
          logits = logits.astype(config.model.compute_dtype.value)
          logprobs = jax.nn.log_softmax(logits, axis=-1)
          return -jnp.take_along_axis(logprobs, targets[..., None], axis=-1).mean()

      loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
      grad_clipped, _, global_grad_norm = grad_norm_and_clip(config, grad)
      update__opt_state = jax.tree.map(
          partial(opt_update, config),
          train_state.params,
          grad_clipped,
          train_state.opt_state,
          weight_decay_mask,
      )
      # Transpose the output tree to get update tree and state tree
      update, opt_state = map(
          lambda i: jax.tree.map(lambda x, y: y[i], grad, update__opt_state), range(2)
      )
      params = jax.tree.map(lambda x, y: x + y, train_state.params, update)
      new_state = TrainState(
          params=params, opt_state=opt_state, kv_cache=train_state.kv_cache
      )

      metrics = {"batch_loss": loss, "grad_norm": global_grad_norm}
      return metrics, new_state

Distributed Training

Next, let’s set up distributed training. On the whole, this is actually very easy to do on Google’s TPU VMs!4 This is because:

  1. The VMs are already configured with the IP addresses / ICI information of the other hosts in the slice.
  2. We’re using jax.sharding.AxisType.Explicit sharding, and our model is relatively simple. So communication gets handled by the XLA compiler! We don’t actually need to use pmap or pjit anywhere.

For our purposes, the only part that requires care is data management, particularly data loading. We’ll describe this below for our simple synthetic data distribution, as well as the other changes we need to make for distributed training.

High-Level Paradigm

In JAX distributed training, each host (or “process”) is physically connected to some number of chips via PCIe. On v4 TPUs, which we’ve been using, each host is connected to 4 chips (each chip with 2 cores). When working with TPU VMs, we can think of one host as corresponding to one VM that we can ssh into; in our tests previously, we’ve been working on a single host.

Different hosts are arranged into “slices”. Within a slice, chips are interconnected via high-bandwidth optical links, letting them communicate relatively efficiently! When it comes to data movement from the host to the device, though, each host can only directly communicate with the chips it is physically connected to. These are called the “addressable” devices of that host.

With jax.distributed, we can programmatically handle these differences while only writing a single program! Our program gets run identically on each host, but we have access to certain functions that can allow us to branch and shard differently depending on which host we are. Key functions that we’ll see below are:

  • jax.distributed.initialize(), which synchronizes information about all interconnected devices among the different hosts;
  • jax.process_index(), which returns a different integer for each host (allowing branching);
  • jax.Array.addressable_shards, jax.Array.is_fully_addressable (etc.), which can be used to programmatically determine which devices can directly modify a given array. For example, if x is a jax.Array that is fully replicated across devices (with jax.P() as its partition spec), then x.is_fully_addressable is False in a multi-host setting! We want to avoid having our program modify an array in a way that is not compatible with either its sharding or the devices with respect to which it is addressable.
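
Here is a tiny hypothetical snippet showing these primitives in action; the same program runs on every host, and each process reports its own view of the system:

import jax

jax.distributed.initialize()  # coordinator/address arguments are auto-detected on TPU VMs
print(
    f"process {jax.process_index()} / {jax.process_count()}: "
    f"{jax.local_device_count()} addressable devices out of {jax.device_count()} total"
)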

For more detailed information, see Chapter 2 of the TPU scaling book (Austin et al., 2025) as well as relevant sections of the JAX documentation (JAX Team, 2025).

Setting up the Mesh and Initializing the Distributed Backend

One multi-host footgun I learned about the hard way is that setting a global mesh with jax.set_mesh, as we have been doing so far, makes it hard to create arrays that exist only on the local process (more on that below). So we switch to creating a mesh from our config options at the start of the main function, then instantiate it as a context manager when we need it. This turns out to be pretty easy—we don’t need to pass the mesh to all parts of our model’s forward pass and add it as an additional parameter to out_sharding (etc.)!

Here’s the new setup (which we saw part of above):

# in config.py
def mesh_from_config(config: Config):
    mesh = jax.make_mesh(
        config.sharding.mesh_shape,
        config.sharding.mesh_axis_names,
        len(config.sharding.mesh_shape) * (jax.sharding.AxisType.Explicit,),
    )
    return mesh


# in train.py
@hydra.main(
    version_base=None,
    config_path=str(Path("configs").absolute().resolve()),
    config_name="base_config",
)
def main(config: Config):
    try:
        # Launch distributed and register configs
        jax.distributed.initialize()
        jax.tree_util.register_static(type(config))
    except RuntimeError:
        # This implies the distributed backend has already been initialized
        pass
    config_post_init(config)
    mesh = mesh_from_config(config)
    # ...

That’s it! We just have to be sure to set the config.sharding.mesh_shape parameter appropriately when we launch a job (see below), as well as all our desired sharding settings. Our config defaults to simple data-parallel training over the entire set of devices.

Modifications for Data Loading

We refactor the data generation code into data.py. At a high level, here’s what we end up doing in train.py to set up the data loading:5

# in train.py
# ...
# Randomness
key = jax.random.key(config.seed)
key_model, key_train, key_val, key_eval = jax.random.split(key, 4)

# Data
batch_iter = get_distributed_batch_iter(config, config.train_dataset, key_train, mesh)
# ...

We abstract the previous dataset-specific code into this factory function, which calls the appropriate sub-functions (letting us reuse this for val/eval):

# in data.py


def dataset_dataloader_factory(config: DatasetConfig):
    match config.name:
        case DatasetName.MNIST:
            return (load_mnist(config), dataloader_without_replacement)
        case DatasetName.NUMBER_STAIRCASE:
            return (
                make_number_staircase_data(config),
                dataloader_with_replacement,
            )
        case DatasetName.SHAKESPEARE:
            return (
                load_shakespeare(config),
                dataloader_with_replacement,
            )


# Helper to just get the global batch iterator, if we don't need the local data/loader
def get_distributed_batch_iter(
    config: Config, dataset_config: DatasetConfig, key, mesh
):
    data, dataloader = dataset_dataloader_factory(dataset_config)
    return get_dataset_on_device(config, dataloader(key, dataset_config, data), mesh)

Distributed data loading has a large number of potential footguns; the JAX Advanced Guides listing has a useful entry about this (JAX Team, 2025), and see also the article about explicit sharding we used in previous blogs (JAX Team, 2025). In the present setting, we take advantage of the fact that we can load the entire dataset into memory on each host. The following code does this:

def make_number_staircase_data(config: DatasetConfig):
    # Simple data sequence!
    # Small, so mega replicate for ease
    text = "012345678987654321" * 1024
    ids = [int(c) for c in text]
    data = jnp.array(ids, dtype=jnp.int32)
    Xtr, Xdev, Xte = split_data(data, 0.8, 0.1)
    match config.split:
        case SplitType.TRAIN:
            return Xtr, jnp.array(0)
        case SplitType.TEST:
            return Xte, jnp.array(0)
        case SplitType.VAL:
            return Xdev, jnp.array(0)


def split_data(data: Array, train_fraction: float, dev_fraction: float):
    num_data = len(data)
    Xtr = data[: int(train_fraction * num_data)]
    Xdev = data[
        int(train_fraction * num_data) : int((train_fraction + dev_fraction) * num_data)
    ]
    Xte = data[int((train_fraction + dev_fraction) * num_data) :]
    return Xtr, Xdev, Xte

The tricky part here is to note that, as configured, the above code generates the data array so that it is locally addressable by every process: each host creates its own (identical) data array. More precisely, if we evaluate data.is_fully_addressable, it will be True. This must be contrasted with the case of a fully replicated global data array, which we don’t want here: our approach will be to piece together a global batch out of local batches on each process, and having a locally addressable data array is (currently) the canonical way to do this!

Next, we need to make sure each host has access to an independent and identically distributed stream of batches. This gets handled by the get_dataset_on_device and dataloader functions.

# in data.py

# Simple next-token prediction dataloader for text training:
# - Each host has the whole dataset (data input)
# - Random sampling with replacement to draw batches (consumes key)
# - Targets are the next token (expect data to be (num_data,) shape)
def dataloader_with_replacement(
    key, config: DatasetConfig, data: tuple[Array, Array]
) -> DataloaderOutputType:
    inputs, _ = data
    num_data = len(inputs)
    key = jax.random.fold_in(key, jax.process_index())
    for step in it.count():
        key = jax.random.fold_in(key, step)
        offsets = jax.random.randint(
            key,
            (config.global_batch_size // jax.process_count(),),
            0,
            num_data - config.seq_len - 1,
        )
        seqs = inputs.at[offsets[:, None] + jnp.arange(config.seq_len)].get()
        targets = inputs.at[1 + offsets[:, None] + jnp.arange(config.seq_len)].get()
        yield seqs, targets


# Helper to map a dataloader with make_array_from_process_local_data
def get_dataset_on_device(
    config: Config, dataloader: DataloaderOutputType, mesh: Mesh
) -> DataloaderOutputType:
    return map(
        lambda batch: jax.make_array_from_process_local_data(
            NamedSharding(mesh, jax.P(*config.sharding.data)), batch
        ),
        dataloader,
    )

We must pass the mesh we created earlier here in order to correctly create the global batch out of the local shards that the dataloader generates. Explicit sharding seems to propagate to outputs, so the data.at operations return locally-addressable batch arrays on each process, which make_array_from_process_local_data pieces together into the global array. To generate independent local batches on each process, we are sure to fold_in the process index to the RNG in order to create a distinct stream of random offsets.

This quick-and-dirty data loader, which does not shuffle nor guarantee sampling without replacement for batches across processes, is fine for our purposes on this synthetic dataset. It should also be okay for use with much larger datasets, when training for a relatively small number of tokens. But it’s worth keeping in mind that something more robust would be needed in general.6

Remaining Cleanup

All that’s left is to judiciously insert our mesh context manager at key points of the training loop. We saw above (when discussing optimizers) that we need to add it to the model initialization call. Other than that, we just need to add it to the model call in our training loop:

# in train.py

# ...
for step in range(config.train_dataset.num_steps):
    batch = next(batch_iter)
    with jax.set_mesh(mesh):
        cur_metrics, train_state = train_step(config, batch, train_state)
    # ...

For simplicity’s sake, we also wrap our sampling code with it.

Launching Everything

Given our DIY ethos, we handle launching distributed jobs on our TPU VM with an ad-hoc launch script:

# run.sh

TPU_NAME="$1"
shift
SSH_FLAGS='-A -o ForwardAgent=yes'
COMMANDS="if [ ! -d \"baremetal-gpt\" ]; then git clone [email protected]:sdbuch/baremetal-gpt; fi \
    && export HYDRA_FULL_ERROR=1 \
    && export WANDB_ENTITY='$WANDB_ENTITY' \
    && export WANDB_API_KEY='$WANDB_API_KEY' \
    && export HF_TOKEN='$HF_TOKEN' \
    && cd baremetal-gpt \
    && git fetch \
    && git checkout -f main \
    && git pull \
    && uv sync --extra tpu \
    && uv run train $@"

gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
  --ssh-flag="$SSH_FLAGS" \
  --command="$COMMANDS" \
  --worker=all

For logging purposes, we need the Weights & Biases entity and API key to be exported locally as the WANDB_ENTITY and WANDB_API_KEY environment variables. We use GitHub for code synchronization, and provide the tpu extra to uv when building since we’re running on a TPU VM.

Here’s an example call on the local machine, for data-parallel training:

./deploy/run.sh tpu-v4-32 model.num_vocab=10 model.num_layers=4 train_dataset.num_steps=300 optimizer.lr=3e-4 'sharding.mesh_shape=[16]'

For this to work well, when I first create the TPU VM,7 I make sure the following startup script is run on each host in the configuration:

# startup.sh
curl -LsSf https://astral.sh/uv/install.sh | sh  # Download uv
sudo cp /root/.local/bin/uv /usr/local/bin/uv
sudo cp /root/.local/bin/uvx /usr/local/bin/uvx
ssh-keyscan -H github.com >> /etc/ssh/ssh_known_hosts  # Code copying

This installs uv and adds github.com to the known hosts, so that I can connect to GitHub via ssh (agent forwarding is enabled for authentication). If we don’t add it to the known hosts list, ssh will prompt to confirm the connection and the gcloud command will fail.

That’s it! This gives a pretty easy minimal setup for distributed training with JAX on TPU VMs.

Adding Logging

Custom logging is easy to add, as well. We add an Enum to the config to select the logger, then add two basic loggers: one for printing to the command line, and one for using Weights and Biases.

There are two things to keep in mind here:

  1. In a distributed setting, by default we only want one process (say, the one with index 0) to log metrics.8
  2. We want to set up appropriate host-device pipelining, so, as we alluded to in the previous post (taking a trick from the JAX training cookbook), we should log the metrics from the previous batch on every step. We can incorporate automatic buffering into our Logger class to do this.

# config.py

# ...
class LoggerType(Enum):
    PRINT = "print"
    WANDB = "wandb"


# ...


# loggers.py
def logger_factory(logger_type: LoggerType):
    match logger_type:
        case LoggerType.PRINT:
            return PrintLogger
        case LoggerType.WANDB:
            return WandbLogger


def get_run_name(base: str) -> str:
    """Prepend current UTC time to config-specified run name"""
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = f"{timestamp}-{base}"
    return run_name


class Logger:
    """Logger base class. Implements a depth-1 buffer for device-host pipelining"""

    def __init__(self, config: Config):
        self.project_name = config.project_name
        self.run_name = get_run_name(config.run_name)
        if isinstance(config, DictConfig):
            config_dict = OmegaConf.to_container(config, resolve=True)
        else:
            config_dict = asdict(config)
        self.config = config_dict
        self.prev_log_data = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False

    def log(self, log_dict: dict):
        pass

    def warn(self, message: str):
        pass

    def buffer(self, log_dict: dict) -> dict | None:
        self.prev_log_data, buffered_metrics = log_dict, self.prev_log_data
        return buffered_metrics

    def flush_buffer(self):
        if self.prev_log_data is not None:
            self.log({})
        self.prev_log_data = None


class PrintLogger(Logger):
    def __init__(self, config: Config):
        super().__init__(config)

    def __enter__(self):
        print(f"Project: {self.project_name}")
        print(f"Run: {self.run_name}")
        return super().__enter__()

    def log(self, log_dict: dict):
        if (buffered_dict := self.buffer(log_dict)) is None:
            return
        print(*[f"{metric}: {val}" for metric, val in buffered_dict.items()], sep="\t")

    def warn(self, message: str):
        print(message)


class WandbLogger(Logger):
    def __init__(self, config: Config):
        super().__init__(config)
        self.is_master = jax.process_index() == 0

    def __enter__(self):
        if self.is_master:
            wandb.init(
                project=self.project_name, name=self.run_name, config=self.config
            )
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        if self.is_master:
            wandb.finish()
        return super().__exit__(exc_type, exc_value, traceback)

    def log(self, log_dict: dict):
        if self.is_master:
            if (buffered_dict := self.buffer(log_dict)) is None:
                return
            wandb.log(buffered_dict)

    def warn(self, message: str):
        if self.is_master:
            print(message)

Then we simply wrap our training loop with this context manager:

# train.py
# ...
Logger = logger_factory(config.logger_type)
# ...
with Logger(config) as logger:
    # ...
    for step, batch in enumerate(batch_iter):
        with jax.set_mesh(mesh):
            metrics, train_state = train_step(config, batch, train_state)
        logger.log(metrics | {"step": step})
        # ...
        if step == config.train_dataset.num_steps - 1:
            break
# ...

Actual Data

Now that we have a reasonable backend, let’s get it working on some text data, and try it at GPT-2 scale! Within the context of the above refactors, we just need to write the data preprocessing code and the dataloader. We also need evaluation code – we’ve written this in the GitHub repository, but haven’t detailed it above.

We’ll do a minimal demo with Andrej Karpathy’s Tiny Shakespeare dataset, available in the nanoGPT repo. We adapt Karpathy’s script for downloading and tokenizing the dataset, making use of the tiktoken library for this purpose:

# data/tiny-shakespeare/download_and_tokenize_tiny_shakespeare.py
import os

import jax.numpy as jnp
import requests
import tiktoken

"""From https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare/prepare.py"""

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
if not os.path.exists(input_file_path):
    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    with open(input_file_path, "w", encoding="utf-8") as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, "r", encoding="utf-8") as f:
    data = f.read()
n = len(data)
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) :]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = jnp.array(train_ids, dtype=jnp.int32)
val_ids = jnp.array(val_ids, dtype=jnp.int32)
jnp.save("data/tiny-shakespeare/train.npy", train_ids)
jnp.save("data/tiny-shakespeare/val.npy", val_ids)

# train.npy has 301,966 tokens
# val.npy has 36,059 tokens

In a distributed setting, we run this on each host, with a command like:

# deploy/download_shakespeare.sh

TPU_NAME="$1"
shift
SSH_FLAGS='-A -o ForwardAgent=yes'
COMMANDS="if [ ! -d \"baremetal-gpt\" ]; then git clone [email protected]:sdbuch/baremetal-gpt; fi \
    && export HF_TOKEN='$HF_TOKEN' \
    && cd baremetal-gpt \
    && git fetch \
    && git checkout -f main \
    && git pull \
    && uv sync --extra tpu \
    && uv run data/tiny-shakespeare/download_and_tokenize_tiny_shakespeare.py $@"

gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
  --ssh-flag="$SSH_FLAGS" \
  --command="$COMMANDS" \
  --worker=all

Then the data is ready for loading. We write a Python loader for the data to go with our previously-written factory, using the same dataloader_with_replacement as for the synthetic data:

# data.py
def load_shakespeare(config: DatasetConfig):
    path = Path(config.path)
    data = jnp.load(path / (config.split.value + ".npy"))
    return data, jnp.array(0)

Then running the experiment is simply a matter of setting up the right experimental config. We use the following, which runs almost all the evaluation code we wrote (to test it!):

# configs/experiment/tiny-shakespeare.yaml
# @package _global_

defaults:
  - /dataset/shakespeare@train_dataset

val_log_interval: 250

train_dataset:
  num_steps: 3000

val_list:
  - dataset:
      name: SHAKESPEARE
      path: ${train_dataset.path}
      split: VAL
      seq_len: ${train_dataset.seq_len}  # At inference time, this functions like prompt size
      global_batch_size: ${train_dataset.global_batch_size}
      num_steps: 200
    evaluator: NLL
  - dataset:
      name: SHAKESPEARE
      path: ${train_dataset.path}
      split: VAL
      seq_len: 128  # At inference time, this functions like prompt size
      global_batch_size: 16
      use_splash: False
    evaluator: AUTOREGRESSIVE_ROLLOUTS

eval_list:
  - dataset:
      name: SHAKESPEARE
      path: ${train_dataset.path}
      split: VAL
      seq_len: 128  # At inference time, this functions like prompt size
      global_batch_size: 16
      use_splash: False
    evaluator: AUTOREGRESSIVE_ROLLOUTS

model:
  transformer_type: DISCRETE
  max_seq_len: ${train_dataset.seq_len}  # no longer than train
  is_causal: True
  num_vocab: 50257  # gpt2 tokenizer

optimizer:
  weight_decay: 1e-1


inference:
  tokenizer: GPT2
  max_tokens_to_generate: 128

The tokenizer parameters, as suggested, are just used for rollouts at inference time. We use some Hydra interpolation in the config above to avoid retyping the same values. Here’s the overall training code we run (which is pretty short!):

# train.py

# ...
@hydra.main(
    version_base=None,
    config_path=str(Path("configs").absolute().resolve()),
    config_name="base_config",
)
def main(config: Config):
    try:
        # Launch distributed and register configs
        jax.distributed.initialize()
        jax.tree_util.register_static(type(config))
    except RuntimeError:
        # This implies the distributed backend has already been initialized
        pass
    config_post_init(config)
    mesh = mesh_from_config(config)
    Logger = logger_factory(config.logger_type)

    # Randomness
    key = jax.random.key(config.seed)
    key_model, key_train, key_val, key_eval = jax.random.split(key, 4)

    # Data
    batch_iter = get_distributed_batch_iter(
        config, config.train_dataset, key_train, mesh
    )

    # Initialize state, configure forward pass and optimization
    with jax.set_mesh(mesh):
        train_state = init_train_state(key_model, config)
    cache_params = CacheParams(enabled=False, size=0)
    kernel = make_splash_kernel(config, config.train_dataset, 0, mesh)
    spec = model_spec(train_state.params)
    opt_update = opt_update_factory(config.optimizer.type)
    weight_decay_mask = jax.tree.map(lambda _, s: bool(s), train_state.params, spec)

    @partial(jax.jit, donate_argnums=2)
    def train_step(config: Config, batch, train_state: TrainState):
        def loss_fn(params: Transformer):
            inputs, targets = batch
            logits, _ = jax.vmap(
                partial(_transformer, config, kernel, params, cache_params=cache_params)
            )(inputs, train_state.kv_cache)
            logits = logits.astype(config.model.compute_dtype.value)
            logprobs = jax.nn.log_softmax(logits, axis=-1)
            return -jnp.take_along_axis(logprobs, targets[..., None], axis=-1).mean()

        loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
        grad_clipped, _, global_grad_norm = grad_norm_and_clip(config, grad)
        update__opt_state = jax.tree.map(
            partial(opt_update, config),
            train_state.params,
            grad_clipped,
            train_state.opt_state,
            weight_decay_mask,
        )
        # Transpose the output tree to get update tree and state tree
        update, opt_state = map(
            lambda i: jax.tree.map(lambda x, y: y[i], grad, update__opt_state), range(2)
        )
        params = jax.tree.map(lambda x, y: x + y, train_state.params, update)
        new_state = TrainState(
            params=params, opt_state=opt_state, kv_cache=train_state.kv_cache
        )

        metrics = {"batch_loss": loss, "grad_norm": global_grad_norm}
        return metrics, new_state

    # Training loop
    with Logger(config) as logger:
        do_evals = partial(eval_loop, config, mesh=mesh, logger=logger)
        for step, batch in enumerate(batch_iter):
            with jax.set_mesh(mesh):
                metrics, train_state = train_step(config, batch, train_state)
            logger.log(metrics | {"step": step})
            if (step + 1) % config.val_log_interval == 0:
                # Calculate val metrics
                key_val = do_evals(key_val, config.val_list, train_state.params, step)
            if step == config.train_dataset.num_steps - 1:
                break

        # Run evals (testing)
        key_eval = do_evals(key_eval, config.eval_list, train_state.params, step)


def eval_loop(
    config: Config,
    key,
    eval_list: list[EvaluationConfig],
    params: Transformer,
    step: int,
    logger: Logger,
    mesh,
):
    logger.flush_buffer()
    for evaluation in eval_list:
        kernel = make_splash_kernel(config, evaluation.dataset, 0, mesh)
        key, key_d, key_e = jax.random.split(key, 3)
        batch_iter = get_distributed_batch_iter(config, evaluation.dataset, key_d, mesh)
        evaluation_fn = evaluator_factory(evaluation)
        metrics = evaluation_fn(config, key_e, kernel, mesh, params, batch_iter)
        logger.log(metrics | {"step": step})
    logger.flush_buffer()
    return key

Check out the evaluators.py file in the GitHub repository for the full structure of the evaluation code.

Checking the Results

The default model we train has 12 layers, embedding dimension 768, 12 heads per attention layer, and an MLP expansion factor of 4. With the GPT-2 vocabulary size, this ends up at about 160M parameters. Our Tiny Shakespeare config follows nanoGPT, setting the defaults as:

# configs/dataset/shakespeare.yaml
name: SHAKESPEARE
path: "data/tiny-shakespeare"
split: TRAIN
seq_len: 256
global_batch_size: 128

So we process 32K tokens per batch. The total number of tokens in the training set is only 301,966! So without additional regularization, we expect to overfit this dataset very rapidly with our model.
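
As a quick back-of-the-envelope calculation (plain Python, using the numbers above):

tokens_per_step = 256 * 128                  # seq_len * global_batch_size = 32,768
steps_per_epoch = 301_966 / tokens_per_step  # ~9.2 steps per epoch
epochs_in_run = 3000 / steps_per_epoch       # ~326 epochs over the full run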

We test this out, via the launch command

./deploy/run.sh tpu-v4-32 +deploy=v4-16 +experiment=tiny-shakespeare

on a 16-chip TPU v4 VM (4 hosts). Following the experiment config we saw above, we train for 3000 steps, and log val metrics every 250 steps.

[Figure: batch loss (negative log-likelihood) versus training step for the first 500 steps of Tiny Shakespeare training; the model rapidly overfits the training data.]

We see that the model rapidly overfits the training data. It takes about 10 steps per epoch, so the plot shows roughly 50 epochs’ worth of training – far more than is typical when training language models on standard datasets.

[Figure: gradient norm versus training step for the first 500 steps of Tiny Shakespeare training; training is stable, even without learning rate scheduling.]

The model trains stably (we are performing gradient clipping to norm 1.0 by default), even without learning rate scheduling. Scheduling would improve the initial stability; since we are heavily overfitting, there does not seem to be any need for LR cooldown at the end of training.

[Figure: validation loss (negative log-likelihood), logged every 250 steps, over the 3000-step Tiny Shakespeare training run; the model is heavily overfit.]

Our evaluation code lets us log various metrics, including the validation-set negative log likelihood every 250 steps of training. We verify that the model is significantly overfit to the training set – the validation loss actually increases monotonically! Taking a look at some of the trained model’s rollouts confirms that it is indeed not generalizing very well (although they’re quite fun to read):

################################
Prompt:
################################
 father? Gentles, methinks you frown:
And wherefore gaze this goodly company,
As if they saw some wondrous monument,
Some comet or unusual prodigy?

BAPTISTA:
Why, sir, you know this is your wedding-day:
First were we sad, fearing you would not come;
Now sadder, that you come so unprovided.
Fie, doff this habit, shame to your estate,
An eye-sore to our solemn festival!

TRANIO:
And tells us, what occasion of import
Hath all
################################
Generated text:
################################
 the castle wall:
And wheither is come to me, and sit by fortune's side,
The right is not! God's my son,
Immoderate but a little distressed widow;
For joyful wife's son is no weary way.

ROMEO:
What say'st thou, that didst me yet?

MERCUTIO:
Good dream of mine own.

ROMEO:
Give me that mattock and the wrenching iron.
Hold, take this letter; early, and my lord

Till then be urged to my advancement?

Nurse:

The model seems to be regurgitating tokens from Romeo and Juliet, with some random digressions sprinkled in! For these rollouts, we generate up to max_tokens_to_generate=128 tokens off of a 128-token prompt, keeping the total length within the training max_seq_len (256 for us). This avoids out-of-distribution RoPE positions and KV cache errors.

Conclusions

That’s all for our base infra buildout! Some of the key takeaways include:

  • Correct configuration of the splash attention Pallas kernel;
  • Flexible infrastructure for training and inference across different datasets and evaluation types;
  • Basic training improvements (with fairly robust infrastructure, except perhaps for our model spec) that will let us experiment with new optimizers in the future!

As we’ve gone from the notebook-based code in the previous blog to the full repository, we’ve seen along the way that building reasonably general experiment infrastructure requires a bit of a different approach than writing one-off research code. It also requires a decent amount of time! There’s a good chance a lot of this can be automated in a fast-paced environment with LLMs – I have not used LLMs for writing any code in these blogs – but writing the code by hand is hard to beat for learning and for precision of implementation (as Karpathy suggested in his recent conversation with Dwarkesh!). As a small piece of evidence of this, we’ve ended up with a very concise and readable train.py script: not much more than 150 lines!

Some remaining infrastructure-related tweaks we should add in the future:

  • Model checkpointing: we’ll add this at some point soon, likely just using Orbax.
  • Further parallelism configurations, in particular basic FSDP.

Look forward to less infrastructure and closer-to-research experimentation in future iterations of this blog series.

Acknowledgments

Thanks to the TRC program for compute support.

  1. The reason we actually check for None is that we generally need to call our model with different configurations (e.g., when doing autoregressive inference, or when evaluating with a different parallelism configuration), and it is suboptimal to create an entirely different Config object for each of these, or to manually edit the global Config object for these as we run. So instead, we check for None, and just pass this as the kernel parameter to the model as we need to. 

  2. For example, we typically only apply weight decay to matrix-like parameters in the network (e.g., not to biases). 

  3. Note that even this approach is not perfect: for example, the w_o matrix associated with the attention output projection actually does reduce over one of the dimensions marked as non-matrix. We should revisit this for future improvements (i.e., Muon-type optimizers!). 

  4. At least in our setting, where we’re not using multislice operation. I think the distributed backend is just as easy to set up for multislice, though! 

  5. Creating the training/val/test split on each distributed process is probably not a good idea, even though we seed each host’s RNG with the same seed, and generate the splits with the same key. It won’t matter for our purposes here because the synthetic data distribution is so simple (even if every process had a different training split, it would be okay, because there isn’t a material distinction between the training and test set for this distribution). We’ll fix this with proper processing of data when we move to text data. 

  6. We could use Grain for this purpose! 

  7. Here, tpu-v4-32, for a four-host v4 setup: the 32 denotes the number of cores, chips = cores / 2, and the number of chips dictates the mesh shape. 

  8. For more complex debugging, say in interpretability settings, we need to add a bit more custom code in order to have each process log its respective metrics (say, per-device gradients before all reduce, etc.).