
FSDP / ZeRO — sharding what DDP replicates

Take the redundancy out of DDP. The three things every rank kept identical copies of — optimizer state, gradients, parameters — get sharded one at a time. Each stage trades a little more communication for a lot less memory.

The observation that starts it all

DDP makes N copies of optimizer state, N copies of gradients, and N copies of parameters. After the gradient AllReduce, all N ranks compute the same optimizer step on the same averaged gradient and arrive at the same new weights. That's redundant compute on identical data. If rank 0 does the optimizer step for the first 1/N of the weights, rank 1 does the next 1/N, …, every rank still ends up with its own slice of new weights — and if we agree to share those slices on demand, we never had to hold all N copies in the first place.
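The claim that a sharded optimizer step lands on the same weights as the replicated one can be checked with a toy example. This is a minimal sketch, assuming plain SGD as a stand-in for any elementwise optimizer and an even split of parameters across ranks; the helper names (`sgd_step`, `shard_bounds`, `sharded_step`) are hypothetical and not from any library.

```python
def sgd_step(weights, grads, lr):
    """Plain SGD on a list of floats (stand-in for any elementwise optimizer)."""
    return [w - lr * g for w, g in zip(weights, grads)]


def shard_bounds(n_params, world_size, rank):
    """Half-open [start, end) slice owned by `rank`; assumes an even split."""
    per_rank = n_params // world_size
    return rank * per_rank, (rank + 1) * per_rank


def sharded_step(weights, grads, lr, world_size):
    """Each rank updates only its own slice.

    Concatenating the slices plays the role of sharing them on demand.
    """
    new_weights = []
    for rank in range(world_size):
        start, end = shard_bounds(len(weights), world_size, rank)
        new_weights.extend(sgd_step(weights[start:end], grads[start:end], lr))
    return new_weights


weights = [float(i) for i in range(8)]
avg_grads = [0.5] * 8  # what every rank holds after the gradient AllReduce

replicated = sgd_step(weights, avg_grads, lr=0.1)  # DDP: done N times over
sharded = sharded_step(weights, avg_grads, lr=0.1, world_size=4)
assert replicated == sharded
```

Each rank in `sharded_step` touches only 1/N of the weights and gradients, which is why it only needs 1/N of the optimizer state.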

That's the entire idea. The "ZeRO" name is from Rajbhandari et al. (DeepSpeed, 2019); "FSDP" (PyTorch's name for the ZeRO-3 mechanic, plus a layer of usability) became the canonical implementation in production.

The three stages

State pieces, in increasing order of "how often you need them":

The optimizer state is the easiest to shard because you only need your shard to do your step. Gradients are next — you need the full grad to AllReduce, but if you ReduceScatter instead, every rank ends up with the slice of the summed gradient it needs. Parameters are the hardest, because every forward pass actually needs the full layer's weights — so we have to "reconstruct" the layer just-in-time and free it right after.
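To make the gradient step concrete, here is a toy model of the three collectives on plain Python lists, one inner list per rank. It is a sketch of the semantics only (no networking, no real library API), and it assumes the vector length divides evenly by the number of ranks.

```python
def all_reduce(per_rank):
    """Every rank ends with the elementwise sum of all ranks' vectors."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def reduce_scatter(per_rank):
    """Rank r ends with only slice r of the elementwise sum."""
    n = len(per_rank)
    chunk = len(per_rank[0]) // n  # assumes an even split
    total = [sum(vals) for vals in zip(*per_rank)]
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]


def all_gather(shards):
    """Every rank ends with the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


grads = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 ranks, 4 gradient elements each

# After ReduceScatter each rank holds just the summed slice it will step on.
assert reduce_scatter(grads) == [[11, 22], [33, 44]]

# Gathering those slices back reproduces what AllReduce would have produced.
assert all_gather(reduce_scatter(grads)) == all_reduce(grads)
```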

| Stage | Param | Grad | Opt | Memory / rank | Comm / step vs DDP |
|---|---|---|---|---|---|
| 0 (DDP) | full | full | full | 2P + 2P + 12P = 16P | 1.0× (one AllReduce) |
| 1 | full | full | 1/N | 2P + 2P + 12P/N | ~1.0× (still one AllReduce) |
| 2 | full | 1/N | 1/N | 2P + 2P/N + 12P/N | ~1.0× (ReduceScatter + AllGather ≡ AllReduce) |
| 3 (FSDP) | 1/N | 1/N | 1/N | (2P + 2P + 12P)/N = 16P/N | ~1.5× (extra AllGather of params per layer) |

P = number of parameters. Bytes-per-param: 2 (bf16 weight) + 2 (bf16 grad) + 4 (fp32 master) + 8 (Adam m,v) = 16 in mixed precision (we packaged "master + m + v" as the 12-byte "opt state" above).
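The memory column of the table can be written as a small function. This is a sketch under the mixed-precision byte counts given above (2 + 2 + 12 bytes per parameter); it ignores activations, buffers, and fragmentation, and the function name is hypothetical.

```python
def bytes_per_rank(n_params, world_size, stage):
    """Training-state bytes held by one rank at a given ZeRO stage (0..3)."""
    param_bytes = 2 * n_params   # bf16 weights
    grad_bytes = 2 * n_params    # bf16 gradients
    opt_bytes = 12 * n_params    # fp32 master + Adam m + Adam v

    if stage >= 1:
        opt_bytes /= world_size    # ZeRO-1 shards optimizer state
    if stage >= 2:
        grad_bytes /= world_size   # ZeRO-2 also shards gradients
    if stage >= 3:
        param_bytes /= world_size  # ZeRO-3 / FSDP also shards parameters
    return param_bytes + grad_bytes + opt_bytes


# Example: a 7e9-parameter model on 8 ranks, reported in GB (1e9 bytes).
for stage in range(4):
    print(stage, bytes_per_rank(7e9, 8, stage) / 1e9, "GB")
```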

FSDP's mechanic — what one forward step actually does

For each layer, in order:

  1. AllGather the layer's parameters. Every rank holds 1/N of the layer's weights; before forward, they cooperatively reconstruct the full layer. Cost: |layer| · (N-1)/N bytes per rank, asymptotically |layer|.
  2. Compute the layer's forward. Local computation, output saved for backward (or recomputed via activation checkpointing).
  3. Free the AllGathered weights. We're done with the full layer; throw it away and reclaim HBM. We keep our 1/N shard.

Backward, in reverse layer order:

  1. AllGather the layer's parameters again. We need the weights to compute the input-side gradient (the chain rule pulls the weights through the operation). Same cost as above.
  2. Compute the layer's backward. Produces a gradient tensor of the same shape as the full layer.
  3. ReduceScatter the gradient. Every rank ends up with the summed gradient of its 1/N slice. Cost: one of AllReduce's two halves, i.e. about half an AllReduce.
  4. Free the AllGathered weights. Same as in forward.
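The forward and backward steps above can be laid out as an event schedule. The sketch below only enumerates the order of operations for one step; it assumes one FSDP unit per layer and ignores the prefetching and compute/communication overlap that real implementations add. The function names are hypothetical.

```python
from collections import Counter


def fsdp_step_schedule(n_layers):
    """Ordered (phase, layer, op) events for one forward + backward pass."""
    events = []
    for layer in range(n_layers):  # forward, in layer order
        events.append(("fwd", layer, "all_gather"))
        events.append(("fwd", layer, "compute"))
        events.append(("fwd", layer, "free"))
    for layer in reversed(range(n_layers)):  # backward, in reverse order
        events.append(("bwd", layer, "all_gather"))
        events.append(("bwd", layer, "compute"))
        events.append(("bwd", layer, "reduce_scatter"))
        events.append(("bwd", layer, "free"))
    return events


def count_collectives(events):
    """Count only the communication ops in a schedule."""
    return Counter(op for _, _, op in events
                   if op in ("all_gather", "reduce_scatter"))


counts = count_collectives(fsdp_step_schedule(4))
print(counts["all_gather"], counts["reduce_scatter"])
```

For 4 layers this gives two AllGathers and one ReduceScatter per layer, which is the per-layer tally the lesson uses.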

So per layer: 2 AllGathers + 1 ReduceScatter. Versus DDP's 1 AllReduce (which is itself ≡ AllGather + ReduceScatter). FSDP is roughly 1.5× the communication. In exchange, each rank's memory drops from 16P to 16P/N.

Animated · one FSDP layer step, scrubbable

Below is a frame-by-frame view of what each of N=4 ranks holds across one forward+backward of a 4-layer transformer. The bottom strip is the per-rank HBM curve over time: notice the spikes at AllGather and the troughs after Free. The flat baseline is the permanent shard each rank keeps; everything above it is the AllGathered, transient full layer.

[Interactive figure: FSDP layer lifecycle, scrubbable timeline. Each rank holds 1/4 of every layer's weights permanently. To compute layer L, ranks cooperatively AllGather the full L, compute, then Free. Readouts: phase, layer, collective, HBM / rank.]
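The shape of the HBM curve can be reproduced with a few lines of arithmetic. This sketch tracks parameter memory only (no activations, gradients, or optimizer state), assumes one FSDP unit per layer, and uses arbitrary size units; `hbm_trace` is a hypothetical helper.

```python
def hbm_trace(layer_sizes, world_size):
    """Per-rank parameter memory after each AllGather and each Free.

    Returns a list of (phase, layer, state, memory) tuples.
    """
    # Permanent baseline: this rank's 1/N shard of every layer.
    baseline = sum(layer_sizes) / world_size
    order = list(range(len(layer_sizes)))
    trace = []
    for phase, layers in (("fwd", order), ("bwd", order[::-1])):
        for layer in layers:
            # Transient part: the (N-1)/N of the layer this rank does not own.
            transient = layer_sizes[layer] * (world_size - 1) / world_size
            trace.append((phase, layer, "gathered", baseline + transient))
            trace.append((phase, layer, "freed", baseline))
    return trace


trace = hbm_trace([1.0, 1.0, 1.0, 1.0], world_size=4)
for phase, layer, state, mem in trace:
    print(f"{phase} layer {layer} {state:9s} {mem:.2f}")
```

With four equal layers on four ranks the trace alternates between the baseline of 1.0 and a spike of 1.75: the permanent shards plus three quarters of one gathered layer.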

2D · the four ZeRO stages laid out across ranks

The same model, the same N=8 ranks, four different ways to slice its training state. Toggle a stage. Solid squares = stored on this rank; faded squares = held by some other rank. The total colored area on a given row is what that rank actually pays for in HBM.

[Interactive figure: ZeRO shard explorer, one view per stage. Three rows of per-rank state (× P bytes): parameters (blue), gradients (orange), optimizer state (green, 6× wider per param because 4B master + 4B m + 4B v = 12 bytes against 2 bytes of bf16 weight).]

Why is it only 1.5×, not 2× or 3×?

Two reasons. First, DDP's AllReduce is itself decomposable as AllGather + ReduceScatter, same bytes total, so DDP already pays two units of communication. Second, FSDP's extra AllGather (the second one, during backward) adds only one more unit; the other two are ZeRO-2-and-below's "normal" cost rewritten as primitives. Three units against two is 1.5×.
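As a worked version of that argument, assume ring-style collectives and let S be the bytes of the tensor being communicated (parameters and gradients are taken to be the same size here, both bf16). The symbols C and S are introduced for this sketch only.

```latex
\begin{aligned}
C_{\text{AllGather}} &= C_{\text{ReduceScatter}} = \frac{N-1}{N}\,S \\
C_{\text{AllReduce}} &= C_{\text{ReduceScatter}} + C_{\text{AllGather}} = 2\,\frac{N-1}{N}\,S \\
C_{\text{DDP}} &= C_{\text{AllReduce}} = 2\,\frac{N-1}{N}\,S \\
C_{\text{FSDP}} &= 2\,C_{\text{AllGather}} + C_{\text{ReduceScatter}} = 3\,\frac{N-1}{N}\,S \\
\frac{C_{\text{FSDP}}}{C_{\text{DDP}}} &= \frac{3}{2} = 1.5
\end{aligned}
```

The ratio is independent of N because every term carries the same (N-1)/N factor.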

What the unit of sharding is

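FSDP doesn't shard the model parameter-by-parameter. It shards at the granularity of an FSDP unit, which is typically a transformer block. The unit is the thing that gets AllGathered and freed as a whole. Choosing the unit is a knob: larger units mean fewer, larger collectives but a higher transient memory peak; smaller units lower the peak but issue more, smaller collectives.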
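One way to see what the unit size trades off is to count collectives and the transient memory peak for different groupings. This is a rough sketch with a hypothetical helper; it assumes 3 collectives per unit per step (2 AllGathers + 1 ReduceScatter) and that only one unit is gathered at a time.

```python
def unit_tradeoff(layer_bytes, layers_per_unit, world_size):
    """Return (n_units, collectives_per_step, peak_transient_bytes)."""
    # Group consecutive layers into FSDP units.
    units = [sum(layer_bytes[i:i + layers_per_unit])
             for i in range(0, len(layer_bytes), layers_per_unit)]
    # While a unit is gathered, a rank holds the (N-1)/N it does not own.
    peak_transient = max(units) * (world_size - 1) / world_size
    collectives_per_step = 3 * len(units)
    return len(units), collectives_per_step, peak_transient


layers = [4.0] * 8  # eight equal layers, arbitrary size units
print(unit_tradeoff(layers, layers_per_unit=1, world_size=4))
print(unit_tradeoff(layers, layers_per_unit=4, world_size=4))
```

Grouping four layers per unit cuts the number of collectives by 4× and raises the transient peak by the same factor.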

HSDP — the production compromise

Pure FSDP across all N GPUs sends collectives across the slow inter-node link constantly. Recall from lesson 03: per-layer AllGather at 50 GB/s inter-node for a 7B model is ~5–10 ms; 80 layers × 3 collectives = 240 collectives, so roughly 1.2–2.4 s per step (call it ~1.5 s) of communication if none of it is hidden behind compute. Useless.
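The back-of-envelope estimate above is simple enough to script. The sketch below uses the figures quoted in the text (80 layers, 3 collectives per layer, 5–10 ms each, 50 GB/s) and assumes the worst case where no communication overlaps with compute; the 400 MB payload and both function names are illustrative assumptions.

```python
def collective_seconds(payload_bytes, bandwidth_bytes_per_s):
    """Bandwidth-only time for one collective (ignores latency)."""
    return payload_bytes / bandwidth_bytes_per_s


def exposed_comm_seconds(n_layers, collectives_per_layer, seconds_per_collective):
    """Total per-step communication time if nothing overlaps with compute."""
    return n_layers * collectives_per_layer * seconds_per_collective


# A hypothetical 400 MB per-layer payload at 50 GB/s inter-node.
per_collective = collective_seconds(400e6, 50e9)
print(f"one collective: {per_collective * 1e3:.1f} ms")

# Range quoted in the text: 5 ms to 10 ms per collective.
low = exposed_comm_seconds(80, 3, 0.005)
high = exposed_comm_seconds(80, 3, 0.010)
print(f"per step: {low:.1f} s to {high:.1f} s")
```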

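HSDP (Hybrid Sharded Data Parallel) shards within a node and replicates across nodes: the per-layer AllGather and ReduceScatter stay on the fast intra-node link, and only a gradient AllReduce between replicas crosses nodes.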

HSDP gets ~80% of FSDP's memory savings with ~1.1× DDP's communication cost. It is the actual production layout used for most ≤ ~30B models.

3D · the HSDP topology, isometric

Four nodes, eight GPUs each: a canonical 32-GPU cluster. The eight GPUs within a node form an FSDP group (they cooperatively hold one full copy of the model, sharded 1/8). The four nodes form a DDP group (each node holds a complete replica). Click a GPU to highlight its FSDP group and its DDP peers across the other nodes; click a different shard color to walk through all eight slices.

[Interactive figure: HSDP topology. Within a node: NVLink (gold edges) connects the 8 FSDP shard-mates. Across nodes: IB (blue) connects DDP replicas at corresponding shard positions. The IB AllReduce only carries 1/8 of the gradient because each rank only owns 1/8. Readouts: selected GPU, FSDP group size, DDP group size, grads/step on IB.]
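The two groups a GPU belongs to follow directly from its rank. This is a minimal sketch, assuming node-major rank numbering (ranks 0..7 on node 0, 8..15 on node 1, and so on); that numbering and the function name are assumptions, not something the lesson specifies.

```python
def hsdp_groups(rank, n_nodes, gpus_per_node):
    """Return (shard_group, replica_group) as lists of global ranks."""
    node, local = divmod(rank, gpus_per_node)
    # FSDP group: every GPU on the same node (shards one model copy).
    shard_group = [node * gpus_per_node + i for i in range(gpus_per_node)]
    # DDP group: the GPU at the same local position on every node
    # (all of them own the same 1/gpus_per_node shard).
    replica_group = [n * gpus_per_node + local for n in range(n_nodes)]
    return shard_group, replica_group


shard_group, replica_group = hsdp_groups(rank=10, n_nodes=4, gpus_per_node=8)
print("FSDP group:", shard_group)
print("DDP group: ", replica_group)
print("gradient fraction per rank on the inter-node link:", 1 / 8)
```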

Sharding vs activation checkpointing — orthogonal axes

Two memory-saving techniques, often confused, addressing two different sources:

| Technique | What it saves | What it costs |
|---|---|---|
| FSDP / ZeRO | Parameter + gradient + optimizer state | ~1.5× communication |
| Activation checkpointing | Activations stored for backward | ~33% extra compute (recompute forward during backward) |

Production training of any large model uses both. ZeRO-3 / FSDP plus activation checkpointing brings memory per rank to ~16P/N plus a small transient term (the currently gathered unit and the checkpointed activations), which is what's been letting Llama-3 and friends fit on H100 clusters.
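The ~33% figure for activation checkpointing follows from a common rule of thumb, assumed here rather than taken from the lesson: the backward pass costs about twice the forward pass, and full checkpointing reruns the whole forward once.

```latex
\begin{aligned}
C_{\text{bwd}} &\approx 2\,C_{\text{fwd}} \\
C_{\text{step}} &= C_{\text{fwd}} + C_{\text{bwd}} \approx 3\,C_{\text{fwd}} \\
C_{\text{step}}^{\text{ckpt}} &= C_{\text{step}} + C_{\text{fwd}} \approx 4\,C_{\text{fwd}} \\
\frac{C_{\text{step}}^{\text{ckpt}}}{C_{\text{step}}} &\approx \frac{4}{3} \approx 1.33
\end{aligned}
```

Selective checkpointing, where only some layers are recomputed, lands somewhere below that.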

Communication–memory tradeoff, exactly

The right way to think about ZeRO stages is as a knob that converts memory pressure into communication pressure. Until the model fits, moving up a stage is effectively free as long as the extra communication stays hidden behind compute (the comm bound isn't hit). Once it fits, increasing the stage further is wasted communication: you'd be better off using the freed memory for a bigger batch or longer sequence.

Rough decision tree:

  1. Single-GPU training state fits → use DDP. No ZeRO needed.
  2. Optimizer state pushes you over → ZeRO-1.
  3. Gradients also too big → ZeRO-2.
  4. Parameters themselves too big → ZeRO-3 / FSDP.
  5. Activations also too big → add activation checkpointing.
  6. A single transformer layer too big to AllGather even temporarily → compose with TP (lesson 06).
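The decision tree above can be sketched as a function that walks the stages in order and stops at the first one that fits. The fit test is deliberately simplified (training state plus activations, plus one gathered layer at stage 3), every argument is a value you would have to measure or estimate yourself, and the function names are hypothetical.

```python
def state_bytes(n_params, world_size, stage):
    """Per-rank training state at a ZeRO stage (2 + 2 + 12 bytes per param)."""
    params = 2 * n_params / (world_size if stage >= 3 else 1)
    grads = 2 * n_params / (world_size if stage >= 2 else 1)
    opt = 12 * n_params / (world_size if stage >= 1 else 1)
    return params + grads + opt


def pick_strategy(hbm_bytes, n_params, world_size,
                  activation_bytes, checkpointed_activation_bytes,
                  largest_layer_bytes):
    """Walk the decision tree; return the first option that fits in HBM."""
    names = ["DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3 / FSDP"]
    for stage, name in enumerate(names):
        # Stage 3 also needs room for one fully gathered layer.
        transient = largest_layer_bytes if stage == 3 else 0
        need = state_bytes(n_params, world_size, stage) + transient + activation_bytes
        if need <= hbm_bytes:
            return name
    need = (state_bytes(n_params, world_size, 3) + largest_layer_bytes
            + checkpointed_activation_bytes)
    if need <= hbm_bytes:
        return "ZeRO-3 / FSDP + activation checkpointing"
    return "compose with TP (lesson 06)"


# Toy units: 1 "parameter", 8 ranks, varying HBM budgets.
for budget in (20, 7, 5, 3.5, 2.8, 1):
    print(budget, "->", pick_strategy(budget, 1, 8, 1.0, 0.5, 0.1))
```

In practice activation checkpointing is often turned on earlier than this ordering suggests; the sketch just follows the tree as written.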

Interactive · ZeRO stage vs memory vs comm

Set the model size, optimizer, world size N. The widget plots memory per rank and per-step communication across the four ZeRO stages, plus HSDP. The "you are here" marker on the memory axis is your chosen GPU's HBM budget. Watch the bars cross under the budget as you move up the stages.

[Interactive figure: ZeRO stages, memory and comm per rank. Memory plot is linear in P/N · 16 bytes. Comm plot is per-step bytes touched, in units of |params|. Readouts: memory @ ZeRO-3, memory @ HSDP, comm @ ZeRO-3 (× P), comm @ HSDP (× P).]

Takeaway

ZeRO is a knob: at each click up, one more redundancy is shaved off in exchange for slightly more communication. FSDP (ZeRO-3) cuts per-rank state from 16P to 16P/N at ~1.5× DDP's comm. HSDP gives ~80% of that memory win at near-DDP comm. Pure FSDP across nodes pays the slow inter-node link on every layer; HSDP is the production default for everything that's not tensor-parallel.