Why distributed at all?
Three failure modes — memory, compute, throughput — each with its own response. Most of the complexity in this series exists to address one of them.
The honest framing
"Distributed" is not an architectural goal; it is a tax you pay because something doesn't fit. Whenever a teammate says "let's go distributed", three questions should fire before any code is touched:
- Does the training state fit in one GPU's HBM? Weights + gradients + optimizer state + activations. If no — you have a memory wall, and your response is to shard those things (FSDP / ZeRO, lesson 05).
- If it fits, can one GPU finish training in tolerable wall-clock time? If no — you have a throughput wall, and your response is to replicate the model and split the data across workers (DDP / FSDP, lesson 04).
- For inference: can one replica meet the QPS / latency target? If no — you have a serving wall, and your response is replication plus, sometimes, intra-replica sharding (lesson 11).
A model that fits, trains in reasonable time, and serves at moderate QPS doesn't need any of this. Distributed engineering is genuinely expensive: NCCL deadlocks, opaque hangs, version skew, straggling ranks, silent data corruption at scale. Pay that tax only when one of the three walls above forces you to.
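The three questions above amount to a small triage procedure. The sketch below is a hypothetical helper, not part of any library: `Workload`, `walls_hit`, and every field name are illustrative, and all inputs are estimates you would produce yourself, for example with the byte and FLOP counts worked out in the rest of this lesson.

```python
from __future__ import annotations
from dataclasses import dataclass

@dataclass
class Workload:
    # Hypothetical inputs; every value is your own estimate.
    state_bytes: float       # weights + grads + optimizer state + activations
    hbm_bytes: float         # HBM capacity of one GPU
    single_gpu_days: float   # estimated wall-clock training time on one GPU
    deadline_days: float     # tolerable wall-clock training time
    replica_qps: float       # QPS one serving replica sustains within the latency target
    target_qps: float        # QPS you actually need to serve

def walls_hit(w: Workload) -> list[str]:
    """Ask the three questions in order and return the walls that are hit."""
    walls = []
    if w.state_bytes > w.hbm_bytes:
        walls.append("memory")       # response: shard the state (FSDP / ZeRO)
    elif w.single_gpu_days > w.deadline_days:
        walls.append("throughput")   # response: replicate the model, split the data (DDP)
    if w.replica_qps < w.target_qps:
        walls.append("serving")      # response: more replicas, sometimes intra-replica sharding
    return walls

# A 70B model's 1.12 TB of training state against an 80 GB GPU:
print(walls_hit(Workload(1.12e12, 80e9, 1e5, 90, 10, 5)))
```

The second question is only asked when the state fits, mirroring the "If it fits" wording of the checklist; the serving question is independent of training.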
Walking the first wall — the bytes per parameter
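The memory wall is the cleanest one to reason about, because we can count bytes. For a parameter θ stored in bf16 (2 bytes), under Adam, in the standard mixed-precision recipe:

$$\underbrace{2}_{\text{bf16 weight}} + \underbrace{2}_{\text{bf16 grad}} + \underbrace{4}_{\text{fp32 master}} + \underbrace{4}_{\text{Adam } m} + \underbrace{4}_{\text{Adam } v} = 16 \ \text{bytes per parameter}$$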
Sixteen bytes is the headline number. For a 70B-parameter model that's 1.12 TB of training state, before counting activations (the intermediates that backward needs). On an 80 GB H100, that's 14× the device's entire HBM.
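The byte count is easy to script. A minimal sketch, assuming the mixed-precision Adam recipe above and decimal units (1 GB = 10⁹ bytes); the names `BYTES_PER_PARAM` and `training_state_bytes` are illustrative:

```python
# Bytes per parameter in the mixed-precision Adam recipe described above.
BYTES_PER_PARAM = {
    "bf16 weight": 2,
    "bf16 grad": 2,
    "fp32 master weight": 4,
    "Adam m (fp32)": 4,
    "Adam v (fp32)": 4,
}

def training_state_bytes(n_params: float) -> float:
    """Weights + grads + optimizer state; activations are NOT included."""
    return n_params * sum(BYTES_PER_PARAM.values())

state = training_state_bytes(70e9)
h100_hbm = 80e9  # 80 GB H100
print(f"{sum(BYTES_PER_PARAM.values())} B/param, "
      f"{state / 1e12:.2f} TB, {state / h100_hbm:.0f}x an H100")
# → 16 B/param, 1.12 TB, 14x an H100
```

Dropping or shrinking an entry in the dictionary shows directly how much a given optimizer or precision change would save per parameter.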
Activations are the other axis, and they scale with sequence length and batch, not just with parameters. For a transformer with hidden size d, sequence length T, batch B, and L layers, the activations saved for backward take roughly:

$$\text{activation memory} \approx k \cdot B \cdot T \cdot d \cdot L \ \text{values} \times 2 \ \text{bytes (bf16)}$$
where k is a small multiplier (typically 12–20 for a vanilla layer; activation checkpointing cuts it dramatically at the cost of recomputing the forward pass during backward). For Llama-70B-ish numbers (d=8192, L=80) and a single sequence at T=8192, that's roughly 130–200 GB per sequence without checkpointing, dropping to roughly 10–20 GB with aggressive activation checkpointing, on top of the 1.1 TB of state. The point is not the exact number but its shape: activations grow with batch and sequence length while state does not, so the two grow independently.
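The activation estimate can be scripted the same way. This sketch assumes k counts bf16 values per token, per hidden unit, per layer (which is what puts k = 12–20 in the 130–200 GB range), and it models checkpointing as keeping only each layer's B·T·d input. Both functions are hypothetical helpers, and the checkpointed figure leaves out the transient activations of the one layer being recomputed.

```python
def activation_bytes(B, T, d, L, k, bytes_per_value=2):
    """k * B * T * d * L saved values, each bf16 (2 bytes) by default."""
    return k * B * T * d * L * bytes_per_value

def checkpointed_activation_bytes(B, T, d, L, bytes_per_value=2):
    """Full per-layer checkpointing: keep only each layer's B*T*d input.
    The recomputed layer's working set comes on top of this."""
    return B * T * d * L * bytes_per_value

B, T, d, L = 1, 8192, 8192, 80  # Llama-70B-ish, one sequence
for k in (12, 14, 20):
    print(f"k={k}: {activation_bytes(B, T, d, L, k) / 1e9:.0f} GB")
print(f"checkpointed: {checkpointed_activation_bytes(B, T, d, L) / 1e9:.1f} GB")
# → k=12: 129 GB, k=14: 150 GB, k=20: 215 GB, checkpointed: 10.7 GB
```

Doubling B or T doubles both numbers, while the 1.12 TB of state stays put.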
Animated · the 16-byte stack vs the HBM ceiling
The bytes-per-param stack literally stacks: bf16 weight + bf16 grad + fp32 master + Adam m + Adam v. Multiply by parameter count and the column grows. Slide it past the H100 ceiling, then the H200, then the B200. Watch the column turn red the moment it overshoots — that's the memory wall as a height. The animation phase shows each byte-category appearing in sequence so you can see which 4 bytes you'd save with Adafactor, or with mixed-precision tricks.
Walking the second wall — the time per step
Even if the model fit in one GPU's HBM, training a 70B model on 15 trillion tokens at single-GPU speed is infeasible. The arithmetic is brutal but worth keeping in your head. For N parameters trained on D tokens:

$$C_{\text{train}} \approx 6 \cdot N \cdot D \ \text{FLOPs}$$
The factor of 6 comes from the standard scaling-laws accounting (Kaplan et al. 2020; also used in Chinchilla / Hoffmann et al. 2022): 2 FLOPs per parameter per token for the forward pass, plus 4 for the backward pass (2 for the input-gradient pass and 2 for the weight-gradient pass). For 70B × 15T tokens, that's 6.3 × 10²⁴ FLOPs. An H100 SXM does about 1 × 10¹⁵ bf16 FLOPs/s peak; in practice you achieve ~40% of that (MFU, model FLOP utilization) because of bandwidth-bound layers and overhead. One GPU at 0.4 PFLOPS:

$$\frac{6.3 \times 10^{24} \ \text{FLOPs}}{0.4 \times 10^{15} \ \text{FLOP/s}} \approx 1.6 \times 10^{10} \ \text{s}$$
500 years. Even if memory were free, time isn't. The whole point of data parallelism (lesson 04) is to cut this number by a factor of N by running N forward and backward passes simultaneously on different data: the elapsed time falls to ~6 months at 1024 GPUs. The cost is one AllReduce of all gradients per step. We will spend lesson 02 making sure that AllReduce is essentially free.
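The same arithmetic in code, as a sketch that assumes perfect linear scaling across GPUs (no communication cost at all) and uses the 1 PFLOP/s peak and 40% MFU figures above; the function names are illustrative:

```python
SECONDS_PER_YEAR = 365.25 * 24 * 3600

def training_flops(n_params, n_tokens):
    # 2 FLOPs/param/token forward + 4 backward (input grads + weight grads) = 6
    return 6 * n_params * n_tokens

def training_seconds(n_params, n_tokens, n_gpus=1, peak_flops=1e15, mfu=0.4):
    """Idealized wall-clock time: assumes perfect scaling across n_gpus."""
    return training_flops(n_params, n_tokens) / (n_gpus * peak_flops * mfu)

flops = training_flops(70e9, 15e12)
years_1 = training_seconds(70e9, 15e12) / SECONDS_PER_YEAR
months_1024 = training_seconds(70e9, 15e12, n_gpus=1024) / SECONDS_PER_YEAR * 12
print(f"{flops:.2e} FLOPs, {years_1:.0f} years on 1 GPU, {months_1024:.1f} months on 1024")
# → 6.30e+24 FLOPs, 499 years on 1 GPU, 5.8 months on 1024
```

The 1024-GPU number is a lower bound: every real step also pays for the gradient AllReduce that this sketch ignores.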
2D · three walls in one picture
The three walls — memory, compute, throughput — are three different scaling lines that each get hit at different model sizes. Pick a model size with the slider; the three bars on the left show how full each budget is. The animated "ball" on each track shows the linear scaling. When a bar fills past 100%, that wall is hit and turns red.
The roofline — the one diagram for "what's bound by what"
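Every kernel on a GPU is bound by one of two things: how fast you can stream bytes from HBM, or how fast the tensor cores can do FLOPs on bytes already in SRAM. The roofline plot makes the choice visible. For a kernel with arithmetic intensity I (FLOPs performed per byte read from HBM):

$$\text{attainable FLOP/s} = \min\left(\text{peak FLOP/s},\ I \cdot \text{HBM bandwidth}\right)$$

Below the ridge point, I = peak / bandwidth, a kernel sits on the sloped bandwidth roof; above it, the kernel sits under the flat compute roof.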
Two of the most important distributed-systems consequences fall out of this picture:
- Decode is memory-bound, prefill is compute-bound. Per-token decode reads every weight matrix from HBM and does only a matrix-vector product per sequence. Arithmetic intensity is therefore roughly equal to the batch size: each 2-byte bf16 weight feeds about 2 FLOPs per sequence in the batch. So a per-replica batch size of 1 lives far down the orange slope, far below peak. Bigger batches climb the slope; eventually you hit the ridge and become compute-bound. This single fact drives PD-disaggregation, continuous batching, speculative decoding, and almost every inference optimization. (Lessons 11 and 12.)
- Training prefers high arithmetic intensity, so big batches win, until the AllReduce stops scaling. A transformer layer trained at batch 1024 on a long sequence sits comfortably above the ridge. Distributed training keeps that batch large by adding ranks: each rank takes B/N of the global batch, the per-rank work stays compute-bound, and the only added cost is gradient sync. This is why DDP works at all.
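A minimal roofline sketch for the decode case. The H100 SXM figures (about 989 TFLOP/s dense bf16 and about 3.35 TB/s of HBM bandwidth) are approximate spec-sheet values assumed here, and `decode_intensity` is a simplification that counts only weight reads, ignoring activation and KV-cache traffic:

```python
def attainable_flops(intensity, peak_flops, hbm_bw):
    """Roofline: the lower of the compute roof and the bandwidth slope."""
    return min(peak_flops, intensity * hbm_bw)

def decode_intensity(batch, bytes_per_weight=2):
    # Each weight is read once per decode step and used in one multiply-add
    # (2 FLOPs) per sequence in the batch. Activation/KV bytes are ignored.
    return 2 * batch / bytes_per_weight

PEAK = 989e12      # assumed: H100 SXM dense bf16, ~989 TFLOP/s
HBM_BW = 3.35e12   # assumed: H100 SXM HBM3, ~3.35 TB/s
ridge = PEAK / HBM_BW  # FLOPs/byte where a kernel turns compute-bound

for batch in (1, 32, 256, 512):
    i = decode_intensity(batch)
    frac = attainable_flops(i, PEAK, HBM_BW) / PEAK
    print(f"batch={batch:4d} intensity={i:6.0f} FLOP/B -> {frac:6.1%} of peak")
print(f"ridge ≈ {ridge:.0f} FLOP/B")  # ≈ 295 FLOP/B
```

At batch 1 the model reaches well under 1% of peak; only batches in the hundreds reach the ridge, which is why inference work concentrates on raising the effective decode batch.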
The accounting habit
The most important skill in this series is doing back-of-envelope estimates of bytes and FLOPs before writing any code. Practising it on a few benchmark numbers:
| Quantity | Symbol / formula | 70B model (B=1, T=8k) |
|---|---|---|
| Param state (mixed-prec Adam) | 16 · params | 1.12 TB |
| Forward activations (no checkpointing) | ~B · T · L · d · 14 · 2 bytes | ~150 GB |
| Forward activations (checkpointed) | ~B · T · L · d · 2 bytes | ~10.7 GB |
| One step FLOPs | 6 · params · B · T | ~3.4 × 10¹⁵ |
| One step time at 1 PFLOPS @ 40% MFU | FLOPs / (0.4 · peak) | ~8.6 s (per GPU, fictionally) |
| Grad AllReduce volume per step | ~2 · 2 · params | 280 GB / rank, ring-asymptotic |
| Grad AllReduce time at 100 GB/s ring BW | vol / BW | 2.8 s if not overlapped |
That last line is the punchline: a 2.8 s AllReduce against an ~8.6 s compute step adds roughly a third to every step if nothing is overlapped. Worse, the AllReduce volume depends only on the parameter count, not on the per-rank batch, so at a fixed global batch the ratio climbs as you add ranks and each rank's share of the work shrinks. So the AllReduce has to be either small (FSDP cuts it; lesson 05), local (intra-node NVLink; lesson 03), or hidden behind the compute (overlap; lesson 04). All three are real strategies, and all three are tested in the lessons ahead.
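The accounting table can be regenerated from its formula column. A sketch under the same assumptions as the table (1 PFLOP/s peak, 40% MFU, 100 GB/s ring bandwidth, bf16 gradients, ring AllReduce in the large-N limit); `accounting` is a hypothetical helper:

```python
def accounting(n_params, B, T, peak_flops=1e15, mfu=0.4, ring_bw=100e9, grad_bytes=2):
    step_flops = 6 * n_params * B * T
    step_s = step_flops / (mfu * peak_flops)
    # Ring AllReduce moves 2 * (N-1)/N * payload per rank; use the N -> inf limit.
    allreduce_bytes = 2 * grad_bytes * n_params
    allreduce_s = allreduce_bytes / ring_bw
    return {
        "state_TB": 16 * n_params / 1e12,   # mixed-precision Adam state
        "step_flops": step_flops,
        "step_s": step_s,
        "allreduce_GB": allreduce_bytes / 1e9,
        "allreduce_s": allreduce_s,
        "comm_over_compute": allreduce_s / step_s,
    }

row = accounting(70e9, B=1, T=8192)
for name, value in row.items():
    print(f"{name:>18}: {value:.3g}")
```

Rerunning with `T=2048` cuts the step time by 4× while the AllReduce time stays at 2.8 s, which is the scaling problem in miniature.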
3D · the regime cube
Plot a model in (memory, compute, bandwidth) space. Each axis is "how much of that resource does one rank need". Models cluster into regimes: small models live near the origin (any GPU), 7B models press against the memory axis, frontier 405B models pin all three. The dot color tells you which axis is the binding wall. Rotate the cube to read off the depth.
Interactive · feel the memory wall
Slide the model size, the precision, the optimizer, and the batch around. Watch the bar chart show which buckets fit on which GPU. The dotted line is your chosen GPU's HBM budget. The lesson is in which slider hits the wall first — and that wall is what each subsequent lesson is going to dismantle.
Where each wall sends you
| Wall | Symptom | First response | Lessons |
|---|---|---|---|
| Memory (state) | OOM at step 0 | FSDP / ZeRO-3 → activation checkpointing → TP | 05, then 06 |
| Memory (activations) | OOM at longer sequence | Activation checkpointing → SP → CP | 05, 08 |
| Memory (KV cache) | OOM serving long context | GQA → KV quantization → paged KV | see vLLM/09, vLLM/02 |
| Time (training) | One-step time × steps > deadline | DP → FSDP-HSDP → 3D parallelism | 04, 05, 10 |
| Latency (inference) | TTFT or ITL miss target | TP per replica → speculative decode → PD-disagg | 11, 12 |
| Throughput (inference) | QPS shortfall | Replicate → continuous batching → APC | 11, see vLLM/04 |