
Collectives — the seven primitives

Every distributed-training cost in this series is a sum of collective costs. This lesson covers the vocabulary and the ring-AllReduce derivation you'll reuse for the rest of the series.

The setting

N processes (one per GPU, by convention) each hold a buffer, and they need to combine those buffers somehow. There are essentially seven shapes of "combine", and learning them once removes 80% of the magic from FSDP, tensor parallelism (TP), expert parallelism (EP), and pipeline parallelism. Vocabulary:

| Collective | Input on rank i | Output on rank i | Where it appears |
|---|---|---|---|
| Broadcast | x (on root only) | Root's tensor | Init weights, checkpoint reload |
| Reduce | x_i | Σ_j x_j (on root only) | Log aggregation |
| AllReduce | x_i | Σ_j x_j (on every rank) | DDP gradient sync, TP output |
| Gather | x_i | [x_0, …, x_{N-1}] (on root) | Log gather |
| AllGather | x_i (a shard) | Full concatenation [x_0, …, x_{N-1}] on every rank | FSDP forward (reconstruct full layer) |
| ReduceScatter | x_i (full) | i-th piece of Σ_j x_j | FSDP backward (shard the gradient) |
| AllToAll | [x_{i,0}, …, x_{i,N-1}] | [x_{0,i}, …, x_{N-1,i}] | MoE routing, transpose |

AllToAll is the one that catches people. Read its row carefully: rank i begins holding one chunk destined for each rank (itself included) and ends holding one chunk from each rank. It is a transpose of the (rank × chunk) matrix. Cost-wise, each rank sends N − 1 distinct messages (one per peer) instead of one, so with the small per-peer chunks typical of MoE routing it tends to be latency-bound rather than bandwidth-bound. Lesson 09 (MoE) is where it bites.
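To pin down the table's semantics, here is a minimal single-process sketch in which each collective is a pure function over a list of per-rank buffers (buffers[i] is rank i's data). It assumes every buffer length divides evenly by N and uses an empty list to mean "no output on this rank"; the function names are illustrative, not any library's API. all_to_all is exactly the (rank × chunk) transpose described above.

```python
from typing import List

Buffers = List[List[int]]  # buffers[i] is rank i's local data


def _sum(buffers: Buffers) -> List[int]:
    # Elementwise sum across ranks.
    return [sum(col) for col in zip(*buffers)]


def _chunks(buf: List[int], n: int) -> List[List[int]]:
    # Split one buffer into n equal chunks (assumes len(buf) % n == 0).
    k = len(buf) // n
    return [buf[j * k:(j + 1) * k] for j in range(n)]


def broadcast(buffers: Buffers, root: int = 0) -> Buffers:
    return [list(buffers[root]) for _ in buffers]


def reduce(buffers: Buffers, root: int = 0) -> Buffers:
    total = _sum(buffers)
    return [total if i == root else [] for i in range(len(buffers))]


def all_reduce(buffers: Buffers) -> Buffers:
    total = _sum(buffers)
    return [list(total) for _ in buffers]


def gather(buffers: Buffers, root: int = 0) -> Buffers:
    cat = [v for buf in buffers for v in buf]
    return [cat if i == root else [] for i in range(len(buffers))]


def all_gather(buffers: Buffers) -> Buffers:
    cat = [v for buf in buffers for v in buf]
    return [list(cat) for _ in buffers]


def reduce_scatter(buffers: Buffers) -> Buffers:
    # Rank i keeps only the i-th chunk of the sum.
    return _chunks(_sum(buffers), len(buffers))


def all_to_all(buffers: Buffers) -> Buffers:
    # Chunk j of rank i's input is destined for rank j, so
    # output[i] = [chunk i of rank 0, chunk i of rank 1, ...]: a transpose.
    n = len(buffers)
    grid = [_chunks(buf, n) for buf in buffers]  # grid[src][dst]
    return [[v for src in range(n) for v in grid[src][dst]] for dst in range(n)]


x = [[1, 2], [3, 4]]      # two ranks, each holding two one-element chunks
print(all_to_all(x))      # prints [[1, 3], [2, 4]]
print(reduce_scatter(x))  # prints [[4], [6]]
```

Applying all_to_all twice returns the original layout, which is another way to see that it is a transpose.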

2D · the seven primitives, side by side

Each primitive is a different rearrangement of "rank × chunk" tiles. The animation below shows the input layout on the left, the output on the right, and the per-rank moves as colored arrows. Notice how AllGather and ReduceScatter are exact complements: they're the two halves of one AllReduce.

[Interactive figure: collective primitive gallery. Four ranks (rows); each cell is one chunk, colored by its origin rank (or by its summed value, for reduce ops).]

The identity that explains FSDP

AllReduce(x)  =  AllGather( ReduceScatter(x) )

Stare at it. ReduceScatter reduces and then leaves each rank holding only its own slice of the sum. AllGather then redistributes those slices so every rank has the whole sum again. Two steps, same final answer as a one-step AllReduce.
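A quick way to convince yourself of the identity is to check it numerically. The sketch below simulates N ranks in one process with plain Python (function names are illustrative); integer data keeps the comparison exact.

```python
import random
from typing import List

Buffers = List[List[int]]


def reduce_scatter(buffers: Buffers) -> Buffers:
    # Sum across ranks, then rank i keeps only slice i of the sum.
    n = len(buffers)
    total = [sum(col) for col in zip(*buffers)]
    k = len(total) // n  # assumes the length divides evenly by n
    return [total[i * k:(i + 1) * k] for i in range(n)]


def all_gather(shards: Buffers) -> Buffers:
    # Every rank receives the concatenation of all slices.
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def all_reduce(buffers: Buffers) -> Buffers:
    total = [sum(col) for col in zip(*buffers)]
    return [list(total) for _ in buffers]


random.seed(0)
for n in (2, 4, 8):
    x = [[random.randint(-9, 9) for _ in range(3 * n)] for _ in range(n)]
    assert all_gather(reduce_scatter(x)) == all_reduce(x)
print("AllReduce == AllGather(ReduceScatter) for N = 2, 4, 8")
```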

Why care? If you only need your slice of the sum (for instance because you only own that slice of the optimizer state, as in ZeRO-1, so it's the only slice you'll use next), you can stop after ReduceScatter: you've done half of an AllReduce. Conversely, if every rank starts with its own slice (because you sharded the weights, as in ZeRO-3) and needs the full thing for the next forward, you do an AllGather. The two halves of AllReduce can be used separately by different optimizations. This is the trick lesson 05 turns into ZeRO.

Ring AllReduce — the bandwidth-optimal derivation

The naive AllReduce: send your tensor to a central root, sum, send back. Cost per non-root rank: 2 · S bytes (one up, one down), but the root is hammered and its link is the bottleneck. With N ranks all sending simultaneously, the root's link receives (N − 1) · S bytes and sends as many back out, so its load scales linearly with cluster size. Bad.

Ring AllReduce does better. Arrange the N ranks in a logical ring: rank i sends only to rank (i+1) mod N. Each tensor is split into N chunks. Now two passes:

  1. ReduceScatter pass (N-1 steps). At step t (t = 0, …, N-2), rank i sends chunk (i-t) mod N to its neighbour and receives chunk (i-t-1) mod N, which it adds to its own copy of that chunk. After N-1 steps, rank i holds the fully reduced version of chunk (i+1) mod N (the chunk it'll be "responsible for" in the second pass), so each of the N chunks is fully reduced on exactly one rank.
  2. AllGather pass (N-1 steps). The chunk each rank now owns gets passed around the ring once; receivers overwrite their stale copy instead of adding to it. After N-1 more steps, every rank has every chunk fully reduced.
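The two passes can be simulated directly. This is a minimal single-process sketch, not NCCL's implementation: it follows the chunk indexing above (send chunk (i-t) mod N during ReduceScatter, then (i+1-t) mod N during AllGather), treats each step's sends as simultaneous, and counts messages and elements sent per rank. It assumes the tensor length divides evenly by N.

```python
from typing import List, Tuple


def ring_all_reduce(buffers: List[List[int]]) -> Tuple[List[List[int]], List[int], List[int]]:
    """Return (per-rank results, messages sent per rank, elements sent per rank)."""
    n = len(buffers)
    assert len(buffers[0]) % n == 0, "sketch assumes N equal chunks"
    k = len(buffers[0]) // n
    # chunks[i][c] is rank i's current copy of chunk c.
    chunks = [[list(buf[c * k:(c + 1) * k]) for c in range(n)] for buf in buffers]
    msgs, sent = [0] * n, [0] * n

    def step(chunk_to_send, accumulate: bool) -> None:
        # Snapshot every rank's outgoing chunk first, so all sends in a step are simultaneous.
        outgoing = [(chunk_to_send(i), list(chunks[i][chunk_to_send(i)])) for i in range(n)]
        for i in range(n):
            c, payload = outgoing[(i - 1) % n]  # rank i receives from its left neighbour
            if accumulate:
                chunks[i][c] = [a + b for a, b in zip(chunks[i][c], payload)]
            else:
                chunks[i][c] = payload  # already fully reduced: overwrite
            msgs[i] += 1  # every rank also sent exactly one chunk of k elements this step
            sent[i] += k

    # Pass 1: ReduceScatter. Afterwards rank i owns the full sum of chunk (i+1) mod N.
    for t in range(n - 1):
        step(lambda i, t=t: (i - t) % n, accumulate=True)
    # Pass 2: AllGather. Each owned chunk travels once around the ring.
    for t in range(n - 1):
        step(lambda i, t=t: (i + 1 - t) % n, accumulate=False)

    results = [[v for c in range(n) for v in chunks[i][c]] for i in range(n)]
    return results, msgs, sent


x = [[10 * r + j for j in range(8)] for r in range(4)]  # N = 4 ranks, S = 8 elements
results, msgs, sent = ring_all_reduce(x)
print(results[0])  # elementwise sum across the 4 ranks
print(msgs, sent)  # 2(N-1) = 6 messages and 2(N-1)/N * S = 12 elements per rank
```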

Per rank, in total: 2(N-1) messages, each of size S/N. Total bytes moved per rank:

bytes_per_rank  =  2 · (N − 1) · S / N  ≈  2 · S (large N)

This is the punchline of the lesson. Per-rank traffic of a ring AllReduce is essentially independent of cluster size for large N: growing a 256-GPU ring to 257 GPUs moves each GPU's send volume from 2 · 255/256 · S to 2 · 256/257 · S, a negligible change. The link load is balanced across all N ring edges.
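Plugging numbers into the formula makes the N-independence concrete. A small sketch with S measured in units of the tensor size; for contrast it also prints the naive root scheme's inbound load, counted here as (N − 1) · S because the root already holds its own tensor.

```python
def ring_bytes_per_rank(S: float, N: int) -> float:
    # ReduceScatter + AllGather: 2(N-1) messages of S/N bytes each.
    return 2 * (N - 1) * S / N


def naive_root_bytes_in(S: float, N: int) -> float:
    # Every non-root rank ships its whole tensor to the root.
    return (N - 1) * S


S = 1.0  # one tensor's worth of bytes
for N in (2, 8, 64, 256, 257, 1024):
    print(f"N={N:5d}  ring per rank = {ring_bytes_per_rank(S, N):.4f} S"
          f"   naive root inbound = {naive_root_bytes_in(S, N):6.0f} S")
```

The ring column creeps toward 2 S and stops; the naive column grows without bound.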

Why the factor of 2
ReduceScatter contributes (N−1)/N · S bytes, and AllGather contributes the same. Adding them gives 2(N−1)/N · S. The "2" is doing real work: any AllReduce must move at least 2S(N−1)/N per rank (information-theoretic lower bound), and ring AllReduce hits it. Other algorithms (tree, butterfly) give up some of that bandwidth optimality in exchange for fewer sequential steps; see below.

Animated · ring AllReduce on N=4 ranks, step by step

Four ranks arranged in a ring. The tensor is split into 4 chunks. Each rank's row shows what it holds; rows light up as chunks accumulate sums. Press step 6 times: three ReduceScatter steps fold the chunks down to one fully-summed chunk per rank, then three AllGather steps redistribute. By the end, every rank holds the entire summed tensor. The counter shows bytes moved per rank — which never exceeds 2(N-1)/N · S.

[Interactive figure: ring AllReduce timeline. Ranks 0..N-1 on a circle, each holding N chunks initially; each step advances one of the 2(N-1) ring exchanges, with the in-flight message's ring edge highlighted. Fully summed chunks are solid, partial sums striped. Readouts: phase, step, bytes sent per rank, and % of 2S(N-1)/N.]

Time, not just bytes

Bandwidth says how fast you can stream bytes once a transfer is under way. Latency is the fixed cost paid per message, regardless of its size. The standard α-β model adds the two:

time  =  α · (number_of_messages)  +  β · (total_bytes)

For a ring AllReduce of size S with bandwidth BW = 1/β and per-message latency α:

T_ring  =  2(N−1) · α  +  2(N−1)/N · S / BW

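Two regimes pop out:

  1. Small S (latency-bound). The 2(N−1) · α term dominates, so cost grows linearly with N no matter how few bytes you send.
  2. Large S (bandwidth-bound). The 2(N−1)/N · S / BW term dominates, so cost approaches 2 · S / BW, independent of N.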

The crossover sets the "bucket size" you'll see in PyTorch DDP and similar: bundle gradients until the bucket is big enough that β · S dominates α · N. PyTorch DDP's default (its bucket_cap_mb argument) is 25 MB.
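The α-β model and its crossover fit in a few lines. This sketch uses illustrative, assumed values (α = 5 µs, BW = 50 GB/s, N = 64), not measurements, and its greedy bucket helper is hypothetical: it only shows why fewer, larger AllReduces help, not how DDP actually assigns parameters to buckets.

```python
def t_ring(S: float, N: int, alpha: float, bw: float) -> float:
    """Alpha-beta time (seconds) of one ring AllReduce of S bytes."""
    return 2 * (N - 1) * alpha + 2 * (N - 1) / N * S / bw


def crossover_bytes(N: int, alpha: float, bw: float) -> float:
    # Bandwidth term 2(N-1)/N * S/bw equals latency term 2(N-1)*alpha when S = alpha * N * bw.
    return alpha * N * bw


def bucket(sizes, cap):
    """Hypothetical greedy bucketing: pack tensors in order until a bucket reaches cap bytes."""
    buckets, cur, cur_bytes = [], [], 0
    for size in sizes:
        cur.append(size)
        cur_bytes += size
        if cur_bytes >= cap:
            buckets.append(cur)
            cur, cur_bytes = [], 0
    if cur:
        buckets.append(cur)
    return buckets


alpha, bw, N = 5e-6, 50e9, 64          # assumed: 5 us per message, 50 GB/s, 64 ranks
print(f"crossover ~ {crossover_bytes(N, alpha, bw) / 1e6:.0f} MB")

grads = [100_000] * 1000               # 1000 gradient tensors of 100 KB each (100 MB total)
unbucketed = sum(t_ring(s, N, alpha, bw) for s in grads)
bucketed = sum(t_ring(sum(b), N, alpha, bw) for b in bucket(grads, 25_000_000))
print(f"one AllReduce per tensor: {unbucketed * 1e3:.1f} ms, 25 MB buckets: {bucketed * 1e3:.1f} ms")
```

Under these assumed numbers the bytes moved are identical in both cases; bucketing wins purely by paying the 2(N−1) · α latency term 4 times instead of 1000.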

Hierarchical AllReduce — the production trick

Inside a node, NVLink + NVSwitch give each GPU ~900 GB/s of aggregate bidirectional bandwidth on H100-class hardware, about 450 GB/s in each direction (the per-pair share depends on how much other traffic is using the switch). Between nodes, one 400 Gb/s InfiniBand (NDR) port gives ~50 GB/s in each direction. That is roughly a 9× asymmetry per direction. A flat ring that spans both intra- and inter-node links runs at the pace of its slowest link, the IB hop.

Hierarchical AllReduce respects the asymmetry:

  1. Intra-node Reduce. All 8 GPUs in a node ring-reduce their copies into one rank's local sum. Uses NVLink only.
  2. Inter-node AllReduce. The chosen rank from each node participates in a ring across nodes (over IB). One participant per node, so IB sees N_nodes ranks, not N_total.
  3. Intra-node Broadcast. The chosen rank propagates the global sum back to its 7 neighbours over NVLink.

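Total inter-node bytes per node: 2 · S · (N_nodes − 1) / N_nodes, sent by one rank per node, versus 2 · S · (N − 1) / N per rank for a flat ring. How much IB traffic that saves depends on the flat ring's rank order: if it interleaves nodes so that every ring edge crosses IB, all 8 GPUs of a node push their share over IB and hierarchical cuts IB bytes per node by ~8×; a node-contiguous flat ring crosses IB only once per node and already sends ≈ 2S per node. Either way, hierarchical wins on latency: with 8 GPUs per node and 64 nodes, the IB-bound phase takes 2 · 63 = 126 sequential steps instead of 2 · 511 = 1022, about 8× fewer. NCCL builds topology-aware rings and trees that exploit this intra/inter-node split automatically when it detects the topology. (Its tree algorithm is a double binary tree, and it has further variants, but ring vs hierarchical is the core mental model.)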
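The IB bookkeeping depends on how ring positions map to nodes. This sketch counts, for a flat ring, the ring edges whose two ends sit on different nodes, charges each one the 2(N−1)/N · S bytes a rank sends, and compares against the hierarchical leader ring. It assumes 8 GPUs per node, 64 nodes, and one IB crossing per inter-node edge; the two orderings (node-contiguous and round-robin across nodes) are illustrative.

```python
def flat_ring_ib_bytes_per_node(node_of, S):
    """node_of[p] = node hosting ring position p. IB bytes sent per node by a flat ring."""
    N = len(node_of)
    per_rank = 2 * (N - 1) / N * S              # each rank's total sends on its outgoing edge
    ib_edges = sum(node_of[p] != node_of[(p + 1) % N] for p in range(N))
    return ib_edges * per_rank / len(set(node_of))


def hierarchical_ib_bytes_per_node(num_nodes, S):
    # One leader per node joins an inter-node ring of num_nodes ranks.
    return 2 * (num_nodes - 1) / num_nodes * S


G, M, S = 8, 64, 1.0                            # GPUs per node, nodes, tensor size (units of S)
N = G * M
contiguous = [p // G for p in range(N)]         # node 0's 8 GPUs, then node 1's, ...
round_robin = [p % M for p in range(N)]         # every neighbouring pair on different nodes

print("flat, contiguous :", round(flat_ring_ib_bytes_per_node(contiguous, S), 3), "S per node,", 2 * (N - 1), "IB-gated steps")
print("flat, round-robin:", round(flat_ring_ib_bytes_per_node(round_robin, S), 3), "S per node,", 2 * (N - 1), "IB-gated steps")
print("hierarchical     :", round(hierarchical_ib_bytes_per_node(M, S), 3), "S per node,", 2 * (M - 1), "IB steps")
```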

Animated · ring vs tree vs hierarchical, side by side

Three algorithms racing the same payload over 8 ranks (or 16, or 32). The ring travels along a circle; the tree fans in and out in log N levels; the hierarchical runs an intra-node reduce, one inter-node ring, and an intra-node broadcast. Press play. The counter shows messages completed per algorithm.

[Interactive figure: three algorithms, one payload. Latency-dominated regime: ring sends 2(N-1) sequential messages; tree completes in 2·log₂(N) hops; hierarchical splits intra- and inter-node phases. The first to finish is, roughly, the algorithm NCCL would pick at this payload size. Readouts: ring, tree, and hierarchical progress; winner.]

What changes in TP, FSDP, MoE

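Naming the collective each parallelism uses is the single most useful mnemonic in this series. From the table at the top:

  1. TP: AllReduce to sum the partial outputs of each sharded layer.
  2. FSDP: AllGather in the forward to reconstruct each full layer, ReduceScatter in the backward to shard the gradient.
  3. MoE: AllToAll to send each token to its expert (and the expert outputs back).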

Interactive · ring vs tree vs hierarchical

Three AllReduce algorithms, same payload. Watch the total-time bars change as you scale ranks, message size, and the latency/bandwidth ratio. Below the crossover message size, tree beats ring: that is the latency-bound regime gradient bucketing is meant to escape. The gap between flat ring and hierarchical is the value of "stay on one node when you can."

[Interactive figure: AllReduce cost, ring vs hierarchical vs tree. Cost in microseconds for a single AllReduce of S bytes on N ranks, parametrised by α (per-message latency, μs) and β = 1/BW (s/byte). Try N=512, S=14 GB (a 7B model's grads in bf16): flat ring is ~280 ms, hierarchical ~80 ms. Readouts: flat ring, tree, hierarchical, winner.]
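The small-message half of this comparison can be reproduced in a few lines. This sketch pits the ring formula from earlier against a simple unpipelined tree model, 2 · ⌈log₂ N⌉ · (α + S / BW) (reduce up a binary tree, then broadcast down, each hop carrying the full message). The α and BW values are assumptions, and real NCCL trees pipeline chunks, so treat the crossover as an order-of-magnitude estimate.

```python
def t_ring(S: float, N: int, alpha: float, bw: float) -> float:
    return 2 * (N - 1) * alpha + 2 * (N - 1) / N * S / bw


def t_tree(S: float, N: int, alpha: float, bw: float) -> float:
    # Unpipelined binary tree: ceil(log2 N) hops up (reduce) plus the same down (broadcast),
    # each hop moving the whole S-byte message.
    levels = (N - 1).bit_length()  # ceil(log2 N) for N >= 1
    return 2 * levels * (alpha + S / bw)


def tree_ring_crossover(N: int, alpha: float, bw: float, hi: float = 1e12) -> float:
    """Bisect for the message size at which tree and ring cost the same."""
    lo = 0.0
    for _ in range(200):
        mid = (lo + hi) / 2
        if t_tree(mid, N, alpha, bw) < t_ring(mid, N, alpha, bw):
            lo = mid  # tree still wins here, so the crossover is larger
        else:
            hi = mid
    return (lo + hi) / 2


alpha, bw, N = 5e-6, 50e9, 512  # assumed: 5 us per message, 50 GB/s, 512 ranks
for S in (1e3, 1e6, 1e8, 1e9):
    print(f"S={S:8.0e} B  ring {t_ring(S, N, alpha, bw) * 1e3:9.3f} ms   tree {t_tree(S, N, alpha, bw) * 1e3:9.3f} ms")
print(f"tree wins below ~{tree_ring_crossover(N, alpha, bw) / 1e6:.0f} MB")
```

The tree's latency term grows with log₂ N while the ring's grows with N; the ring's bandwidth term stays near 2S / BW while the tree's grows with log₂ N. That trade is the whole story of the crossover.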

The two-line vocabulary check

| If you hear … | … what's actually happening |
|---|---|
| "All gradients are summed across ranks" | AllReduce |
| "Reconstruct the full weight from shards" | AllGather |
| "Sum the grads, but only keep your slice" | ReduceScatter |
| "Each token goes to its expert" | AllToAll |
| "Activation crosses a pipeline stage" | Send/Recv (point-to-point) |
| "Broadcast the initial weights" | Broadcast (rank 0 → all) |

Takeaway
Per-rank ring-AllReduce bandwidth cost is 2(N−1)/N · S / BW, essentially independent of N; only the latency term 2(N−1) · α grows with N. Tree wins for small messages (latency). Hierarchical wins when intra- and inter-node bandwidth differ, which they always do. Bucketing exists to push small AllReduces into the bandwidth-bound regime where ring is optimal.