SPMD in JAX #3: Infrastructure Buildout
In this third post in a series on programming with JAX for training transformers, we pick up where we left off at the end of the previous post. We’ll extend our previous small transformer model to the NanoGPT scale, incorporating actual text data, distributed training, and a real optimizer, and perform some basic evaluation! We’ll end up with a reasonable substrate for building future experiments on top of.
As in previous posts in the series, we make good use of the Google scaling playbook (Austin et al., 2025) for profiling.
- Package Conversion and Improvements
- Actual Data
- Conclusions
- Acknowledgments
Package Conversion and Improvements
We’re running in Python 3.13, and we’ve upgraded to JAX 0.8.0 from previous posts. Now that we’ve written a good chunk of model, training, and config code, we’ll also be working on top of a packaged version of our code.
The code is hosted in this GitHub repository: baremetal-gpt. There are instructions to install the repository as a package using uv there.
We’ll start by talking through some refactoring that leads us from the state of our code in the previous post to the state of the code in the GitHub repository.
Config Upgrades
To run experiments with our code in a distributed (i.e., multi-host) setting, we’ll benefit from having a robust way to dispatch different experimental configurations from the command line. We’ll use Hydra for this purpose. Hydra is not the most minimal solution, but it’s the most feature-complete library for managing configs that I know of (e.g., specifying lists as arguments), and for our simple purposes we’ll be able to adopt it with only a small amount of added code and complexity.
Small Modifications to our Config Dataclass
Hydra supports registering config dataclasses as so-called “Structured Configs”. This tutorial gives a good description of the core functionalities we need. We only need to make a few small modifications to our previous dataclass:
- Use only supported types: instead of random class objects (e.g., for dtypes or PartitionSpecs), we need to create Enums for passing values, and disambiguate them at call time (see the sketch just after this list). There are other minor caveats (e.g., we can’t arbitrarily nest lists; there are restrictions on the use of Optional).
- Make the config mutable (remove frozen=True).
- Re-register the config as static with JAX. Hydra will wrap our dataclass so that it ends up duck-typed as a Config object, but we will need to re-register the wrapped config as static. Since the dataclass is no longer frozen, we can do this with a post-init function, which we’ll now call when we launch our training loop.
- Create a function to register the Config dataclass in a Hydra ConfigStore instance. In general, we can create arbitrarily stratified configs for easier management (e.g., separate configs for model, experiment, data, etc.). We’ll call this function in our main method.
- Refactor the training loop into a train.py file with a main method that we’ll call when we want to launch an experiment, and add the necessary Hydra setup code.
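As a quick illustration of the first point, the pattern looks roughly like the following (an illustrative sketch only – the exact enum names in the repository may differ): enum members carry the actual objects, Hydra passes the member around, and call sites recover the object via .value (e.g., config.model.param_dtype.value later in model.py).
# Illustrative sketch, not the exact code in the repository
from enum import Enum

import jax.numpy as jnp


class ParamDtype(Enum):
    # Hydra/OmegaConf handles the enum member itself; call sites recover the
    # underlying dtype with .value, e.g. config.model.param_dtype.value
    FLOAT32 = jnp.float32
    BFLOAT16 = jnp.bfloat16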
We will also make a larger refactor, which will help us later with setting up config YAML files for experiments: we will change from a flat config to a hierarchical config, with different parameters split into different sub-configs (e.g., for the optimizer, for data, …). This is not an objectively superior design – for example, it’s arguably a bit easier to write new code and integrate static type checking with a flat config – but we’ll follow it in order to add a bit more encapsulation to distinct config parameters.
Here’s a representative snippet from the new config.py with the modifications discussed above. There are also many forward-looking modifications, of which we only show a few out-of-context selections.
# config.py
# We omit the sub-configs, and other nitty-gritty details...
# Config is now top-level (orchestration parameters, etc.)
@dataclass(kw_only=True, unsafe_hash=True)
class Config:
"""Overall config class, containing top-level orchestration parameters"""
seed: int = 1337
logger_type: LoggerType = LoggerType.WANDB
project_name: str = "bmgpt-debug"
run_name: str = ""
val_log_interval: int = 1000 # log validation metrics every <this many> batches
train_dataset: DatasetConfig = MISSING
val_list: list[EvaluationConfig] = field(default_factory=list) # validation metrics
eval_list: list[EvaluationConfig] = field(default_factory=list)
optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
model: ModelConfig = field(default_factory=ModelConfig)
inference: InferenceConfig = field(default_factory=InferenceConfig)
sharding: ShardingConfig = field(default_factory=ShardingConfig)
def register_configs():
cs = ConfigStore.instance()
cs.store(group="optimizer", name="base_optimizer", node=OptimizerConfig)
cs.store(group="model", name="base_model", node=ModelConfig)
cs.store(group="inference", name="base_inference", node=InferenceConfig)
cs.store(group="sharding", name="base_sharding", node=ShardingConfig)
cs.store(name="base_config", node=Config)
def config_post_init(config: Config):
"""Input validation."""
# ...
And here are the relevant parts of train.py that we modify. There is some extra code for distributed initialization, which we discuss in more detail later in the post.
from pathlib import Path
import hydra
from bmgpt.config import Config, config_post_init, register_configs
register_configs()
# ...
@hydra.main(
version_base=None,
config_path=str(Path("configs").absolute().resolve()),
config_name="base_config",
)
def main(config: Config):
try:
# Launch distributed and register configs
jax.distributed.initialize()
jax.tree_util.register_static(type(config))
except RuntimeError:
# This implies the distributed backend has already been initialized
pass
config_post_init(config)
if __name__ == "__main__":
main()
The modifications are quite simple: we register our Config dataclass in the ConfigStore instance as 'base_config', then decorate our main method in train.py to use that config. The config_path argument in the decorator specifies a directory where we store different .yaml config files for different experimental setups (e.g., different model architectures, different datasets, different distributed training environments, etc.). Then we just call our post-init config method right at the start of main: right now, it only validates some config inputs.
When we want to launch an experiment now, we can run a command like the following (train loop refactoring follows later in the blog):
uv run --extra tpu train model.num_layers=1
Here, we are using uv optional dependencies so that we can run on either CPU or TPU
depending on the platform, as well as a build script in our package (see the
pyproject.toml).
Super easy! Anything that isn’t specified follows our defaults from the Config
dataclass definition.
As the experiments we want to run become more diverse and complex, it becomes harder and
harder to productively specify all default overrides via the command line. To address
this, we can write YAML configs for different experiments as we mentioned above. For
example, we can make a file at configs/experiments/text_toy.yaml, with:
# @package _global_
defaults:
- /dataset/staircase@train_dataset
train_dataset:
num_steps: 100
eval_list:
- dataset:
name: NUMBER_STAIRCASE
path: ""
split: TEST
seq_len: 1 # At inference time, this functions like prompt size
global_batch_size: 16
evaluator: AUTOREGRESSIVE_ROLLOUTS
model:
transformer_type: DISCRETE
is_causal: True
num_vocab: 10
See the Hydra docs for more details about how to parse this file – it has some
boilerplate to ensure it overrides the config defaults we specify in our config.py.
We’ll discuss some of the enums that appear here (you can identify them by the all-caps
text) later on. Then to run an experiment, we can simply override as:
uv run --extra tpu train +experiment=text_toy
In addition, we can add any other command-line overrides we want on top!
Model Improvements and Refactors
We’ll add support for the Pallas Splash Attention kernel, which will help us with model scaling later!
Splash Attention Integration
Splash Attention stands for “sparse Flash Attention” (Pagliardini et al., 2023). Flash Attention is a well-known memory-efficient implementation of the attention operation (Dao et al., 2022); it also is often used to refer to the associated CUDA kernel and surrounding PyTorch package for the algorithm. Compared to flash attention, splash attention adds better support for sparse attention masks, which arise frequently in practice: for example, when we want to pack many relatively short sequences into one long sequence, or precisely mask between document boundaries, we need to make sure not to attend across these boundaries.
The types of optimizations that flash/splash attention entail require a level of control over hardware-level execution that is not possible within the standard JAX-XLA programming pipeline. So JAX exposes a domain-specific language for writing these kernels, called Pallas! The authoring of Pallas kernels is out-of-scope for this post, but we will be interested in learning how to call pre-written Pallas kernels so that we can integrate splash attention.
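Just to give a flavor of what a Pallas kernel looks like, here is the standard element-wise addition example (essentially the Pallas quickstart example, lightly commented – it is not part of our codebase, and nothing below depends on it):
# Toy Pallas example, for illustration only
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Refs are views into fast on-chip memory; the body describes one block of work
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode, so this toy also runs off-TPU
    )(x, y)


print(add(jnp.arange(8.0), jnp.ones(8)))  # [1. 2. ... 8.]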
The interfaces for achieving this are naturally a bit rough around the edges, with scant documentation and relatively few examples for how to proceed. With the help of multiple LLMs, I eventually pieced together a solution based on example code available in the JAX Pallas tests. Here’s a high-level overview, with the code to follow:
- We need to configure the Pallas kernel itself based on how we’ll use it at runtime: specifically, the sequence length, number of heads, base mask type, and parallelism configuration. The latter has to be enforced manually with shard_map. Once we figure out where everything needs to go, this works more-or-less okay with our current model implementation structure, which uses Explicit mesh axes and vmap as much as possible to simplify ‘inner’ model code. However, it currently doesn’t work in JAX 0.8.1. It should work in the next release, thanks to a bug report I submitted and a fix from Yash Katariya (thanks!).
- Once we have the kernel (represented as a tuple of two Callables, one corresponding to the actual kernel and one corresponding to a wrapper with the shard_map configuration), we need to thread it through our top-level _transformer call down to the _attn call. A more ‘elegant’ approach to doing this (at least at the call site) might be to use some metaprogramming, but we will go with the simplest solution.
The code for the kernel setup is in splash_helpers.py:
# splash_helpers.py
from functools import partial
import jax
from bmgpt.config import Config, DatasetConfig
from jax.experimental.pallas.ops.tpu.splash_attention import (
BlockSizes,
CausalMask,
FullMask,
MultiHeadMask,
make_splash_mha,
)
# ...
def make_splash_kernel(
config: Config,
config_data: DatasetConfig,
cache_capacity: int,
mesh,
head_shards: int = 1,
q_seq_shards: int = 1,
):
if not config_data.use_splash:
# None ends up calling jax-xla attention
# see _attn
return None
# s is Q len (seq_len @ train; variable/1 at prefill/decode)
# t is K len (s + cache_capacity)
s = config_data.seq_len
t = s + cache_capacity
if config.model.is_causal:
mask = MultiHeadMask(
[
CausalMask(shape=(s, t), offset=cache_capacity)
for _ in range(config.model.num_heads)
]
)
else:
mask = MultiHeadMask(
[FullMask(shape=(s, t)) for _ in range(config.model.num_heads)]
)
BLOCK_SIZE = min(config_data.seq_len, 128)
if config_data.seq_len % 128 != 0 or BLOCK_SIZE % 128 != 0:
# splash attention kernel requires block size to be a multiple of 128
raise NotImplementedError("Splash block size needs to be a multiple of 128")
block_sizes = BlockSizes(
block_q=BLOCK_SIZE,
block_kv=BLOCK_SIZE,
block_kv_compute=BLOCK_SIZE,
block_q_dkv=BLOCK_SIZE,
block_kv_dkv=BLOCK_SIZE,
block_kv_dkv_compute=BLOCK_SIZE,
block_q_dq=BLOCK_SIZE,
block_kv_dq=BLOCK_SIZE,
)
splash_spec = jax.P(None, None)
splash_sharding = jax.sharding.NamedSharding(mesh, splash_spec)
kernel = make_splash_mha(
mask,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
block_sizes=block_sizes,
)
kernel_spec = kernel.manual_sharding_spec(splash_sharding)
@partial(
jax.shard_map,
mesh=mesh,
in_specs=(kernel_spec, splash_spec, splash_spec, splash_spec, jax.P()),
out_specs=splash_spec,
check_vma=False,
)
def splash_sharded(kernel, q, k, v, segment_ids):
return kernel(q, k, v, segment_ids=segment_ids)
return (splash_sharded, kernel)
The function argument cache_capacity specifies the maximum KV cache size we’ll use (we’ve split this off from max_seq_len, which previously did double duty for this purpose). This allows us to specify a cache size of 0 for training, for example. Within the function body, we check a new configuration parameter use_splash that denotes whether splash attention should be used. For example, we always fall back to XLA attention for autoregressive inference, since the splash attention Pallas kernel requires the block size to be a multiple of 128 (and for decode, we almost necessarily have to process a query block size of 1). In that case, the kernel creation factory just returns None, and we check for this in _attn.1
The “sparse” part of splash attention comes from the proper use of the segment_ids
argument: distinct documents/sub-sequences can be assigned different segment_ids.
We’ll see an example below, in how to set this up for use during training with respect
to the KV cache.
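For intuition, here is a toy example (hypothetical values, not from our codebase) of what the segment ids look like when two documents of lengths 3 and 2 are packed into one length-5 sequence; attention is only computed where the query and key segment ids agree, on top of any causal mask:
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import SegmentIds

# Two packed documents (lengths 3 and 2) in a single length-5 sequence
q_segment_ids = jnp.array([0, 0, 0, 1, 1], dtype=jnp.int32)
kv_segment_ids = jnp.array([0, 0, 0, 1, 1], dtype=jnp.int32)
segment_ids = SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
# Token 3 (the start of document 1) cannot attend to tokens 0-2 (document 0)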
First, at the level of train.py, we now create our kernels during setup. For training, this amounts to a simple kernel = make_splash_kernel(config, config.train_dataset, 0, mesh) statement. We then pass the kernel to every model call we make: for example, logits, _ = jax.vmap(partial(_transformer, config, kernel, params, cache_params=cache_params))(inputs, train_state.kv_cache) in the training loop step.
Here, we’ve refactored the way we use the KV cache – a CacheParams class is defined
simply as follows:
class CacheParams(NamedTuple):
enabled: bool
size: int
This way, as we’ll see momentarily, we can easily trace with respect to the size
argument (i.e., for inference) while still omitting all KV cache maintenance code at
compile time if we want to train (i.e., by setting enabled=False).
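Concretely, the two modes look like this (a sketch; current_cache_size is a stand-in for whatever traced fill level the inference loop maintains):
# Training: the cache-maintenance branch is compiled out entirely
cache_params = CacheParams(enabled=False, size=0)

# Autoregressive decoding (sketch): the fill level is traced, so one compiled
# function serves every decode position
cache_params = CacheParams(enabled=True, size=current_cache_size)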
The kernel argument and this new cache_params argument get threaded down to the level
of _attn, which now reads as follows (abbreviating unchanged components):
# model.py
def _make_causal_mask(seq_len_q: int, seq_len_k: int, cache_size: int) -> Array:
"""(seq_len_q, seq_len_k) bool array with True for past token positions"""
q_positions = cache_size + jnp.arange(seq_len_q)
k_positions = jnp.arange(seq_len_k)
return q_positions[:, None] >= k_positions[None, :]
def _make_cache_mask(seq_len_q: int, seq_len_k: int, cache_size: int) -> Array:
"""(1, seq_len_k) bool array with True for actual cache+context positions"""
k_positions = jnp.arange(seq_len_k)
return k_positions[None, :] < cache_size
def _attn(
config: Config,
kernel,
params: Attn,
x_seq: jax.Array,
kv_cache: jax.Array, # 2 x cache_capacity x n x h
cache_params: CacheParams,
):
# s: sequence length
# d: embedding dim (config.d_model)
# n: attention heads (config.num_heads)
# h: head dim (config.d_head)
# x_seq: s x d
q, k, v = jnp.einsum(
"sd,d3nh->3nsh",
x_seq,
params.w_qkv,
out_sharding=jax.P(*config.sharding.att_qkv),
)
s = q.shape[1]
# skipping rope...
k_cache, v_cache = kv_cache[0], kv_cache[1]
if cache_params.enabled:
k_cache_out = jax.lax.dynamic_update_slice(
k_cache, k, (0, cache_params.size, 0)
)
v_cache_out = jax.lax.dynamic_update_slice(
v_cache, v, (0, cache_params.size, 0)
)
kv_cache_out = jnp.concatenate((k_cache_out[None], v_cache_out[None]), axis=0)
else:
kv_cache_out = kv_cache
# Cache read scheme: to enable same mask for the same s value (Q seq len),
# we concatenate the full cache to K, and mask empty entries
# For efficient training, set cache size zero + cache_params.enabled=False
cache_capacity = k_cache.shape[-2]
k = jnp.concatenate((k_cache, k), axis=1)
v = jnp.concatenate((v_cache, v), axis=1)
# Attention
t = k.shape[1] # t = s + cache_capacity
if kernel:
q_segment_ids = jnp.zeros((s,))
kv_mask = _make_cache_mask(s, t, cache_params.size) | (
~_make_cache_mask(s, t, cache_capacity)
)
kv_mask = ~kv_mask[0]
kv_segment_ids = kv_mask.astype(jnp.int32)
segment_ids = SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
splash_sharded, kernel = kernel
attn_out = splash_sharded(
kernel,
q / config.model.d_head**0.25,
k / config.model.d_head**0.25,
v,
segment_ids,
)
else:
# Make mask
if config.model.is_causal:
mask = _make_causal_mask(s, t, cache_capacity)
cache_mask = _make_cache_mask(s, t, cache_params.size) | (
~_make_cache_mask(s, t, cache_capacity)
)
mask = mask & cache_mask
else:
mask = ~_make_cache_mask(s, t, 0) # full attention
mask = mask[None, ...] # broadcast over heads
# Scale and causal mask
logits = jnp.einsum("nsh,nth->nst", q, k).astype(
config.model.compute_dtype.value
)
logits *= 1.0 / config.model.d_head**0.5
logits = jnp.where(mask, logits, -jnp.inf)
probs = jax.nn.softmax(logits, axis=2) # type: ignore[reportArgumentType]
probs = probs.astype(config.model.param_dtype.value)
attn_out = jnp.einsum("nst,nth->nsh", probs, v)
out = jnp.einsum(
"nsh,hnd->sd",
attn_out,
params.w_o,
out_sharding=jax.P(*config.sharding.res_stream),
)
return out, kv_cache_out
# ... skipping to the cache init function
def init_kv_cache(config: Config, global_batch_size: int, cache_capacity: int):
if not config.sharding.data:
sharding_batch_layer = [None, None]
else:
sharding_batch_layer = config.sharding.data + [None]
sharding = jax.P(*(sharding_batch_layer + config.sharding.att_qkv))
return jnp.zeros(
(
global_batch_size,
config.model.num_layers,
2,
config.model.num_heads,
cache_capacity,
config.model.d_head,
),
dtype=config.model.param_dtype.value,
out_sharding=sharding,
)
The splash Pallas kernel expects the shape to be N x S x H, so we refactor some code
compared to what we had before. The key change here is that in order to repeatedly reuse
the same splash Pallas kernel as we grow the cache, we write the code so that the KV
sequence length is always fixed as the cache capacity plus S (the Q sequence length).
The cache capacity is specified dynamically at initialization, allowing us to use the
same Config object for different training/inference scenarios (e.g., cache size 0 for
training, and some fixed positive integer for autoregressive inference).
This means we need to construct a SegmentIds object that masks out all the unused
portions of the KV cache, which is performed correctly by the above code. Previously, we
were just updating the cache with jax.lax.dynamic_update_slice, rather than
concatenating, but with the splash kernel structure, this would require us to change the
mask after every additional cache update (requiring a new kernel!).
That’s it! We’ve integrated splash attention, allowing us to scale much more effectively to larger model sizes down the road.
Optimizer Improvements
We were using 100% vanilla SGD in the previous post. Let’s upgrade that to a much stronger optimizer, AdamW (Loshchilov & Hutter, 2017)!
In our previous post, we used a very simple tree-mapped update, facilitated by the simple structure of constant-learning-rate SGD:
# ...
loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
new_state = TrainState(
params=jax.tree.map(lambda x, y: x - config.lr * y, train_state.params, grad),
opt=None,
)
# ...
For implementing Adam and AdamW (the “W” denotes “decoupled weight decay”), it’s better
to use a more flexible API: both because the updates become harder to express in a
single concise lambda, and because we generally want to be able to apply different
updates to different parameters.2 Here’s how we’ll do this, at a basic level:
- We’ll define a handful of optimizer updates that share the same API.
- We’ll allow one update to be selected via the Config class (i.e., with the Enum we mentioned above), and then we’ll jax.tree.map it over the model pytree.
- We’ll be able to provide a pytree of arguments to the update, as necessary, allowing us to do different things for different parameters.
For example, here’s how this looks for the previous constant-learning-rate SGD update. Since AdamW requires us to save state information (i.e., first-order and second-order momentum buffers), we’ll incorporate that into the API.
# Factory for getting the update, used externally
def opt_update_factory(opt_type: OptType):
match opt_type:
case OptType.ADAMW:
return adamw_update
case OptType.SGD:
return sgd_update
class OptState(NamedTuple):
mu: jax.Array # 1st moment EMA
nu: jax.Array # 2nd moment EMA
step: jax.Array # step number
def init_adam_state(config: Config, param: jax.Array) -> OptState:
return OptState(
mu=jnp.zeros_like(param, dtype=config.optimizer_dtype.value),
nu=jnp.zeros_like(param, dtype=config.optimizer_dtype.value),
step=jnp.array(0, dtype=jnp.int32),
)
def sgd_update(
config: Config, param: jax.Array, grad: jax.Array, state: OptState, wd_mask: bool
):
update = -config.lr * grad
if wd_mask:
# Apply weight decay
update = update - config.lr * config.weight_decay * param
return update, state
The update API takes parameters, gradients, state, and additional flags, and returns the
update direction and the updated state. We add the update direction to the parameters
externally since this is a useful debugging abstraction (and matches the optax API).
Here’s how it looks to apply the optimizer now. It gets a bit more cumbersome to do this with only pytrees, since the update returns a tuple – meaning the output pytree structure is different from the model/grad pytree structure. The latter is a prefix of the former, though, so we can still extract what we need fairly easily.
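Here’s a toy illustration of that extraction trick, with made-up leaves (not part of the training code):
import jax

params = {"w": 1.0, "b": 2.0}
# An update that returns (update, new_state) at each leaf produces:
out = {"w": (0.1, 10), "b": (0.2, 20)}
# params' structure is a prefix of out's, so each (update, state) tuple is
# treated as a single leaf when we map over (params, out):
updates, states = map(
    lambda i: jax.tree.map(lambda _, y: y[i], params, out), range(2)
)
# updates == {"w": 0.1, "b": 0.2}; states == {"w": 10, "b": 20}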
# In train.py
# ...
class TrainState(NamedTuple):
params: Transformer
opt_state: Any
kv_cache: jax.Array
@jax.jit
def init_train_state(key, config: Config) -> TrainState:
model_params = init_transformer(key, config)
adam_state = jax.tree.map(partial(init_adam_state, config), model_params)
cache = init_kv_cache(config, config.train_dataset.global_batch_size, 0)
return TrainState(params=model_params, opt_state=adam_state, kv_cache=cache)
# ...
def main(config: Config):
# ...
opt_update = opt_update_factory(config.optimizer.type)
# ...
@partial(jax.jit, donate_argnums=2)
def train_step(config: Config, batch, train_state: TrainState):
# ...
loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
# ...
update__opt_state = jax.tree.map(
partial(opt_update, config),
train_state.params,
grad,
train_state.opt_state,
weight_decay_mask,
)
# Transpose the output tree to get update tree and state tree
update, opt_state = map(
lambda i: jax.tree.map(lambda x, y: y[i], grad, update__opt_state),
range(2),
)
params = jax.tree.map(lambda x, y: x + y, train_state.params, update)
new_state = TrainState(
params=params, opt_state=opt_state, kv_cache=train_state.kv_cache
)
# ...
This is pretty straightforward, once we’ve written it!
Now, to add a new optimizer – AdamW in our case – we just need to define the
appropriate opt_update and possibly define new state initialization functions. We’ve
already done the latter above for Adam. Here’s a simple implementation of AdamW that
agrees with the optax implementation (not bitwise):
def adamw_update(
config: Config, param: jax.Array, grad: jax.Array, state: OptState, wd_mask: bool
):
beta1 = config.beta1
beta2 = config.beta2
lr = config.lr
eps = config.eps_adam
weight_decay = config.weight_decay
mu = beta1 * state.mu + (1 - beta1) * grad.astype(config.optimizer_dtype.value)
nu = beta2 * state.nu + (1 - beta2) * grad.astype(config.optimizer_dtype.value) ** 2
new_state = OptState(mu=mu, nu=nu, step=state.step + 1)
mu_debias = mu / (1 - beta1**new_state.step)
nu_debias = nu / (1 - beta2**new_state.step)
update = -lr * mu_debias / (eps + jnp.sqrt(nu_debias))
if wd_mask:
# Apply weight decay
update = update - lr * weight_decay * param
return update.astype(config.param_dtype.value), new_state
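For reference, the update computed above is the standard AdamW recursion (lr, beta1, beta2, eps_adam, and weight_decay in the code correspond to η, β₁, β₂, ε, and λ below; the decay term is applied only where the mask is set):

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \\
\hat m_t &= m_t / (1-\beta_1^t), \qquad
\hat v_t = v_t / (1-\beta_2^t), \\
\theta_t &= \theta_{t-1} - \eta \left( \frac{\hat m_t}{\epsilon + \sqrt{\hat v_t}} + \lambda\, \theta_{t-1} \right).
\end{aligned}
$$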
The slightly tricky part remaining is how to pass the right weight decay mask pytree.
It’s best to do this with model metadata, since the format we store the parameters in
doesn’t necessarily match how they’re used. For example, we store our attention QKV
parameters as a D x 3 x N x H array, where D is the model dimension, N the number
of attention heads, and H the projected attention head dimension. We use these
parameters as D x H matrices, but we can’t safely infer this from only the shape
information! Similarly, since we vmap our Blocks initialization over the layer
dimension L when constructing the model, our MLP output bias parameters are shaped
like L x D. This is a matrix, so when iterating over the model pytree, we can’t just
check whether the current leaf has ndim > 1!
This is a situation where we clearly see the advantage of the bigger, standard model
libraries (e.g. Equinox): there, metadata and parameters can easily be stored together.
For us, we’d need to depart from our appealingly simple NamedTuple model architecture
to do this (e.g. register custom pytree leaves). So we’ll instead use a slightly
unpleasant approach: we’ll just add a function to our model.py file that returns a
model “spec” associated to the architecture, containing all relevant metadata.3
# in model.py
def model_spec(model: Transformer) -> Any:
# Make the spec (we need some way to pass metadata around)
def _make_spec_from_str(path: str) -> tuple[int, int] | None:
param_str = path[-1].__str__()
matrix_axes_dict = {
".w_qkv": (-4, -1),
".w_o": (-3, -1),
".w_up": (-2, -1),
".w_down": (-2, -1),
".w": (-2, -1),
# ...
}
return matrix_axes_dict.get(param_str, None)
spec = jax.tree.map_with_path(lambda p, _: _make_spec_from_str(p), model)
return spec
The keys in the dictionary used to generate the output pytree correspond to the parameter names in our Transformer model. We can get these from the pytree using jax.tree.map_with_path, which gives the key path of each leaf (easily converted into a string, like .blocks.norm_attn.gamma) in addition to the leaf itself.
The returned pytree looks like the following:
Transformer(emb=Embedding(w=(-2, -1)), blocks=Block(norm_attn=LayerNorm(gamma=None, beta=None), attn=Attn(w_qkv=(-4, -1), w_o=(-3, -1)), norm_mlp=LayerNorm(gamma=None, beta=None), mlp=Mlp(w_up=(-2, -1), bias_up=None, w_down=(-2, -1), bias_down=None)), unemb=Unembedding(w=(-2, -1)))
Now, back to the matter at hand: we need to generate a mask pytree that tells us which parameters in the model to apply weight decay to. To apply it only to the matrix parameters in the model, we can do the following:
spec = model_spec(train_state.params)
weight_decay_mask = jax.tree.map(lambda x, s: bool(s), train_state.params, spec)
Since None is not truthy, this will give us the mask we desire! Then we use it as
above, where we already provided weight_decay_mask as an argument to the update
calculation tree map (but didn’t specify how to generate it).
Here’s the entire updated train step. We add gradient clipping to it, since this doesn’t require too much effort:
# in optimizers.py
def grad_norm_and_clip(
config: Config, model: Transformer
) -> tuple[Transformer, Transformer, float]:
# Gradient norms in param precision (NOTE: might want fp32?)
grad_norms_squared = jax.tree.map(lambda grad: jnp.sum(grad**2), model)
global_grad_norm = jax.tree.reduce(operator.add, grad_norms_squared) ** 0.5
truncated_norm = jax.lax.select(
global_grad_norm >= config.clip_grad,
global_grad_norm,
jnp.ones_like(global_grad_norm),
)
return (
jax.tree.map(lambda grad: grad / truncated_norm, model),
grad_norms_squared,
global_grad_norm,
)
It’s quite straightforward overall!
# Initialize state, configure forward pass and optimization
with jax.set_mesh(mesh):
train_state = init_train_state(key_model, config)
cache_params = CacheParams(enabled=False, size=0)
kernel = make_splash_kernel(config, config.train_dataset, 0, mesh)
spec = model_spec(train_state.params)
opt_update = opt_update_factory(config.optimizer.type)
weight_decay_mask = jax.tree.map(lambda _, s: bool(s), train_state.params, spec)
@partial(jax.jit, donate_argnums=2)
def train_step(config: Config, batch, train_state: TrainState):
def loss_fn(params: Transformer):
inputs, targets = batch
logits, _ = jax.vmap(
partial(_transformer, config, kernel, params, cache_params=cache_params)
)(inputs, train_state.kv_cache)
logits = logits.astype(config.model.compute_dtype.value)
logprobs = jax.nn.log_softmax(logits, axis=-1)
return -jnp.take_along_axis(logprobs, targets[..., None], axis=-1).mean()
loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
grad_clipped, _, global_grad_norm = grad_norm_and_clip(config, grad)
update__opt_state = jax.tree.map(
partial(opt_update, config),
train_state.params,
grad_clipped,
train_state.opt_state,
weight_decay_mask,
)
# Transpose the output tree to get update tree and state tree
update, opt_state = map(
lambda i: jax.tree.map(lambda x, y: y[i], grad, update__opt_state), range(2)
)
params = jax.tree.map(lambda x, y: x + y, train_state.params, update)
new_state = TrainState(
params=params, opt_state=opt_state, kv_cache=train_state.kv_cache
)
metrics = {"batch_loss": loss, "grad_norm": global_grad_norm}
return metrics, new_state
Distributed Training
Next, let’s set up distributed training. On the whole, this is actually very easy to do on Google’s TPU VMs!4 This is because:
- The VMs are already configured with the IP addresses / ICI information of the other hosts in the slice.
- We’re using jax.sharding.AxisType.Explicit sharding, and our model is relatively simple. So communication gets handled by the XLA compiler! We don’t actually need to use pmap or pjit anywhere.
For our purposes, the only part that requires care is data management, particularly data loading. We’ll describe this below for our simple synthetic data distribution, as well as the other changes we need to make for distributed training.
High-Level Paradigm
In JAX distributed training, each host (or “process”) is physically connected to some
number of chips via PCIe. On v4 TPUs, which we’ve been using, each host
is connected to 4 chips (each chip with 2 cores). When working with TPU VMs, we can
think of one host as corresponding to one VM that we can ssh into; in our tests
previously, we’ve been working on a single host.
Different hosts are arranged into “slices”. Within a slice, chips are interconnected via high-bandwidth optical links, letting them communicate relatively efficiently! When it comes to data movement from the host to the device, though, each host can only directly communicate with the chips it is physically connected to. These are called the “addressable” devices of that host.
With jax.distributed, we can programmatically handle these differences while only
writing a single program! Our program gets run identically on each host, but we have
access to certain functions that can allow us to branch and shard differently depending
on which host we are. Key functions that we’ll see below are:
- jax.distributed.initialize(), which synchronizes information about all interconnected devices among the different hosts;
- jax.process_index(), which returns a different integer for each host (allowing branching);
- jax.Array.addressable_shards, jax.Array.is_fully_addressable (etc.), which can be used to programmatically determine which devices can directly modify a given array. For example, if x is a jax.Array that is fully replicated across devices (with jax.P() as its partition spec), then x.is_fully_addressable is False in a multi-host setting! We want to avoid having our program modify an array in a way that is not compatible with either its sharding or the devices with respect to which it is addressable.
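For orientation, here is the kind of per-host sanity check these functions enable (a sketch, not code from the repository):
import jax

jax.distributed.initialize()  # on TPU VMs, coordinator/peer info is auto-detected
print(
    f"process {jax.process_index()} of {jax.process_count()}: "
    f"{jax.local_device_count()} addressable devices "
    f"out of {jax.device_count()} total"
)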
For more detailed information, see Chapter 2 of the TPU scaling book (Austin et al., 2025) as well as relevant sections of the JAX documentation (JAX Team, 2025).
Setting up the Mesh and Initializing the Distributed Backend
One multi-host footgun I learned about the hard way is that setting a global mesh with
jax.set_mesh, as we have been doing so far, makes it hard to create arrays that exist
only on the local process (more on that below). So we switch to creating a mesh from
our config options at the start of the main function, then instantiate it as a context
manager when we need it. This turns out to be pretty easy—we don’t need to pass the
mesh to all parts of our model’s forward pass and add it as an additional parameter to
out_sharding (etc.)!
Here’s the new setup (which we saw part of above):
# in config.py
def mesh_from_config(config: Config):
mesh = jax.make_mesh(
config.sharding.mesh_shape,
config.sharding.mesh_axis_names,
len(config.sharding.mesh_shape) * (jax.sharding.AxisType.Explicit,),
)
return mesh
# in train.py
@hydra.main(
version_base=None,
config_path=str(Path("configs").absolute().resolve()),
config_name="base_config",
)
def main(config: Config):
try:
# Launch distributed and register configs
jax.distributed.initialize()
jax.tree_util.register_static(type(config))
except RuntimeError:
# This implies the distributed backend has already been initialized
pass
config_post_init(config)
mesh = mesh_from_config(config)
# ...
That’s it! We just have to be sure to set the config.sharding.mesh_shape parameter appropriately when we launch a job (see below), as well as all our desired sharding settings. Our config makes simple data-parallel training over the entire list of devices the default.
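For instance, with a single data axis over a 16-chip slice (illustrative values), the call in mesh_from_config reduces to:
mesh = jax.make_mesh(
    (16,),                              # mesh_shape: one axis spanning all 16 chips
    ("data",),                          # mesh_axis_names
    (jax.sharding.AxisType.Explicit,),  # one axis type per mesh axis
)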
Modifications for Data Loading
We refactor the data generation code into data.py. At a high level, here’s what we end
up doing in train.py to set up the data loading:5
# in train.py
# ...
# Randomness
key = jax.random.key(config.seed)
key_model, key_train, key_val, key_eval = jax.random.split(key, 4)
# Data
batch_iter = get_distributed_batch_iter(config, config.train_dataset, key_train, mesh)
# ...
We abstract the previous dataset-specific code into this factory function, which calls the appropriate sub-functions (letting us reuse this for val/eval):
# in data.py
def dataset_dataloader_factory(config: DatasetConfig):
match config.name:
case DatasetName.MNIST:
return (load_mnist(config), dataloader_without_replacement)
case DatasetName.NUMBER_STAIRCASE:
return (
make_number_staircase_data(config),
dataloader_with_replacement,
)
case DatasetName.SHAKESPEARE:
return (
load_shakespeare(config),
dataloader_with_replacement,
)
# Helper to just get the global batch iterator, if we don't need the local data/loader
def get_distributed_batch_iter(
config: Config, dataset_config: DatasetConfig, key, mesh
):
data, dataloader = dataset_dataloader_factory(dataset_config)
return get_dataset_on_device(config, dataloader(key, dataset_config, data), mesh)
Distributed data loading has a large number of potential footguns; the JAX Advanced Guides listing has a useful entry about this (JAX Team, 2025), and see also the article about explicit sharding we used in previous blogs (JAX Team, 2025). In the present setting, we take advantage of the fact that we can load the entire dataset into memory on each host. The following code does this:
def make_number_staircase_data(config: DatasetConfig):
# Simple data sequence!
# Small, so mega replicate for ease
text = "012345678987654321" * 1024
ids = [int(c) for c in text]
data = jnp.array(ids, dtype=jnp.int32)
Xtr, Xdev, Xte = split_data(data, 0.8, 0.1)
match config.split:
case SplitType.TRAIN:
return Xtr, jnp.array(0)
case SplitType.TEST:
return Xte, jnp.array(0)
case SplitType.VAL:
return Xdev, jnp.array(0)
def split_data(data: Array, train_fraction: float, dev_fraction: float):
num_data = len(data)
Xtr = data[: int(train_fraction * num_data)]
Xdev = data[
int(train_fraction * num_data) : int((train_fraction + dev_fraction) * num_data)
]
Xte = data[int((train_fraction + dev_fraction) * num_data) :]
return Xtr, Xdev, Xte
The tricky part here is to note that, as configured, the above code generates the data array so that it is fully addressable by every process: each host creates its own (identical) local data array. More precisely, if we evaluate data.is_fully_addressable, it will be True. This must be contrasted with the case of a fully replicated global data array, which we don’t want here—our approach will be to piece together a global batch out of local batches on each host, and having a process-local data array is (currently) the canonical way to do this!
Next, we need to make sure each host has access to an independent and identically
distributed stream of batches. This gets handled by the get_dataset_on_device and
dataloader functions.
# in data.py
# Simple next-token prediction dataloader for text training:
# - Each host has the whole dataset (data input)
# - Random sampling with replacement to draw batches (consumes key)
# - Targets are the next token (expect data to be (num_data,) shape)
def dataloader_with_replacement(
key, config: DatasetConfig, data: tuple[Array, Array]
) -> DataloaderOutputType:
inputs, _ = data
num_data = len(inputs)
key = jax.random.fold_in(key, jax.process_index())
for step in it.count():
key = jax.random.fold_in(key, step)
offsets = jax.random.randint(
key,
(config.global_batch_size // jax.process_count(),),
0,
num_data - config.seq_len - 1,
)
seqs = inputs.at[offsets[:, None] + jnp.arange(config.seq_len)].get()
targets = inputs.at[1 + offsets[:, None] + jnp.arange(config.seq_len)].get()
yield seqs, targets
# Helper to map a dataloader with make_array_from_process_local_data
def get_dataset_on_device(
config: Config, dataloader: DataloaderOutputType, mesh: Mesh
) -> DataloaderOutputType:
return map(
lambda batch: jax.make_array_from_process_local_data(
NamedSharding(mesh, jax.P(*config.sharding.data)), batch
),
dataloader,
)
We must pass the mesh we created earlier here in order to correctly create the
global batch out of the local shards that the dataloader generates. Explicit sharding
seems to propagate to outputs, so the data.at operations return locally-addressable
batch arrays on each process, which make_array_from_process_local_data pieces
together into the global array. To generate independent local batches on each process,
we are sure to fold_in the process index to the RNG in order to create a distinct
stream of random offsets.
This quick-and-dirty data loader, which does not shuffle nor guarantee sampling without replacement for batches across processes, is fine for our purposes on this synthetic dataset. It should also be okay for use with much larger datasets, when training for a relatively small number of tokens. But it’s worth keeping in mind that something more robust would be needed in general.6
Remaining Cleanup
All that’s left is to judiciously insert our mesh context manager at key points of the training loop. We saw above (when discussing optimizers) that we need to add it to the model initialization call. Other than that, we just need to add it to the model call in our training loop:
# in train.py
# ...
for step in range(config.num_steps):
batch = next(batch_iter)
with jax.set_mesh(mesh):
cur_metrics, train_state = train_step(config, batch, train_state)
# ...
For simplicity’s sake, we also wrap our sampling code with it.
Launching Everything
Given our DIY ethos, we handle launching distributed jobs on our TPU VM with an ad-hoc launch script:
# run.sh
TPU_NAME="$1"
shift
SSH_FLAGS='-A -o ForwardAgent=yes'
COMMANDS="if [ ! -d \"baremetal-gpt\" ]; then git clone [email protected]:sdbuch/baremetal-gpt; fi \
&& export HYDRA_FULL_ERROR=1 \
&& export WANDB_ENTITY='$WANDB_ENTITY' \
&& export WANDB_API_KEY='$WANDB_API_KEY' \
&& export HF_TOKEN='$HF_TOKEN' \
&& cd baremetal-gpt \
&& git fetch \
&& git checkout -f main \
&& git pull \
&& uv sync --extra tpu \
&& uv run train $@"
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
--ssh-flag="$SSH_FLAGS" \
--command="$COMMANDS" \
--worker=all
For logging purposes, we need the Weights & Biases default entity and API key to be exported as environment variables locally. We use GitHub for code synchronization, and provide the tpu extra to uv when building since we’re running on a TPU VM.
Here’s an example call on the local machine, for data-parallel training:
./deploy/run.sh tpu-v4-32 num_vocab=10 num_layers=4 num_steps=300 lr=3e-4 'mesh_shape=[16]'
For this to work well, when I first create the TPU VM,7 I make sure the following startup script is run on each host in the configuration:
# startup.sh
curl -LsSf https://astral.sh/uv/install.sh | sh # Download uv
sudo cp /root/.local/bin/uv /usr/local/bin/uv
sudo cp /root/.local/bin/uvx /usr/local/bin/uvx
ssh-keyscan -H github.com >> /etc/ssh/ssh_known_hosts # Code copying
This installs uv and adds github.com to the known hosts, so that I can connect to GitHub via ssh (agent forwarding is enabled for authentication). If we don’t add it to the known hosts list, ssh will ask to confirm the connection and the gcloud command will fail.
That’s it! This gives a pretty easy minimal setup for distributed training with JAX on TPU VMs.
Adding Logging
Custom logging is easy to add, as well. We add an Enum to the config to select the
logger, then add two basic loggers: one for printing to the command line, and one for
using Weights and Biases.
There are two things to keep in mind here:
- In a distributed setting, by default we only want one process (say, the one with index 0) to log metrics.8
- We want to set up appropriate host-device pipelining, so as we alluded to in the previous post (taking a trick from the JAX training cookbook), we should log the metrics from the previous batch on every step. We can incorporate automatic buffering into our Logger class to do this.
# config.py
# ...
class LoggerType(Enum):
PRINT = "print"
WANDB = "wandb"
# ...
# loggers.py
def logger_factory(logger_type: LoggerType):
match logger_type:
case LoggerType.PRINT:
return PrintLogger
case LoggerType.WANDB:
return WandbLogger
def get_run_name(base: str) -> str:
"""Prepend current UTC time to config-specified run name"""
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
run_name = f"{timestamp}-{base}"
return run_name
class Logger:
"""Logger base class. Implements a depth-1 buffer for device-host pipelining"""
def __init__(self, config: Config):
self.project_name = config.project_name
self.run_name = get_run_name(config.run_name)
if isinstance(config, DictConfig):
config_dict = OmegaConf.to_container(config, resolve=True)
else:
config_dict = asdict(config)
self.config = config_dict
self.prev_log_data = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return False
def log(self, log_dict: dict):
pass
def warn(self, message: str):
pass
def buffer(self, log_dict: dict) -> dict | None:
self.prev_log_data, buffered_metrics = log_dict, self.prev_log_data
return buffered_metrics
def flush_buffer(self):
if self.prev_log_data is not None:
self.log({})
self.prev_log_data = None
class PrintLogger(Logger):
def __init__(self, config: Config):
super().__init__(config)
def __enter__(self):
print(f"Project: {self.project_name}")
print(f"Run: {self.run_name}")
return super().__enter__()
def log(self, log_dict: dict):
if (buffered_dict := self.buffer(log_dict)) is None:
return
print(*[f"{metric}: {val}" for metric, val in buffered_dict.items()], sep="\t")
def warn(self, message: str):
print(message)
class WandbLogger(Logger):
def __init__(self, config: Config):
super().__init__(config)
self.is_master = jax.process_index() == 0
def __enter__(self):
if self.is_master:
wandb.init(
project=self.project_name, name=self.run_name, config=self.config
)
return super().__enter__()
def __exit__(self, exc_type, exc_value, traceback):
if self.is_master:
wandb.finish()
return super().__exit__(exc_type, exc_value, traceback)
def log(self, log_dict: dict):
if self.is_master:
if (buffered_dict := self.buffer(log_dict)) is None:
return
wandb.log(buffered_dict)
def warn(self, message: str):
if self.is_master:
print(message)
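To make the depth-1 buffer concrete, here is how a hypothetical sequence of calls behaves for a PrintLogger (the metric values are made up):
with PrintLogger(config) as logger:
    logger.log({"loss": 2.3, "step": 0})  # buffers step 0; prints nothing yet
    logger.log({"loss": 2.1, "step": 1})  # prints step-0 metrics; buffers step 1
    logger.flush_buffer()                 # prints step-1 metrics and clears the buffer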
Then we simply wrap our training loop with this context manager:
# train.py
# ...
Logger = logger_factory(config.logger_type)
# ...
with Logger(config) as logger:
# ...
for step, batch in enumerate(batch_iter):
with jax.set_mesh(mesh):
metrics, train_state = train_step(config, batch, train_state)
logger.log(metrics | {"step": step})
# ...
if step == config.train_dataset.num_steps - 1:
break
# ...
Actual Data
Now that we have a reasonable backend, let’s get it working on some text data, and try it at GPT-2 scale! Within the context of the above refactors, we just need to write the data preprocessing code and the dataloader. We also need to add evaluation code – we’ve done this in the GitHub repository, but haven’t detailed it above.
We’ll do a minimal demo with Andrej Karpathy’s Tiny Shakespeare dataset, available in
the nanoGPT repo. We adapt Karpathy’s script for downloading and tokenizing the dataset,
making use of the tiktoken library for this purpose:
# data/download_and_tokenize_tiny_shakespeare.py
import os
import jax.numpy as jnp
import requests
import tiktoken
"""From https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare/prepare.py"""
# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
if not os.path.exists(input_file_path):
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
with open(input_file_path, "w", encoding="utf-8") as f:
f.write(requests.get(data_url).text)
with open(input_file_path, "r", encoding="utf-8") as f:
data = f.read()
n = len(data)
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) :]
# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")
# export to bin files
train_ids = jnp.array(train_ids, dtype=jnp.int32)
val_ids = jnp.array(val_ids, dtype=jnp.int32)
jnp.save("data/tiny-shakespeare/train.npy", train_ids)
jnp.save("data/tiny-shakespeare/val.npy", val_ids)
# train.bin has 301,966 tokens
# val.bin has 36,059 tokens
In a distributed setting, we run this on each host, with a command like:
# deploy/download_shakespeare.sh
TPU_NAME="$1"
shift
SSH_FLAGS='-A -o ForwardAgent=yes'
COMMANDS="if [ ! -d \"baremetal-gpt\" ]; then git clone [email protected]:sdbuch/baremetal-gpt; fi \
&& export HF_TOKEN='$HF_TOKEN' \
&& cd baremetal-gpt \
&& git fetch \
&& git checkout -f main \
&& git pull \
&& uv sync --extra tpu \
&& uv run data/tiny-shakespeare/download_and_tokenize_tiny_shakespeare.py $@"
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
--ssh-flag="$SSH_FLAGS" \
--command="$COMMANDS" \
--worker=all
Then the data is ready for loading. We write a Python loader for the data to go with our previously-written factory, using the same dataloader_with_replacement as for the synthetic data:
# data.py
def load_shakespeare(config: DatasetConfig):
path = Path(config.path)
data = jnp.load(path / (config.split.value + ".npy"))
return data, jnp.array(0)
Then running the experiment is simply a matter of setting up the right experimental config. We use the following, which runs almost all the evaluation code we wrote (to test it!):
# configs/experiment/tiny-shakespeare.yaml
# @package _global_
defaults:
- /dataset/shakespeare@train_dataset
val_log_interval: 250
train_dataset:
num_steps: 3000
val_list:
- dataset:
name: SHAKESPEARE
path: ${train_dataset.path}
split: VAL
seq_len: ${train_dataset.seq_len} # At inference time, this functions like prompt size
global_batch_size: ${train_dataset.global_batch_size}
num_steps: 200
evaluator: NLL
- dataset:
name: SHAKESPEARE
path: ${train_dataset.path}
split: VAL
seq_len: 128 # At inference time, this functions like prompt size
global_batch_size: 16
use_splash: False
evaluator: AUTOREGRESSIVE_ROLLOUTS
eval_list:
- dataset:
name: SHAKESPEARE
path: ${train_dataset.path}
split: VAL
seq_len: 128 # At inference time, this functions like prompt size
global_batch_size: 16
use_splash: False
evaluator: AUTOREGRESSIVE_ROLLOUTS
model:
transformer_type: DISCRETE
max_seq_len: ${train_dataset.seq_len} # no longer than train
is_causal: True
num_vocab: 50257 # gpt2 tokenizer
optimizer:
weight_decay: 1e-1
inference:
tokenizer: GPT2
max_tokens_to_generate: 128
The tokenizer parameters, as suggested, are just used for rollouts at inference time. We use some Hydra interpolation in the config above to avoid retyping the same values. Here’s the overall training code we run (which is pretty short!):
# train.py
# ...
@hydra.main(
version_base=None,
config_path=str(Path("configs").absolute().resolve()),
config_name="base_config",
)
def main(config: Config):
try:
# Launch distributed and register configs
jax.distributed.initialize()
jax.tree_util.register_static(type(config))
except RuntimeError:
# This implies the distributed backend has already been initialized
pass
config_post_init(config)
mesh = mesh_from_config(config)
Logger = logger_factory(config.logger_type)
# Randomness
key = jax.random.key(config.seed)
key_model, key_train, key_val, key_eval = jax.random.split(key, 4)
# Data
batch_iter = get_distributed_batch_iter(
config, config.train_dataset, key_train, mesh
)
# Initialize state, configure forward pass and optimization
with jax.set_mesh(mesh):
train_state = init_train_state(key_model, config)
cache_params = CacheParams(enabled=False, size=0)
kernel = make_splash_kernel(config, config.train_dataset, 0, mesh)
spec = model_spec(train_state.params)
opt_update = opt_update_factory(config.optimizer.type)
weight_decay_mask = jax.tree.map(lambda _, s: bool(s), train_state.params, spec)
@partial(jax.jit, donate_argnums=2)
def train_step(config: Config, batch, train_state: TrainState):
def loss_fn(params: Transformer):
inputs, targets = batch
logits, _ = jax.vmap(
partial(_transformer, config, kernel, params, cache_params=cache_params)
)(inputs, train_state.kv_cache)
logits = logits.astype(config.model.compute_dtype.value)
logprobs = jax.nn.log_softmax(logits, axis=-1)
return -jnp.take_along_axis(logprobs, targets[..., None], axis=-1).mean()
loss, grad = jax.value_and_grad(loss_fn)(train_state.params)
grad_clipped, _, global_grad_norm = grad_norm_and_clip(config, grad)
update__opt_state = jax.tree.map(
partial(opt_update, config),
train_state.params,
grad_clipped,
train_state.opt_state,
weight_decay_mask,
)
# Transpose the output tree to get update tree and state tree
update, opt_state = map(
lambda i: jax.tree.map(lambda x, y: y[i], grad, update__opt_state), range(2)
)
params = jax.tree.map(lambda x, y: x + y, train_state.params, update)
new_state = TrainState(
params=params, opt_state=opt_state, kv_cache=train_state.kv_cache
)
metrics = {"batch_loss": loss, "grad_norm": global_grad_norm}
return metrics, new_state
# Training loop
with Logger(config) as logger:
do_evals = partial(eval_loop, config, mesh=mesh, logger=logger)
for step, batch in enumerate(batch_iter):
with jax.set_mesh(mesh):
metrics, train_state = train_step(config, batch, train_state)
logger.log(metrics | {"step": step})
if (step + 1) % config.val_log_interval == 0:
# Calculate val metrics
key_val = do_evals(key_val, config.val_list, train_state.params, step)
if step == config.train_dataset.num_steps - 1:
break
# Run evals (testing)
key_eval = do_evals(key_eval, config.eval_list, train_state.params, step)
def eval_loop(
config: Config,
key,
eval_list: list[EvaluationConfig],
params: Transformer,
step: int,
logger: Logger,
mesh,
):
logger.flush_buffer()
for evaluation in eval_list:
kernel = make_splash_kernel(config, evaluation.dataset, 0, mesh)
key, key_d, key_e = jax.random.split(key, 3)
batch_iter = get_distributed_batch_iter(config, evaluation.dataset, key_d, mesh)
evaluation_fn = evaluator_factory(evaluation)
metrics = evaluation_fn(config, key_e, kernel, mesh, params, batch_iter)
logger.log(metrics | {"step": step})
logger.flush_buffer()
return key
Check out the evaluators.py file in the GitHub repository for the full structure of the evaluation code.
Checking the Results
The default model we train has 12 layers, embedding dimension 768, 12 heads per attention layer, and an MLP expansion factor of 4. With the GPT-2 vocabulary size, this ends up at about 160M parameters. Our Tiny Shakespeare config follows nanoGPT, setting the defaults as:
# configs/dataset/shakespeare.yaml
name: SHAKESPEARE
path: "data/tiny-shakespeare"
split: TRAIN
seq_len: 256
global_batch_size: 128
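As a rough sanity check on the ~160M parameter count mentioned above (ignoring biases and LayerNorm parameters, and assuming untied embedding/unembedding matrices, as in our model):
d, n_layers, vocab = 768, 12, 50257
per_block = 4 * d * d + 8 * d * d    # attention (QKV + output proj) + MLP at 4x expansion
blocks = n_layers * per_block        # ~85M
emb = 2 * vocab * d                  # embedding + unembedding, ~77M
print(f"~{(blocks + emb) / 1e6:.0f}M parameters")  # ~162M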
With these defaults, we process about 32K tokens per batch (256 × 128 = 32,768). The total number of tokens in the training set is only 301,966! So without additional regularization, we expect to overfit this dataset very rapidly with our model.
We test this out, via the launch command
./deploy/run.sh tpu-v4-32 +deploy=v4-16 +experiment=tiny-shakespeare
on a 16-chip TPU v4 VM (4 hosts). Following the experiment config we saw above, we train for 3000 steps, and log val metrics every 250 steps.
We see that the model rapidly overfits the training data. It takes about 10 steps per epoch, so the plot shows about 50 epochs worth of training – significantly more than we encounter in standard language model training datasets.
The model trains stably (we are performing gradient clipping to norm 1.0 by default), even without learning rate scheduling. Scheduling would improve the initial stability; since we are heavily overfitting, there does not seem to be any need for LR cooldown at the end of training.
Our evaluation code lets us log various metrics, including the validation set negative log likelihood every 250 steps of training. We verify that the model is significantly overfit to the training set – the validation loss actually increases monotonically! Taking a look at some of the trained model’s rollouts verifies that they are indeed not generalizing very well (although they’re quite fun to read):
################################
Prompt:
################################
father? Gentles, methinks you frown:
And wherefore gaze this goodly company,
As if they saw some wondrous monument,
Some comet or unusual prodigy?
BAPTISTA:
Why, sir, you know this is your wedding-day:
First were we sad, fearing you would not come;
Now sadder, that you come so unprovided.
Fie, doff this habit, shame to your estate,
An eye-sore to our solemn festival!
TRANIO:
And tells us, what occasion of import
Hath all
################################
Generated text:
################################
the castle wall:
And wheither is come to me, and sit by fortune's side,
The right is not! God's my son,
Immoderate but a little distressed widow;
For joyful wife's son is no weary way.
ROMEO:
What say'st thou, that didst me yet?
MERCUTIO:
Good dream of mine own.
ROMEO:
Give me that mattock and the wrenching iron.
Hold, take this letter; early, and my lord
Till then be urged to my advancement?
Nurse:
The model seems to be regurgitating tokens from Romeo and Juliet, with some random
digressions sprinkled in!
For these rollouts, we generate up to max_seq_len total tokens (based on the training config; for us, it’s 256) off of a 128-token prompt. This keeps us clear of out-of-distribution RoPE positions and KV cache overflows.
Conclusions
That’s all for our base infra buildout! Some of the key takeaways include:
- Correct configuration of the splash attention Pallas kernel;
- Flexible infrastructure for training and inference across different datasets and evaluation types;
- Basic training improvements (with fairly robust infrastructure, except perhaps for our model spec) that will let us experiment with new optimizers in the future!
As we’ve gone from the notebook-based code in the previous blog to the full repository,
we’ve seen along the way that building reasonably general experiment infrastructure
requires a bit of a different approach than writing one-off research code. It also
requires a decent amount of time! There’s a good chance a lot of this can be automated
in a fast-paced environment with LLMs – I have not used LLMs for writing any code in
these blogs – but writing the code by hand is hard to beat for learning and for
precision of implementation (as Karpathy suggested in his recent conversation with
Dwarkesh!). As a small piece of evidence
of this, we’ve ended up with a very concise and readable train.py script: not much
more than 150 lines!
Some remaining infrastructure-related tweaks we should add in the future:
- Model checkpointing: we’ll add this at some point soon, likely just using Orbax.
- Further parallelism configurations, in particular basic FSDP.
Look forward to less infrastructure and closer-to-research experimentation in future iterations of this blog series.
Acknowledgments
Thanks to the TRC program for compute support.
- The reason we actually check for None is that we generally need to call our model with different configurations (e.g., when doing autoregressive inference, or when evaluating with a different parallelism configuration), and it is suboptimal to create an entirely different Config object for each of these, or to manually edit the global Config object for these as we run. So instead, we check for None, and just pass this as the kernel parameter to the model as we need to. ↩
- For example, we typically only apply weight decay to matrix-like parameters in the network (e.g., not to biases). ↩
- Note that even this approach is not perfect, as for example the w_out matrix associated to the attention output projection actually does reduce over one of the dimensions marked as non-matrix. We should revisit this for future improvements (i.e., Muon-type optimizers!). ↩
- At least in our setting, where we’re not using multislice operation. I think the distributed backend is just as easy to set up for multislice, though! ↩
- Creating the training/val/test split on each distributed process is probably not a good idea, even though we seed each host’s RNG with the same seed, and generate the splits with the same key. It won’t matter for our purposes here because the synthetic data distribution is so simple (even if every process had a different training split, it would be okay, because there isn’t a material distinction between the training and test set for this distribution). We’ll fix this with proper processing of data when we move to text data. ↩
- Here, tpu-v4-32, for a four-host v4 setup: the 32 denotes the number of cores, chips = cores / 2, and the number of chips dictates the mesh shape. ↩
- For more complex debugging, say in interpretability settings, we need to add a bit more custom code in order to have each process log its respective metrics (say, per-device gradients before all reduce, etc.). ↩
References
- JAX Team (2025). Explicit sharding (a.k.a. “sharding in types”). Retrieved from https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html. (Accessed 2025-08-26)
- JAX Team (2025). Distributed data loading. Retrieved from https://docs.jax.dev/en/latest/distributed_data_loading.html. (Accessed 2025-11-09)
- JAX Team (2025). Introduction to multi-controller JAX (aka multi-process/multi-host JAX). Retrieved from https://docs.jax.dev/en/latest/multi_process.html. (Accessed 2025-11-09)