Attention asymmetry & FlashAttention
Same math, two regimes: prefill is compute-bound, decode is bandwidth-bound. FlashAttention is one specific fix for the math layer; it does not by itself fix the systems layer. This lesson derives both, with the online softmax in seven lines.
The question this lesson answers
Lesson 02 said decode reads the whole KV cache per layer per token. Lesson 01 said the roofline rewards kernels whose arithmetic intensity reaches hundreds of FLOPs per byte. So why not implement attention naively as "compute the score matrix, softmax, multiply by V"? Because that materializes a T×T tensor in HBM, and at long T it is enormous and entirely wasted: every score is written once, read back once, and then discarded. FlashAttention removes that round trip. Then the systems story (paged KV, prefix reuse) handles the part FlashAttention doesn't touch.
Naive attention, byte-by-byte
For one head with sequence length T, Q ∈ ℝ^(T×d), K ∈ ℝ^(T×d), V ∈ ℝ^(T×d), stored in bf16 (b = 2 bytes per element):
- S = Q · Kᵀ: the kernel reads Q and K and writes the T×T score matrix to HBM. Writes 2·T² bytes.
- P = softmax(S): the kernel reads S and writes P. Reads 2·T², writes 2·T².
- O = P · V: the kernel reads P and V and writes O. Reads 2·T² + 2·T·d, writes 2·T·d.
Intermediate-matrix HBM traffic adds up to 2·T² + 4·T² + 2·T² = 8·T² bytes. At T=8192 that is 512 MiB per head, per layer, just for scores and probabilities no one keeps. Multiply by heads and layers and the kernel is hopelessly bandwidth-bound, even though its O(T²·d) FLOPs over only O(T·d) bytes of real inputs and outputs would, in principle, give it high arithmetic intensity.
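To make the gap concrete, here is a rough sketch of the FLOP and byte accounting for one head. It rests on assumptions that are not from the lesson: softmax FLOPs are ignored, each matmul costs 2·T²·d FLOPs, and Q, K, V, and O each cross HBM exactly once. The function name and the "fused" label are illustrative only.

```python
def naive_attention_intensity(T: int, d: int, bytes_per_elem: int = 2) -> dict:
    """Rough FLOP and HBM-byte counts for one head of naive attention."""
    b = bytes_per_elem
    # Two matmuls (Q·K^T and P·V), each about 2·T²·d FLOPs; softmax ignored.
    flops = 4 * T * T * d
    # Score/probability traffic: write S, read S, write P, read P.
    intermediate_bytes = 4 * b * T * T
    # Assumption: Q, K, V are read once and O is written once.
    io_bytes = 4 * b * T * d
    return {
        "intermediate_bytes": intermediate_bytes,
        "naive_intensity": flops / (intermediate_bytes + io_bytes),
        "fused_intensity": flops / io_bytes,
    }

stats = naive_attention_intensity(T=8192, d=128)
print(stats["intermediate_bytes"] / 2**20)   # 512.0 (MiB of scores and probabilities)
print(round(stats["naive_intensity"], 1))    # 63.0 FLOPs per byte
print(stats["fused_intensity"])              # 4096.0 FLOPs per byte
```

Under these assumptions the naive kernel sits near d/2 FLOPs per byte regardless of T, while a kernel that never materializes the intermediates would sit near T/2.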
FlashAttention as a tile schedule
Idea: load a Q tile (Bq rows) and a K/V tile (Bk rows) into SRAM, compute their contribution to softmax, accumulate into the output, then move to the next K/V tile. Never write the score matrix to HBM. The complication: softmax needs the global max and sum, which you can't know until you've seen every K. The fix is an online softmax that updates the running max, scale, and accumulator as new K/V tiles arrive.
The seven-line online softmax (per Q row, sweeping K/V tiles indexed by j):
# initialize
m, l, O = -inf, 0, 0
for j in tiles:                          # one tile of K/V at a time
    s = Q · K_j^T                        # tile of scores, in SRAM
    m_new = max(m, max(s))               # new running max
    p = exp(s - m_new)                   # numerically safe
    l = exp(m - m_new) · l + sum(p)      # rescale old denom
    O = exp(m - m_new) · O + p · V_j     # rescale old numerator
    m = m_new
return O / l
Two things to take from this:
- The score matrix never lands in HBM. For each Q tile, the kernel reads each K and V tile once, does its math, and discards it. HBM traffic drops from O(T²) intermediate reads and writes to O(T·d) K/V reads per Q tile.
- The kernel produces the same output as the naive version up to floating-point rounding, because only the order of operations changes. The "online" rescaling preserves softmax exactly; it is not an approximation.
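The pseudocode above can be checked numerically. The following is a minimal pure-Python sketch for a single Q row, compared against a reference that builds all scores first. The function names, the tile size of 4, and the random test data are illustrative assumptions, and a real kernel would of course operate on tiles in SRAM rather than Python lists.

```python
import math
import random

def naive_attention_row(q, K, V):
    """Reference: all scores first, one global softmax, then the weighted sum."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    denom = sum(p)
    return [sum(p[t] * V[t][c] for t in range(len(V))) / denom
            for c in range(len(V[0]))]

def online_attention_row(q, K, V, tile=4):
    """Same result, but K/V are consumed one tile at a time."""
    m, l = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), tile):
        K_j, V_j = K[start:start + tile], V[start:start + tile]
        s = [sum(qi * ki for qi, ki in zip(q, k)) for k in K_j]
        m_new = max(m, max(s))
        p = [math.exp(x - m_new) for x in s]
        scale = math.exp(m - m_new)      # 0.0 on the first tile, since m = -inf
        l = scale * l + sum(p)           # rescale old denominator
        acc = [scale * a + sum(p[t] * V_j[t][c] for t in range(len(V_j)))
               for c, a in enumerate(acc)]  # rescale old numerator
        m = m_new
    return [a / l for a in acc]

random.seed(0)
T, d = 10, 3
q = [random.gauss(0, 1) for _ in range(d)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]

ref = naive_attention_row(q, K, V)
out = online_attention_row(q, K, V, tile=4)
print(max(abs(a - b) for a, b in zip(ref, out)))  # rounding-level difference only
```

T = 10 with a tile of 4 deliberately leaves a short final tile, so the sketch also exercises the ragged-edge case.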
Why prefill loves FlashAttention and decode wants a variant
In prefill, T new tokens are processed at once: Q has T rows, and the kernel walks them as outer Q tiles with an inner loop over K/V tiles. Tile sizes (e.g., Bq=128, Bk=128) are chosen so that each K/V tile loaded into SRAM is reused across all Bq rows of the current Q tile. Tensor cores stay busy; the kernel approaches its roofline FLOP rate.
In decode, Q has one row per sequence. The natural Q tile is tiny. There is little reuse along the Q axis, so the kernel's bottleneck is reading the entire history of K and V from HBM. Variants like FlashDecoding split the K/V sequence dimension across thread blocks instead of the Q dimension to keep multiple SMs busy, then merge the partial softmax results. FlashAttention-2/3 and FlashInfer include specialized batch-decode kernels that take a ragged batch (each sequence with its own T) and process it with one launch.
| Phase | Q shape | Best parallelism | Typical kernel |
|---|---|---|---|
| Prefill (long Q, dense) | Bq × d, many rows | over Q rows | FlashAttention-2/3 (forward) |
| Decode (Bq = 1) | 1 × d per sequence | over K splits (FlashDecoding) or over sequences | FlashDecoding / FlashInfer batch-decode |
| Mixed batch (prefill+decode) | ragged Q lengths | over sequences, with per-sequence tile sizes | FlashInfer / vLLM unified attention |
| MLA (DeepSeek-style) | compressed latent | different shape entirely | FlashMLA |
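The split-and-merge step behind the "over K splits" row can be sketched for a single decode query. Each split produces a local max, a local denominator, and an unnormalized output, and the merge rescales them to a common max. This is a pure-Python illustration under assumed names and toy data, not the code of any particular library.

```python
import math

def partial_softmax(scores, values):
    """One K/V split: returns (local max, local denominator, unnormalized output)."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    acc = [sum(pi * v[c] for pi, v in zip(p, values))
           for c in range(len(values[0]))]
    return m, sum(p), acc

def merge(parts):
    """Combine per-split (m, l, acc) triples into the final attention output."""
    m_all = max(m for m, _, _ in parts)
    l_all = 0.0
    acc_all = [0.0] * len(parts[0][2])
    for m, l, acc in parts:
        w = math.exp(m - m_all)          # rescale this split to the global max
        l_all += w * l
        acc_all = [a + w * x for a, x in zip(acc_all, acc)]
    return [a / l_all for a in acc_all]

# Toy data: six precomputed scores for one query and six 2-dim value rows.
scores = [0.5, -1.0, 2.0, 0.1, 1.5, -0.3]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 1.0], [0.0, 3.0], [4.0, 0.0]]

full = merge([partial_softmax(scores, values)])
split = merge([partial_softmax(scores[i:i + 2], values[i:i + 2]) for i in (0, 2, 4)])
print(max(abs(a - b) for a, b in zip(full, split)))  # rounding-level difference only
```

Because the splits are independent until the merge, each one can run on a different thread block, which is what keeps multiple SMs busy when Q has a single row.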
What FlashAttention does not fix
This is the most common confusion in the topic. FlashAttention is a kernel-internal trick for the attention math. It does nothing about where K and V live in HBM. Specifically, it does not:
- Reduce KV cache size — it still reads the same K and V tensors.
- Handle non-contiguous KV — naive FA assumes a contiguous KV layout per sequence; supporting paged KV requires a kernel variant that follows a block table (lesson 04).
- Share KV across requests — that's the prefix-reuse problem (lesson 05).
- Hide launch overhead — that's CUDA graphs (lesson 06).
Interactive · attention HBM accountant
Move T and d. The widget computes naive attention HBM traffic (with materialized score matrix) vs FlashAttention traffic (no intermediate matrix). The ratio grows linearly in T/d.
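The same accounting can be sketched in a few lines. The model here is an assumption: both variants read Q, K, V once and write O once, and the naive variant additionally moves 8·T² bytes of scores and probabilities in bf16. K/V re-reads across Q tiles are ignored.

```python
def hbm_bytes(T, d, bytes_per_elem=2):
    """Return (naive_bytes, fused_bytes) for one head under the stated model."""
    b = bytes_per_elem
    io = 4 * b * T * d                   # Q, K, V reads plus the O write
    naive = io + 4 * b * T * T           # plus S and P, each written and read once
    return naive, io

for T in (1024, 8192, 65536):
    naive, fused = hbm_bytes(T, d=128)
    print(T, naive / fused)              # equals 1 + T/d: 9.0, 65.0, 513.0
```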
Causal masking, sliding windows, and friends
Production kernels accept extra constraints as cheap variations of the same tile schedule:
- Causal mask: for each Q tile, skip K/V tiles that lie entirely above the diagonal (keys that come after every query in the tile). Cuts work roughly in half during prefill.
- Sliding window: only attend to the last W tokens. The K/V tile loop bounds change; everything else is unchanged.
- ALiBi / position biases: add a position-dependent offset to scores before the running-max update. No HBM cost.
- GQA / MQA: several Q heads share one KV head. One K/V tile read serves every Q head in the group, provided the kernel processes those heads together; that sharing is the only way GQA actually helps decode bandwidth.
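To see how the causal and sliding-window rules change the loop bounds, the sketch below counts the (Q tile, K/V tile) pairs a prefill kernel would visit. The function name and the skip conditions are an illustrative reading of the two rules, assuming query i attends keys in [i − W + 1, i] when a window W is set.

```python
def visited_kv_tiles(T, Bq=128, Bk=128, causal=False, window=None):
    """Count (Q tile, K/V tile) pairs that cannot be skipped."""
    visited = 0
    for q_start in range(0, T, Bq):
        q_end = min(q_start + Bq, T) - 1         # last query row in this tile
        for k_start in range(0, T, Bk):
            k_end = min(k_start + Bk, T) - 1     # last key in this tile
            if causal and k_start > q_end:
                continue  # every key comes after every query in the tile
            if window is not None and k_end < q_start - window + 1:
                continue  # every key is older than the earliest query's window
            visited += 1
    return visited

print(visited_kv_tiles(8192))                             # 4096
print(visited_kv_tiles(8192, causal=True))                # 2080
print(visited_kv_tiles(8192, causal=True, window=1024))   # 540
```

Tiles that straddle the diagonal or the window edge are still visited and need element-wise masking inside the tile, which is why the causal count is slightly more than half.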
System-level consequences
Now we can name three things the serving engine must coordinate, all of which are downstream of the kernel choice:
- The selected attention backend must support the model's head dim, KV dtype, sliding window, GQA ratio, and the storage layout the rest of the system uses.
- If KV is stored non-contiguously (next lesson), the backend must accept a block table — i.e., it is the paged variant of FlashAttention/FlashInfer/etc.
- If decode shapes are stable (lesson 06), the backend should be capturable into a CUDA graph. Some backends are friendlier to this than others.
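One way to picture this coordination is a capability check run before the engine commits to a backend. Everything in the sketch is hypothetical: the class names, the fields, and the example values are invented for illustration and do not correspond to the API of vLLM, FlashInfer, or any other library.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BackendCaps:
    """Hypothetical record of what one attention backend supports."""
    head_dims: frozenset
    kv_dtypes: frozenset
    max_gqa_ratio: int
    sliding_window: bool
    paged_kv: bool
    cuda_graph_capturable: bool

@dataclass(frozen=True)
class ModelNeeds:
    """Hypothetical record of what the model and engine require."""
    head_dim: int
    kv_dtype: str
    gqa_ratio: int
    sliding_window: bool
    paged_kv: bool
    cuda_graph: bool

def unmet_requirements(caps, needs):
    """Return the names of requirements this backend cannot satisfy."""
    problems = []
    if needs.head_dim not in caps.head_dims:
        problems.append("head_dim")
    if needs.kv_dtype not in caps.kv_dtypes:
        problems.append("kv_dtype")
    if needs.gqa_ratio > caps.max_gqa_ratio:
        problems.append("gqa_ratio")
    if needs.sliding_window and not caps.sliding_window:
        problems.append("sliding_window")
    if needs.paged_kv and not caps.paged_kv:
        problems.append("paged_kv")
    if needs.cuda_graph and not caps.cuda_graph_capturable:
        problems.append("cuda_graph")
    return problems

caps = BackendCaps(frozenset({64, 128}), frozenset({"bf16", "fp16"}), 8, True, False, True)
needs = ModelNeeds(head_dim=128, kv_dtype="bf16", gqa_ratio=4,
                   sliding_window=False, paged_kv=True, cuda_graph=True)
print(unmet_requirements(caps, needs))   # ['paged_kv']
```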