FlashAttention
Same math, different schedule. Stream attention as a fused tile-loop with running softmax stats — never write the T×T score matrix to HBM.
Different paper, different problem
Before any math: PagedAttention (lesson 02) is about memory management of the KV cache across requests — a block table, an allocator, copy-on-write. FlashAttention is about the arithmetic kernel inside a single attention op. Different paper (Dao et al., NeurIPS 2022), different group (Stanford, not UC Berkeley), different problem.
They compose. vLLM ships a kernel that does Flash-style tiling over a Paged-style block layout. Confusing them is the most common mistake in vLLM conversations.
| | PagedAttention | FlashAttention |
|---|---|---|
| concern | where in HBM the KV lives | how the attention op reads/writes HBM |
| changes | memory layout | kernel schedule |
| numerics | identical to contiguous | identical to naive (up to float noise) |
| buys you | fewer wasted KV bytes | fewer HBM bytes moved per op |
The pain of naive attention
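Textbook attention computes three tensors and feeds them through one another:

$$S = \frac{QK^\top}{\sqrt{d}}, \qquad P = \operatorname{softmax}(S)\ \text{(row-wise)}, \qquad O = PV$$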
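For a batch of B, h heads, sequence length T and head dim d, S and P each have shape (B, h, T, T). At T = 8192, B = 1, h = 32 in fp16 (2 bytes per element), one of them takes

$$1 \cdot 32 \cdot 8192^2 \cdot 2\ \text{B} = 2^{32}\ \text{B} \approx 4.3\ \text{GB}$$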
Four gigabytes for one tensor in one layer of one request. Two things break:
- You can't fit it. S and P together need ~8.6 GB for this one layer before you've even computed O, memory that HBM also has to spend on weights and the KV cache.
- You can't move it fast enough. Each of S and P gets written to HBM and read back, so that's ~17 GB of HBM traffic for the score and probability tensors alone, per layer. On an H100 at ~3 TB/s, that's ~6 ms per layer of pure memory traffic, while the two matmuls (~1.1 TFLOP) take only about 1 ms at the H100 SXM's ~1 PFLOP/s dense fp16 peak.
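The sizing arithmetic above is easy to reproduce. A small sketch, assuming fp16 (2 bytes per element) and ~3 TB/s of HBM bandwidth; `naive_score_traffic` is a hypothetical helper name, not from any library.

```python
FP16_BYTES = 2  # bytes per fp16 element

def naive_score_traffic(T, h, B=1, hbm_bytes_per_s=3e12):
    """Size of one materialized S or P tensor, plus the HBM traffic and time to
    write and read back both S and P once (assumes ~3 TB/s HBM)."""
    tensor_bytes = B * h * T * T * FP16_BYTES      # one (B, h, T, T) tensor
    traffic_bytes = 4 * tensor_bytes               # S write + read, P write + read
    seconds = traffic_bytes / hbm_bytes_per_s      # pure memory time, no math counted
    return tensor_bytes, traffic_bytes, seconds

size, traffic, secs = naive_score_traffic(T=8192, h=32)
print(f"one S or P tensor: {size / 1e9:.2f} GB")      # → 4.29 GB
print(f"S + P traffic    : {traffic / 1e9:.2f} GB")   # → 17.18 GB
print(f"memory time      : {secs * 1e3:.1f} ms")      # → 5.7 ms
```

Doubling T quadruples every number here, which is the whole problem in one line.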
The insight: you never need the whole P
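Look at any one row of O:

$$O_i = \sum_j P_{i,j} V_j = \frac{\sum_j \exp(S_{i,j} - m_i)\, V_j}{\sum_j \exp(S_{i,j} - m_i)}, \qquad m_i = \max_j S_{i,j}$$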
To compute it, all you need from the full row of S is the row max $\max_j S_{i,j}$, the softmax denominator $\sum_j \exp(S_{i,j} - \max_j S_{i,j})$, and a running weighted sum of the V rows. If we can compute those incrementally as we scan across the row in chunks, we never have to materialize the row.
The online-softmax trick
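Suppose we've already processed the first chunk of a row and tracked two scalars per row, the chunk max and the denominator relative to it:

$$m_1 = \max_{j \in \text{chunk}_1} S_j, \qquad \ell_1 = \sum_{j \in \text{chunk}_1} \exp(S_j - m_1)$$

A new chunk arrives with its own stats $(m_2, \ell_2)$. Combine them:

$$m = \max(m_1, m_2), \qquad \ell = e^{m_1 - m}\,\ell_1 + e^{m_2 - m}\,\ell_2$$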
This is just the identity $\exp(x - m) = \exp(x - m_1) \cdot \exp(m_1 - m)$ applied to each partial sum, rebasing it to the new shared max. It's exact, not an approximation. Recurse over all chunks and you get the true softmax denominator over the full row, without ever having looked at more than one chunk at a time.
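A minimal pure-Python sketch of the merge rule, assuming scores arrive as plain lists; the helper names `chunk_stats` and `merge_stats` are ours, not from any library. It folds one row in three-score chunks and checks the result against computing the stats over the whole row at once.

```python
import math

def chunk_stats(scores):
    """Per-chunk softmax stats: the chunk max m and ℓ = Σ exp(s - m)."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

def merge_stats(m1, l1, m2, l2):
    """Combine two chunks' (m, ℓ) by rebasing both partial sums onto the shared max."""
    m = max(m1, m2)
    return m, math.exp(m1 - m) * l1 + math.exp(m2 - m) * l2

row = [0.3, -1.2, 2.5, 0.0, 4.1, -0.7, 1.9, 3.3]   # one row of S, arbitrary values
m, l = chunk_stats(row[:3])
for start in range(3, len(row), 3):                 # fold in the rest, 3 scores at a time
    m, l = merge_stats(m, l, *chunk_stats(row[start:start + 3]))

m_full, l_full = chunk_stats(row)                   # reference: whole row at once
print(m == m_full, math.isclose(l, l_full))         # → True True
```

Starting from `(m, ℓ) = (-inf, 0)` also works, since `exp(-inf)` is exactly 0; that is why the kernel can initialize its state that way.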
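The same rescaling factor $\exp(m_{\text{old}} - m_{\text{new}})$ fixes the running unnormalized output:

$$O' \leftarrow e^{m_{\text{old}} - m_{\text{new}}}\, O' + \sum_{j \in \text{chunk}} e^{S_j - m_{\text{new}}}\, V_j$$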
At the end, divide by ℓ once. That's the whole algorithm.
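Putting the (m, ℓ) merge and the O' rescale together gives the whole algorithm for one query row. A plain-Python sketch, assuming small list inputs; `streaming_row` and `naive_row` are illustrative names. The second chunk holds a larger score than the first, so the rescale factor actually does some work.

```python
import math

def streaming_row(scores, values, chunk=2):
    """Online softmax for one query row: only m, ℓ and the unnormalized O' survive between chunks."""
    dv = len(values[0])
    m, l, o = float("-inf"), 0.0, [0.0] * dv
    for start in range(0, len(scores), chunk):
        s_blk = scores[start:start + chunk]
        v_blk = values[start:start + chunk]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)                  # rescales everything accumulated so far
        p = [math.exp(s - m_new) for s in s_blk]     # unnormalized probs for this chunk
        l = alpha * l + sum(p)
        o = [alpha * o[k] + sum(pj * vj[k] for pj, vj in zip(p, v_blk)) for k in range(dv)]
        m = m_new
    return [ok / l for ok in o]                      # the single final division

def naive_row(scores, values):
    """Reference: full softmax over the row, then the weighted sum of V rows."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wj * vj[k] for wj, vj in zip(w, values)) / z for k in range(len(values[0]))]

scores = [0.5, 1.0, 3.0, -2.0, 0.25, 2.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5], [1.0, -1.0]]
out, ref = streaming_row(scores, values), naive_row(scores, values)
print(all(math.isclose(a, b) for a, b in zip(out, ref)))   # → True
```

Changing `chunk` changes how often the state is rescaled, never the answer.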
The tile schedule
FlashAttention-2's loop structure (Q on the outer loop, the version vLLM uses):
```
for each Q-block (rows):                         # outer loop, parallel across heads / batch / Q-blocks
    init (m, ℓ, O') = (-inf, 0, 0) in SRAM       # per row, for this Q-block
    load Q-block into SRAM
    for each (K, V)-block (cols):                # inner loop, streams across all of K, V
        load K-block, V-block into SRAM
        S_tile = Q-block @ K-block^T / sqrt(d)   # in SRAM
        m_new  = max(m, rowmax(S_tile))          # per-row running max
        alpha  = exp(m - m_new)                  # per-row rescale factor
        P_tile = exp(S_tile - m_new)             # unnormalized probs (m_new broadcast per row)
        ℓ      = alpha * ℓ + rowsum(P_tile)      # row-wise sum, one value per query row
        O'     = alpha * O' + P_tile @ V-block   # streaming weighted sum
        m      = m_new
    O[Q-block] = O' / ℓ                          # finalize: one division per row
```
Three things to notice:
- Nothing intermediate touches HBM. S_tile and P_tile live in registers/shared memory, sized to fit the per-SM budget (up to 164 KB of shared memory per SM on an A100, 228 KB on an H100). The full T×T matrix is never instantiated anywhere.
- HBM I/O is essentially just Q, K, V, O. Each Q-block is loaded once, its inner loop streams every K, V block exactly once, and O is written once at the end. Strictly, K and V are re-read once per Q-block across the whole op, which the idealized byte count below ignores, but nothing of size T×T ever crosses HBM.
- The math is unchanged. See the rescaling derivation — every running update is algebraically equal to the corresponding term in the naive computation. The output matches naive attention to within float rounding (~1e-5 in fp32).
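The loop above translates nearly line for line into NumPy. This sketch is for checking numerics, not speed: NumPy slices stand in for SRAM tiles, there is no causal mask or batching, and the block sizes `Br`, `Bc` and the function names are our own choices. It compares against a naive implementation that does materialize S and P.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Textbook attention: materializes the full T×T S and P."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def flash_attention(Q, K, V, Br=16, Bc=16):
    """Tiled forward in FlashAttention-2 loop order (Q-blocks outer, K/V-blocks inner).
    Only Br×Bc score tiles ever exist; NumPy arrays stand in for SRAM tiles."""
    T, d = Q.shape
    O = np.empty((T, V.shape[1]))
    for r in range(0, T, Br):
        Qb = Q[r:r + Br]
        m = np.full(len(Qb), -np.inf)              # running row max
        l = np.zeros(len(Qb))                      # running row denominator ℓ
        Ob = np.zeros((len(Qb), V.shape[1]))       # running unnormalized output O'
        for c in range(0, T, Bc):
            Kb, Vb = K[c:c + Bc], V[c:c + Bc]
            S = Qb @ Kb.T / np.sqrt(d)             # Br×Bc score tile
            m_new = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_new)              # per-row rescale factor
            P = np.exp(S - m_new[:, None])         # unnormalized probs
            l = alpha * l + P.sum(axis=1)          # row-wise sum
            Ob = alpha[:, None] * Ob + P @ Vb
            m = m_new
        O[r:r + Br] = Ob / l[:, None]              # one division per row
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 32)) for _ in range(3))
print(np.abs(flash_attention(Q, K, V) - naive_attention(Q, K, V)).max())
```

The printed maximum difference should sit at float64 rounding level, and it stays there for any tile sizes (including ones that don't divide T), which is the "math is unchanged" point in executable form.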
Interactive · watch the online softmax tile by tile
The hardest part of FlashAttention to internalize is the running rescale. Below is the inner loop executed on one query row against four K, V tiles. Step through one tile at a time and watch m, ℓ, O' evolve: they are the only state the kernel keeps. The full row of scores is never materialized, let alone the T×T matrix.
- The synthetic data is chosen so a bigger max shows up in tile #2 — watch the running α shrink the old O' and ℓ on that step. That rescale is the entire trick.
- Hit auto-step, then watch the O' bars walk. At the very end one division by ℓ finalizes the output.
- The "in HBM" / "in SRAM" badges are the point: only Q row, K tile, V tile, m, ℓ, O' ever sit in SRAM, and only O ever leaves at the end.
How much HBM does each version move?
The numbers below assume one attention layer, fp16, B = 1. We'll count bytes that cross the HBM ↔ SRAM boundary.
Naive
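$$\begin{aligned}
\text{S write + read} &: 2 \cdot h \cdot T^2 \cdot 2\ \text{B} \\
\text{P write + read} &: 2 \cdot h \cdot T^2 \cdot 2\ \text{B} \\
\text{Q, K, V read + O write} &: 4 \cdot h \cdot T \cdot d \cdot 2\ \text{B} \\
\text{total} &\approx 4hT^2 \cdot 2\ \text{B} + 8hTd\ \text{B} = 8hT^2 + 8hTd\ \text{B}
\end{aligned}$$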
The T² term dominates as soon as T > d (T > 128 at d = 128), which covers every long-context workload.
Flash
(no S, no P in HBM)
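$$\begin{aligned}
\text{Q, K, V read + O write} &: 4 \cdot h \cdot T \cdot d \cdot 2\ \text{B} \\
\text{total} &\approx 8hTd\ \text{B}
\end{aligned}$$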
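The ratio at T = 8192, h = 32, d = 128, fp16:

$$\begin{aligned}
\text{naive} &\approx 8 \cdot 32 \cdot 8192^2\ \text{B} + 0.27\ \text{GB} \approx 17.4\ \text{GB} \\
\text{flash} &\approx 8 \cdot 32 \cdot 8192 \cdot 128\ \text{B} \approx 0.27\ \text{GB} = 270\ \text{MB} \\
\text{reduction} &= 1 + T/d = 65\times
\end{aligned}$$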
Round numbers: naive ~17 GB, flash ~270 MB, ~65× less I/O. That translates to roughly a 2–4× wall-clock speedup on an A100/H100 at long context. Not 65× wall-clock: HBM bandwidth isn't the only cost, and at long T the matmul FLOPs themselves become non-trivial. But it's the dominant axis.
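The byte counts above fit in two short functions. A sketch assuming fp16 and the idealized accounting used in this section (S and P each written and read once; Q, K, V, O each crossing HBM once); the function names are hypothetical.

```python
FP16_BYTES = 2

def naive_hbm_bytes(T, h, d):
    """S and P each written and read once, plus Q, K, V read and O written once."""
    return 4 * h * T * T * FP16_BYTES + 4 * h * T * d * FP16_BYTES   # 8hT² + 8hTd

def flash_hbm_bytes(T, h, d):
    """Idealized Flash: only Q, K, V, O cross HBM, once each."""
    return 4 * h * T * d * FP16_BYTES                                # 8hTd

for T, h, d in [(8192, 32, 128), (16384, 32, 128), (256, 32, 128), (128, 32, 128), (8192, 32, 256)]:
    naive, flash = naive_hbm_bytes(T, h, d), flash_hbm_bytes(T, h, d)
    print(f"T={T:>5} d={d}: naive {naive / 1e9:7.3f} GB  flash {flash / 1e9:6.3f} GB  ratio {naive / flash:5.1f}x")
```

The ratio column is exactly 1 + T/d: 65× at the default, 129× at T = 16384, 3× at T = 256, 2× at T = d = 128, and 33× once d grows to 256.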
The surprise: at short T, Flash doesn't help
Look at the formula again. The naive total is dominated by its 8hT² term only when that term beats the 8hTd constant. Solve:

$$8hT^2 > 8hTd \iff T > d, \qquad \frac{\text{naive}}{\text{flash}} = \frac{8hT^2 + 8hTd}{8hTd} = 1 + \frac{T}{d}$$

At d = 128, the crossover where the quadratic term starts to matter is T = 128. But the practical win opens up later: you need the quadratic term to dominate the constant by a healthy multiple before kernel-launch and SRAM-management overhead pay back. Empirically that's around T ≈ 1024–2048. At T = 128 (small chat), naive moves only 2× Flash's bytes, both totals are a few megabytes, and the fused kernel just has more bookkeeping.
Convince yourself with the widget below.
Interactive · HBM-traffic calculator
Sliders set the workload; bars show naive vs flash HBM traffic per attention layer (one forward pass, fp16, B = 1). Try:
- Default (T = 8192, h = 32, d = 128): see the ~65× gap that justifies every Flash deployment.
- Drop T to 256: Flash's lead shrinks to 3×. In the naive bar, S+P is now only twice the QKVO term; push T below d = 128 and QKVO takes over.
- Push T to 16384: the T² term in naive explodes; Flash grows linearly.
- Bump d to 256: the QKVO constant doubles. Flash's floor rises; the crossover moves out.
Composition with PagedAttention
In production vLLM, the attention kernel does both at once:
- For each Q-block (Flash's outer loop)…
- …iterate over the sequence's logical K, V positions in K-block-sized chunks…
- …look up each K-block's physical location via the block table (Paged's indirection)…
- …load that physical block into SRAM and run the Flash inner update.
The two ideas are perfectly compatible: Flash doesn't care where in HBM the K, V tile came from, and Paged doesn't care what arithmetic you do on the tile once you have it. The vLLM kernel just happens to do both.
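A toy version of that composition for a single query vector. It is a sketch, not vLLM's kernel: the pool layout `(num_physical_blocks, BLOCK, d)`, the list-valued `block_table`, the block size of 4 and the function name are all hypothetical. The only change from the plain Flash inner loop is the one indexing step through the block table.

```python
import numpy as np

BLOCK = 4   # tokens per KV block (hypothetical value for the demo)

def paged_flash_row(q, k_pool, v_pool, block_table, seq_len):
    """Online-softmax attention for one query over a paged KV cache.
    k_pool, v_pool: (num_physical_blocks, BLOCK, d), a stand-in for the paged HBM pool.
    block_table[logical] = physical block index for this sequence."""
    d = q.shape[0]
    m, l, o = -np.inf, 0.0, np.zeros(v_pool.shape[2])
    for logical, physical in enumerate(block_table):
        n = min(BLOCK, seq_len - logical * BLOCK)    # last block may be partially filled
        if n <= 0:
            break
        Kb = k_pool[physical, :n]                    # Paged: indirection through the block table
        Vb = v_pool[physical, :n]
        s = Kb @ q / np.sqrt(d)                      # Flash: the usual inner update
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)
        p = np.exp(s - m_new)
        l = alpha * l + p.sum()
        o = alpha * o + p @ Vb
        m = m_new
    return o / l

rng = np.random.default_rng(1)
d, seq_len = 8, 10
K = rng.standard_normal((seq_len, d))
V = rng.standard_normal((seq_len, d))
q = rng.standard_normal(d)

# Scatter the contiguous K, V into a 6-block pool at arbitrary physical slots.
block_table = [5, 0, 3]                              # logical blocks 0, 1, 2 -> physical 5, 0, 3
k_pool = np.zeros((6, BLOCK, d))
v_pool = np.zeros((6, BLOCK, d))
for logical, physical in enumerate(block_table):
    start, stop = logical * BLOCK, min((logical + 1) * BLOCK, seq_len)
    k_pool[physical, :stop - start] = K[start:stop]
    v_pool[physical, :stop - start] = V[start:stop]

s = K @ q / np.sqrt(d)                               # contiguous reference
w = np.exp(s - s.max())
ref = (w / w.sum()) @ V
print(np.allclose(paged_flash_row(q, k_pool, v_pool, block_table, seq_len), ref))  # → True
```

Scrambling `block_table` (and the pool contents to match) leaves the output unchanged, which is the "Flash doesn't care where the tile came from" claim made concrete.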
Where the speedup goes (and doesn't)
| regime | bottleneck | flash helps? |
|---|---|---|
| prefill, long context (T ≫ 1k) | HBM traffic for S, P | yes, 2–4× wall-clock |
| decode, long context | KV cache read | partly: fusion skips writing the 1×T score row, but every K, V byte must still be read once |
| prefill, short context (T ≲ 256) | QKVO I/O + launch overhead | marginal |
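| training (backward) | same as fwd, plus P kept in HBM for the backward pass | even more: only per-row softmax stats are saved and S, P tiles are recomputed in the backward (FlashAttention-2 further reworks the backward's parallelism) |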
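The "recompute" in the training row means the backward pass never reads a stored P. In the FlashAttention papers the forward saves only per-row softmax statistics (FlashAttention-2 saves the logsumexp, m + log ℓ), and the backward rebuilds each P tile from Q, K and those stats. A NumPy sketch of that rebuild, with hypothetical function names; the logsumexp is computed from the full S here only for brevity, whereas a tiled forward gets it from its final m and ℓ.

```python
import numpy as np

def row_logsumexp(Q, K):
    """Per-row lse_i = m_i + log ℓ_i: O(T) numbers instead of an O(T²) P."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    m = S.max(axis=1)
    return m + np.log(np.exp(S - m[:, None]).sum(axis=1))

def recompute_p_tile(Q_blk, K_blk, lse_blk):
    """Rebuild one normalized P tile from Q, K and the saved per-row lse."""
    S_tile = Q_blk @ K_blk.T / np.sqrt(Q_blk.shape[1])
    return np.exp(S_tile - lse_blk[:, None])

rng = np.random.default_rng(2)
Q, K = rng.standard_normal((64, 16)), rng.standard_normal((64, 16))
lse = row_logsumexp(Q, K)

S = Q @ K.T / np.sqrt(16)                          # reference P, materialized only to compare
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)

tile = recompute_p_tile(Q[16:32], K[32:48], lse[16:32])
print(np.allclose(tile, P[16:32, 32:48]))          # → True
```

The trade is extra matmul FLOPs in the backward for never writing or reading a T×T tensor, the same bargain as the forward pass.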
Next: continuous batching — the scheduling-side win that fixes the other half of vLLM's 24× throughput claim. It requires paged KV (lesson 02) to work, which is why we did paging first.