
Tensor parallel — sharding inside a single layer

When a single transformer block is too big to fit on one GPU, FSDP's just-in-time gather no longer works. The fix is Megatron-LM's column/row trick: split the weights along carefully chosen axes so that each attention or MLP block reduces to local matmuls plus a single AllReduce.

What FSDP can't do

FSDP's mechanic depends on temporarily reconstructing a full layer on one rank during forward and freeing it afterward. For a 70B model split into 80 layers, each layer is ~875M parameters, about 1.75 GB in bf16. With eight ranks each holding 1/8, every rank stores ~220 MB of that layer. AllGathering it back to the full 1.75 GB on each rank easily fits.

But for a 540B model with the same depth, each layer is ~6.75B parameters, about 13.5 GB in bf16. A Llama-style 405B built with fewer, wider layers runs into the same wall. At some point, even one layer is bigger than one GPU can hold. FSDP needs the full layer transiently in HBM; if you can't fit it even temporarily, FSDP breaks. Tensor parallel is the response.
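The per-layer arithmetic above can be checked quickly. This is a minimal sketch that counts bf16 weights only (2 bytes per parameter) and assumes parameters are spread evenly across layers. It ignores embeddings, gradients, optimizer state, and activations. `layer_bytes` is a hypothetical helper.

```python
def layer_bytes(n_params: float, n_layers: int, bytes_per_param: int = 2) -> float:
    """Weight bytes in one layer, assuming parameters are spread evenly across layers."""
    return n_params / n_layers * bytes_per_param


GB, MB = 1e9, 1e6

per_layer_70b = layer_bytes(70e9, 80)
print(per_layer_70b / GB)           # 1.75   -> full layer after FSDP's AllGather
print(per_layer_70b / 8 / MB)       # 218.75 -> each rank's 1/8 shard
print(layer_bytes(540e9, 80) / GB)  # 13.5   -> one 540B layer at the same depth
```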

The Megatron-LM MLP split

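An MLP block computes y = σ(xA) B with weights A ∈ ℝ^{d × 4d} and B ∈ ℝ^{4d × d}. We want to shard A and B across N ranks. The trick is to choose which axis of each matrix to shard:

  1. Column-parallel A. Split A by columns, A = [A_1, …, A_N], so rank i holds A_i ∈ ℝ^{d × 4d/N}. With x replicated, each rank computes x·A_i of shape (B, 4d/N): different columns on each rank, no communication.
  2. Local activation. σ is elementwise, so σ(xA_i) needs nothing from other ranks.
  3. Row-parallel B. Split B by rows so rank i holds B_i ∈ ℝ^{4d/N × d}. Each rank computes y_i = σ(xA_i)·B_i of shape (B, d), a partial sum over its slice of the hidden axis.
  4. AllReduce. y = Σ_i y_i. This is the only communication in the block.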

So one MLP block becomes: column-parallel A, elementwise activation, row-parallel B, AllReduce. That is one AllReduce per MLP block per forward pass. Backward mirrors it. The backward of the forward AllReduce is an identity, and a single AllReduce instead sums the input gradient at the column-parallel A.
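The column/row split can be simulated in one process with NumPy. This is a sketch, not Megatron-LM's implementation. Each "rank" is one loop iteration, the sum over partials stands in for the AllReduce, and the tanh-approximated GELU is an assumed choice of σ. The final check confirms that the sharded result matches the unsharded y = σ(xA) B.

```python
import numpy as np


def gelu(z):
    # tanh approximation of GELU; any elementwise activation works the same way
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))


def mlp_reference(x, A, B):
    """Unsharded y = σ(xA) B on a single device."""
    return gelu(x @ A) @ B


def mlp_tensor_parallel(x, A, B, n_ranks):
    """Simulate column-parallel A / row-parallel B across n_ranks in one process."""
    A_shards = np.split(A, n_ranks, axis=1)  # rank i holds a (d, 4d/N) column block
    B_shards = np.split(B, n_ranks, axis=0)  # rank i holds the matching (4d/N, d) row block
    partials = []
    for A_i, B_i in zip(A_shards, B_shards):
        h_i = gelu(x @ A_i)          # (batch, 4d/N): rank-local, no communication
        partials.append(h_i @ B_i)   # (batch, d): partial sum over the hidden axis
    # Summing the partials stands in for the AllReduce, the block's only communication.
    return np.sum(partials, axis=0)


rng = np.random.default_rng(0)
d, n_ranks = 16, 4
x = rng.standard_normal((2, d))  # replicated input
A = rng.standard_normal((d, 4 * d))
B = rng.standard_normal((4 * d, d))

y_tp = mlp_tensor_parallel(x, A, B, n_ranks)
print(np.allclose(y_tp, mlp_reference(x, A, B)))  # True
```

Because the hidden axis is split consistently in both matrices, rank i's columns of A line up with its rows of B. That alignment is what keeps the intermediate local.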

Animated · the column/row dance on 4 GPUs

Below is the same MLP block — Y = σ(X · A) · B — animated across N = 4 ranks. Watch how X stays replicated, A is sharded column-wise (each rank holds a vertical slice), the intermediate σ(XA_i) lives only on its rank, then B is sharded row-wise so each rank produces a partial output of the full shape — and an AllReduce sums them. Step through to see exactly what bytes move at which step.

Column/row TP · 4 GPUs cooperating on one MLP
Five phases: X replicated → local XA_i → elementwise σ → partial Y_i = σ(XA_i)·B_i → AllReduce. The AllReduce is the only communication. Backward mirrors this.

2D · which axis to shard, and which collective fires

The choice of axis per matrix is what lets the chain of two matmuls get away with a single AllReduce. Toggle the two sharding policies below. The smart one (column-A → row-B) needs one AllReduce at the end. The naive one (row-A → column-B, or any mismatched pair) needs two collectives. One reassembles the intermediate before σ, either as an AllReduce of partial sums or an AllGather of column slices. The other assembles the output. Each blue/orange tile is a matrix slice; the arrows are the collectives that fire.

Sharding policy explorer · how many AllReduces?
Two matrices in series: A ∈ ℝ^{d×4d} and B ∈ ℝ^{4d×d}. Color = ownership by rank. The arrows below the matrices are the collectives needed between matmuls.
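The row-A case shows why a mismatched pair costs an extra collective. This minimal NumPy check simulates the ranks in one process and assumes a tanh-approximated GELU for σ. With A split by rows, each rank produces a full-shape partial sum of the intermediate. Because σ is nonlinear, those partials must be summed (an AllReduce) before the activation, not after.

```python
import numpy as np


def gelu(z):
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))


rng = np.random.default_rng(1)
d, n_ranks = 16, 4
x = rng.standard_normal((2, d))
A = rng.standard_normal((d, 4 * d))
h_ref = gelu(x @ A)

# Naive row-parallel A: rank i multiplies its slice of x's features by its row block of A,
# producing a full-shape (batch, 4d) *partial sum* of the intermediate.
x_shards = np.split(x, n_ranks, axis=1)
A_rows = np.split(A, n_ranks, axis=0)
partial_h = [x_i @ A_i for x_i, A_i in zip(x_shards, A_rows)]

# σ is nonlinear, so applying it per rank and summing afterwards is wrong ...
h_wrong = sum(gelu(p) for p in partial_h)
# ... the partials must be reduced first: an extra collective before σ.
h_right = gelu(sum(partial_h))

print(np.allclose(h_right, h_ref))  # True
print(np.allclose(h_wrong, h_ref))  # False
```

Column-parallel A avoids this problem because each rank's input to σ is a complete set of columns, not a partial sum.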

3D · 4-GPU tensor-parallel layout, isometric

Four GPUs in a row, each a cube. The top face of each cube shows that GPU's slice of A (column-sharded — colored by column index); the front face shows its slice of B (row-sharded — colored by row index). The gold edges are the AllReduces that fire twice per layer (once after attention, once after the MLP). Click a GPU to see exactly which slices it owns. Rotate to spin the view.

TP topology · click a GPU to see its slices
Each GPU permanently holds one column block of A and one row block of B. Gold edges = AllReduce ring (NVLink). Per transformer layer: 2 AllReduces in fwd (attention out, MLP out) + 2 in bwd = 4 total.
Diagram · the MLP split: column-parallel A, no comm: x·A_i per rank, shape (B, 4d/N), different columns per rank → σ(·) elementwise, local → row-parallel B: y_i = σ(xA_i)·B_i, shape (B, d), a partial sum across the hidden axis → y = AllReduce(y_i), the only cost.

The attention split

Multi-head attention is naturally shardable across heads. Each rank holds h/N Q heads and the corresponding K, V heads. Forward:

  1. Compute Q_i, K_i, V_i projections. Like the MLP's column-parallel A: shard the QKV projection matrix column-wise so rank i holds the columns producing only its heads. No comm.
  2. Local attention. Each rank computes softmax(Q_i K_i^T / √d_k) V_i. Heads are independent, so this is rank-local.
  3. Output projection. Like the MLP's row-parallel B: shard the output projection row-wise so rank i contributes only the partial sum from its heads. AllReduce sums across ranks.

So attention costs one AllReduce per block per forward too — same pattern as MLP.
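Head sharding can be checked the same way. This NumPy sketch works on a single sequence, omits the causal mask for brevity, and assumes heads occupy contiguous column blocks of the Q/K/V projections. Each loop iteration plays one rank. The rank keeps the QKV columns for its h/N heads and the matching rows of the output projection. Summing the partial outputs (the simulated AllReduce) reproduces full multi-head attention.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multihead(x, Wq, Wk, Wv, n_heads):
    """Attention for one sequence x of shape (T, d_in); returns concatenated heads (T, n_heads*d_k)."""
    T = x.shape[0]
    d_k = Wq.shape[1] // n_heads

    def split_heads(m):
        return m.reshape(T, n_heads, d_k).transpose(1, 0, 2)  # (heads, T, d_k)

    q, k, v = split_heads(x @ Wq), split_heads(x @ Wk), split_heads(x @ Wv)
    out = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_k)) @ v  # heads are independent
    return out.transpose(1, 0, 2).reshape(T, n_heads * d_k)


rng = np.random.default_rng(2)
T, d, h, n_ranks = 5, 16, 8, 4
x = rng.standard_normal((T, d))
Wq, Wk, Wv, Wo = (rng.standard_normal((d, d)) for _ in range(4))

y_ref = multihead(x, Wq, Wk, Wv, h) @ Wo

partials = []
for cols in np.split(np.arange(d), n_ranks):  # columns for the heads this rank owns
    # Column-parallel QKV: only this rank's h/N heads, computed with no communication.
    o_i = multihead(x, Wq[:, cols], Wk[:, cols], Wv[:, cols], h // n_ranks)
    # Row-parallel output projection: matching rows of Wo give a partial (T, d) sum.
    partials.append(o_i @ Wo[cols, :])
y_tp = np.sum(partials, axis=0)  # simulated AllReduce

print(np.allclose(y_tp, y_ref))  # True
```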

GQA / MQA wrinkle
Grouped-query attention uses fewer KV heads than Q heads (e.g. Llama-3-70B has 64 Q heads but only 8 KV heads). If you try TP=16, you can't evenly shard 8 KV heads across 16 ranks. Common solution: replicate the KV projections across pairs of TP ranks, so each KV head lives on two ranks. This means TP > h_kv is wasteful. The KV weights (and, at inference, the KV cache) are duplicated rather than split, so the KV side stops scaling past TP = h_kv. Modern models often deliberately keep h_kv ≥ 8 to support TP=8.
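Here is one way to sketch the replication rule. `q_heads_for_rank` and `kv_heads_for_rank` are hypothetical helpers, not a library API. They assume contiguous head assignment and, once TP exceeds h_kv, that each KV head is copied onto TP / h_kv consecutive ranks. The loop checks that every Q head a rank owns finds its KV head locally. It also counts the stored KV-head copies: 8 at TP=8, 16 at TP=16.

```python
def q_heads_for_rank(rank, tp, n_q):
    """Q heads owned by `rank` when n_q query heads are split evenly over tp ranks."""
    if n_q % tp:
        raise ValueError("TP must divide the number of Q heads")
    per_rank = n_q // tp
    return list(range(rank * per_rank, (rank + 1) * per_rank))


def kv_heads_for_rank(rank, tp, n_kv):
    """KV heads `rank` must hold; replicated once tp exceeds n_kv."""
    if tp <= n_kv:
        if n_kv % tp:
            raise ValueError("TP must divide the number of KV heads")
        per_rank = n_kv // tp
        return list(range(rank * per_rank, (rank + 1) * per_rank))
    if tp % n_kv:
        raise ValueError("the number of KV heads must divide TP")
    copies = tp // n_kv  # ranks sharing (and duplicating) each KV head
    return [rank // copies]


# Llama-3-70B-like shape: 64 Q heads, 8 KV heads, so each KV head serves 8 Q heads.
n_q, n_kv = 64, 8
group = n_q // n_kv
for tp in (8, 16):
    total_copies = sum(len(kv_heads_for_rank(r, tp, n_kv)) for r in range(tp))
    for r in range(tp):
        # Every Q head on a rank must find its KV head on that same rank.
        assert all(q // group in kv_heads_for_rank(r, tp, n_kv) for q in q_heads_for_rank(r, tp, n_q))
    print(tp, total_copies)
```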

The cost, explicitly

Per transformer layer, per forward pass: 2 AllReduces (one for the attention output, one for the MLP output). Per layer, per backward pass: 2 more. The backward of each forward AllReduce is an identity. However, the backward of each column-parallel entry (the QKV projection and MLP A) needs an AllReduce of its input gradient, mirroring the forward. So 4 AllReduces per layer per training step.
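The backward AllReduce can be verified with a hand-written backward pass for the MLP block. This is a minimal NumPy sketch with manual gradients: tanh stands in for σ, and each loop iteration plays one rank. Every rank already holds the full upstream gradient dy, which is why the forward AllReduce's backward is an identity. Each rank's input gradient, however, covers only its slice of the hidden axis, so the partial dx values must be summed. Weight gradients stay local.

```python
import numpy as np

rng = np.random.default_rng(3)
n_ranks, batch, d = 4, 3, 8
x = rng.standard_normal((batch, d))
A = rng.standard_normal((d, 4 * d))
B = rng.standard_normal((4 * d, d))
dy = rng.standard_normal((batch, d))  # upstream gradient dL/dy

# Reference backward for y = tanh(xA) B (tanh stands in for σ).
h = x @ A
dh = (dy @ B.T) * (1.0 - np.tanh(h) ** 2)
dx_ref = dh @ A.T

dx_partials, dA_shards = [], []
for A_i, B_i in zip(np.split(A, n_ranks, axis=1), np.split(B, n_ranks, axis=0)):
    h_i = x @ A_i
    # Backward of the forward AllReduce is an identity: every rank already holds all of dy.
    dh_i = (dy @ B_i.T) * (1.0 - np.tanh(h_i) ** 2)
    dA_shards.append(x.T @ dh_i)      # weight gradient stays local to the rank
    dx_partials.append(dh_i @ A_i.T)  # only a partial sum over this rank's hidden slice

dx = np.sum(dx_partials, axis=0)      # this sum is the backward AllReduce
print(np.allclose(dx, dx_ref))                                   # True
print(np.allclose(np.concatenate(dA_shards, axis=1), x.T @ dh))  # True
```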

Message size of each AllReduce: the output of the block, shape (B, T, d), in bf16. That's B · T · d · 2 bytes.

For Llama-3-70B at B=4, T=4096, d=8192: each AllReduce is ~270 MB. With 80 layers and 4 AllReduces per layer = 320 AllReduces × 270 MB = ~86 GB total. At intra-node 900 GB/s effective NVLink, that's ~95 ms of communication per step. At inter-node 50 GB/s, that's ~1.7 seconds — bigger than the step itself, totally unworkable. This is the calculation that pins TP to one node.
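The estimate above can be reproduced with a small function. It follows the same simplification as the text: each AllReduce's message is divided once by an effective bandwidth. A ring AllReduce actually moves about 2(N−1)/N times the message per rank, which would push both times up unless the "effective" figure already absorbs that factor. `tp_allreduce_traffic` is a hypothetical helper.

```python
def tp_allreduce_traffic(batch, seq, d, n_layers, allreduces_per_layer=4, bytes_per_elem=2):
    """Bytes per AllReduce and per training step, following the estimate above."""
    msg = batch * seq * d * bytes_per_elem  # one (B, T, d) bf16 activation
    return msg, msg * n_layers * allreduces_per_layer


msg, per_step = tp_allreduce_traffic(batch=4, seq=4096, d=8192, n_layers=80)
print(f"{msg / 1e6:.0f} MB per AllReduce, {per_step / 1e9:.1f} GB per step")  # 268 MB, 85.9 GB
for fabric, bw in [("NVLink, 900 GB/s", 900e9), ("inter-node, 50 GB/s", 50e9)]:
    print(f"{fabric}: {per_step / bw * 1e3:.0f} ms")  # 95 ms, then 1718 ms
```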

Sequence parallel — what TP can't reach

TP shards along the feature axis. But the inputs to LayerNorm and dropout immediately around each block are still full (B, T, d) tensors on every rank. For long sequences these activations alone can exceed HBM. Sequence parallel (Korthikanti et al. 2022) shards these along the sequence axis. An AllGather right before the next TP block rebuilds the full, replicated (B, T, d) input that the block needs, and a ReduceScatter right after it replaces the AllReduce. An AllReduce is equivalent to a ReduceScatter followed by an AllGather, so total communication matches plain TP, while the activation memory in those regions drops by a factor of N. Lesson 08 covers this and its longer-context cousin (context parallel).
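Here is a single-process NumPy simulation of the sequence-parallel swap. Lists of arrays stand in for ranks, and the three collective functions are hypothetical stand-ins, not NCCL calls. The sketch checks two things. First, ReduceScatter along T followed by AllGather reproduces the AllReduce result. Second, LayerNorm normalizes each token over its features, so it can run on the sequence shard in between.

```python
import numpy as np


def all_reduce(partials):
    """Every rank ends with the full elementwise sum."""
    total = np.sum(partials, axis=0)
    return [total.copy() for _ in partials]


def reduce_scatter(partials, axis):
    """Rank i ends with only the i-th chunk (along `axis`) of the sum."""
    total = np.sum(partials, axis=0)
    return np.split(total, len(partials), axis=axis)


def all_gather(chunks, axis):
    """Every rank ends with all chunks concatenated along `axis`."""
    full = np.concatenate(chunks, axis=axis)
    return [full.copy() for _ in chunks]


def layer_norm(z, eps=1e-5):
    # Normalizes over the feature axis only, so each token is independent.
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(4)
n_ranks, batch, T, d = 4, 2, 8, 16
partials = [rng.standard_normal((batch, T, d)) for _ in range(n_ranks)]  # row-parallel outputs

# Sequence parallel: ReduceScatter along T instead of AllReduce ...
seq_shards = reduce_scatter(partials, axis=1)  # each rank: (batch, T/N, d)
normed = [layer_norm(s) for s in seq_shards]   # LayerNorm runs on the shard
# ... then AllGather along T right before the next TP block.
gathered = all_gather(normed, axis=1)

full = all_reduce(partials)[0]
print(seq_shards[0].shape)                         # (2, 2, 16)
print(np.allclose(gathered[0], layer_norm(full)))  # True
```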

What changes for inference

The same TP split applies at inference: two AllReduces per layer. The communication pattern is the same, but what it buys is different. At decode time, each forward pass produces just one token per sequence yet reads the entire weight matrix from HBM. TP=8 spreads that read across 8 GPUs' HBM bandwidth, ideally dividing the per-token time by ~8. The decode AllReduce carries only a (B, 1, d) activation, so its overhead is small relative to the memory read, especially on long context. This is the lesson-11 finding: TP wins on inference latency, not throughput.
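A back-of-the-envelope decode estimate makes the contrast concrete. This sketch assumes an HBM bandwidth of 3.35 TB/s (roughly an H100 SXM figure) and the Llama-3-70B-like dims used earlier. It ignores KV-cache reads and per-collective latency. With 160 small AllReduces per token, that latency, rather than bandwidth, is what the AllReduces mostly cost. `decode_estimate` is a hypothetical helper.

```python
def decode_estimate(n_params, n_layers, d, tp, hbm_bw, batch=1, bytes_per_elem=2):
    """Per-token weight-read time per rank, plus the TP AllReduce traffic per token."""
    weight_read_s = n_params * bytes_per_elem / tp / hbm_bw  # each rank streams 1/tp of the weights
    allreduce_msg = batch * 1 * d * bytes_per_elem           # one new token: a (B, 1, d) activation
    n_allreduces = 2 * n_layers if tp > 1 else 0             # attention out + MLP out, per layer
    return weight_read_s, allreduce_msg, n_allreduces


HBM_BW = 3.35e12  # assumption: roughly H100 SXM HBM bandwidth, bytes/s
for tp in (1, 8):
    t, msg, n = decode_estimate(70e9, n_layers=80, d=8192, tp=tp, hbm_bw=HBM_BW)
    print(f"TP={tp}: {t * 1e3:.1f} ms weight read, {n} AllReduces of {msg // 1024} KiB")
```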

The four "stay on one node" facts

  1. TP fires 2 AllReduces per layer per forward pass.
  2. Each AllReduce is roughly the size of the layer's output activation.
  3. That's hundreds of AllReduces per training step.
  4. The math only works if those AllReduces run on NVLink, not IB.

This is also why TP is the innermost axis when composing parallelism (lesson 10). The fastest interconnect goes to the heaviest-communicating axis.

Interactive · scaling TP, fabric-limited

Pick model dims, sequence, batch, TP size, and the interconnect bandwidth. The widget plots per-layer compute vs per-layer comm and shows whether the AllReduce can hide behind the local matmul. In practice NCCL can overlap the AllReduce with the next kernel's launch, but only marginally: the next matmul needs the reduced output, so the kernel boundary usually fences it. Try TP=8 on NVLink vs TP=8 on IB; the second is a disaster.

TP per-block timeline · compute vs AllReduce
Each block is (matmul + AllReduce). Compute scales as B·T·d²/TP; AllReduce scales as B·T·d. Watch the AllReduce fraction climb as you cross from intra-node to inter-node bandwidth.
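The widget's model can be approximated offline. This sketch covers one MLP block per rank with no compute/comm overlap. It assumes a sustained 400 TFLOP/s of bf16 matmul throughput per GPU and ring-AllReduce traffic of 2(TP−1)/TP times the message. Both are illustrative assumptions, not measured figures. Compute falls as B·T·d²/TP, while the AllReduce stays near B·T·d, so the comm fraction is governed by link bandwidth.

```python
def tp_block_times(batch, seq, d, tp, link_bw, flops_per_s, bytes_per_elem=2):
    """Per-rank compute and AllReduce time for one MLP block, with no overlap."""
    tokens = batch * seq
    # Two matmuls (d x 4d and 4d x d), 2 FLOPs per multiply-add, split over tp ranks.
    compute_s = 2 * tokens * (d * 4 * d + 4 * d * d) / tp / flops_per_s
    msg = tokens * d * bytes_per_elem
    # Ring AllReduce: each rank sends/receives about 2(tp-1)/tp of the message.
    comm_s = 2 * (tp - 1) / tp * msg / link_bw
    return compute_s, comm_s, comm_s / (compute_s + comm_s)


FLOPS = 400e12  # assumption: sustained bf16 matmul throughput per GPU
for fabric, bw in [("NVLink", 900e9), ("IB", 50e9)]:
    c, a, frac = tp_block_times(batch=4, seq=4096, d=8192, tp=8, link_bw=bw, flops_per_s=FLOPS)
    print(f"TP=8 on {fabric}: compute {c * 1e3:.1f} ms, AllReduce {a * 1e3:.1f} ms, comm {frac:.0%}")
```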
Takeaway
TP shards weights inside a layer. The cost is two AllReduces per transformer layer per forward pass (one per attention block, one per MLP block): large and frequent, only affordable on NVLink. It's the right answer when one layer is too big to gather on one GPU; it's the wrong answer for everything else. Below ~30B params, FSDP almost always beats TP for the same memory result with less complexity.