Collectives — the seven primitives
Every distributed-training cost in this series is a sum of collective costs. This lesson is the vocabulary and the ring-AllReduce derivation that you'll reuse for the rest of the series.
The setting
N processes (one per GPU, by convention). Each one holds a buffer; they need to combine their buffers somehow. There are essentially seven shapes of "combine", and learning them once removes 80% of the magic from FSDP, TP, EP, and pipeline parallel. Vocabulary:
| Collective | Input on rank i | Output on rank i | Where it appears |
|---|---|---|---|
| Broadcast | (only on root) | Root's tensor | Init weights, checkpoint reload |
| Reduce | x_i | Σ_j x_j on root only | Log aggregation |
| AllReduce | x_i | Σ_j x_j on every rank | DDP gradient sync, TP output |
| Gather | x_i | [x_0, …, x_{N-1}] on root only | Log gather |
| AllGather | x_i (a shard) | Full concatenation [x_0, …, x_{N-1}] on every rank | FSDP forward (reconstruct full layer) |
| ReduceScatter | x_i (full) | i-th piece of Σ_j x_j | FSDP backward (shard the gradient) |
| AllToAll | [x_{i,0}, …, x_{i,N-1}] | [x_{0,i}, …, x_{N-1,i}] | MoE routing, transpose |
AllToAll is the one that catches people. Read its row carefully: rank i begins holding one chunk for each rank, and ends holding one chunk from each rank. It is a transpose of the (rank × chunk) matrix. Cost-wise, each rank sends N-1 different messages instead of one, so with small chunks it tends to be latency-bound rather than bandwidth-bound. Lesson 09 (MoE) is where it bites.
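To make the table concrete, here is a minimal single-process sketch of the seven primitives. It models "rank i's buffer" as `bufs[i]` in a plain Python list and moves no data over a network, so it only illustrates input and output shapes. The function names are illustrative and are not the API of NCCL or any other library.

```python
from typing import List, Optional

def broadcast(bufs: List[list], root: int = 0) -> List[list]:
    """Every rank ends with a copy of the root's buffer."""
    return [list(bufs[root]) for _ in bufs]

def reduce(bufs: List[list], root: int = 0) -> List[Optional[list]]:
    """Only the root ends with the elementwise sum; other ranks get nothing."""
    total = [sum(col) for col in zip(*bufs)]
    return [total if i == root else None for i in range(len(bufs))]

def all_reduce(bufs: List[list]) -> List[list]:
    """Every rank ends with the elementwise sum."""
    total = [sum(col) for col in zip(*bufs)]
    return [list(total) for _ in bufs]

def gather(bufs: List[list], root: int = 0) -> List[Optional[list]]:
    """Only the root ends with the list of every rank's buffer."""
    everything = [list(b) for b in bufs]
    return [everything if i == root else None for i in range(len(bufs))]

def all_gather(shards: List[list]) -> List[list]:
    """Every rank ends with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(bufs: List[list]) -> List[list]:
    """Rank i ends with the i-th slice of the sum (length must divide by N)."""
    n = len(bufs)
    total = [sum(col) for col in zip(*bufs)]
    k = len(total) // n
    return [total[i * k:(i + 1) * k] for i in range(n)]

def all_to_all(chunks: List[list]) -> List[list]:
    """chunks[i][j] is what rank i holds for rank j; the output is the transpose."""
    n = len(chunks)
    return [[chunks[j][i] for j in range(n)] for i in range(n)]

bufs = [[1, 2], [10, 20]]                # two ranks, two elements each
print(all_reduce(bufs))                  # [[11, 22], [11, 22]]
print(reduce_scatter(bufs))              # [[11], [22]]
print(all_to_all([["a0", "a1"], ["b0", "b1"]]))  # [['a0', 'b0'], ['a1', 'b1']]
```

The AllToAll line is the one to look at: rank 0 starts with `a0` (for itself) and `a1` (for rank 1), and ends with `a0` and `b0`, the chunks every rank had addressed to rank 0.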
2D · the seven primitives, side by side
Each primitive is a different rearrangement of "rank × chunk" tiles. The animation below shows the input layout on the left, the output on the right, and the per-rank moves as colored arrows. Tab through them. Notice how AllGather and ReduceScatter are exact complements — they're two halves of one AllReduce.
The identity that explains FSDP
$$\mathrm{AllReduce} = \mathrm{AllGather} \circ \mathrm{ReduceScatter}$$
Stare at it. ReduceScatter reduces and then leaves each rank holding only its own slice of the sum. AllGather then redistributes those slices so every rank has the whole sum again. Two steps, same final answer as a one-step AllReduce.
Why care? Because if you only need your slice of the sum (say, because you only own that slice of the optimizer state, as in ZeRO-1), you can stop after ReduceScatter. You've now done half of an AllReduce. Conversely, if every rank starts with its own slice (because you sharded the weights, as in ZeRO-3) and you need the full thing for the next forward, you do an AllGather. The two halves of AllReduce can be charged separately by different optimizations. This is the trick lesson 05 turns into ZeRO.
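The "stop after ReduceScatter" idea can be checked on a toy update step. The sketch below is a single-process model, not real ZeRO code: the parameter values, the learning rate of 1, and the helper names are all made up for illustration. Path A sums gradients with an AllReduce and updates everything on every rank. Path B does a ReduceScatter, lets each rank update only its own slice, then AllGathers the updated slices.

```python
def all_reduce(grads):
    total = [sum(col) for col in zip(*grads)]
    return [list(total) for _ in grads]

def reduce_scatter(grads):
    n = len(grads)
    total = [sum(col) for col in zip(*grads)]
    k = len(total) // n
    return [total[r * k:(r + 1) * k] for r in range(n)]

def all_gather(shards):
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

n_ranks = 2
params = [8, 8, 8, 8]                       # replicated parameters (toy values)
grads = [[1, 2, 3, 4], [10, 20, 30, 40]]    # one gradient buffer per rank

# Path A: AllReduce, then every rank updates the full parameter vector.
summed = all_reduce(grads)
path_a = [[p - g for p, g in zip(params, summed[r])] for r in range(n_ranks)]

# Path B: ReduceScatter, each rank updates only the slice it owns, AllGather.
k = len(params) // n_ranks
grad_slices = reduce_scatter(grads)
updated_slices = [
    [p - g for p, g in zip(params[r * k:(r + 1) * k], grad_slices[r])]
    for r in range(n_ranks)
]
path_b = all_gather(updated_slices)

assert path_a == path_b
print(path_b[0])  # [-3, -14, -25, -36]
```

Both paths end with identical parameters on every rank; the difference is that in path B each rank only ever needed optimizer state for its own slice.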
Ring AllReduce — the bandwidth-optimal derivation
The naive AllReduce: send your tensor to a central root, sum, send back. Cost per non-root rank: 2 · S bytes (one up, one down), but the root is hammered and the link to it is the bottleneck. With the other N-1 ranks all sending simultaneously, the root's link receives (N-1) · S bytes, so its load scales linearly with cluster size. Bad.
Ring AllReduce does better. Arrange the N ranks in a logical ring: rank i sends only to rank (i+1) mod N. Each tensor is split into N chunks. Now two passes:
- ReduceScatter pass (N-1 steps). At step t (t = 0, …, N-2), rank i sends chunk (i-t) mod N to its neighbour and receives chunk (i-t-1) mod N from its predecessor, which it adds to its own copy of that chunk. After N-1 steps, rank i has the fully reduced version of chunk (i+1) mod N (the chunk it'll be "responsible for" in the second pass), so every rank holds a different fully reduced chunk.
- AllGather pass (N-1 steps). The chunk each rank now owns gets passed around the ring once. After N-1 more steps, every rank has every chunk fully reduced.
Per rank, in total: 2(N-1) messages, each of size S/N. Total bytes moved per rank:
$$2(N-1) \cdot \frac{S}{N} = \frac{2(N-1)}{N}\,S \approx 2S \quad \text{for large } N$$
This is the punchline of the lesson. Per-rank traffic of a ring AllReduce is independent of cluster size for large N. Adding a 257th GPU doesn't change how many bytes each existing GPU has to send. The link load is balanced across all N ring edges.
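The two passes can be simulated in one process to check both the result and the byte count. This is a minimal sketch that follows the chunk indexing described above, with step t running from 0 to N-2. It counts elements sent rather than bytes, and `ring_all_reduce` is an illustrative name, not a library call.

```python
def ring_all_reduce(bufs):
    """bufs[i] is rank i's tensor, pre-split into N chunks (lists of numbers).

    Returns (data, sent): each rank's final chunks and elements sent per rank.
    """
    n = len(bufs)
    data = [[list(chunk) for chunk in rank] for rank in bufs]  # deep copy
    sent = [0] * n

    # ReduceScatter pass: at step t, rank i sends chunk (i - t) mod n.
    for t in range(n - 1):
        msgs = []  # snapshot first, so all sends in a step are simultaneous
        for i in range(n):
            c = (i - t) % n
            msgs.append((c, list(data[i][c])))
            sent[i] += len(data[i][c])
        for i in range(n):
            c, payload = msgs[(i - 1) % n]  # receive from the predecessor
            data[i][c] = [a + b for a, b in zip(data[i][c], payload)]
    # Rank i now holds the fully reduced chunk (i + 1) mod n.

    # AllGather pass: at step t, rank i forwards chunk (i + 1 - t) mod n.
    for t in range(n - 1):
        msgs = []
        for i in range(n):
            c = (i + 1 - t) % n
            msgs.append((c, list(data[i][c])))
            sent[i] += len(data[i][c])
        for i in range(n):
            c, payload = msgs[(i - 1) % n]
            data[i][c] = payload  # overwrite with the reduced chunk
    return data, sent

# 4 ranks, tensor of S = 4 elements split into 4 one-element chunks.
bufs = [[[10 * i + c] for c in range(4)] for i in range(4)]
data, sent = ring_all_reduce(bufs)
print(data[0])  # [[60], [64], [68], [72]]
print(sent)     # [6, 6, 6, 6]
```

With N = 4 and S = 4 elements, each rank sends 6 elements, which matches 2(N-1)/N · S.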
Animated · ring AllReduce on N=4 ranks, step by step
Four ranks arranged in a ring. The tensor is split into 4 chunks. Each rank's row shows what it holds; rows light up as chunks accumulate sums. Press step 6 times: three ReduceScatter steps fold the chunks down to one fully-summed chunk per rank, then three AllGather steps redistribute. By the end, every rank holds the entire summed tensor. The counter shows bytes moved per rank — which never exceeds 2(N-1)/N · S.
Time, not just bytes
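Bandwidth says how fast you can stream data once a connection is open. Latency says how long it takes to open one. Real cost is a mix:
$$T(S) = \alpha + \frac{S}{BW}$$
For a ring AllReduce of size S with bandwidth BW = 1/β and per-message latency α:
$$T_{\text{ring}} = 2(N-1)\,\alpha + \frac{2(N-1)}{N} \cdot \frac{S}{BW}$$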
Two regimes pop out:
- Big tensors. The S/BW term dominates; cost is ~2S/BW regardless of N. Ring is great.
- Small tensors. The 2(N−1) · α latency term dominates; ring is O(N) in latency. Bad. Use a tree (latency O(log N)) or, better, batch many small reduces together (this is exactly what DDP's bucket size knob in lesson 04 controls).
The crossover sets the "bucket size" you'll see in PyTorch DDP and similar: bundle gradients until the bucket is big enough that β · S dominates α · N. Default in PyTorch DDP is 25 MB.
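The two regimes and the case for bucketing can be sketched numerically. The ring model below uses 2(N-1) messages of size S/N, as derived earlier. The tree model is an assumption made for comparison only (a reduce up and a broadcast down a binary tree, sending the full message at every level), and the values of α, BW, and N are illustrative, not measurements.

```python
import math

def ring_time(n, size_bytes, alpha, bw):
    """Ring AllReduce: 2(N-1) messages of S/N bytes each."""
    return 2 * (n - 1) * alpha + 2 * (n - 1) / n * size_bytes / bw

def tree_time(n, size_bytes, alpha, bw):
    """ASSUMED simple model: reduce up, then broadcast down a binary tree,
    sending the full S bytes at each of ceil(log2 N) levels."""
    levels = math.ceil(math.log2(n))
    return 2 * levels * (alpha + size_bytes / bw)

# Illustrative numbers (assumptions): 64 ranks, 5 us latency, 50 GB/s links.
N, ALPHA, BW = 64, 5e-6, 50e9

small, large = 1e3, 1e9  # 1 KB and 1 GB
print("small:", ring_time(N, small, ALPHA, BW), tree_time(N, small, ALPHA, BW))
print("large:", ring_time(N, large, ALPHA, BW), tree_time(N, large, ALPHA, BW))

# Bucketing: 1000 separate 25 KB reduces vs. one 25 MB bucket.
unbucketed = 1000 * ring_time(N, 25e3, ALPHA, BW)
bucketed = ring_time(N, 25e6, ALPHA, BW)
print("unbucketed:", unbucketed, "bucketed:", bucketed)
```

Under these assumptions the tree wins for the 1 KB message and the ring wins for the 1 GB one, and bucketing pays the 2(N-1) · α latency once instead of a thousand times.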
Hierarchical AllReduce — the production trick
Inside a node, NVLink + NVSwitch give each GPU ~900 GB/s of aggregate bidirectional bandwidth (the per-pair share depends on how much other traffic is using the switch). Between nodes, one InfiniBand (IB) port gives ~50 GB/s. That is an asymmetry of ~18× per link. A ring that spans both intra- and inter-node links is governed by the slowest link.
Hierarchical AllReduce respects the asymmetry:
- Intra-node Reduce. All 8 GPUs in a node ring-reduce their copies into one rank's local sum. Uses NVLink only.
- Inter-node AllReduce. The chosen rank from each node participates in a ring across nodes (over IB). One participant per node, so IB sees N_nodes ranks, not N_total.
- Intra-node Broadcast. The chosen rank propagates the global sum back to its 7 neighbours over NVLink.
Total inter-node bytes per node: 2 · S · (N_nodes - 1) / N_nodes, instead of 2 · S · (N - 1) / N for flat ring. With 8 GPUs per node and 64 nodes, that's a ~8× reduction in IB traffic for one node. NCCL picks this automatically when it detects the topology. (It also has tree, double-binary tree, and several others — but ring vs hierarchical is the core mental model.)
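The three phases above can be sketched in a single process to confirm that they produce the same result as a flat AllReduce. This models only the data flow: the "leader" is taken to be the first rank of each node, and the inter-node step is written as a plain sum rather than a real ring. The function names are illustrative and do not show how NCCL implements it.

```python
def hierarchical_all_reduce(bufs, gpus_per_node):
    """bufs[r] is rank r's tensor; ranks are grouped into nodes in order."""
    n = len(bufs)
    assert n % gpus_per_node == 0, "ranks must fill whole nodes"
    nodes = [bufs[s:s + gpus_per_node] for s in range(0, n, gpus_per_node)]

    # 1. Intra-node Reduce: each node sums into its leader (fast links only).
    node_sums = [[sum(col) for col in zip(*node)] for node in nodes]

    # 2. Inter-node AllReduce: only one leader per node participates.
    global_sum = [sum(col) for col in zip(*node_sums)]

    # 3. Intra-node Broadcast: each leader hands the result to its node.
    return [list(global_sum) for _ in range(n)]

def inter_node_bytes_per_node(size_bytes, n_nodes):
    """Bytes one node's leader sends over the slow link: 2 * S * (Nn - 1) / Nn."""
    return 2 * size_bytes * (n_nodes - 1) / n_nodes

# 4 ranks as 2 nodes of 2 GPUs each.
bufs = [[1, 2], [3, 4], [5, 6], [7, 8]]
print(hierarchical_all_reduce(bufs, gpus_per_node=2)[0])  # [16, 20]
print(inter_node_bytes_per_node(64, n_nodes=64))          # 126.0
```

The point of the structure is in step 2: only `len(nodes)` participants touch the slow link, no matter how many GPUs sit behind each one.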
Animated · ring vs tree vs hierarchical, side by side
Three topologies racing the same payload over 8 ranks (or 16, or 32). The ring travels along a circle; the tree fans in and out in log N levels; the hierarchical does an intra-node reduce, an inter-node ring, and an intra-node broadcast. Press play. The counter shows messages-completed per algorithm.
What changes in TP, FSDP, MoE
Naming the collective each parallelism uses is the single most useful mnemonic in this series. From the table at the top:
- DDP (lesson 04): one AllReduce of all gradients per step. Hidden behind backward via overlap.
- FSDP / ZeRO-3 (lesson 05): AllGather to reconstruct a full layer (forward), AllGather + ReduceScatter (backward). One AllGather plus the ReduceScatter move the same bytes as an AllReduce, just split; the second AllGather brings the total to roughly 1.5× the traffic of DDP. The win is memory, not communication.
- Tensor parallel (lesson 06): two AllReduces per transformer layer per forward (one in attention, one in MLP). High frequency, modest size — pinned intra-node.
- Pipeline parallel (lesson 07): point-to-point Send/Recv only (not technically a "collective" — just NCCL P2P). The bubble cost dwarfs the communication cost.
- Expert parallel (lesson 09): two AllToAlls per MoE forward, two per backward. AllToAll is the latency-sensitive collective in the table.
Interactive · ring vs tree vs hierarchical
Three algorithms for the same collective, same payload. Watch the total-time bars change as you scale ranks, message size, and the latency/bandwidth ratio. The crossover where tree beats ring is the message size at which gradient bucketing makes sense; the gap between flat ring and hierarchical is the value of "stay on one node when you can."
The two-line vocabulary check
| If you hear … | … what's actually happening |
|---|---|
| "All gradients are summed across ranks" | AllReduce |
| "Reconstruct the full weight from shards" | AllGather |
| "Sum the grads, but only keep your slice" | ReduceScatter |
| "Each token goes to its expert" | AllToAll |
| "Activation crosses a pipeline stage" | Send/Recv (point-to-point) |
| "Broadcast the initial weights" | Broadcast (rank 0 → all) |