FlashAttention

Same math, different schedule. Stream attention as a fused tile-loop with running softmax stats — never write the T×T score matrix to HBM.

Different paper, different problem

Before any math: PagedAttention (lesson 02) is about memory management of the KV cache across requests — a block table, an allocator, copy-on-write. FlashAttention is about the arithmetic kernel inside a single attention op. Different paper (Dao et al., NeurIPS 2022), different group (Stanford, not UC Berkeley), different problem.

They compose. vLLM ships a kernel that does Flash-style tiling over a Paged-style block layout. Confusing them is the most common mistake in vLLM conversations.

Orthogonality, summarized
         | PagedAttention            | FlashAttention
concern  | where in HBM the KV lives | how the attention op reads/writes HBM
changes  | memory layout             | kernel schedule
numerics | identical to contiguous   | identical to naive (up to float noise)
buys you | fewer wasted KV bytes     | fewer HBM bytes moved per op

The pain of naive attention

Textbook attention computes three tensors and feeds them through one another:

S = QK^T / √d    P = softmax(S)    O = PV

For batch size B, h heads, sequence length T, and head dimension d, S and P each have shape (B, h, T, T). At T = 8192, B = 1, h = 32, in fp16 (2 bytes per element):

|P| = 32 · 8192² · 2 bytes  =  4.3 GB

Four gigabytes for one tensor in one layer of one request. Two things break:

The bottleneck flipped
Naive attention is memory-bound: the GPU spends most of its time waiting on HBM, not multiplying. Doing less HBM I/O directly buys wall-clock — see lesson 01, consequence B.

The insight: you never need the whole P

Look at any one row of O:

O_i = Σ_j softmax(S_{i,:})_j · V_j

To compute it, the only things you need from the full row of S are the softmax denominator Σ_j exp(S_{i,j} − max_k S_{i,k}) and a running weighted sum. If we can compute those incrementally as we scan across the row in chunks, we never have to materialize the row.

The online-softmax trick

Suppose we've already processed the first chunk of a row and tracked two scalars per row:

m_1 = max of logits so far     ℓ_1 = Σ exp(logit − m_1)

A new chunk arrives with its own stats (m_2, ℓ_2). Combine them:

m = max(m_1, m_2)
ℓ = exp(m_1 − m) · ℓ_1 + exp(m_2 − m) · ℓ_2

This is just the identity exp(x − m) = exp(x − m_1) · exp(m_1 − m), and likewise for m_2, applied to each partial sum to rebase it onto the new shared max. It's exact, not an approximation. Recurse over all chunks and you get the true softmax denominator over the full row, without ever having looked at more than one chunk at a time.

The same rescaling factor exp(m_old − m_new) fixes the running unnormalized output:

O' ← exp(m_old − m_new) · O' + Σ_{j ∈ new chunk} exp(S_{i,j} − m_new) · V_j

At the end, divide by ℓ once. That's the whole algorithm.
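The merge rule and the running update can be sketched in a few lines of plain Python. This is a minimal illustration, not kernel code: the function names (`chunk_stats`, `merge_stats`, `online_softmax_row`), the sample numbers, and the use of scalar values in place of V rows are all assumptions made for the example.

```python
import math

def chunk_stats(logits):
    """Return (m, l) for one chunk: its max and the sum of exp(logit - m)."""
    m = max(logits)
    l = sum(math.exp(x - m) for x in logits)
    return m, l

def merge_stats(a, b):
    """Combine two (m, l) pairs by rebasing both sums onto the shared max."""
    m1, l1 = a
    m2, l2 = b
    m = max(m1, m2)
    return m, math.exp(m1 - m) * l1 + math.exp(m2 - m) * l2

def online_softmax_row(logits, values, chunk=4):
    """Softmax-weighted sum of scalar values, seeing one chunk at a time."""
    m, l, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(logits), chunk):
        s = logits[start:start + chunk]
        v = values[start:start + chunk]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)        # exp(-inf) is 0.0 on the first chunk
        p = [math.exp(x - m_new) for x in s]
        l = alpha * l + sum(p)
        acc = alpha * acc + sum(pi * vi for pi, vi in zip(p, v))
        m = m_new
    return acc / l                          # single division at the end

# Hypothetical sample row (8 logits, 8 scalar "values").
logits = [0.5, 2.0, -1.0, 3.5, 0.0, 1.5, 4.0, -2.0]
values = [1.0, -2.0, 0.5, 3.0, 2.5, -1.0, 4.0, 0.0]

# Stats over the full row vs. stats merged from two half-row chunks.
m_full, l_full = chunk_stats(logits)
m_merged, l_merged = merge_stats(chunk_stats(logits[:4]), chunk_stats(logits[4:]))

# Direct softmax-weighted sum vs. the streamed version (ragged chunks of 3).
weights = [math.exp(x - m_full) / l_full for x in logits]
direct = sum(w * v for w, v in zip(weights, values))
streamed = online_softmax_row(logits, values, chunk=3)
```

The merged stats and the streamed result should agree with the full-row versions up to floating-point rounding, whatever chunk size is chosen.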

The tile schedule

FlashAttention-2's loop structure (Q on the outer loop, the version vLLM uses):

for each Q-block (rows):                # outer loop, parallel across heads / batch
    init  (m, ℓ, O') = (-inf, 0, 0)  in SRAM    # per-row, for this Q-block
    load  Q-block                    into SRAM
    for each (K, V)-block (cols):   # inner loop, streams across all of K, V
        load K-block, V-block        into SRAM
        S_tile  = Q-block @ K-block^T / sqrt(d)             # in SRAM
        m_new   = max(m, max-of-S_tile-per-row)
        alpha   = exp(m - m_new)                            # rescale factor
        P_tile  = exp(S_tile - m_new)                       # unnormalized probs
        ℓ      = alpha * ℓ + sum(P_tile)
        O'     = alpha * O' + P_tile @ V-block              # streaming weighted sum
        m      = m_new
    O[Q-block] = O' / ℓ              # finalize: one division per row

Three things to notice:

  1. Nothing intermediate touches HBM. S_tile and P_tile live in registers/SRAM, sized to fit (~100 KB per SM). The full T×T matrix is never instantiated anywhere.
  2. HBM I/O is exactly Q, K, V, O. Each Q-block is loaded once. K and V are streamed through SRAM once per Q-block; the byte counts later in this lesson treat that as a single pass over each, which is the idealized lower bound. O is written once at the end. That's it.
  3. The math is unchanged. See the rescaling derivation — every running update is algebraically equal to the corresponding term in the naive computation. The output matches naive attention to within float rounding (~1e-5 in fp32).
It's exact
FlashAttention is not a different model, not a different attention, not an approximation. It computes the same tensor as naive attention. The wins are entirely from changing when bytes move between HBM and SRAM.
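The exactness claim can be checked numerically. The following is a minimal pure-Python sketch of the Q-outer / KV-inner tile loop next to a naive reference; it runs on nested lists on the CPU, so it shows only the arithmetic, not the memory behavior. The function names, block sizes, and random test data are assumptions for illustration.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention(Q, K, V):
    """Reference: build each full row of scores, softmax it, then weight V."""
    d, dv = len(Q[0]), len(V[0])
    out = []
    for q in Q:
        s = [dot(q, k) / math.sqrt(d) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        denom = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / denom
                    for c in range(dv)])
    return out

def tiled_attention(Q, K, V, block_q=4, block_kv=4):
    """Q-outer / KV-inner loop keeping running (m, l, O') per query row."""
    d, dv = len(Q[0]), len(V[0])
    out = []
    for q0 in range(0, len(Q), block_q):
        q_blk = Q[q0:q0 + block_q]
        m = [-math.inf] * len(q_blk)
        l = [0.0] * len(q_blk)
        acc = [[0.0] * dv for _ in q_blk]
        for k0 in range(0, len(K), block_kv):
            k_blk = K[k0:k0 + block_kv]
            v_blk = V[k0:k0 + block_kv]
            for r, q in enumerate(q_blk):
                s = [dot(q, k) / math.sqrt(d) for k in k_blk]
                m_new = max(m[r], max(s))
                alpha = math.exp(m[r] - m_new)      # rescale factor for old stats
                p = [math.exp(x - m_new) for x in s]
                l[r] = alpha * l[r] + sum(p)
                acc[r] = [alpha * acc[r][c]
                          + sum(p[j] * v_blk[j][c] for j in range(len(v_blk)))
                          for c in range(dv)]
                m[r] = m_new
        # Finalize this Q-block: one division per row.
        out.extend([x / l[r] for x in acc[r]] for r in range(len(q_blk)))
    return out

random.seed(0)
T, d = 10, 4   # T is deliberately not a multiple of the block sizes
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]

ref = naive_attention(Q, K, V)
got = tiled_attention(Q, K, V)
max_err = max(abs(a - b) for ra, rb in zip(ref, got) for a, b in zip(ra, rb))
```

In double precision `max_err` should sit near machine epsilon; changing `block_q` or `block_kv` should not move it meaningfully, since the tiling only reorders the same sums.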

Interactive · watch the online softmax tile by tile

The hardest part of FlashAttention to internalize is the running rescale. Below is the inner loop executed on one query row against four K, V tiles. Step through one tile at a time and watch m, ℓ, O' evolve; they are the only state the kernel keeps. The full row of scores is never materialized; the full T×T matrix is never imagined.

[Interactive widget: online softmax · one Q row · 4 K/V tiles of 8 · d = 8. The orange tile is the one loaded into SRAM right now; blue tiles are already integrated into the running stats and discarded; grey tiles haven't been touched. Readouts: step (0 to 5), running m, running ℓ, α (rescale).]

The algorithm this widget runs:
# One query row q; iterate over K, V in tiles of TILE rows.
m, l, O' = -inf, 0, zeros(d)       # running stats, live in SRAM
for t in tiles:
    S      = q @ K_t.T / sqrt(d)    # 1 × TILE,  in SRAM only
    m_new  = max(m, max(S))
    alpha  = exp(m - m_new)         # rescale factor for the old stats
    P      = exp(S - m_new)         # 1 × TILE,  in SRAM only
    l      = alpha * l + sum(P)
    O'     = alpha * O' + P @ V_t   # 1 × d,     in SRAM
    m      = m_new
return O' / l                       # one division per row at the end → O

How much HBM does each version move?

The numbers below assume one attention layer, fp16, B = 1. We'll count bytes that cross the HBM ↔ SRAM boundary.

Naive

Q, K, V, O reads/writes : 4 · h · T · d · 2 B
S write + read  : 2 · h · T² · 2 B
P write + read  : 2 · h · T² · 2 B
total ≈ 4hT² · 2 B  +  8hTd B

The T² term exceeds the linear term as soon as 8hT² > 8hTd, i.e. T > d, which holds for all but the shortest sequences.

Flash

Q, K, V, O reads/writes : 4 · h · T · d · 2 B
(no S, no P in HBM)
total ≈ 8hTd B

The ratio at T = 8192, h = 32, d = 128, fp16:

naive ≈ 4 · 32 · 8192² · 2 B  +  8 · 32 · 8192 · 128 B  =  17.2 GB + 0.27 GB ≈ 17.4 GB
flash ≈ 8 · 32 · 8192 · 128 B  =  0.27 GB = 270 MB
reduction ≈ 65×

Round numbers: naive ~17 GB, flash ~270 MB, ~65× less I/O. That translates to a ~2–4× wall-clock speedup on an A100/H100 at long context. It is not 65× wall-clock: HBM bandwidth isn't the only cost, and at long T the matmul FLOPs themselves become non-trivial. HBM traffic is still the dominant axis.
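The byte accounting can be reproduced directly. This is a small sketch under the same assumptions as the formulas in this section (one layer, B = 1, each of Q, K, V, O crossing HBM once, S and P each written and read once in the naive case); the function name `hbm_bytes` is made up for the example.

```python
def hbm_bytes(h, T, d, bytes_per_elem=2):
    """Return (naive, flash) HBM traffic in bytes for one attention layer, B = 1."""
    qkvo = 4 * h * T * d * bytes_per_elem      # Q, K, V, O: one pass each
    s_io = 2 * h * T * T * bytes_per_elem      # S: written, then read back
    p_io = 2 * h * T * T * bytes_per_elem      # P: written, then read back
    naive = qkvo + s_io + p_io
    flash = qkvo                               # S and P never leave SRAM
    return naive, flash

naive, flash = hbm_bytes(h=32, T=8192, d=128)
reduction = naive / flash

print(f"naive     {naive / 1e9:.2f} GB")
print(f"flash     {flash / 1e9:.3f} GB")
print(f"reduction {reduction:.0f}x")
```

Because h and the element size appear in every term, they cancel in the ratio, which depends only on T and d.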

The surprise: at short T, Flash doesn't help

Look at the formula again. The naive total is dominated by the 8hT² B quadratic term only when that term beats the 8hTd B linear term. Solve:

8hT² > 8hTd  ⇒  T > d

At d = 128, the crossover where the quadratic term starts to matter is T ≈ 128. But the practical win opens up later: you need the quadratic term to dominate the linear term by a healthy multiple before kernel-launch and SRAM-management overhead pay back. Empirically that's around T ≈ 1024–2048. At T = 128 (small chat), naive moves only about 2× the HBM bytes that Flash does, and both totals are small; the kernel just has more bookkeeping.
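Under the byte accounting from the previous section (8hT² + 8hTd bytes for naive, 8hTd bytes for flash), the ratio of the two totals reduces to 1 + T/d. The sketch below tabulates that ratio at d = 128 for a few sequence lengths; the function name and the chosen values of T are illustrative assumptions.

```python
def naive_over_flash(T, d):
    """Ratio of naive to flash HBM bytes; h and the element size cancel out."""
    quadratic = 8 * T * T      # S and P traffic, per head, in bytes (fp16)
    linear = 8 * T * d         # Q, K, V, O traffic, per head, in bytes (fp16)
    return (quadratic + linear) / linear   # algebraically equal to 1 + T/d

d = 128
ratios = {T: naive_over_flash(T, d) for T in (128, 256, 1024, 2048, 8192)}

for T, r in ratios.items():
    print(f"T = {T:5d}: naive moves {r:.0f}x the bytes of flash")
```

The ratio grows linearly in T, so the I/O gap is modest at chat-length prompts and only becomes large at long context.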

Convince yourself with the widget below.

Interactive · HBM-traffic calculator

[Interactive widget: HBM traffic per attention layer, naive vs flash. fp16, batch = 1, single forward pass. Sliders set the workload; bars are on a log scale because the gap can span four orders of magnitude. Readouts: naive HBM, flash HBM, reduction, dominant term (naive).]

The formulas this widget uses:
// Per attention layer, fp16 (2 B), batch 1.
// Naive moves: Q, K, V, O (each once) + S and P (each written and read).
const qkvo  = 4 * h * T * d * 2;          // bytes
const s_io  = 2 * h * T * T * 2;          // write + read
const p_io  = 2 * h * T * T * 2;
const naive = qkvo + s_io + p_io;

// Flash keeps S, P in SRAM; moves only Q, K, V, O across HBM.
// (block_q, block_kv don't change the asymptote — they only affect SRAM
//  pressure and inner-loop count. We display them so you can see they
//  don't move the bars on the right.)
const flash = 4 * h * T * d * 2;

Composition with PagedAttention

In production vLLM, the attention kernel does both at once:

  1. For each Q-block (Flash's outer loop)…
  2. …iterate over the sequence's logical K, V positions in K-block-sized chunks…
  3. …look up each K-block's physical location via the block table (Paged's indirection)…
  4. …load that physical block into SRAM and run the Flash inner update.

The two ideas are perfectly compatible: Flash doesn't care where in HBM the K, V tile came from, and Paged doesn't care what arithmetic you do on the tile once you have it. The vLLM kernel just happens to do both.
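The four steps above can be sketched for a single query row in plain Python. This is a toy model, not vLLM's kernel: the block size, the pool layout (a list of fixed-size physical blocks), the `block_table` contents, and every function name here are hypothetical choices made for illustration.

```python
import math
import random

BLOCK = 4  # tokens per KV block (hypothetical size)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend_paged(q, k_pool, v_pool, block_table, seq_len):
    """Online-softmax attention for one query row over block-table-indexed KV."""
    d, dv = len(q), len(v_pool[0][0])
    m, l, acc = -math.inf, 0.0, [0.0] * dv
    for logical, physical in enumerate(block_table):
        n_valid = min(BLOCK, seq_len - logical * BLOCK)  # last block may be partial
        if n_valid <= 0:
            break
        k_blk = k_pool[physical][:n_valid]   # Paged: indirection picks the block
        v_blk = v_pool[physical][:n_valid]
        s = [dot(q, k) / math.sqrt(d) for k in k_blk]    # Flash: inner update
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]
        l = alpha * l + sum(p)
        acc = [alpha * acc[c] + sum(p[j] * v_blk[j][c] for j in range(n_valid))
               for c in range(dv)]
        m = m_new
    return [x / l for x in acc]

def attend_contiguous(q, K, V):
    """Reference: plain softmax attention over contiguous K, V."""
    s = [dot(q, k) / math.sqrt(len(q)) for k in K]
    mx = max(s)
    p = [math.exp(x - mx) for x in s]
    denom = sum(p)
    return [sum(p[j] * V[j][c] for j in range(len(V))) / denom
            for c in range(len(V[0]))]

random.seed(1)
seq_len, d = 10, 4
q = [random.gauss(0, 1) for _ in range(d)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq_len)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq_len)]

# Scatter the sequence into non-adjacent physical blocks of a larger pool.
block_table = [5, 0, 3]                     # logical block i -> physical block id
pad = [0.0] * d
k_pool = [[pad] * BLOCK for _ in range(6)]
v_pool = [[pad] * BLOCK for _ in range(6)]
for t in range(seq_len):
    phys = block_table[t // BLOCK]
    k_pool[phys][t % BLOCK] = K[t]
    v_pool[phys][t % BLOCK] = V[t]

paged = attend_paged(q, k_pool, v_pool, block_table, seq_len)
contig = attend_contiguous(q, K, V)
paged_err = max(abs(a - b) for a, b in zip(paged, contig))
```

The only place the block table appears is the lookup that selects `k_blk` and `v_blk`; the running update after it is the same as in the contiguous case, which is the sense in which the two ideas are independent.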

Where the speedup goes (and doesn't)

regime                           | bottleneck                 | flash helps?
prefill, long context (T ≫ 1k)   | HBM traffic for S, P       | yes, 2–4× wall-clock
decode, long context             | KV cache read              | yes, same mechanism
prefill, short context (T ≲ 256) | QKVO I/O + launch overhead | marginal
training (backward)              | same as fwd + recompute    | even more (Flash-2 handles backward specifically)
Takeaway
Naive attention writes a tensor it doesn't need to keep, the T×T scores, to HBM and reads it back. Flash never writes it. The math is identical; the byte movement is ~65× less at T = 8192, giving ~2–4× wall-clock. The win is real only past T ≈ 1024; below that, you're paying overhead for a number that's already small.

Next: continuous batching, the scheduling-side win that accounts for the other half of vLLM's 24× throughput claim. It requires paged KV (lesson 02) to work, which is why we did paging first.