This post will be the first in a series on programming with JAX, for training models like transformers. I’m experimenting with these as an alternative to my usual scratch paper or LaTeX notes while learning, in the hope that it will help me with recall and perhaps be useful to others learning this material.

The focus of this post is on building a mathematical mental model for sharding and communication based on linear algebra, in particular block matrices, and on studying some low-level code for the different communication primitives. The notes are based on two tutorials on sharding: (Austin et al., 2025) and (JAX Team, 2025), as well as some ‘original research’.

Setup

Here’s the environment we will be working in—using Python 3.13 and JAX 0.7.1, on a single TPU v4 host.

from functools import partial

import jax
import jax.numpy as jnp

jax.devices()
Output:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

Sharding

When we want to perform a high-throughput computation involving some data and some mathematical operations (e.g., passing data through layers of a neural network) using multiple hardware accelerators (e.g., TPUs), we need to, at a minimum:

  1. Have a scheme for indexing hardware accelerators relative to their physical (spatial) layout.
  2. Specify what data and what parameters (inputs) go on which accelerators.
  3. Specify where the outputs of the computation should go.

In JAX, these three steps are abstracted into creation of a Mesh specifying device layout and a PartitionSpec specifying how data is split across devices relative to this mesh. Such a splitting is called a sharding.

Here’s how the first two steps look in code:

mesh = jax.make_mesh((2, 2), ("x", "y"))  # 1x v4 tpu host
sharding = jax.NamedSharding(mesh, jax.P("x", "y"))

A = jnp.zeros((1024, 1024), device=sharding)
jax.debug.visualize_array_sharding(A)
Output:
┌──────────┬──────────┐
│          │          │
│  TPU 0   │  TPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  TPU 2   │  TPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘

Above,

  • The Mesh indexes devices based on spatial axes. JAX will optimize this indexing based on the actual device topology.
  • The NamedSharding creates a sharding specification relative to the mesh, which is defined by the PartitionSpec we provide. Here, we specify that a 2D array has its first dimension sharded along the x dimension of the mesh and its second dimension along y.
  • For the PartitionSpec, we give None for array dimensions that should not be sharded. Such a dimension is kept whole: its full extent is present on every device.
  • If a mesh axis does not appear in the sharding, then the sharded array is fully replicated over that axis. In the extreme case, a PartitionSpec of all Nones puts the full array on every device (see the quick check just after this list).
  • The compiler will complain about invalid shardings.
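
For example, here’s a quick sanity check of the fully replicated case, reusing the mesh defined above (the comment states what I’d expect to see rather than captured output):

# A PartitionSpec with no mesh axes: every device gets the full array.
replicated = jax.NamedSharding(mesh, jax.P(None, None))
X_rep = jnp.zeros((1024, 1024), device=replicated)
jax.debug.visualize_array_sharding(X_rep)  # should show a single tile labeled TPU 0,1,2,3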

We can also move data around after creating it on, say, a single device.

key = jax.random.key(42)
B = jax.random.normal(key, (1024, 1024))
B_sharded = jax.device_put(B, sharding)
jax.debug.visualize_array_sharding(B)
jax.debug.visualize_array_sharding(B_sharded)
Output:
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘
┌──────────┬──────────┐
│          │          │
│  TPU 0   │  TPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  TPU 2   │  TPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘

In addition, we can ‘combine’ physical device axes into one or more aggregated axes. This is a common operation, for example in parallelizing a batch axis across all devices. Since the Mesh is our link between partition indices and physical devices, we do this through the mesh:

mesh_flat = jax.make_mesh((4,), ("xy",))
sharding_flat = jax.NamedSharding(mesh_flat, jax.P("xy", None))
C = jnp.zeros((1024, 1024), device=sharding_flat)
jax.debug.visualize_array_sharding(C)
Output:
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 3         │
└───────────────────────┘

When we perform computations on data, we can either let the JAX compiler infer the output sharding, or we can specify it ourselves.

@jax.jit
def f_contract(x):
    return x.sum(axis=0)


@partial(
    jax.jit,
    out_shardings=jax.NamedSharding(
        mesh,
        jax.P(
            None,
        ),
    ),
)
def f_contract_replicated(x):
    return x.sum(axis=0)


result = f_contract(B_sharded)
result_replicated = f_contract_replicated(B_sharded)
print("Inferred sharding:")
jax.debug.visualize_array_sharding(result)
print("Manual sharding (replicated):")
jax.debug.visualize_array_sharding(result_replicated)
Output:
Inferred sharding:
┌───────┬───────┐
│TPU 0,2│TPU 1,3│
└───────┴───────┘
Manual sharding (replicated):
┌───────────┐
│TPU 0,1,2,3│
└───────────┘

Passing a NamedSharding to jax.jit’s out_shardings keyword argument lets us specify here that the result of the computation should be replicated across all devices. This entails some communication!

Given a valid input-output sharding for a more complex computation, the JAX compiler can heuristically optimize the inner parts of the computation (the shardings of intermediate values, and the necessary communication) to try to make it run as efficiently as possible. In cases where this doesn’t work, JAX has ways to give the compiler hints, as well as more advanced programming paradigms for explicitly specifying intermediate shardings and communication.

Valid Shardings

Given a mesh with $M$ axes, and an array with $N$ axes, we can enumerate all valid shardings as follows. First note:

  • A sharding corresponds to, for each of the $N$ array axes, the selection of one of the $M$ mesh axes, or None.
  • Each mesh axis can appear at most once in a valid sharding.

Then for each $k = 0, \dots, \min \set{M, N}$, choose $k$ of the $M$ mesh axes ($\binom{M}{k}$ ways) and $k$ of the $N$ array axes ($\binom{N}{k}$ ways), and consider all $k!$ ways of assigning the chosen mesh axes to the chosen array axes. This gives a total of

\[\sum_{k=0}^{\min \set{M, N}} \binom{M}{k} \binom{N}{k} k!\]

possible shardings. This is actually not too large—roughly on the order of $N^M$ when $N$ is large.1

Example: Matrix Multiplication

The most common operations in neural networks are related to matrix multiplications, and so thinking in terms of sharding data with two axes goes a long way. Say we have matrices $\vA \in \bbR^{m \times n}$ and $\vB \in \bbR^{n \times p}$. A sharding of a 2D array (matrix) corresponds to two things: a partitioning of the matrix into sub-matrices, and a splitting of those sub-matrices across devices. In JAX, these two operations are coupled through the Mesh/NamedSharding interface, as we described above:

  • If an array dimension is sharded across a mesh axis, that dimension is partitioned, and the partitioned components will be split across that mesh axis.
  • If an array dimension is not sharded, that dimension is not partitioned, and it appears across all devices.
  • If a mesh axis does not appear in a sharding specification, the sharded data is copied across that mesh axis.

Let’s consider the special case of a 2D mesh, in particular $2 \times 2$. For simplicity, let’s place the first mesh axis in correspondence with the first array dimension, and the second mesh axis in correspondence with the second array dimension.2 Then every partitioning of $\vA$ or $\vB$, for example

\[\begin{equation}\label{eq:partitioning-0} \begin{bmatrix} \vA_{00} & \vA_{01} \\ \vA_{10} & \vA_{11} \end{bmatrix}, \end{equation}\]

corresponds to an arrangement of the matrix across devices. In the above example, we have $\vA_{ij} \in \bbR^{(m/2) \times (n/2)}$, and the coordinates $(i, j)$ of the submatrix correspond to the coordinates of the device it is stored on. For a different partitioning, for example

\[\begin{bmatrix} \vA_{0} \\ \vA_{1} \end{bmatrix},\]

we need to take copying into account: here we have $\vA_{i} \in \bbR^{(m/2) \times n}$, and the physical device layout is

\[\begin{equation}\label{eq:partitioning-1-dev} \begin{bmatrix} \vA_{0} & \vA_{0} \\ \vA_{1} & \vA_{1} \end{bmatrix}. \end{equation}\]

Notice that $\eqref{eq:partitioning-0}$ and $\eqref{eq:partitioning-1-dev}$ are incomparable: the second one uses twice as much memory! But there is a simple ‘algorithm’ to go from a partitioning to a device layout: if an axis is not partitioned, simply copy the array along that axis to reach the full mesh size (and remember that this leads to additional memory usage).
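
To make the memory accounting concrete, here’s a small check of the per-device shard shapes for the two layouts (reusing the $2 \times 2$ mesh from the setup; the 2x factor is specific to that mesh size):

fully_sharded = jnp.zeros((1024, 1024), device=jax.NamedSharding(mesh, jax.P("x", "y")))
row_sharded = jnp.zeros((1024, 1024), device=jax.NamedSharding(mesh, jax.P("x", None)))

# P("x", "y"): each of the 4 devices holds a distinct 512 x 512 block (no duplication).
print([s.data.shape for s in fully_sharded.addressable_shards])
# P("x", None): each device holds a 512 x 1024 block, so each row block is stored twice
# (once per device along the unused "y" axis), i.e., 2x the memory in total.
print([s.data.shape for s in row_sharded.addressable_shards])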

Now, what happens if we want to compute the product $\vA \vB$ when $\vA$ is sharded in some way, for example as in $\eqref{eq:partitioning-0}$? First, thinking abstractly: by properties of block matrix multiplication, we have

\[\vA \vB = \begin{bmatrix} \vA_{00} & \vA_{01} \\ \vA_{10} & \vA_{11} \end{bmatrix} \begin{bmatrix} \vB_{00} & \vB_{01} \\ \vB_{10} & \vB_{11} \end{bmatrix} = \begin{bmatrix} \vA_{00}\vB_{00} + \vA_{01} \vB_{10} & \vA_{00}\vB_{01} + \vA_{01} \vB_{11} \\ \vA_{10}\vB_{00} + \vA_{11} \vB_{10} & \vA_{10}\vB_{01} + \vA_{11} \vB_{11} \end{bmatrix}.\]

So no matter how $\vB$ is sharded, we need to arrange the computation (and any necessary communication) so that each of the matrix products appearing above gets computed. We can think about two cases:

  • No communication involving $\vA$: Here, we can use the RHS of the above display to tell us what needs to be computed. Processor $(0, 0)$ needs the first (block) row of $\vB$, processor $(0, 1)$ needs the second (block) row of $\vB$, etc.—so the sharding of $\vB$ had better support this, or communication will be necessary. Regardless, some communication between devices is necessary to accumulate the block matrix products: e.g., we need to add results from processor $(0, 0)$ and $(0, 1)$ to get the top-left block of the product.
  • We can re-structure how $\vA$ is sharded: It might be possible, given a specific sharding for $\vB$, to have a more efficient computation by sharding $\vA$ differently (and vice versa). Thinking in terms of block partitions, this is the same as observing that there are many ways of partitioning a matrix into blocks, each of which gives a different decomposition for the matrix product. For example, this is also a compatible partitioning (with suitable definitions of the blocks):

    \[\begin{equation}\label{eq:nice-sharding} \vA \vB = \begin{bmatrix} \vA_{1} \\ \vA_{2} \end{bmatrix} \begin{bmatrix} \vB_{1} & \vB_{2} \end{bmatrix} = \begin{bmatrix} \vA_{1}\vB_{1} & \vA_{1}\vB_{2} \\ \vA_{2}\vB_{1} & \vA_{2}\vB_{2} \end{bmatrix}, \end{equation}\]

    and so is

    \[\vA \vB = \begin{bmatrix} \vA_{1} & \vA_{2} \end{bmatrix} \begin{bmatrix} \vB_{1} \\ \vB_{2} \end{bmatrix} = \vA_1 \vB_1 + \vA_2 \vB_2.\]

    Re-sharding requires communication, though, and so this may or may not be more efficient.

As an exercise, verify that $\eqref{eq:nice-sharding}$ corresponds to a zero-communication sharding for a $2 \times 2$ output sharding.3
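
Here’s a sketch of this zero-communication case in code, using the mesh from the setup: shard $\vA$ by rows over x and $\vB$ by columns over y, request the $2 \times 2$ output sharding, and check the compiled HLO for collectives (the op names below are XLA’s collective primitives, which we discuss next). I’d expect this to print False:

A_rows = jnp.zeros((1024, 512), device=jax.NamedSharding(mesh, jax.P("x", None)))
B_cols = jnp.zeros((512, 1024), device=jax.NamedSharding(mesh, jax.P(None, "y")))


@partial(jax.jit, out_shardings=jax.NamedSharding(mesh, jax.P("x", "y")))
def matmul_xy(a, b):
    return a @ b


hlo = matmul_xy.lower(A_rows, B_cols).compile().as_text()
collectives = ("all-gather", "all-reduce", "reduce-scatter", "all-to-all", "collective-permute")
print(any(op in hlo for op in collectives))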

Now, if $\vA$ and $\vB$ both have their own sharding, in general these correspond to (possibly incompatible) block matrix partitions of each matrix. For each individual partition, we can, as above, write out the blockwise matrix products and accumulation operations needed to compute the product, and try to intuit what computation/communication should happen—in particular involving re-sharding one or both arrays. In JAX, the XLA compiler will be recruited to perform this optimization; as algorithm designers, we can try to provide sensible shardings up front in order to make sure we end up with an efficient system.

Types of Inter-Device Communication with Sharded Computation

The matrix multiplication example suggests that different forms of communication naturally arise in distributed computation. We will describe the ones that are implemented in XLA here. The JAX scaling book (Austin et al., 2025) gives a nice way to think about matrix multiplication of sharded arrays using a variant of named-axis notation. Since this is closer to code, we’ll describe this briefly below.

Communication Primitives through Matrix Multiplication

We saw in the previous section how to think about sharding of 2D arrays in terms of block matrices:

  • Sharded array axes are partitioned; non-sharded array axes aren’t.
  • Any mesh axis that isn’t used means the partitioned array is copied across that axis. This gives us a scheme to map a block matrix to the corresponding arrangement of that matrix across physical devices.

With named index notation, things are more compact. Consider a 2D mesh with axes $x$ and $y$, and let us write $A[I, J]$ for the 2D array $\vA$ (likewise for $\vB$) we looked at in the previous section. We write $A[I_x, J]$ (etc.) to denote the sharding of $A$ along its first dimension with respect to the mesh axis $x$. The use of $I$, $J$, etc. for ‘dummy’ indices is in analogy to einsum-type notation, which is very useful for tensor operations.

Now consider a matrix product of sharded matrices $A[I_x, J] \cdot B[J, K_y]$. Suppose we want the output, say $\vC$, to be sharded as $C[I_x, K_y]$. Then this matrix product can be computed with no communication: to see why, think in terms of the correspondence to block matrix partitions and their associated map to device arrangements in the previous section. However, for an unsharded output $C[I, K]$, it would be necessary to perform some communication—either calculating $C[I_x, K_y]$ and then sharing the local results among all devices, or something else.

The following is a convenient way to think about this, following the previous example:

  • Given a sharding of matrix multiplicands $A$, $B$, there is a ‘natural’ output sharding whenever the contracting dimension is unsharded. (As above, this comes by ‘removing’ the contracting axis, just like in einsum notation.) If this sharding is invalid, communication is necessary.
  • If the contracting dimension is sharded, different types of communication are necessary to turn the multiplication into one with an unsharded contracting dimension. We can identify typical ‘best practices’ for what communication type to use based on how the contracting dimension is sharded.
  • If an output sharding is specified and it is not the ‘natural’ one, some communication is required to achieve it. We can use the techniques from either of the two previous bullets to do this.

We’ll describe these different cases below.

Case 1: Unsharded Contracting Dimension, Valid Output

As above, the ‘natural’ output sharding is just an einsum-style replacement:

\[A[I_x, J] \cdot B[J, K] \to C[I_x, K].\]

If this natural output sharding is invalid, we need to do something else (see below).

Case 2: One Multiplicand Sharded Along Contracting Dimension (AllGather)

For a product like

\[A[I, J_x] \cdot B[J, K] \to C[I, K],\]

it isn’t possible to perform the contraction directly: we either need to perform a series of block-matrix multiplies and then accumulate the results (the AllReduce case below), or make the entire contracting dimension of both arrays available on every device (reducing to Case 1 above).

We use AllGather for this, which removes a specified sharded axis:

\[\mathrm{AllGather}_{x}(A[I, J_x]) \to A[I, J].\]

For a compound mesh, AllGather can also just remove one sub-axis:

\[\mathrm{AllGather}_{x}(A[I, J_{xy}]) \to A[I, J_y].\]

AllGather can also be used to satisfy certain output shardings that aren’t naturally compatible with the input shardings. For example, in the product

\[A[I_x, J] \cdot B[J, K] \to C[I, K],\]

one removes the sharded axis either before or after the computation in order to achieve the desired output sharding.
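
To see the primitive itself, rather than letting the compiler insert it, here’s a sketch using shard_map and jax.lax.all_gather on a 1D mesh. (shard_map is exposed as jax.shard_map in recent JAX versions; in older releases it lives in jax.experimental.shard_map. The variable names here are just for illustration.)

mesh_1d = jax.make_mesh((4,), ("x",))
A_cols = jnp.zeros((8, 16), device=jax.NamedSharding(mesh_1d, jax.P(None, "x")))


@partial(jax.shard_map, mesh=mesh_1d, in_specs=jax.P(None, "x"), out_specs=jax.P(None, None))
def gather_cols(a_block):
    # a_block is the local (8, 4) shard; gather the shards back together along array axis 1.
    return jax.lax.all_gather(a_block, "x", axis=1, tiled=True)


gathered = gather_cols(A_cols)
print(gathered.shape, gathered.sharding)  # full (8, 16) array, replicated across the mesh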

AllGather Cost

On TPUs with enough chips to form a cube, AllGather can be performed using algorithms that exploit the toroidal interconnectivity of TPU devices. For example, to AllGather a single axis, the following procedure can be used:

  1. Initialize a buffer on each device with its local shard.
  2. In each round, send the current buffer contents to the neighboring devices in both directions around the ring, and receive the corresponding shards from both neighbors.
  3. Store the received shards in the output and use them as the next round’s buffer contents.
  4. Repeat until every device holds all of the shards.

This takes at most $N/2$ rounds of communication if the mesh axis has size $N$. For an array of size $V$ bytes, each round sends $V/N$ bytes over each link. Given an ICI bandwidth of $W$ (unidirectional),4 the total time this takes is no more than

\[\frac{N}{2} \cdot \frac{V}{NW} = \frac{V}{2W} \text{ s},\]

which is independent of the size of the mesh. However, one must also take into account the intrinsic overhead of ICI communication. Each hop takes at least around $T_{\min} = 10^{-6} \text{ s}$ (Austin et al., 2025), so a better estimate of the total time is

\[\max \set*{\frac{V}{2W}, \frac{NT_{\min}}{2}}.\]

In particular, there are latency-bound (small array size) and throughput-bound (large array size) regimes of communication.

AllGathering a compound axis (e.g., $xy$) is similar: with some per-device bookkeeping, data can be communicated along all of the constituent mesh axes simultaneously. The effective bandwidth therefore increases by a factor proportional to the number of mesh axes, but the maximum possible latency also increases, since there are more devices to traverse. The total time is then approximately

\[\max \set*{\frac{V}{2n_{\mathrm{axes}}W}, \frac{N_{\mathrm{total}}T_{\min}}{2}},\]

where $N_{\mathrm{total}}$ is the product of the axes lengths.
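
Plugging in the rough numbers quoted above and in the footnote (ballpark figures, not measurements), we can estimate the array size at which the two regimes cross over:

W = 42e9  # rough unidirectional ICI bandwidth for v4, bytes/s
T_MIN = 1e-6  # rough per-hop overhead, s
N = 16  # example size of the gathered mesh axis

# The throughput term V / (2 W) equals the latency term N * T_MIN / 2 when V = N * W * T_MIN.
V_crossover = N * W * T_MIN
print(f"crossover array size ~ {V_crossover / 1e6:.2f} MB")  # smaller arrays are latency-bound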

Case 3: Both Multiplicands Sharded Along Contracting Dimension (AllReduce)

In this case, we want to compute a matrix product like

\[A[I, J_x] \cdot B[J_x, K] \to C[I, K].\]

Thinking in terms of block matrices, we know that computing this matrix product entails a sum of the block matrix multiplications corresponding to the shards of both $A$ and $B$, and since the contracting axis is sharded, we can:

  1. Compute per-device multiplications of the sharded matrices.
  2. Add these up.

The communication primitive that performs this reduction is called AllReduce. It is similar to AllGather, but it sums up the components being communicated, and distributes the resulting sum to all devices.

We will use the notation from (Austin et al., 2025): we write

\[C[I, K]\{ U_x \}\]

to denote an array $C$ which is “unreduced” over the mesh axis $x$: each device along $x$ holds a partial sum, and the full array is obtained by summing these partials over $x$. This gives

\[\mathrm{AllReduce}_y\left( A[I_x, J]\{ U_y \} \right) \to A[I_x, J]\]

as the signature for AllReduce.
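
Here’s what Case 3 looks like written out at the per-device level (a shard_map sketch, with the same caveats as before; under jit you’d normally just let the compiler insert the AllReduce): each device multiplies its slices of the contracting dimension, and jax.lax.psum performs the AllReduce over the unreduced partial products.

mesh_1d = jax.make_mesh((4,), ("x",))


@partial(
    jax.shard_map,
    mesh=mesh_1d,
    in_specs=(jax.P(None, "x"), jax.P("x", None)),
    out_specs=jax.P(None, None),
)
def matmul_allreduce(a_block, b_block):
    # a_block: (8, 4) slice of A's columns; b_block: (4, 4) slice of B's rows.
    # Their product is a partial sum over the local piece of the contracting dimension.
    partial_product = a_block @ b_block
    # AllReduce: sum the unreduced partial products over the 'x' mesh axis.
    return jax.lax.psum(partial_product, "x")


A_small = jax.random.normal(jax.random.key(0), (8, 16))
B_small = jax.random.normal(jax.random.key(1), (16, 4))
# Loose tolerance: TPU matmuls at default precision introduce some rounding error.
print(jnp.allclose(matmul_allreduce(A_small, B_small), A_small @ B_small, rtol=1e-2, atol=1e-2))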

Decomposition in Terms of ReduceScatter

Generally, an AllReduce is about two times as expensive as an AllGather. One way to see why is to decompose it into another useful communication primitive: the ReduceScatter operation.

ReduceScatter has signature

\[\mathrm{ReduceScatter}_{y, J}\left( A[I_x, J]\{ U_y \} \right) \to A[I_x, J_y].\]

It takes an array that is unreduced along a mesh axis, sums the partial results over that mesh axis, and shards the summed array along a specified array axis with respect to that mesh axis. To compute an AllReduce, we can compose a ReduceScatter with an AllGather:

\[\mathrm{AllReduce}_y\left( A[I_x, J]\{ U_y \} \right) = \mathrm{AllGather}_{y}\left( \mathrm{ReduceScatter}_{y, J}\left( A[I_x, J]\{ U_y \} \right) \right).\]

For multidimensional arrays, we are free to choose which array axis gets scattered, which can be exploited for performance.

ReduceScatter can be performed by an algorithm similar to the approach we discussed for AllGather. The process is somewhat difficult to describe algorithmically, but there is a very good pictorial representation in (Austin et al., 2025).

[Figure: the ReduceScatter communication pattern, from (Austin et al., 2025).]

The upshot is that the latency and throughput analysis carries over essentially unchanged from AllGather, so a ReduceScatter costs about the same as an AllGather. Since an AllReduce is a ReduceScatter followed by an AllGather, this explains why an AllReduce costs roughly twice as much as an AllGather.
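
To make the decomposition concrete, here’s a numerical check (again a shard_map sketch) that ReduceScatter followed by AllGather agrees with AllReduce; jax.lax.psum_scatter is the ReduceScatter primitive at this level.

mesh_1d = jax.make_mesh((4,), ("x",))


@partial(
    jax.shard_map,
    mesh=mesh_1d,
    in_specs=jax.P("x", None),
    out_specs=(jax.P(None, None), jax.P(None, None)),
)
def reduce_two_ways(x_block):
    # Treat each device's (4, 8) block as an unreduced partial sum over 'x'.
    all_reduced = jax.lax.psum(x_block, "x")
    # ReduceScatter: sum over 'x' while sharding the result along array axis 0 ...
    scattered = jax.lax.psum_scatter(x_block, "x", scatter_dimension=0, tiled=True)
    # ... then AllGather the scattered axis back.
    regathered = jax.lax.all_gather(scattered, "x", axis=0, tiled=True)
    return all_reduced, regathered


partials = jax.random.normal(jax.random.key(2), (16, 8))
reduced_a, reduced_b = reduce_two_ways(partials)
print(jnp.allclose(reduced_a, reduced_b))  # expect True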

Aside: The Backward Pass

There is a more fundamental relationship between the AllGather and ReduceScatter operations, suggested by the representation of AllReduce above: they can be represented as adjoints of one another, which means that an AllGather in the forward pass will trigger a ReduceScatter in the backward pass!

To see this, let’s go back to our more visual block matrix example. As above, consider a simplified setting where a matrix is sharded along a 1D mesh corresponding to the first array dimension. With mesh size $3$ for simplicity, we represent this as the block matrix

\[\vA = \begin{bmatrix} \vA_0 \\ \vA_1 \\ \vA_2 \end{bmatrix}.\]

Remember that the partitioning of the block matrix corresponds to the device layout in this model. Now, if we perform an AllGather, we have the ($3\times$ larger) block matrix output

\[\mathrm{AllGather}(\vA) = \begin{bmatrix} \vA \\ \vA \\ \vA \end{bmatrix} = \begin{bmatrix} \begin{bmatrix} \vA_0 \\ \vA_1 \\ \vA_2 \end{bmatrix} \\ \begin{bmatrix} \vA_0 \\ \vA_1 \\ \vA_2 \end{bmatrix} \\ \begin{bmatrix} \vA_0 \\ \vA_1 \\ \vA_2 \end{bmatrix} \end{bmatrix}.\]

Let’s imagine we perform some subsequent computations on the AllGathered output, and we want to compute a backward pass. In general, each individual device will have different gradients, which we can write in partitioned form as

\[\begin{bmatrix} \Delta \vA^0 \\ \Delta \vA^1 \\ \Delta \vA^2 \end{bmatrix} = \begin{bmatrix} \begin{bmatrix} \Delta \vA_0^0 \\ \Delta \vA_1^0 \\ \Delta \vA_2^0 \end{bmatrix} \\ \begin{bmatrix} \Delta \vA_0^1 \\ \Delta \vA_1^1 \\ \Delta \vA_2^1 \end{bmatrix} \\ \begin{bmatrix} \Delta \vA_0^2 \\ \Delta \vA_1^2 \\ \Delta \vA_2^2 \end{bmatrix} \end{bmatrix}.\]

Then because AllGather behaves in the sense of linear algebra as a copying operation, its backward pass involves its adjoint operation, which is a sum. We end up with5

\[\mathrm{dAllGather}^* \left( \begin{bmatrix} \Delta \vA^0 \\ \Delta \vA^1 \\ \Delta \vA^2 \end{bmatrix} \right) = \begin{bmatrix} \sum_{i=0}^{2}\Delta \vA^i_0 \\ \sum_{i=0}^{2} \Delta \vA^i_1 \\ \sum_{i=0}^{2} \Delta \vA^i_2 \end{bmatrix}.\]

This is exactly what we’d get if we:

  1. Interpret the per-device gradients as unreduced partial sums with respect to the mesh axis of the AllGather;
  2. Perform a ReduceScatter with respect to this mesh axis and the array axis that was sharded before the AllGather.

So, an AllGather in the forward pass produces a corresponding ReduceScatter in the backward pass!

Similarly, if there is a ReduceScatter in the forward pass that produces an array sharded along one of its axes, then the backward pass will contain an AllGather along that axis. And finally, because an AllReduce can be expressed as the composition of an AllGather with a ReduceScatter, the backward pass of an AllReduce is an AllReduce!
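
We can check the adjoint relationship itself without any multi-device machinery, by modeling AllGather over a mesh axis of size 3 as plain replication and asking jax.vjp for its backward pass (a toy, single-device sketch):

def toy_all_gather(a_shard):
    # Model AllGather as stacking three copies of the local shard.
    return jnp.stack([a_shard, a_shard, a_shard])


a_shard = jax.random.normal(jax.random.key(3), (4, 5))
_, vjp_fn = jax.vjp(toy_all_gather, a_shard)

# In general, each 'device' sees a different cotangent in the backward pass.
cotangents = jax.random.normal(jax.random.key(4), (3, 4, 5))
(grad,) = vjp_fn(cotangents)

# The adjoint of copying is a sum over the copies: a ReduceScatter of the
# per-device cotangents back onto the original shard.
print(jnp.allclose(grad, cotangents.sum(axis=0)))  # expect True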

Case 4: Sharded Non-Contracting Dimensions, Same Axis

Directly applying the ‘natural’ output sharding in this case would be invalid:

\[A[I_x, J] \cdot B[J, K_x] \to C[I_x, K_x]\]

A mesh axis cannot appear multiple times in a sharding specification.

To resolve this, we just AllGather one of the sharded axes of the multiplicands:

\[\begin{split} \mathrm{AllGather}_x(A[I_x, J]) \to A[I, J] \\ A[I, J] \cdot B[J, K_x] \to C[I, K_x]. \end{split}\]

We can choose which multiplicand to gather based on the shardings needed downstream.

AllToAll Communication

This primitive can be thought of as a resharding operation. It has signature

\[\mathrm{AllToAll}_{x, J}( A[I_x, J] ) \to A[I, J_x].\]

AllToAll is actually significantly cheaper than an AllGather, by roughly a factor of 4. Here too, the picture in (Austin et al., 2025) shows nicely where this efficiency comes from.

[Figure: the AllToAll communication pattern, from (Austin et al., 2025).]

Notice that the messages sent over each link are composed of pieces of different lengths from the different shards. The communication is balanced overall, but the later messages are shorter, which can push the operation toward the latency-bound regime.
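
Here’s the resharding signature written out at the per-device level (a shard_map sketch, with the same caveats as before):

mesh_1d = jax.make_mesh((4,), ("x",))


@partial(jax.shard_map, mesh=mesh_1d, in_specs=jax.P("x", None), out_specs=jax.P(None, "x"))
def all_to_all_x(a_block):
    # Move the sharding from array axis 0 to array axis 1: split the local (4, 8) block
    # along axis 1, exchange the pieces across 'x', and concatenate them along axis 0.
    return jax.lax.all_to_all(a_block, "x", split_axis=1, concat_axis=0, tiled=True)


x_global = jax.random.normal(jax.random.key(5), (16, 8))
resharded = all_to_all_x(x_global)
print(jnp.allclose(x_global, resharded))  # expect True: same global array, new sharding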

Example: Communication Primitives in JAX

Let’s test out what we learned above in the JAX sandbox we started this post with. For slightly more granular control of sharding, we’ll switch to Explicit mode sharding (see e.g. (JAX Team, 2025)).

mesh = jax.make_mesh(
    (2, 2), ("x", "y"), axis_types=(jax.sharding.AxisType.Explicit,) * 2
)  # 1x v4 tpu host
jax.set_mesh(mesh)

A = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P("x", "y")))

jax.typeof(A)
Output:
ShapedArray(bfloat16[2048@x,8192@y])

AllGather

First, a simple AllGather implementation along the y mesh axis—we compile the resharding operation and inspect the “high-level operations” (HLO) to see what actually gets implemented on the device.

@jax.jit
def all_gather(x):
    return jax.sharding.reshard(x, jax.NamedSharding(mesh, jax.P("x")))


lowered = all_gather.lower(A)
compiled = lowered.compile()
B = compiled(A)

print("Output type: ", jax.typeof(B))
print()
print("StableHLO:\n", lowered.as_text())
print()
print("XLA HLO:\n", compiled.as_text())
Output:
Output type:  bfloat16[2048@x,8192]

StableHLO:
 module @jit_all_gather attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=2, "y"=2]>
  func.func public @main(%arg0: tensor<2048x8192xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<2048x8192xbf16> {jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{"x"}, {}]> : tensor<2048x8192xbf16>
    return %0 : tensor<2048x8192xbf16>
  }
}

XLA HLO:
 HloModule jit_all_gather, is_scheduled=true, entry_computation_layout={(bf16[1024,4096]{1,0:T(8,128)(2,1)})->bf16[1024,8192]{1,0:T(8,128)(2,1)}}, num_partitions=4
ENTRY %main.5_spmd (param: bf16[1024,4096]) -> bf16[1024,8192] {
  %param = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[2,2]<=[4]}, metadata={op_name="x"}
  %all-gather = bf16[1024,8192]{1,0:T(8,128)(2,1)S(3)} all-gather(%param), channel_id=1, replica_groups=[2,2]<=[4], dimensions={1}, use_global_device_ids=true, metadata={op_name="jit(all_gather)/reshard" source_file="/tmp/ipykernel_794398/1925679859.py" source_line=3}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[],"collective_algorithm_config":{"emitter":"1DAllGatherNonMajorDim","debug":"\ngroup_size = 2 \nhas_reordering_map: false \nper_stride_size = 65536 bytes \nshard_size = 8388608 bytes "},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"0"}],"retry_config":{"retry_count":"0"}}
  ROOT %copy.3 = bf16[1024,8192]{1,0:T(8,128)(2,1)} copy(%all-gather), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16","64"],"input_window_bounds":[],"estimated_cycles":"58232","iteration_bounds":["8","1"]},"megacore_config":{"megacore_split_dim":"0"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"8388608"}],"retry_config":{"retry_count":"0"}}
}

The first representation above (output of lower) is in XLA’s StableHLO language. It is slightly imposing, but there is a clearly written spec that helps with parsing it after a bit of studying. It’s still too high level to clearly show what communication primitives are being generated, though.

The second representation (output of compile) is the result of the XLA compiler’s first device-independent compilation stage; it’s in XLA’s own HLO dialect, which is different from StableHLO and doesn’t have as easily accessible a public spec.6 Here’s how to parse the output above:

  • Types are relatively self-explanatory, aside from the text in braces—the first sequence of integers here is most important, and it specifies the memory layout of the array (see layout.h in the XLA source). Axis indices are listed in minor-to-major order (i.e., 1, 0 denotes row-major layout).
  • The all-gather semantics are clear at a high level: we gather along dimension $1$ of our array. The replica_groups argument specifies which devices communicate in the gather (see tile_assignment.h in the XLA source)—compare to our input %param’s sharding specified above to see that these agree.
  • Independent of these lower-level details, we can see that the input and output shapes are as we expect.

AllToAll

We can test out AllToAll with a similar high-level API. Let’s start by transposing the sharding axis of the output of the previous AllGather.

@jax.jit
def all_to_all(x):
    return jax.sharding.reshard(x, jax.NamedSharding(mesh, jax.P(None, "x")))


lowered = all_to_all.lower(B)
compiled = lowered.compile()
C = compiled(B)

print("Output type: ", jax.typeof(C))
print()
print("StableHLO:\n", lowered.as_text())
print()
print("XLA HLO:\n", compiled.as_text())

The HLO that we get is more complex:

Output:
Output type:  bfloat16[2048,8192@x]
StableHLO:
 module @jit_all_to_all attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=2, "y"=2]>
  func.func public @main(%arg0: tensor<2048x8192xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<2048x8192xbf16> {jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{}, {"x"}]> : tensor<2048x8192xbf16>
    return %0 : tensor<2048x8192xbf16>
  }
}
XLA HLO:
 HloModule jit_all_to_all, is_scheduled=true, entry_computation_layout={(bf16[1024,8192]{1,0:T(8,128)(2,1)})->bf16[2048,4096]{1,0:T(8,128)(2,1)}}, num_partitions=4
%fused_computation (param_0.1: bf16[128,2,8,4096]) -> bf16[1024,2,4096] {
  %param_0.1 = bf16[128,2,8,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0)
  %copy.4 = bf16[128,2,8,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1), metadata={op_name="jit(all_to_all)/reshard" source_file="/tmp/ipykernel_874047/2971478236.py" source_line=3}
  ROOT %bitcast.5 = bf16[1024,2,4096]{2,0,1:T(8,128)(2,1)S(3)} bitcast(%copy.4), metadata={op_name="jit(all_to_all)/reshard" source_file="/tmp/ipykernel_874047/2971478236.py" source_line=3}
}
ENTRY %main.5_spmd (param: bf16[1024,8192]) -> bf16[2048,4096] {
  %param = bf16[1024,8192]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}, metadata={op_name="x"}
  %bitcast.7 = bf16[128,2,8,4096]{3,2,1,0:T(8,128)(2,1)} bitcast(%param)
  %copy_bitcast_fusion = bf16[1024,2,4096]{2,0,1:T(8,128)(2,1)S(3)} fusion(%bitcast.7), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(all_to_all)/reshard" source_file="/tmp/ipykernel_874047/2971478236.py" source_line=3}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","32","32"],"input_window_bounds":[],"estimated_cycles":"78212","iteration_bounds":["2","4","1"]},"megacore_config":{"megacore_split_dim":"1"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"8388608"}],"retry_config":{"retry_count":"0"}}
  %all-to-all = bf16[1024,2,4096]{2,0,1:T(8,128)(2,1)S(3)} all-to-all(%copy_bitcast_fusion), channel_id=1, replica_groups=[2,2]<=[2,2]T(1,0), dimensions={1}, metadata={op_name="jit(all_to_all)/reshard" source_file="/tmp/ipykernel_874047/2971478236.py" source_line=3}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"GLOBAL","id":"-1"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"262144"}],"retry_config":{"retry_count":"0"}}
  %bitcast.6 = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} bitcast(%all-to-all)
  ROOT %copy.6 = bf16[2048,4096]{1,0:T(8,128)(2,1)} copy(%bitcast.6), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32","32"],"input_window_bounds":[],"estimated_cycles":"58232","iteration_bounds":["8","1"]},"megacore_config":{"megacore_split_dim":"0"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"8388608"}],"retry_config":{"retry_count":"0"}}
}

Some of the above is memory-layout optimization needed to actually implement the communication in physical memory.7 We can more easily interpret what’s going on by thinking in terms of block matrices again: recall that we started with a $2048 \times 8192$ array $\vA$, sharded along both mesh axes:

\[\vA = \begin{bmatrix} \vA_{00} & \vA_{01} \\ \vA_{10} & \vA_{11} \end{bmatrix},\]

and we performed an AllGather along the second axis, leading to (we’ll use $\oplus_i$ for concatenation along dimension $i$)

\[\vB = \begin{bmatrix} \vA_{00} \oplus_1 \vA_{01} & \vA_{00} \oplus_1 \vA_{01} \\ \vA_{10} \oplus_1 \vA_{11} & \vA_{10} \oplus_1 \vA_{11} \end{bmatrix}.\]

The AllToAll moves the sharding to the trailing array axis, still with respect to the 'x' mesh axis, which gives

\[\vC = \begin{bmatrix} \vA_{00} \oplus_0 \vA_{10} & \vA_{00} \oplus_0 \vA_{10} \\ \vA_{01} \oplus_0 \vA_{11} & \vA_{01} \oplus_0 \vA_{11} \end{bmatrix}.\]

To go from $\vB$ to $\vC$, we can do the following, which is essentially what the HLO above is doing (modulo memory operations): unbind the concatenated arrays, to give (in a rough approximation of tensor notation)

\[\begin{bmatrix} (\vA_{00}, \vA_{01}) & (\vA_{00}, \vA_{01}) \\ (\vA_{10}, \vA_{11}) & (\vA_{10}, \vA_{11}) \end{bmatrix},\]

then communicate ‘vertically’, swapping $01$ and $10$ on the left and right, to get

\[\begin{bmatrix} (\vA_{00}, \vA_{10}) & (\vA_{00}, \vA_{10}) \\ (\vA_{01}, \vA_{11}) & (\vA_{01}, \vA_{11}) \end{bmatrix}.\]

If we concatenate these in the right order, we end up with $\vC$! We can see that this is what the HLO is doing:

  • After some memory manipulations, an extra axis of size $2$ is split off from the second dimension, making the local array shaped like $1024 \times 2 \times 4096$.
  • The all-to-all op has to be interpreted relative to the input sharding. Here {devices=[2,1,2]<=[4] last_tile_dim_replicate} denotes a $2 \times 1$ tiling of the array, with each tile replicated across 2 devices; <=[4] is equivalent to the device list {0, 1, 2, 3}. This is as we expect—the input is sharded over 'x'.
  • Now, in all-to-all, replica_groups=[2,2]<=[2,2]T(1,0) specifies how the communication is done across devices (relative to the sharding previously specified). This corresponds to a specification of {{0, 2}, {1, 3}} (see hlo_sharding.h in the XLA source), which means that each replica’s array gets split into two parts along dimension $1$ (the axis of length $2$), then communication occurs across the specified groups. This gets us what we wanted above.

It’s interesting to note that a different input sharding leads to different memory operations in the HLO for an AllToAll:

@jax.jit
def all_to_all(x):
    return jax.sharding.reshard(x, jax.NamedSharding(mesh, jax.P("x")))


lowered = all_to_all.lower(C)
compiled = lowered.compile()
D = compiled(C)

print("Output type: ", jax.typeof(D))
print()
print("StableHLO:\n", lowered.as_text())
print()
print("XLA HLO:\n", compiled.as_text())
Output:
Output type:  bfloat16[2048@x,8192]
StableHLO:
 module @jit_all_to_all attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=2, "y"=2]>
  func.func public @main(%arg0: tensor<2048x8192xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) -> (tensor<2048x8192xbf16> {jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{"x"}, {}]> : tensor<2048x8192xbf16>
    return %0 : tensor<2048x8192xbf16>
  }
}
XLA HLO:
 HloModule jit_all_to_all, is_scheduled=true, entry_computation_layout={(bf16[2048,4096]{1,0:T(8,128)(2,1)})->bf16[1024,8192]{1,0:T(8,128)(2,1)}}, num_partitions=4
ENTRY %main.5_spmd (param: bf16[2048,4096]) -> bf16[1024,8192] {
  %param = bf16[2048,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[1,2,2]<=[4] last_tile_dim_replicate}, metadata={op_name="x"}
  %bitcast.5 = bf16[2,1024,4096]{2,1,0:T(8,128)(2,1)} bitcast(%param)
  %all-to-all = bf16[2,1024,4096]{2,1,0:T(8,128)(2,1)S(3)} all-to-all(%bitcast.5), channel_id=1, replica_groups=[2,2]<=[2,2]T(1,0), dimensions={0}, metadata={op_name="jit(all_to_all)/reshard" source_file="/tmp/ipykernel_874047/814736672.py" source_line=3}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"GLOBAL","id":"-1"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"0"}],"retry_config":{"retry_count":"0"}}
  %copy.3 = bf16[2,1024,4096]{1,2,0:T(8,128)(2,1)S(3)} copy(%all-to-all), metadata={op_name="jit(all_to_all)/reshard" source_file="/tmp/ipykernel_874047/814736672.py" source_line=3}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","48","8"],"input_window_bounds":["1","128","3"],"estimated_cycles":"47696","iteration_bounds":["1","11","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"3145728"}],"retry_config":{"retry_count":"0"}}
  %bitcast.4 = bf16[1024,8192]{0,1:T(8,128)(2,1)S(3)} bitcast(%copy.3)
  ROOT %copy.4 = bf16[1024,8192]{1,0:T(8,128)(2,1)} copy(%bitcast.4), metadata={op_name="jit(all_to_all)/reshard" source_file="/tmp/ipykernel_874047/814736672.py" source_line=3}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["64","6"],"input_window_bounds":["96","4"],"estimated_cycles":"67782","iteration_bounds":["1","11"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"3182592"}],"retry_config":{"retry_count":"0"}}
}

The HLO for this operation is more concise, but neither is able to avoid a copy operation.

AllReduce and ReduceScatter

To trigger an AllReduce/ReduceScatter, let’s follow our matrix multiplication example above—we’ll multiply two arrays that are both sharded along the contracting dimension. In the AllToAll section, we generated C, which is sharded over 'x' along its second dimension. We can try computing a Gram matrix associated to C, and specifying the out_sharding (singular!) in the dot operation to avoid the compiler complaining:

@jax.jit
def gram(x):
    return jnp.dot(x, x.mT, out_sharding=jax.NamedSharding(mesh, jax.P()))


lowered = gram.lower(C)
compiled = lowered.compile()
E = compiled(C)


print("Output type: ", jax.typeof(E))
print()
print("StableHLO:\n", lowered.as_text())
print()
print("XLA HLO:\n", compiled.as_text())
Output:
Output type:  bfloat16[2048,2048]
StableHLO:
 module @jit_gram attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=2, "y"=2]>
  func.func public @main(%arg0: tensor<2048x8192xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) -> (tensor<2048x2048xbf16> {jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2048x8192xbf16>) -> tensor<8192x2048xbf16>
    %1 = sdy.sharding_constraint %0 <@mesh, [{"x"}, {}]> : tensor<8192x2048xbf16>
    %2 = stablehlo.dot_general %arg0, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2048x8192xbf16>, tensor<8192x2048xbf16>) -> tensor<2048x2048xbf16>
    %3 = sdy.sharding_constraint %2 <@mesh, [{}, {}]> : tensor<2048x2048xbf16>
    return %3 : tensor<2048x2048xbf16>
  }
}
XLA HLO:
 HloModule jit_gram, is_scheduled=true, entry_computation_layout={(bf16[2048,4096]{1,0:T(8,128)(2,1)})->bf16[2048,2048]{1,0:T(8,128)(2,1)}}, num_partitions=4
%add.clone (x.1: bf16[], y.1: bf16[]) -> bf16[] {
  %y.1 = bf16[]{:T(256)} parameter(1)
  %x.1 = bf16[]{:T(256)} parameter(0)
  ROOT %add.1 = bf16[]{:T(256)} add(%x.1, %y.1)
}
%bitcast_fusion (bitcast_input: bf16[2048,4096]) -> bf16[2048,4096] {
  %bitcast_input = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  ROOT %bitcast = bf16[2048,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input)
}
%bitcast_fusion.1 (bitcast_input.1: bf16[2048,4096]) -> bf16[2048,4096] {
  %bitcast_input.1 = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  ROOT %bitcast.1 = bf16[2048,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input.1)
}
%fused_computation (param_0: bf16[2048,4096]) -> bf16[2048,2048] {
  %param_0 = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  %fusion.1 = bf16[2048,4096]{1,0:T(8,128)(2,1)} fusion(%param_0), kind=kLoop, calls=%bitcast_fusion
  %fusion.2 = bf16[2048,4096]{1,0:T(8,128)(2,1)} fusion(%param_0), kind=kLoop, calls=%bitcast_fusion.1
  ROOT %convolution.1 = bf16[2048,2048]{1,0:T(8,128)(2,1)S(3)} convolution(%fusion.1, %fusion.2), dim_labels=bf_oi->bf, metadata={op_name="jit(gram)/dot_general" source_file="/tmp/ipykernel_874047/670048905.py" source_line=3}
}
ENTRY %main.8_spmd (param: bf16[2048,4096]) -> bf16[2048,2048] {
  %param = bf16[2048,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[1,2,2]<=[4] last_tile_dim_replicate}, metadata={op_name="x"}
  %copy-start = (bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)}, bf16[2048,4096]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(%param), cross_program_prefetch_index=0
  %copy-done = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} copy-done(%copy-start)
  %fusion = bf16[2048,2048]{1,0:T(8,128)(2,1)S(3)} fusion(%copy-done), kind=kOutput, calls=%fused_computation, metadata={op_name="jit(gram)/dot_general" source_file="/tmp/ipykernel_874047/670048905.py" source_line=3}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":["64","32"],"output_window_bounds":["64","4"],"input_window_bounds":["64","32"],"estimated_cycles":"149352","iteration_bounds":["4","4","1"]},"megacore_config":{"megacore_split_dim":"0"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"5734400"}],"retry_config":{"retry_count":"0"},"convolution_algorithm_config":{"emitter":"EmitAllBatchInSublanes"}}
  ROOT %all-reduce = bf16[2048,2048]{1,0:T(8,128)(2,1)} all-reduce(%fusion), channel_id=1, replica_groups=[2,2]<=[2,2]T(1,0), use_global_device_ids=true, to_apply=%add.clone, metadata={op_name="jit(gram)/dot_general" source_file="/tmp/ipykernel_874047/670048905.py" source_line=3}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[{"memory_space":"0","offset":"0","size":"67108864"}],"collective_algorithm_config":{"emitter":"RotatedPincerEmitter","strategy":"UniDirection1DRingStrategy","debug":"\nUniDirection1DRingStrategy{colors:2 phases:1 cores:{2},{2} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"12582912"}],"retry_config":{"retry_count":"0"}}
}

The above is relatively clear! A copy operation accounts for the fact that we’re multiplying the matrix with itself to calculate the Gram matrix; the actual multiplication is performed with the XLA convolution operation on each local shard, and the resulting partial products are then accumulated and shared with an AllReduce.

We expect to be able to trigger a ReduceScatter instead of an AllReduce by asking the computation output to be sharded in a particular way (rather than fully replicated, as above). Let’s try to verify this:

@jax.jit
def dot_custom(x, y):
    return jnp.dot(x, y, out_sharding=jax.NamedSharding(mesh, jax.P('x', 'y')))

F = jnp.zeros((8192, 4096), dtype=jnp.bfloat16, out_sharding=jax.NamedSharding(mesh, jax.P('y')))

lowered = dot_custom.lower(A, F)
compiled = lowered.compile()
G = compiled(A, F)


print("Output type: ", jax.typeof(G))
print()
print("StableHLO:\n", lowered.as_text())
print()
print("XLA HLO:\n", compiled.as_text())
Output:
Output type:  bfloat16[2048@x,4096@y]
StableHLO:
 module @jit_dot_custom attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=2, "y"=2]>
  func.func public @main(%arg0: tensor<2048x8192xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<8192x4096xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<2048x4096xbf16> {jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2048x8192xbf16>, tensor<8192x4096xbf16>) -> tensor<2048x4096xbf16>
    %1 = sdy.sharding_constraint %0 <@mesh, [{"x"}, {"y"}]> : tensor<2048x4096xbf16>
    return %1 : tensor<2048x4096xbf16>
  }
}
XLA HLO:
 HloModule jit_dot_custom, is_scheduled=true, entry_computation_layout={(bf16[1024,4096]{1,0:T(8,128)(2,1)}, bf16[4096,4096]{1,0:T(8,128)(2,1)})->bf16[1024,2048]{1,0:T(8,128)(2,1)}}, allow_spmd_sharding_propagation_to_parameters={false,false}, num_partitions=4
%add.1.clone (x.3: bf16[], y.3: bf16[]) -> bf16[] {
  %y.3 = bf16[]{:T(256)} parameter(1)
  %x.3 = bf16[]{:T(256)} parameter(0)
  ROOT %add.3 = bf16[]{:T(256)} add(%x.3, %y.3)
}
%all-reduce-scatter (input: bf16[1024,4096]) -> bf16[1024,2048] {
  %input = bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  %all-reduce.2 = bf16[1024,4096]{1,0:T(8,128)(2,1)} all-reduce(%input), channel_id=3, replica_groups={{0,1},{2,3}}, use_global_device_ids=true, to_apply=%add.1.clone, frontend_attributes={from-cross-replica-sharding="true"}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[],"used_scoped_memory_configs":[]}
  %constant.13 = u32[] constant(0)
  %constant.shard_id_table = u32[4]{0:T(128)} constant({0, 1, 0, 1})
  %partition-id.1 = u32[] partition-id()
  %dynamic-slice.3 = u32[1]{0:T(128)} dynamic-slice(%constant.shard_id_table, %partition-id.1), dynamic_slice_sizes={1}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"}],"is_index_aligned":[false]},"used_scoped_memory_configs":[]}
  %bitcast.1 = u32[]{:T(128)} bitcast(%dynamic-slice.3)
  %constant.14 = u32[] constant(2048)
  %multiply.3 = u32[]{:T(128)} multiply(%bitcast.1, %constant.14)
  ROOT %dynamic-slice.4 = bf16[1024,2048]{1,0:T(8,128)(2,1)} dynamic-slice(%all-reduce.2, %constant.13, %multiply.3), dynamic_slice_sizes={1024,2048}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294961151","ones":"0","bitwidth":"32"}],"is_index_aligned":[true,false]},"used_scoped_memory_configs":[]}
}
%bitcast_fusion (bitcast_input: bf16[1024,4096]) -> bf16[1024,4096] {
  %bitcast_input = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0)
  ROOT %bitcast.2 = bf16[1024,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input)
}
%bitcast_fusion.1 (bitcast_input.1: bf16[4096,4096]) -> bf16[4096,4096] {
  %bitcast_input.1 = bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  ROOT %bitcast.3 = bf16[4096,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input.1)
}
%fused_computation (param_0: bf16[1024,4096], param_1: bf16[4096,4096]) -> bf16[1024,4096] {
  %param_0 = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0)
  %fusion.2 = bf16[1024,4096]{1,0:T(8,128)(2,1)} fusion(%param_0), kind=kLoop, calls=%bitcast_fusion
  %param_1 = bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)} parameter(1)
  %fusion.3 = bf16[4096,4096]{1,0:T(8,128)(2,1)} fusion(%param_1), kind=kLoop, calls=%bitcast_fusion.1
  ROOT %convolution.1 = bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)} convolution(%fusion.2, %fusion.3), dim_labels=bf_io->bf, metadata={op_name="jit(dot_custom)/dot_general" source_file="/tmp/ipykernel_874047/871715024.py" source_line=3}
}
ENTRY %main.7_spmd (param: bf16[1024,4096], param.1: bf16[4096,4096]) -> bf16[1024,2048] {
  %param.1 = bf16[4096,4096]{1,0:T(8,128)(2,1)} parameter(1), sharding={devices=[2,1,2]<=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name="y"}
  %copy-start = (bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)}, bf16[4096,4096]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(%param.1), cross_program_prefetch_index=0
  %param = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[2,2]<=[4]}, metadata={op_name="x"}
  %copy-done = bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)} copy-done(%copy-start)
  %fusion.1 = bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)} fusion(%param, %copy-done), kind=kOutput, calls=%fused_computation, metadata={op_name="jit(dot_custom)/dot_general" source_file="/tmp/ipykernel_874047/871715024.py" source_line=3}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":["512","4"],"output_window_bounds":["32","4"],"input_window_bounds":["32","32"],"estimated_cycles":"144544","iteration_bounds":["8","4","1"]},"megacore_config":{"megacore_split_dim":"1"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"5591040"}],"retry_config":{"retry_count":"0"},"convolution_algorithm_config":{"emitter":"EmitAllBatchInSublanes"}}
  ROOT %fusion = bf16[1024,2048]{1,0:T(8,128)(2,1)} fusion(%fusion.1), kind=kCustom, calls=%all-reduce-scatter, metadata={op_name="jit(dot_custom)/dot_general" source_file="/tmp/ipykernel_874047/871715024.py" source_line=3}, backend_config={"flag_configs":[],"scoped_memory_configs":[{"memory_space":"0","offset":"0","size":"67108864"}],"collective_algorithm_config":{"emitter":"SingleInputAllReduceScatterFusion","strategy":"StrategyRing","debug":"\nStrategyRing{colors:1 phases:1 cores:{2} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}\nStrategyRing{colors:1 phases:1 cores:{2} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}\nType: 1D phase_count: 1; color_count: 1; sharded_partitions: 4; original_shape: bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)}; per_color_shard_counts: 2; color_dim: -1; sharding_dim: 1; sharding_type: minor; convert_all_gather_output_to_bf16: 0; formatting steps: ()\nall-reduce-scatter fusion ND: span_size:8192*k512Byte, shard count:2, span_count:2, total_size:16384*k512Byte, valid_granules:16384*k512Byte"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"12582912"}],"retry_config":{"retry_count":"0"}}
}

I actually haven’t found a way to generate a standalone ReduceScatter operation in the HLO—if you know how to do this, please let me know! Above, I’ve picked different shardings and arrays, and we do end up triggering an operation equivalent in effect to a ReduceScatter; but the HLO implements it as a fusion of all-reduce and dynamic-slice that splits up the reduced array, rather than a dedicated reduce-scatter op.

Conclusion

This turned out to be on the longer side! But we have two valuable takeaways:

  1. A linear-algebraic mental model for sharding in JAX—sharded array dimensions partition the data, and unused mesh axes lead to replication! All of this can be understood in terms of block matrices and their algebra.
  2. The ability to read HLO and parse the communications primitives that arise therein when we manipulate sharded arrays.

We also recapped information from the TPU Scaling book about how to interpret performance tradeoffs between these different forms of communication.

Acknowledgments

Thanks to the TRC program for compute support.

  1. One way to generalize this to simultaneously considering different meshes is to consider that the mesh simply indexes different physical device axes, or ‘merges’ of them. Hence if there are $M$ physical device axes (e.g., devices with 3D interconnects), one can count shardings across meshes by allowing the $k$ selected axes to be merged together. This is a noncommutative merging, as the traversal order for the devices will be different based on the ordering of the merged axes. 

  2. To go beyond this restriction, it suffices to think of the mapping between block indices and devices as being performed by a pre-configured mesh. Algebraically, the mesh axis corresponds to a sort of ‘block permutation matrix’ (i.e., a tensor product of a permutation matrix with the identity), which left-multiplies for sharding the first dimension of the matrix, and right-multiplies for the second. 

  3. Note that, algebraically, if we perform copying to represent a partitioned matrix in terms of devices, then elementwise matrix multiplication corresponds to the within-device matrix products! 

  4. E.g., about 42 GB/s for v4 TPUs. 

  5. Since AllGather is linear, its derivative doesn’t involve any cached activations from the forward pass, so we omit these from the notation. 

  6. But the translator is open source, and the admissible operations can be gleaned from the source code (use Google’s LLM for help with this!). A high-level specification of operations can be found here

  7. Some high-level takeaways about these operations: think of bitcast like a view of the array, at least in this context, whereas copy doesn’t change the underlying tensor, but does change the memory layout. (Notice that the copy operation involves a transposition of the layout axes.)