
Sequence and context parallel — sharding the time axis

TP shards the feature dim. PP shards the layer dim. Neither shards the sequence. As contexts grow past 32k tokens, the activations that scale with T (not d) become the bottleneck. Sequence parallel reclaims them cheaply; context parallel reclaims attention itself.

The activations TP doesn't reach

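Inside a TP'd transformer block, the matmuls operate on (B, T, d) activations whose hidden dimension is sharded across ranks. After the AllReduce that closes each TP region (one after attention, one after the MLP), every rank holds the full (B, T, d) output. LayerNorm, dropout, and the residual add then run on that full tensor, replicated on every TP rank. TP doesn't help here: these ops sit outside the sharded matmuls, and LayerNorm normalises over d, so there is no cheap "feature dim to shard."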

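Memory accounting per rank, around a TP block: every replicated (B, T, d) tensor costs B · T · d · 2 bytes in bf16, no matter how many TP ranks there are.

For B=4, T=8192, d=8192, one such tensor is ~537 MB, and each layer keeps several (B, T, d)-sized activations around LayerNorm and dropout. Across 80 layers, even with activation checkpointing (which still keeps each layer's (B, T, d) input, ~43 GB in total), this sums to tens of gigabytes per rank that don't shrink as you add TP ranks. Sequence parallel (SP, Korthikanti et al. 2022) fixes this at no extra communication cost.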
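The arithmetic above can be written as a small sketch. It assumes bf16 activations and that checkpointing keeps exactly one (B, T, d) input per layer. `tensor_bytes` and `checkpointed_bytes` are illustrative helpers, not a framework API.

```python
# Worked arithmetic for the replicated activations around a TP block.
# Assumptions (not from the lesson): bf16 (2 bytes/element), and activation
# checkpointing keeps exactly one (B, T, d) layer input per layer.

BYTES_BF16 = 2


def tensor_bytes(B: int, T: int, d: int, bytes_per_elem: int = BYTES_BF16) -> int:
    """Bytes for one (B, T, d) activation tensor."""
    return B * T * d * bytes_per_elem


def checkpointed_bytes(B: int, T: int, d: int, layers: int, sp_degree: int = 1) -> int:
    """Per-rank bytes for one checkpointed (B, T, d) input per layer.

    Without SP every TP rank stores the full tensor (sp_degree=1);
    with SP each rank stores only its T/N slice along the sequence axis.
    """
    return layers * tensor_bytes(B, T // sp_degree, d)


one = tensor_bytes(4, 8192, 8192)
no_sp = checkpointed_bytes(4, 8192, 8192, layers=80)
with_sp = checkpointed_bytes(4, 8192, 8192, layers=80, sp_degree=8)

print(f"one (B, T, d) tensor:   {one / 1e9:.2f} GB")      # 0.54 GB
print(f"80 layers, no SP:       {no_sp / 1e9:.2f} GB")    # 42.95 GB
print(f"80 layers, SP with N=8: {with_sp / 1e9:.2f} GB")  # 5.37 GB
```

The TP degree never appears in `no_sp`, which is the point. Only sharding along T (`sp_degree`) divides it.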

SP — shard along the sequence axis

The trick: at the LayerNorm/dropout boundary, shard the activation along T instead of d, so each rank holds B · T/N · d elements. To enter the next TP'd matmul (which needs the full (B, T, d) input), AllGather along T. To exit a TP'd matmul, ReduceScatter along T.

Crucially, the combination (AllGather along T) + (ReduceScatter along T) moves exactly the same bytes as a plain AllReduce, because a bandwidth-optimal (ring) AllReduce is itself a ReduceScatter followed by an AllGather. SP doesn't add comm. It splits the existing AllReduce into its two halves and puts the LayerNorm/dropout work in between, inside the sharded region.
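Here is a toy, single-process check of that equivalence. The byte counts assume a bandwidth-optimal ring implementation. The collectives are hypothetical list-based stand-ins for the real ones (e.g. NCCL's). Each buffer is a flattened (T, d) activation, so contiguous slices are slices along T. SP's forward pass runs the two halves in the other order (AllGather, then ReduceScatter), but each half moves the same bytes either way.

```python
# Toy check that ReduceScatter + AllGather == AllReduce, and that their
# ring-algorithm byte counts match. "Ranks" are list entries; each buffer
# is a flattened (T, d) activation, so contiguous slices are slices along T.

from typing import List

Buffer = List[float]


def all_reduce(bufs: List[Buffer]) -> List[Buffer]:
    total = [sum(vals) for vals in zip(*bufs)]
    return [list(total) for _ in bufs]


def reduce_scatter(bufs: List[Buffer]) -> List[Buffer]:
    """Rank r ends up with the summed r-th contiguous slice."""
    n = len(bufs)
    c = len(bufs[0]) // n
    total = [sum(vals) for vals in zip(*bufs)]
    return [total[r * c:(r + 1) * c] for r in range(n)]


def all_gather(shards: List[Buffer]) -> List[Buffer]:
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def ring_bytes_sent_per_rank(op: str, buf_bytes: int, n: int) -> float:
    """Bytes each rank sends in a ring implementation over a buf_bytes buffer."""
    phase = buf_bytes * (n - 1) / n  # one ReduceScatter or one AllGather
    return {"reduce_scatter": phase, "all_gather": phase, "all_reduce": 2 * phase}[op]


N, T, d = 4, 8, 3
bufs = [[float(r * 100 + i) for i in range(T * d)] for r in range(N)]

assert all_gather(reduce_scatter(bufs)) == all_reduce(bufs)
sp_bytes = (ring_bytes_sent_per_rank("all_gather", 1 << 20, N)
            + ring_bytes_sent_per_rank("reduce_scatter", 1 << 20, N))
assert sp_bytes == ring_bytes_sent_per_rank("all_reduce", 1 << 20, N)
```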

Figure: activations around a TP block, SP variant. LayerNorm (B, T/N, d) → AllGather along T → TP matmul (B, T, d/N) → TP matmul (output) → ReduceScatter along T → Dropout (B, T/N, d). Memory shrinks by N in the LayerNorm/dropout regions; comm is the same as a plain AllReduce.

Same communication cost, N× less memory in the LayerNorm/dropout regions. SP is almost always on in production TP setups.

The attention problem at long context

SP shrinks the LayerNorm/dropout activations, but the attention matmul itself still holds O(B · T · h · d_k) activations per rank. For T = 1M tokens that's hundreds of gigabytes. FlashAttention doesn't fix this. Its tile-by-tile streaming computes the same final values and avoids materialising the T×T attention matrix, but every rank still stores Q, K, V, and O for the full sequence.

What we need is to shard the sequence for attention itself: each rank holds T/N tokens' worth of Q, K, V and still computes exactly the full attention. The catch: with a causal mask, the query at position i attends to all positions j ≤ i, and those live on different ranks.

Ring Attention (Liu et al. 2023)

The insight: FlashAttention's online-softmax accumulator (m, ℓ, O) can be extended one chunk of K, V at a time, and partial results merge associatively. You feed in another chunk and stably update the running max m, normaliser ℓ, and output O. So if rank i holds Q chunk i and K, V chunk i, it can:

  1. Compute its local Q_i · K_i^T contribution to attention.
  2. Send K_i, V_i to the next rank in a ring.
  3. Receive K_{i-1}, V_{i-1} from the previous rank, compute the cross contribution, fold it into the online accumulator.
  4. Repeat for N - 1 rounds: at the end every rank has attended to all K, V.
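The four steps can be simulated in one process. This sketch assumes one head, batch 1, and no causal mask. Ranks are list indices, and rotation is a list shift. `fold_chunk` is a per-key version of the online-softmax update, not FlashAttention's tiled kernel. The check is that every rank ends up with exactly the full-attention output for its queries.

```python
# Single-process simulation of Ring Attention (non-causal, one head, batch 1).
# Pure-Python lists stand in for tensors; "ranks" are list indices.

import math
import random
from typing import List, Tuple

Vec = List[float]
State = Tuple[float, float, Vec]  # (running max m, running denominator l, unnormalised O)


def fold_chunk(state: State, q: Vec, ks: List[Vec], vs: List[Vec], scale: float) -> State:
    """Fold one K/V chunk into a single query's online-softmax state."""
    m, l, o = state
    for k, v in zip(ks, vs):
        s = scale * sum(qi * ki for qi, ki in zip(q, k))
        m_new = max(m, s)
        alpha = math.exp(m - m_new)  # rescales old stats; 0.0 while m is still -inf
        p = math.exp(s - m_new)
        l = l * alpha + p
        o = [oi * alpha + p * vi for oi, vi in zip(o, v)]
        m = m_new
    return m, l, o


def ring_attention(q_chunks: List[List[Vec]], k_chunks: List[List[Vec]],
                   v_chunks: List[List[Vec]]) -> List[List[Vec]]:
    n = len(q_chunks)
    scale = 1.0 / math.sqrt(len(q_chunks[0][0]))
    d_v = len(v_chunks[0][0])
    # Each rank keeps its Q chunk and one (m, l, O) state per query row.
    states = [[(-math.inf, 0.0, [0.0] * d_v) for _ in qc] for qc in q_chunks]
    kv = list(zip(k_chunks, v_chunks))  # kv[r]: the K/V chunk rank r holds now
    for step in range(n):
        for r in range(n):
            ks, vs = kv[r]
            states[r] = [fold_chunk(st, q, ks, vs, scale)
                         for st, q in zip(states[r], q_chunks[r])]
        if step < n - 1:  # N - 1 rotations in total
            kv = [kv[(r - 1) % n] for r in range(n)]  # rank r receives from rank r - 1
    return [[[oi / l for oi in o] for (_, l, o) in rank] for rank in states]


def full_attention(qs: List[Vec], ks: List[Vec], vs: List[Vec]) -> List[Vec]:
    """Reference: plain softmax(QK^T / sqrt(d_k)) V over the whole sequence."""
    scale = 1.0 / math.sqrt(len(qs[0]))
    out = []
    for q in qs:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in ks]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, vs)) / z for j in range(len(vs[0]))])
    return out


def rand_matrix(rows: int, cols: int) -> List[Vec]:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def split(x: List[Vec], n: int) -> List[List[Vec]]:
    c = len(x) // n
    return [x[r * c:(r + 1) * c] for r in range(n)]


random.seed(0)
T, N, D = 8, 4, 4
Q, K, V = rand_matrix(T, D), rand_matrix(T, D), rand_matrix(T, D)
per_rank = ring_attention(split(Q, N), split(K, N), split(V, N))
ring_out = [row for rank in per_rank for row in rank]
ref_out = full_attention(Q, K, V)
max_err = max(abs(a - b) for ra, rb in zip(ring_out, ref_out) for a, b in zip(ra, rb))
print(f"max |ring - full| = {max_err:.1e}")
```

Each rank sees the chunks in a different order (rank r sees r, r - 1, ...). The result still matches up to rounding, which is what lets every rank start on its local chunk.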

Communication: per attention layer, rank i sends and receives N - 1 K/V chunks, each holding 2 · B · T/N · h_kv · d_k elements (K plus V, 2 bytes each in bf16). Crucially, this overlaps with compute: chunk i+1 can be in flight while chunk i is being processed. As long as transferring one chunk takes no longer than the attention compute on one chunk, the communication cost is hidden.
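The byte count above can be turned into a calculator. The shape is an assumed GQA configuration (h_kv = 8, d_k = 128) at T = 128k over N = 8 ranks, in bf16, forward pass only. The helper names are illustrative.

```python
# Per-rank Ring Attention traffic for one attention layer (forward pass).
# Assumptions: K and V both circulate, bf16 (2 bytes/element), and the example
# shape below is illustrative only.

def kv_chunk_bytes(B: int, T: int, N: int, h_kv: int, d_k: int, bytes_per_elem: int = 2) -> int:
    """Bytes in one K/V chunk: K plus V, each B * (T/N) * h_kv * d_k elements."""
    return 2 * B * (T // N) * h_kv * d_k * bytes_per_elem


def ring_bytes_per_layer(B: int, T: int, N: int, h_kv: int, d_k: int) -> int:
    """Each rank sends (and receives) N - 1 chunks per attention layer."""
    return (N - 1) * kv_chunk_bytes(B, T, N, h_kv, d_k)


chunk = kv_chunk_bytes(B=1, T=131072, N=8, h_kv=8, d_k=128)
total = ring_bytes_per_layer(B=1, T=131072, N=8, h_kv=8, d_k=128)
print(f"per round: {chunk / 2**20:.0f} MiB, per layer: {total / 2**20:.0f} MiB")  # 64 MiB, 448 MiB
```

As N grows, the per-layer total approaches 2 · B · T · h_kv · d_k elements. Each rank ends up moving roughly the whole sequence's K and V once, regardless of N.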

Why ring attention works only because of FlashAttention
Plain attention computes softmax(QK^T)V, and the softmax denominator needs the full row of QK^T at once. Naïve sharding of K, V would force you either to materialise the full T×T matrix or to stream K, V around twice (once for the row max and denominator, once more to recompute the scores and accumulate the output). FlashAttention's online algorithm makes the softmax decomposable across K, V chunks: keep running statistics, fold in each new chunk. Ring Attention is exactly this online algorithm distributed across ranks. Lesson 03 in vllm covers the online math.

Animated · Ring Attention, one rotation at a time

Below is the same algorithm but rendered as a literal ring of N = 4 ranks. Each rank starts owning (Q_i, K_i, V_i) for its sequence chunk and an online-softmax state (m_i, ℓ_i, O_i). On each tick, every rank uses its current K/V chunk to update its accumulator, then passes K/V to its neighbour. After N - 1 rotations, every rank has seen every K/V chunk — without ever materialising the full attention matrix.

Figure (interactive): Ring Attention, one round at a time. Solid arcs are ranks (Q stays put); inner labels show which K/V chunk each rank is currently using. The labels rotate each step, and every rank's accumulator updates.

2D · the sequence, sharded

SP and CP both shard the sequence axis, but the scope of what they protect differs. SP shards only the activations around LayerNorm/dropout; the AllGather before each TP matmul rebuilds the full sequence first. CP shards Q, K, V and the attention math itself. The figure below colors a long token strip by owning rank under each mode and lists what each rank actually computes.

Figure (interactive): sequence ownership map, SP vs CP. A 64-token sequence colored by rank, with a per-block tableau of what each rank holds: LN/dropout activations, attention activations, and extra comm.

Animated · online-softmax accumulator

The reason Ring Attention works at all is that FlashAttention's online softmax can ingest K/V chunks sequentially and still produce the same result as a single pass (up to floating-point rounding). Below is that mechanic in isolation: as each chunk arrives, the running statistics (m, ℓ, O) rescale themselves stably.

Figure (interactive): online softmax, chunk-by-chunk update. Each rectangle is a K/V chunk's contribution, folded into m (running max), ℓ (running denominator), and O (running weighted sum), with old values rescaled whenever a new max appears. Readouts: running max m, running denominator ℓ, rescale factor, final probability.
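Here is the same mechanic in code, for a single query. The scores and values are made-up numbers, and O is a weighted sum of scalar values rather than vectors, to keep the trace readable. The rescale factor is exp(m_old - m_new). It is 0 for the first chunk (there is nothing to rescale yet), below 1 when a new max appears, and exactly 1 otherwise.

```python
# One query's online softmax over score chunks, tracking (m, l, O) and the
# rescale factor applied to the old statistics at each chunk.

import math


def online_softmax(score_chunks, value_chunks):
    m, l, o = -math.inf, 0.0, 0.0
    trace = []
    for scores, values in zip(score_chunks, value_chunks):
        m_new = max(m, max(scores))
        rescale = math.exp(m - m_new)  # < 1 only when a new max appears
        p = [math.exp(s - m_new) for s in scores]
        l = l * rescale + sum(p)
        o = o * rescale + sum(pi * vi for pi, vi in zip(p, values))
        m = m_new
        trace.append((m, l, rescale))
    return m, l, o, trace


scores = [1.0, 3.0, 0.5, 4.0, 2.0, -1.0]
values = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
m, l, o, trace = online_softmax([scores[0:2], scores[2:4], scores[4:6]],
                                [values[0:2], values[2:4], values[4:6]])

# Single-pass reference over the whole row.
mx = max(scores)
w = [math.exp(s - mx) for s in scores]
direct = sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

for i, (mi, li, r) in enumerate(trace):
    print(f"chunk {i}: m={mi:.1f}  l={li:.4f}  rescale={r:.4f}")
print(f"online O/l = {o / l:.6f}, single pass = {direct:.6f}")
```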

Causal masking gets a free ride

For causal LMs with contiguous chunks, rank i only needs K, V chunks j ≤ i. The ring still rotates K, V through all N - 1 steps, because each rank must relay chunks it doesn't need so they reach the ranks that do. Even so, ~half the block matmuls are skipped because the upper-triangular contributions are fully masked. Contiguous chunking leaves later ranks doing more work than earlier ones. Zigzag and striped schedules (Liu et al., Megatron-LM) reorder chunks across ranks so each one ends up with roughly equal compute. Megatron-LM's default CP uses the zigzag-style split: the sequence is cut into 2N chunks, and rank i holds chunks i and 2N - 1 - i.
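You can see the imbalance by counting the query-key pairs each rank keeps under the causal mask. The sketch assumes a 64-token sequence over N = 4 ranks. `zigzag` implements the 2N-chunk split (rank r takes chunks r and 2N - 1 - r), and both helpers are illustrative.

```python
# Causal attention work per rank for two ways of assigning tokens to N ranks.
# Work = query-key pairs that survive the causal mask: the token at position p
# attends to p + 1 keys.

from typing import List


def contiguous(T: int, N: int) -> List[List[int]]:
    c = T // N
    return [list(range(r * c, (r + 1) * c)) for r in range(N)]


def zigzag(T: int, N: int) -> List[List[int]]:
    """Cut into 2N chunks; rank r gets chunk r and chunk 2N - 1 - r."""
    c = T // (2 * N)

    def chunk(i: int) -> List[int]:
        return list(range(i * c, (i + 1) * c))

    return [chunk(r) + chunk(2 * N - 1 - r) for r in range(N)]


def causal_work(assignment: List[List[int]]) -> List[int]:
    return [sum(p + 1 for p in tokens) for tokens in assignment]


T, N = 64, 4
print("contiguous:", causal_work(contiguous(T, N)))  # [136, 392, 648, 904]
print("zigzag:    ", causal_work(zigzag(T, N)))      # [520, 520, 520, 520]
```

Pairing an early chunk with its mirror-image late chunk makes each rank's work sum to the same total. The last rank no longer does almost seven times the work of the first.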

SP vs CP — different sizes, different problems

| Strategy | What it shards | When it bites | Comm |
|---|---|---|---|
| SP (sequence parallel) | LayerNorm/dropout activations | Whenever TP is used and T > ~2k | No extra comm (replaces AllReduce with AllGather + ReduceScatter) |
| CP (context parallel) | The attention computation itself | T > ~32k | One full ring rotation of K, V (N - 1 steps) per attention block; ~2 · B · T · h_kv · d_k elements per rank per layer |

SP is on by default whenever TP > 1 in modern frameworks (Megatron-LM, NeMo). CP is opt-in for long-context pretraining and fine-tuning. The two compose: each rank can sit in a TP group of 8 and a CP group of 4 at once, so 32 ranks share one layer's work, which is useful for 1M-token training.
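This sketch shows one way the TP=8 × CP=4 grid can map onto 32 ranks. It assumes TP is the fastest-varying rank index, so each TP group is one 8-GPU node and each CP ring strides across nodes. Frameworks let you choose a different order, and `build_groups` is a hypothetical helper.

```python
# One possible rank layout for TP=8 x CP=4 (32 ranks). Assumption: TP is the
# fastest-varying index, so each TP group is 8 consecutive ranks (one node)
# and each CP ring strides across nodes.

def build_groups(tp: int, cp: int):
    tp_groups = [[c * tp + t for t in range(tp)] for c in range(cp)]
    cp_groups = [[c * tp + t for c in range(cp)] for t in range(tp)]
    return tp_groups, cp_groups


tp_groups, cp_groups = build_groups(tp=8, cp=4)
print("TP groups:", tp_groups[0], "...")  # [0, 1, 2, 3, 4, 5, 6, 7] ...
print("CP ring 0:", cp_groups[0])         # [0, 8, 16, 24]
```

Keeping TP inside a node puts its latency-sensitive, per-matmul collectives on NVLink. The CP ring's K/V transfers can cross nodes because they overlap with attention compute.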

Cost detail — when CP becomes profitable

FlashAttention's compute on one rank for T tokens scales as O(B · T² · h · d_k). CP shards this across N ranks, so each rank does O(B · T²/N · h · d_k) compute. The ring sends O(B · T/N · h_kv · d_k) bytes per rank per round for N - 1 rounds, so O(B · T · h_kv · d_k) bytes per rank in total. Moving them takes that volume divided by the link bandwidth (NVLink within a node). The compute-to-comm time ratio in CP is therefore:

compute time / comm time  ∝  T · (h · d_k) / (h_kv · d_k · N) · (BW / FLOPS)

Here BW is the per-rank link bandwidth and FLOPS the rank's attention throughput. The ratio is linear in T (more precisely, in the per-rank chunk length T/N). The longer your context, the more compute you get per byte of K, V you move, so CP gets more efficient at longer contexts. This is the opposite of the bandwidth trap TP and FSDP have. CP is the only parallelism strategy in this series that likes long sequences.
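Here is a back-of-envelope version of that ratio for one ring step. All hardware numbers are hypothetical: 400 TFLOP/s sustained attention throughput and a 50 GB/s full-duplex link. The block is treated as unmasked, and K/V are bf16. The function names are illustrative.

```python
# Does one Ring Attention step hide its K/V transfer? Assumptions (hypothetical):
# unmasked block, sustained attention throughput `flops` per rank, full-duplex
# link bandwidth `bw` in bytes/s, bf16 K/V.

def step_times(T, N, h, h_kv, d_k, flops, bw, B=1):
    c = T // N                                # tokens per chunk
    step_flops = 4 * B * h * c * c * d_k      # QK^T and PV, 2 FLOPs per multiply-add
    step_bytes = 2 * B * c * h_kv * d_k * 2   # K and V chunk, 2 bytes per element
    return step_flops / flops, step_bytes / bw


def overlap_ratio(T, N, h, h_kv, d_k, flops, bw):
    t_compute, t_comm = step_times(T, N, h, h_kv, d_k, flops, bw)
    return t_compute / t_comm                 # > 1 means the transfer can hide


def break_even_chunk(h, h_kv, flops, bw):
    """Chunk length T/N at which compute time equals comm time."""
    return (flops / bw) * (h_kv / h)


cfg = dict(N=8, h=64, h_kv=8, d_k=128, flops=400e12, bw=50e9)
for T in (16_384, 131_072, 1_048_576):
    print(f"T={T:>9}: compute/comm = {overlap_ratio(T, **cfg):.1f}")
print("break-even chunk length:", break_even_chunk(64, 8, 400e12, 50e9))
```

With a causal mask, roughly half of each step's blocks are skipped, so the effective ratio is about half. The linear growth in T at fixed N is unchanged.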

Interactive · how much activation memory does SP/CP save?

The widget plots activation memory per rank as T and the parallelism degrees vary, for no SP/CP, SP only, CP only, and both. Watch the long-context regime: at T = 128k, CP is the only thing that matters; below T = 4k, SP alone is enough.

Figure (interactive): activation memory per rank, and which parallelism reaches it. Memory accounting for a Llama-style block (with activation checkpointing). Stacked bars: matmul activations (TP-sharded) + LN/dropout (SP-able) + attention internals (CP-able), shown for no SP/CP, SP only, and SP + CP, with the resulting savings.
Takeaway
SP shards the activations TP can't reach at no extra comm cost, so turn it on whenever TP is on. CP shards the attention itself by passing K/V around a ring and reusing FlashAttention's online-softmax accumulator; comm is hidden behind compute, and the ratio improves with sequence length. SP is universal; CP is the long-context lever.