FSDP / ZeRO — sharding what DDP replicates
Take the redundancy out of DDP. The three things every rank kept identical copies of — optimizer state, gradients, parameters — get sharded one at a time. Each stage trades a little more communication for a lot less memory.
The observation that starts it all
DDP makes N copies of optimizer state, N copies of gradients, and N copies of parameters. After the gradient AllReduce, all N ranks compute the same optimizer step on the same averaged gradient and arrive at the same new weights. That's redundant compute on identical data. If rank 0 does the optimizer step for the first 1/N of the weights, rank 1 does the next 1/N, …, every rank still ends up with its own slice of new weights — and if we agree to share those slices on demand, we never had to hold all N copies in the first place.
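A toy check of the observation above, as a minimal sketch: plain SGD stands in for Adam, list concatenation stands in for the sharing step, and the function names `sgd_step` and `sharded_step` are illustrative, not from any library.

```python
# Toy check: N ranks each update only their own 1/N slice, then share slices.
# Plain SGD stands in for Adam; all names here are illustrative assumptions.

def sgd_step(weights, grads, lr):
    """One SGD update over a list of weights."""
    return [w - lr * g for w, g in zip(weights, grads)]

def sharded_step(weights, grads, lr, n_ranks):
    """Each rank updates its own contiguous slice; slices are then concatenated."""
    shard = len(weights) // n_ranks  # assumes len(weights) is divisible by n_ranks
    slices = []
    for rank in range(n_ranks):
        lo, hi = rank * shard, (rank + 1) * shard
        # A rank only needs its slice of the weights and of the averaged gradient.
        slices.append(sgd_step(weights[lo:hi], grads[lo:hi], lr))
    # Concatenation plays the role of sharing the slices (an AllGather).
    return [w for s in slices for w in s]

weights = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
avg_grads = [0.5, -1.0, 0.25, 0.0, 2.0, -0.5, 1.0, 4.0]

replicated = sgd_step(weights, avg_grads, lr=0.1)  # what every DDP rank computes
sharded = sharded_step(weights, avg_grads, lr=0.1, n_ranks=4)
assert sharded == replicated
```

The two results match exactly because the optimizer update is elementwise: no slice needs to see any other slice to compute its own new weights.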
That's the entire idea. The "ZeRO" name is from Rajbhandari et al. (DeepSpeed, 2019); "FSDP" (PyTorch's name for the ZeRO-3 mechanic, plus a layer of usability) became the canonical implementation in production.
The three stages
State pieces, in increasing order of "how often you need them":
- Optimizer state (fp32 master weights + Adam m/v) — touched only at the optimizer step, once per training step.
- Gradients — touched at backward, once per training step.
- Parameters — touched at every forward and backward, once per layer.
The optimizer state is the easiest to shard because you only need your shard to do your step. Gradients are next — you need the full grad to AllReduce, but if you ReduceScatter instead, every rank ends up with the slice of the summed gradient it needs. Parameters are the hardest, because every forward pass actually needs the full layer's weights — so we have to "reconstruct" the layer just-in-time and free it right after.
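The gradient argument above can be simulated with plain lists, one inner list per rank. This is a sketch of the semantics only, not of how a real collective library moves data; the helper names are made up for illustration.

```python
# Semantics-only simulation of three collectives; each inner list is one rank's buffer.

def all_reduce(per_rank):
    """Every rank ends up with the full elementwise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def reduce_scatter(per_rank):
    """Rank r ends up with only the r-th slice of the elementwise sum."""
    n = len(per_rank)
    shard = len(per_rank[0]) // n  # assumes the buffer length is divisible by n
    total = [sum(vals) for vals in zip(*per_rank)]
    return [total[r * shard:(r + 1) * shard] for r in range(n)]

def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]

grads = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 ranks, 4 gradient elements each
summed_shards = reduce_scatter(grads)     # rank 0 holds [11, 22], rank 1 holds [33, 44]

# ReduceScatter followed by AllGather reproduces AllReduce.
assert all_gather(summed_shards) == all_reduce(grads)
```

A rank that only updates its own slice can stop after `reduce_scatter`: it already holds the summed gradient for exactly the weights it owns.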
| Stage | Param | Grad | Opt | Memory / rank | Comm / step vs DDP |
|---|---|---|---|---|---|
| 0 (DDP) | full | full | full | 2P + 2P + 12P = 16P | 1.0× (one AllReduce) |
| 1 | full | full | 1/N | 2P + 2P + 12P/N | ~1.0× (still one AllReduce) |
| 2 | full | 1/N | 1/N | 2P + 2P/N + 12P/N | ~1.0× (ReduceScatter + AllGather ≡ AllReduce) |
| 3 (FSDP) | 1/N | 1/N | 1/N | (2P + 2P + 12P)/N = 16P/N | ~1.5× (extra AllGather of params per layer) |
P = number of parameters. Bytes-per-param: 2 (bf16 weight) + 2 (bf16 grad) + 4 (fp32 master) + 8 (Adam m,v) = 16 in mixed precision (we packaged "master + m + v" as the 12-byte "opt state" above).
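The memory column of the table can be reproduced with a few lines. This is a sketch using the 2 / 2 / 12 bytes-per-parameter split stated above; it counts training state only (no activations, no temporary buffers), and the 7B-parameter, 8-rank example is an arbitrary illustration.

```python
# Per-rank training-state memory for ZeRO stages 0-3, mixed precision with Adam.
BYTES_PARAM = 2   # bf16 weight
BYTES_GRAD = 2    # bf16 gradient
BYTES_OPT = 12    # fp32 master weight (4) + Adam m and v (8)

def memory_per_rank_bytes(n_params, n_ranks, stage):
    """Bytes of params + grads + optimizer state held by one rank."""
    param_div = n_ranks if stage >= 3 else 1  # stage 3 shards parameters
    grad_div = n_ranks if stage >= 2 else 1   # stage 2 shards gradients
    opt_div = n_ranks if stage >= 1 else 1    # stage 1 shards optimizer state
    per_param = (BYTES_PARAM / param_div
                 + BYTES_GRAD / grad_div
                 + BYTES_OPT / opt_div)
    return n_params * per_param

if __name__ == "__main__":
    # Illustrative numbers only: 7e9 parameters on 8 ranks.
    for stage in range(4):
        gb = memory_per_rank_bytes(7e9, 8, stage) / 1e9
        print(f"ZeRO-{stage}: {gb:.2f} GB of training state per rank")
```

Most of the saving arrives at stage 1, because the optimizer state is 12 of the 16 bytes per parameter.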
FSDP's mechanic — what one forward step actually does
For each layer, in order:
- AllGather the layer's parameters. Every rank holds 1/N of the layer's weights; before forward, they cooperatively reconstruct the full layer. Cost: |layer| · (N-1)/N bytes per rank, asymptotically |layer|.
- Compute the layer's forward. Local computation, output saved for backward (or recomputed via activation checkpointing).
- Free the AllGathered weights. We're done with the full layer; throw it away and reclaim HBM. We keep our 1/N shard.
Backward, in reverse layer order:
- AllGather the layer's parameters again. We need the weights to compute the input-side gradient (the chain rule pulls the weights through the operation). Same cost as above.
- Compute the layer's backward. Produces a gradient tensor of the same shape as the full layer.
- ReduceScatter the gradient. Every rank ends up with the summed gradient of its 1/N slice. Cost: one of AllReduce's two halves, the same as a single AllGather.
- Free the AllGathered weights. Same as in forward.
So per layer: 2 AllGathers + 1 ReduceScatter. Versus DDP's 1 AllReduce (which is itself ≡ AllGather + ReduceScatter). FSDP is roughly 1.5× the communication. In exchange, each rank's memory drops from 16P to 16P/N.
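The 1.5× figure follows from counting bytes. A minimal sketch, assuming each AllGather and each ReduceScatter moves |layer| · (N-1)/N bytes per rank (the cost quoted above) and that AllReduce is a ReduceScatter followed by an AllGather:

```python
# Per-rank communication volume for one layer, under the cost model stated above.

def all_gather_bytes(size, n):
    return size * (n - 1) / n

def reduce_scatter_bytes(size, n):
    return size * (n - 1) / n

def all_reduce_bytes(size, n):
    # AllReduce modeled as ReduceScatter + AllGather.
    return reduce_scatter_bytes(size, n) + all_gather_bytes(size, n)

def fsdp_layer_bytes(size, n):
    # Forward AllGather + backward AllGather + gradient ReduceScatter.
    return 2 * all_gather_bytes(size, n) + reduce_scatter_bytes(size, n)

layer_bytes = 400e6  # illustrative layer size, not taken from the text
n_ranks = 8
ratio = fsdp_layer_bytes(layer_bytes, n_ranks) / all_reduce_bytes(layer_bytes, n_ranks)
print(f"FSDP / DDP communication per layer: {ratio:.2f}x")
```

The ratio does not depend on the layer size or on N: both cancel, leaving three half-AllReduces against two.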
Animated · one FSDP layer step, scrubbable
Below is a frame-by-frame view of what each of N=4 ranks holds across one forward+backward of a 4-layer transformer. The bottom strip is the per-rank HBM curve over time: notice the spikes at AllGather and the troughs after Free. The flat baseline is the permanent shard each rank keeps; everything above it is the AllGathered, transient full layer.
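The shape of that HBM curve can be sketched in a few lines. This toy trace tracks parameter memory only (no activations, gradients, or optimizer state) and assumes each layer is gathered and freed strictly one at a time; an implementation that prefetches the next unit would hold more at the peak.

```python
# Toy per-rank parameter-memory trace for one forward + backward pass.

def hbm_trace(layer_sizes, n_ranks):
    """Return (baseline, trace); trace is a list of (event, memory) pairs."""
    baseline = sum(layer_sizes) / n_ranks  # the permanent 1/N shards
    trace = []
    order = list(range(len(layer_sizes)))
    for phase, layers in (("fwd", order), ("bwd", order[::-1])):
        for i in layers:
            # Gathering adds the (N-1)/N of the layer this rank did not already hold.
            extra = layer_sizes[i] * (n_ranks - 1) / n_ranks
            trace.append((f"{phase} layer {i}: allgather", baseline + extra))
            trace.append((f"{phase} layer {i}: free", baseline))
    return baseline, trace

baseline, trace = hbm_trace([1.0, 1.0, 1.0, 1.0], n_ranks=4)  # sizes in arbitrary units
peak = max(mem for _, mem in trace)

if __name__ == "__main__":
    for event, mem in trace:
        print(f"{event:28s} {mem:.2f}")
```

With four equal layers on four ranks, the baseline is one layer's worth of memory and each spike adds three quarters of a layer on top of it.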
2D · the four ZeRO stages laid out across ranks
The same model, the same N=8 ranks, four different ways to slice its training state. Toggle a stage. Solid squares = stored on this rank; faded squares = held by some other rank. The total colored area on a given row is what that rank actually pays for in HBM.
What the unit of sharding is
FSDP doesn't shard the model parameter-by-parameter. It shards at the granularity of an FSDP unit, which is typically a transformer block. The unit is the thing that gets AllGathered and freed as a whole. Choosing the unit is a knob:
- Bigger units (e.g. the whole model) → fewer collectives, but each unit must fit on one rank temporarily when AllGathered. Defeats the purpose of FSDP for very large models.
- Smaller units (every nn.Linear) → many tiny collectives, latency-dominated, low effective bandwidth.
- Per transformer block (the production default) → good balance. The unit is hundreds of MB to a few GB, large enough that AllGather is bandwidth-bound, small enough that you can free it after forward.
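The tradeoff in the list above can be made concrete with a simple latency-plus-bandwidth cost model. This is a sketch under loud assumptions: the 50 µs per-collective latency is a made-up placeholder, the 900 GB/s bandwidth is the NVLink figure used later in this lesson, the (N-1)/N factor is ignored, and the model is 14 GB of bf16 weights.

```python
# Cost model: time per collective = fixed latency + bytes / bandwidth.

def unit_tradeoff(model_bytes, n_units, latency_s, bandwidth_bytes_per_s):
    """Return (total AllGather time for one pass, transient bytes of one gathered unit)."""
    unit_bytes = model_bytes / n_units
    per_collective = latency_s + unit_bytes / bandwidth_bytes_per_s
    return n_units * per_collective, unit_bytes

MODEL_BYTES = 14e9   # assumption: 7e9 params x 2 bytes
LATENCY_S = 50e-6    # assumption: hypothetical per-collective launch latency
BANDWIDTH = 900e9    # bytes per second

t_whole, peak_whole = unit_tradeoff(MODEL_BYTES, 1, LATENCY_S, BANDWIDTH)
t_block, peak_block = unit_tradeoff(MODEL_BYTES, 32, LATENCY_S, BANDWIDTH)
t_tiny, peak_tiny = unit_tradeoff(MODEL_BYTES, 10_000, LATENCY_S, BANDWIDTH)

if __name__ == "__main__":
    for name, t, peak in (("whole model", t_whole, peak_whole),
                          ("per block", t_block, peak_block),
                          ("tiny units", t_tiny, peak_tiny)):
        print(f"{name:12s} time={t * 1e3:8.2f} ms  transient={peak / 1e9:7.3f} GB")
```

The bandwidth term is the same in all three cases; only the latency term and the transient footprint change, which is why the middle setting wins on both counts.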
HSDP — the production compromise
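Pure FSDP across all N GPUs sends collectives across the slow inter-node link constantly. Recall from lesson 03: a per-layer AllGather at 50 GB/s inter-node for a 7B model takes ~5–10 ms; 80 layers × 3 collectives is 240 collectives, or roughly 1.2–2.4 s of communication per step if none of it overlaps with compute. That is too slow to be practical.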
HSDP (Hybrid Sharded Data Parallel) shards within a node and replicates across nodes:
- Intra-node: FSDP across the 8 GPUs of a node. AllGathers use NVLink at 900 GB/s. Memory per GPU is 16P/8 = 2P.
- Inter-node: DDP replication across nodes. One AllReduce of gradients per step (across the shards — each rank only holds 1/8 of the gradient, so the inter-node AllReduce is 8× smaller too).
HSDP gets ~80% of FSDP's memory savings with ~1.1× DDP's communication cost. It is the actual production layout used for most ≤ ~30B models.
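The two HSDP figures from the list above (16P/8 of state per GPU, and an inter-node AllReduce that is 8× smaller) can be checked directly. A minimal sketch, assuming 16 bytes of training state and 2 bytes of gradient per parameter as defined earlier; the 7e9-parameter model is an illustrative choice.

```python
# HSDP bookkeeping: shard within a node, replicate across nodes.

def hsdp_state_bytes_per_gpu(n_params, gpus_per_node, bytes_per_param=16):
    """Training state per GPU when the model is sharded across one node."""
    return n_params * bytes_per_param / gpus_per_node

def internode_grad_bytes(n_params, gpus_per_node, bytes_per_grad=2):
    """Gradient buffer each GPU contributes to the inter-node AllReduce."""
    return n_params * bytes_per_grad / gpus_per_node

N_PARAMS = 7e9  # illustrative model size
state = hsdp_state_bytes_per_gpu(N_PARAMS, gpus_per_node=8)
hsdp_grad = internode_grad_bytes(N_PARAMS, gpus_per_node=8)
ddp_grad = internode_grad_bytes(N_PARAMS, gpus_per_node=1)  # plain DDP: full gradient

print(f"state per GPU: {state / 1e9:.1f} GB")
print(f"inter-node gradient buffer: {hsdp_grad / 1e9:.2f} GB vs {ddp_grad / 1e9:.1f} GB for DDP")
```

Adding more nodes does not change either number: extra nodes add replicas, not shards.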
3D · the HSDP topology, isometric
Four nodes, eight GPUs each — the canonical 32-GPU box. The eight GPUs within a node form an FSDP group (they cooperatively hold one full copy of the model, sharded 1/8). The four nodes form a DDP group (each node holds a complete replica). Click a GPU to highlight its FSDP group and its DDP peers across the other nodes; click a different shard color to walk through all eight slices.
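The grouping described above can be written as a small helper. This is a sketch that assumes global ranks are numbered node by node (ranks 0 to 7 on node 0, 8 to 15 on node 1, and so on); real launchers may number ranks differently, and the function name is illustrative.

```python
# Which ranks share an FSDP (shard) group and a DDP (replica) group with a given rank.

def hsdp_groups(rank, n_nodes=4, gpus_per_node=8):
    """Return (shard_group, replica_group) as lists of global ranks."""
    node, local = divmod(rank, gpus_per_node)
    # Shard group: all GPUs on the same node; together they hold one full model copy.
    shard_group = [node * gpus_per_node + i for i in range(gpus_per_node)]
    # Replica group: the GPU with the same local index on every node; same shard slice.
    replica_group = [n * gpus_per_node + local for n in range(n_nodes)]
    return shard_group, replica_group

shard_group, replica_group = hsdp_groups(rank=10)
print("FSDP group:", shard_group)
print("DDP peers: ", replica_group)
```

Every rank belongs to exactly one group of each kind, and the two groups intersect only at the rank itself.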
Sharding vs activation checkpointing — orthogonal axes
Two memory-saving techniques, often confused, addressing two different sources:
| Technique | What it saves | What it costs |
|---|---|---|
| FSDP / ZeRO | Parameter + gradient + optimizer state | ~1.5× communication |
| Activation checkpointing | Activations stored for backward | ~33% extra compute (recompute forward during backward) |
Production training of any large model uses both. ZeRO-3 / FSDP plus activation checkpointing brings memory per rank to roughly 16P/N plus a small activation footprint, which is what's been letting Llama-3 and friends fit on H100 clusters.
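The ~33% figure in the table can be derived from a common rule of thumb, which is an assumption here rather than something stated in this lesson: the backward pass costs about twice the forward pass. Full checkpointing then reruns one forward on top of a baseline of three units.

```python
# Extra compute from activation checkpointing, relative to a normal step.

def checkpoint_overhead(fwd_cost=1.0, bwd_cost=2.0, recompute_fraction=1.0):
    """Fraction of extra compute; bwd_cost = 2 * fwd_cost is a rule-of-thumb assumption."""
    baseline = fwd_cost + bwd_cost         # one forward + one backward
    extra = recompute_fraction * fwd_cost  # forward work redone during backward
    return extra / baseline

full = checkpoint_overhead()                           # recompute every layer
partial = checkpoint_overhead(recompute_fraction=0.5)  # recompute half the layers
print(f"full checkpointing: +{full:.0%} compute, half the layers: +{partial:.0%}")
```

The `recompute_fraction` parameter is a hypothetical knob for selective checkpointing: recomputing fewer layers costs less compute but frees less activation memory.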
Communication–memory tradeoff, exactly
The right way to think about ZeRO stages is as a knob that smoothly converts memory pressure into communication pressure. Until the model fits, increasing ZeRO stage is free (the comm bound isn't hit). Once it fits, increasing the stage further is wasted communication — you'd be better off using the freed memory for a bigger batch or longer sequence.
Rough decision tree:
- Single-GPU training state fits → use DDP. No ZeRO needed.
- Optimizer state pushes you over → ZeRO-1.
- Gradients also too big → ZeRO-2.
- Parameters themselves too big → ZeRO-3 / FSDP.
- Activations also too big → add activation checkpointing.
- A single transformer layer too big to AllGather even temporarily → compose with tensor parallelism (TP, lesson 06).
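The decision tree above can be sketched as a function. This is a rough illustration, not a sizing tool: it reuses the 2 / 2 / 12 bytes-per-parameter split, treats activations as a single caller-supplied number, and optimistically assumes checkpointing makes activations negligible. The function and parameter names are hypothetical.

```python
# Walk the ZeRO stages in order and return the first configuration that fits.

def state_bytes(n_params, n_ranks, stage):
    """Per-rank params + grads + optimizer state at a given ZeRO stage."""
    p = 2 / (n_ranks if stage >= 3 else 1)
    g = 2 / (n_ranks if stage >= 2 else 1)
    o = 12 / (n_ranks if stage >= 1 else 1)
    return n_params * (p + g + o)

def pick_strategy(n_params, n_ranks, hbm_bytes,
                  activation_bytes=0.0, largest_layer_bytes=0.0):
    # A layer that cannot be gathered on one GPU needs tensor parallelism.
    if largest_layer_bytes > hbm_bytes:
        return "compose with TP"
    names = ["DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3 / FSDP"]
    for stage, name in enumerate(names):
        if state_bytes(n_params, n_ranks, stage) + activation_bytes <= hbm_bytes:
            return name
    # State fits at stage 3 but activations do not: shrink the activations.
    if state_bytes(n_params, n_ranks, 3) <= hbm_bytes:
        return "ZeRO-3 / FSDP + activation checkpointing"
    return "does not fit; add ranks or another parallelism axis"

if __name__ == "__main__":
    for p in (1e9, 7e9, 30e9, 70e9):
        print(f"{p / 1e9:.0f}B params on 8 x 80 GB:", pick_strategy(p, 8, 80e9))
```

The loop stops at the lowest stage that fits, matching the advice that any stage beyond that only adds communication.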
Interactive · ZeRO stage vs memory vs comm
Set the model size, optimizer, world size N. The widget plots memory per rank and per-step communication across the four ZeRO stages, plus HSDP. The "you are here" marker on the memory axis is your chosen GPU's HBM budget. Watch the bars cross under the budget as you move up the stages.