Lesson 12 / 17 · attention & flash

Attention asymmetry & FlashAttention

Same math, two regimes: prefill is compute-bound, decode is bandwidth-bound. FlashAttention is one specific fix for the math layer; by itself it does not fix the systems layer. This lesson derives both regimes and writes the online softmax in seven lines.

The question this lesson answers

Lesson 02 said decode reads the whole KV cache per layer per token. Lesson 01 said a kernel needs an arithmetic intensity of hundreds of FLOPs per byte to reach the compute roof. So why can't attention just be "compute the score matrix, softmax, multiply by V"? Because that materializes a T×T tensor in HBM, and at long T it is enormous and entirely wasted: each score is written once, read back once, and discarded. FlashAttention removes that round trip. The systems story (paged KV, prefix reuse) then handles the part FlashAttention doesn't touch.

Naive attention, byte-by-byte

For one head with sequence length T, Q, K, V ∈ ℝ^(T×d), stored in bf16 (b = 2 bytes per element):

  1. S = Q · Kᵀ. The kernel reads Q and K (4·T·d bytes) and writes the T×T score matrix to HBM: 2·T² bytes written.
  2. P = softmax(S). The kernel reads S and writes P: 2·T² bytes read, 2·T² written.
  3. O = P · V. The kernel reads P and V and writes O: 2·T² + 2·T·d bytes read, 2·T·d written.

Intermediate-matrix HBM traffic adds up to 2·T² + 4·T² + 2·T² = 8·T² bytes. At T = 8192 that is 512 MiB per head, per layer, spent on scores and probabilities nobody keeps. Multiply by heads and layers and the kernel is hopelessly bandwidth-bound, even though its O(T²·d) FLOPs need only O(T·d) bytes of essential input and output.
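
The three kernels above can be sketched in NumPy as follows. It is a minimal illustration: float16 stands in for bf16 because NumPy has no native bfloat16 (both are 2 bytes), scores are left unscaled to match S = Q · Kᵀ, and `naive_attention` is a hypothetical helper name. The point is that S and P exist as full T×T arrays, and their byte count is exactly the 8·T² derived above.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Three separate 'kernels' that materialize S and P.
    Returns O and the HBM bytes spent on the two T x T intermediates."""
    S = Q @ K.T                                   # kernel 1: write T x T scores
    P = np.exp(S - S.max(axis=1, keepdims=True))  # kernel 2: read S, row softmax...
    P = P / P.sum(axis=1, keepdims=True)          # ...and write P
    O = P @ V                                     # kernel 3: read P and V, write O
    # S written + S read + P written + P read: four passes over a T x T matrix
    intermediate_bytes = 2 * S.nbytes + 2 * P.nbytes
    return O, intermediate_bytes

rng = np.random.default_rng(0)
T, d = 256, 64
# float16 as a 2-byte stand-in for bf16
Q, K, V = (rng.standard_normal((T, d)).astype(np.float16) for _ in range(3))
O, nbytes = naive_attention(Q, K, V)
print(nbytes == 8 * T * T)      # True: 2*T^2 + 4*T^2 + 2*T^2
print(8 * 8192**2 / 2**20)      # 512.0 (MiB per head, per layer, at T = 8192)
```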

Subtle
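The math has high arithmetic intensity: about 4·T²·d FLOPs (two T×T×d matmuls) against 8·T·d bytes if we only count Q, K, V, O traffic, i.e., roughly T/2 FLOPs per byte. Naive implementations destroy that by writing intermediates: the extra 8·T² bytes drag intensity down to roughly d/2 once T ≫ d. FlashAttention's job is to recover the natural intensity by never writing S or P to HBM.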
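
A quick arithmetic check of the intensity argument. It is a sketch under this lesson's simplified model: one head, only the two matmuls counted as FLOPs (softmax ignored), and b = 2 bytes per element. `attention_intensity` is a hypothetical name.

```python
def attention_intensity(T, d, b=2):
    """FLOPs per HBM byte for one head, with and without materialized S and P."""
    flops = 4 * T * T * d            # Q @ K^T and P @ V: 2*T*T*d FLOPs each
    io_bytes = 4 * T * d * b         # read Q, K, V once; write O once
    intermediate = 4 * T * T * b     # write S, read S, write P, read P
    return flops / io_bytes, flops / (io_bytes + intermediate)

ideal, naive = attention_intensity(T=8192, d=128)
print(f"{ideal:.0f} vs {naive:.1f} FLOPs/byte")   # prints: 4096 vs 63.0 FLOPs/byte
```

The ideal figure is T/2, while the naive one approaches d/2 as T grows, which is why materializing S and P pushes attention below a typical GPU's ridge point.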

FlashAttention as a tile schedule

Idea: load a Q tile (Bq rows) and a K/V tile (Bk rows) into SRAM, compute their contribution to softmax, accumulate into the output, then move to the next K/V tile. Never write the score matrix to HBM. The complication: softmax needs the global max and sum, which you can't know until you've seen every K. The fix is an online softmax that updates the running max, scale, and accumulator as new K/V tiles arrive.

Figure: a Q tile (Bq × d) sweeps K/V tiles 1 to 4 along the sequence. The running state (m, running max; ℓ, running denominator; O, running output) stays in SMEM/registers and is written to HBM only once, at the end of the Q tile. Under a causal mask, K/V tiles beyond the Q tile's row range are skipped.

The seven-line online softmax (per Q row, sweeping K/V tiles indexed by j):

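```python
import numpy as np

def online_softmax_attention_row(q, K, V, Bk=128):
    """One query row q (d,) vs K, V (T, d); scores unscaled as in S = Q·Kᵀ."""
    m, l, O = -np.inf, 0.0, np.zeros(V.shape[1])   # initialize running max, denom, output
    for j in range(0, K.shape[0], Bk):              # one tile of K/V at a time
        s = K[j:j + Bk] @ q                         # tile of scores, in SRAM
        m_new = max(m, s.max())                     # new running max
        p = np.exp(s - m_new)                       # numerically safe
        l = np.exp(m - m_new) * l + p.sum()         # rescale old denom
        O = np.exp(m - m_new) * O + p @ V[j:j + Bk] # rescale old numerator
        m = m_new
    return O / l
```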

Two things to take from this:

Why prefill loves FlashAttention and decode wants a variant

In prefill, T new tokens are processed at once: Q has T rows, and the kernel walks them as outer Q tiles with an inner loop over K/V tiles. Tile sizes (e.g., Bq = 128, Bk = 128) are chosen so that each K/V tile, once loaded into SRAM, is reused by all Bq rows of the current Q tile. Tensor cores stay busy, and the kernel runs close to the compute roof.
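
The prefill schedule can be sketched in NumPy as below. This is a minimal illustration under stated assumptions: one head, float64 for clarity, unscaled scores to match S = Q · Kᵀ, and a hypothetical function name `flash_attention_causal`. A real kernel runs each Q tile in its own thread block with SRAM-resident tiles rather than in a Python loop. The outer loop walks Q tiles and the inner loop walks K/V tiles with the same online-softmax update as the per-row version. The causal mask lets the inner loop stop at the Q tile's last row.

```python
import numpy as np

def flash_attention_causal(Q, K, V, Bq=128, Bk=128):
    """Tiled causal attention: outer loop over Q tiles, inner loop over K/V tiles.
    S and P only ever exist one Bq x Bk tile at a time."""
    T, d = Q.shape
    O = np.empty_like(Q)
    for i in range(0, T, Bq):
        q = Q[i:i + Bq]                                # Q tile
        bq = q.shape[0]
        rows = np.arange(i, i + bq)[:, None]           # absolute query positions
        m = np.full((bq, 1), -np.inf)                  # running max per row
        l = np.zeros((bq, 1))                          # running denominator per row
        acc = np.zeros((bq, d))                        # running numerator per row
        for j in range(0, i + bq, Bk):                 # causal: skip tiles past the last query row
            s = q @ K[j:j + Bk].T                      # score tile
            cols = np.arange(j, j + s.shape[1])[None, :]
            s = np.where(cols <= rows, s, -np.inf)     # mask future keys inside the tile
            m_new = np.maximum(m, s.max(axis=1, keepdims=True))
            p = np.exp(s - m_new)
            scale = np.exp(m - m_new)                  # rescale factor for old state
            l = scale * l + p.sum(axis=1, keepdims=True)
            acc = scale * acc + p @ V[j:j + Bk]
            m = m_new
        O[i:i + bq] = acc / l                          # written once per Q tile
    return O
```

The first K/V tile always contains key 0, which every query may attend to, so each row's running max is finite after one step and fully masked later tiles contribute zeros rather than NaNs.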

In decode, Q has one row per sequence, so the natural Q tile is tiny. There is little reuse along the Q axis, and the kernel's bottleneck is reading the entire K and V history from HBM. Variants like FlashDecoding split the key/value sequence across thread blocks instead of the Q dimension, which keeps many SMs busy, and then merge the partial softmax results. FlashAttention-2/3 and FlashInfer include specialized batch-decode kernels that take a ragged batch (each sequence with its own T) and process it in one launch.
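
The split-and-merge idea behind FlashDecoding can be sketched as follows. This is a sequential NumPy illustration, not FlashDecoding's actual kernel. `partial_attention` and `split_kv_decode` are hypothetical names. Each split would map to its own thread block on a GPU, and every split is assumed to be non-empty (T ≥ num_splits). Each split returns its own (max, denominator, unnormalized output). The merge rescales them to a common max, which is the same correction the online softmax applies tile by tile.

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention of one query over one KV split: returns (m, l, o) partial state."""
    s = K @ q                      # scores for this split only
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V       # o is the unnormalized numerator

def split_kv_decode(q, K, V, num_splits=4):
    """Split the KV sequence into chunks, attend to each, then merge partial softmaxes."""
    parts = [partial_attention(q, Ks, Vs)
             for Ks, Vs in zip(np.array_split(K, num_splits),
                               np.array_split(V, num_splits))]
    m = max(pm for pm, _, _ in parts)                        # global max across splits
    l = sum(np.exp(pm - m) * pl for pm, pl, _ in parts)      # rescaled denominators
    o = sum(np.exp(pm - m) * po for pm, _, po in parts)      # rescaled numerators
    return o / l
```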

| Phase | Q shape | Best parallelism | Typical kernel |
|---|---|---|---|
| Prefill (long Q, dense) | Bq × d, many rows | over Q rows | FlashAttention-2/3 (forward) |
| Decode (Bq = 1) | 1 × d per sequence | over K splits (FlashDecoding) or over sequences | FlashDecoding / FlashInfer batch-decode |
| Mixed batch (prefill + decode) | ragged Q lengths | over sequences, with per-sequence tile sizes | FlashInfer / vLLM unified attention |
| MLA (DeepSeek-style) | compressed latent | different shape entirely | FlashMLA |

What FlashAttention does not fix

This is the most common confusion in the topic. FlashAttention is a kernel-internal trick for the attention math. It does nothing about where K and V live in HBM. Specifically, it does not:

The clean separation
FlashAttention is the kernel for "compute attention without writing intermediates." PagedAttention is the layout for "store KV without reserving contiguous max-length buffers." Prefix/Radix caching is the policy for "don't recompute what you've computed before." All three can coexist; combining them is exactly what production engines do.

Attention HBM accountant

For a given sequence length T and head dim d, compare naive attention HBM traffic (with a materialized score matrix) against FlashAttention traffic (no intermediate matrix). The ratio grows linearly in T/d.

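"Naive" writes S and P once each and reads each back once (8·T² bytes), on top of reading Q, K, V and writing O. "Flash" reads Q, K, V once each and writes O once (8·T·d bytes in bf16), so the ratio is exactly 1 + T/d under this model. The Flash figure is an idealization, because real kernels re-stream K and V from HBM once per Q tile. Both ignore weight matrices (those are GEMM territory).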
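
The accountant can be sketched in a few lines of Python. It is a minimal version of the byte model described here: per head, per layer, bf16, weights ignored, and Flash idealized to a single pass over Q, K, V, O. `hbm_traffic` is a hypothetical name.

```python
def hbm_traffic(T, d, b=2):
    """Per-head, per-layer HBM bytes: (naive, flash, naive / flash)."""
    qkvo = 4 * T * d * b              # read Q, K, V once; write O once
    naive = qkvo + 4 * T * T * b      # plus write S, read S, write P, read P
    flash = qkvo                      # idealized: S and P never leave SRAM
    return naive, flash, naive / flash

for T in (1024, 8192, 32768):
    naive, flash, ratio = hbm_traffic(T, d=128)
    print(f"T={T}: naive {naive / 2**20:.1f} MiB, flash {flash / 2**20:.1f} MiB, ratio {ratio:.0f}x")
```

With d = 128 the ratio comes out to 9x, 65x and 257x for T = 1024, 8192 and 32768. Each value is 1 + T/d.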

Causal masking, sliding windows, and friends

Production kernels accept extra constraints as cheap variations of the same tile schedule:

System-level consequences

Now we can name three things the serving engine must coordinate, all of which are downstream of the kernel choice:

  1. The selected attention backend must support the model's head dim, KV dtype, sliding window, GQA ratio, and the storage layout the rest of the system uses.
  2. If KV is stored non-contiguously (next lesson), the backend must accept a block table — i.e., it is the paged variant of FlashAttention/FlashInfer/etc.
  3. If decode shapes are stable (lesson 06), the backend should be capturable in a CUDA graph. Some backends are friendlier to this than others.
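
A hypothetical sketch of how an engine might check points 1 to 3 before choosing a backend. Every name here (`AttnRequirements`, `BackendCaps`, `select_backend`, and the backend names in the usage) is invented for illustration. None of it mirrors a real engine's API or any real backend's capability set. A real engine would likely encode richer constraints than, for example, a single maximum GQA ratio.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AttnRequirements:          # hypothetical: what the model and engine need
    head_dim: int
    kv_dtype: str
    sliding_window: bool
    gqa_ratio: int               # query heads per KV head
    paged_kv: bool               # KV addressed through a block table (point 2)
    cuda_graph: bool             # decode shapes stable enough to capture (point 3)

@dataclass(frozen=True)
class BackendCaps:               # hypothetical capability record for one backend
    name: str
    head_dims: frozenset
    kv_dtypes: frozenset
    sliding_window: bool
    max_gqa_ratio: int
    paged_kv: bool
    graph_capturable: bool

def compatible(req: AttnRequirements, caps: BackendCaps) -> bool:
    """Point 1: model shape/dtype support; points 2-3: layout and graph capture."""
    return (req.head_dim in caps.head_dims
            and req.kv_dtype in caps.kv_dtypes
            and (caps.sliding_window or not req.sliding_window)
            and req.gqa_ratio <= caps.max_gqa_ratio
            and (caps.paged_kv or not req.paged_kv)
            and (caps.graph_capturable or not req.cuda_graph))

def select_backend(req: AttnRequirements, backends: list) -> str:
    """Return the first compatible backend, in the engine's preference order."""
    for caps in backends:
        if compatible(req, caps):
            return caps.name
    raise ValueError("no attention backend supports this model/layout combination")

req = AttnRequirements(head_dim=128, kv_dtype="fp8", sliding_window=False,
                       gqa_ratio=8, paged_kv=True, cuda_graph=True)
a = BackendCaps("backend_a", frozenset({64, 128}), frozenset({"bf16"}), True, 16, True, True)
b = BackendCaps("backend_b", frozenset({128, 256}), frozenset({"bf16", "fp8"}), False, 8, True, True)
print(select_backend(req, [a, b]))   # backend_a lacks fp8 KV, so backend_b is chosen
```
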

Closing model for this lesson
Attention is a single mathematical operation with two distinct performance regimes. FlashAttention recovers the natural arithmetic intensity by streaming the softmax over K/V tiles instead of materializing it. Whether a given request runs the prefill kernel, the decode kernel, or a mixed kernel is a scheduler choice, but which kernels are available is dictated by the layout choices in the next lesson.