Flash Attention — block-tiled attention
The flagship Triton kernel. It composes every primitive from lessons 04–11 in service of one idea: never materialise the O(N²) attention matrix in HBM. Block the queries, stream the keys/values, carry the online (m, ℓ) recurrence, and cut HBM traffic by roughly 10×.
The problem in one diagram
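Naive attention is bandwidth-bound on S and P, which each hold N² elements and have to round-trip through HBM. The matmul FLOPs also grow quadratically (as N²d), so doubling N quadruples both; the trouble is that the GPU can do the math much faster than HBM can move the N² intermediates. At long context, we're almost entirely waiting on HBM, doing far too much data movement for the amount of math.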
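To put rough numbers on this, here is a back-of-the-envelope estimate for a single attention head. The head dimension (d = 128), the 2-byte element size, and the figure of about three N²-sized transfers for S and P are illustrative assumptions rather than measurements, and `naive_traffic_mib` is a hypothetical helper name.

```python
def naive_traffic_mib(n, d=128, bytes_per_elem=2, sp_passes=3):
    """Rough HBM traffic (MiB) for one head of naive attention.

    Assumptions (illustrative only): Q, K, V and O each move once, and the
    N x N intermediates S and P cost about `sp_passes` N^2-sized transfers.
    """
    mib = 1024 * 1024
    qkvo = 4 * n * d * bytes_per_elem / mib       # grows linearly in N
    sp = sp_passes * n * n * bytes_per_elem / mib  # grows quadratically in N
    return qkvo, sp


for n in (2048, 4096, 8192):
    qkvo, sp = naive_traffic_mib(n)
    print(f"N={n:5d}  Q,K,V,O = {qkvo:5.1f} MiB   S,P = {sp:6.1f} MiB")
```

Under these assumptions, N = 8192 gives 8 MiB for Q, K, V, O against 384 MiB for S and P. Doubling N doubles the first figure and quadruples the second.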
The Flash Attention idea
Three observations:
- Output O is small (N × d, with d ≪ N). We never need the full S or P; we only need to produce O.
- Softmax can be done online (lesson 10's recurrence). We can update the per-row (m, ℓ) one chunk of S at a time without ever writing S.
- P · V can also be done online. Since P is exp(S − m) / ℓ and we hold (m, ℓ) per row, we can stream PV and rescale as we go — same trick as the ℓ rescale.
Putting them together: block the Q rows, stream the K/V columns. For each block of queries:
- Initialise O tile = 0, m = −∞, ℓ = 0 — all in registers.
- For each block of (K, V): compute the partial scores Sblock, update the online (m, ℓ), update O with the rescale.
- At the end, divide O by ℓ and write it out.
No S or P ever leaves the registers. HBM traffic drops to just Q, K, V, O: 6 MB for the three inputs plus 2 MB for the output, 8 MB in total, instead of roughly 390 MB at N = 8192.
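The steps above can be sketched in plain Python before looking at any GPU code. This is a reference model, not the kernel: it processes one query row at a time instead of a BM-row tile, uses nested lists rather than tensors, and the function names `naive_attention` and `streamed_attention` are made up for illustration. It only aims to show that streaming K/V in blocks while carrying (m, ℓ, acc) gives the same answer as materialising the full score row.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference: build each full score row, softmax it, then weight V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(p[n] * V[n][c] for n in range(len(V))) / l
                    for c in range(len(V[0]))])
    return out


def streamed_attention(Q, K, V, BN=4):
    """Same result, but K/V arrive BN rows at a time; only (m, l, acc) persist."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:  # a real kernel handles a BM-row tile of queries per program
        m, l, acc = -math.inf, 0.0, [0.0] * len(V[0])
        for j in range(0, len(K), BN):
            k_blk, v_blk = K[j:j + BN], V[j:j + BN]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)  # exp(-inf) is 0.0 on the first block
            p = [math.exp(x - m_new) for x in s]
            l = l * alpha + sum(p)
            acc = [a * alpha + sum(p[n] * v_blk[n][c] for n in range(len(v_blk)))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # final normalise
    return out


random.seed(0)
N, d = 10, 8  # N is deliberately not a multiple of BN
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]

ref = naive_attention(Q, K, V)
out = streamed_attention(Q, K, V, BN=4)
max_err = max(abs(a - b) for r, o in zip(ref, out) for a, b in zip(r, o))
print(f"max abs difference: {max_err:.2e}")
```

The two results should agree up to floating-point rounding, even though `streamed_attention` never holds more than BN scores at once.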
The complete kernel — annotated
Here is a minimal Flash Attention forward (causal mask omitted for clarity; it is added further down in this lesson):
import triton
import triton.language as tl


@triton.jit
def flash_attn_fwd(
    Q, K, V, O,
    sqz, sqh, sqm, sqd,  # Q strides (Z=batch, H=head, M=seq, D=head_dim)
    skz, skh, skn, skd,  # K strides
    svz, svh, svn, svd,  # V strides
    soz, soh, som, sod,  # O strides
    Z, H, N_CTX,
    BM: tl.constexpr, BN: tl.constexpr, D: tl.constexpr,
):
    # 1 · which Q block am I, and which (batch, head)?
    pid_m = tl.program_id(0)
    bh = tl.program_id(1)
    z = bh // H
    h = bh % H
    # 2 · pointers for this (z, h)
    Q += z*sqz + h*sqh
    K += z*skz + h*skh
    V += z*svz + h*svh
    O += z*soz + h*soh
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_d = tl.arange(0, D)
    # 3 · load my Q block into registers (stays there for the whole K/V loop)
    q = tl.load(Q + offs_m[:, None]*sqm + offs_d[None, :]*sqd,
                mask=offs_m[:, None] < N_CTX, other=0.0)
    # 4 · online state in registers
    m = tl.full([BM], -float('inf'), dtype=tl.float32)  # running max per row
    l = tl.zeros([BM], dtype=tl.float32)                # running ℓ per row
    acc = tl.zeros([BM, D], dtype=tl.float32)           # running O per row
    # 5 · sliding window over K and V columns
    # D is a constexpr, so the scale is computed at compile time
    softmax_scale = 1.0 / (D ** 0.5)
    for j in range(0, N_CTX, BN):
        offs_n = j + tl.arange(0, BN)
        k = tl.load(K + offs_n[:, None]*skn + offs_d[None, :]*skd,
                    mask=offs_n[:, None] < N_CTX, other=0.0)
        v = tl.load(V + offs_n[:, None]*svn + offs_d[None, :]*svd,
                    mask=offs_n[:, None] < N_CTX, other=0.0)
        # 5a · partial scores S_block = q · kᵀ (BM × BN)
        s = tl.dot(q, tl.trans(k)) * softmax_scale
        # padded key columns (offs_n >= N_CTX) must not contribute to the softmax
        s = tl.where(offs_n[None, :] < N_CTX, s, -float('inf'))
        # 5b · ONLINE SOFTMAX UPDATE
        m_new = tl.maximum(m, tl.max(s, axis=1))   # new row max
        alpha = tl.exp(m - m_new)                  # rescale factor for old state
        p = tl.exp(s - m_new[:, None])             # P_block w.r.t. new max
        l = l*alpha + tl.sum(p, axis=1)            # new ℓ
        # new O: rescale the old accumulator, then add this block
        acc = acc*alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m = m_new
    # 6 · final normalise: O = acc / ℓ
    acc = acc / l[:, None]
    tl.store(O + offs_m[:, None]*som + offs_d[None, :]*sod,
             acc.to(tl.bfloat16),
             mask=offs_m[:, None] < N_CTX)
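The kernel's first step implies a 2-D launch grid: one program per BM-row block of queries on axis 0, and one per (batch, head) pair on axis 1. The helpers below are a plain-Python sketch of that indexing, with hypothetical names (`launch_grid`, `decode_program`) and example sizes chosen only for illustration.

```python
def launch_grid(Z, H, N_CTX, BM):
    """Grid shape matching the kernel's indexing: (Q blocks, batch * heads)."""
    num_q_blocks = (N_CTX + BM - 1) // BM  # ceil division; last block may be ragged
    return (num_q_blocks, Z * H)


def decode_program(pid_m, bh, H, BM):
    """Mirror of step 1: recover (batch, head, first query row) for one program."""
    z, h = divmod(bh, H)
    return z, h, pid_m * BM


grid = launch_grid(Z=2, H=16, N_CTX=8192, BM=128)
print(grid)                                          # (64, 32)
print(decode_program(pid_m=3, bh=21, H=16, BM=128))  # (1, 5, 384)
```

With Triton, a tuple like this is what goes in the square brackets of a launch, as in `flash_attn_fwd[grid](...)`, followed by the tensors, their strides, and the constexpr block sizes.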
Reading the kernel: what each piece is doing
| Line | Lesson it uses | What it does |
|---|---|---|
| pid_m = tl.program_id(0) | 03 | One program per BM-block of queries; second axis indexes (batch × head). |
| q = tl.load(... mask=...) | 04 | Load the Q tile; mask the M boundary. |
| m, l, acc = ... | 06, 10 | The online state. m is the running max, l is ℓ, acc is the running output before normalising. |
| s = tl.dot(q, tl.trans(k)) | 05, 09 | The partial attention scores: a tile matmul, lowered to tensor cores. |
| m_new, alpha, l, acc = ... | 10 | The online softmax update. alpha = exp(m - m_new) is the rescale factor. Both l and acc get multiplied by α; that's the line that retroactively corrects previous blocks for the new running max. |
| acc = acc / l[:, None] | 10 | Final divide. Now acc is the true softmax(QKᵀ)·V row. |
Why the rescale is exact
Ideally, after block j we would hold acc[i, :] = Σ_{n ≤ j} P[i, n] · V[n, :], where P[i, n] = exp(S[i, n] − m_final) / ℓ_final. We don't know m_final or ℓ_final yet, so we use the current running max m as a stand-in and defer the division by ℓ to the very end. When m grows to m_new, every previous block's exp(S − m) needs to become exp(S − m_new) = exp(S − m) · exp(m − m_new) = previous · α. So we multiply the accumulated acc and ℓ by α. The new block contributes P_new · V, which is already computed at the new max.
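A small numeric check of that identity, using made-up scores for a single query row split into two blocks. It verifies only the ℓ part of the recurrence; acc is rescaled by the same α, so the same argument applies to it.

```python
import math

# Two blocks of scores for one query row (arbitrary illustrative values).
block_a = [0.5, 2.0, -1.0]
block_b = [3.5, 1.0]

# Process block A using its own max as the stand-in.
m = max(block_a)
l = sum(math.exp(s - m) for s in block_a)

# Block B raises the max: rescale the old sum by alpha, then add B's terms.
m_new = max(m, max(block_b))
alpha = math.exp(m - m_new)
l_online = l * alpha + sum(math.exp(s - m_new) for s in block_b)

# Direct computation with the final max known up front.
l_direct = sum(math.exp(s - m_new) for s in block_a + block_b)

print(l_online, l_direct)
```

The two printed values should match up to rounding, which is the sense in which the rescale is exact rather than an approximation.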
The bandwidth saving, quantified
| Quantity | Naive | Flash Attention |
|---|---|---|
| Reads of Q | 1× | 1× (once per program) |
| Reads of K | 1× | ~M/BM × (each K block read once per Q block) |
| Reads of V | 1× | ~M/BM × (same) |
| S / P writes & reads | ~3 N² | 0 |
| O writes | 1× | 1× |
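The "K/V read M/BM times" looks bad, but at the actual numbers it isn't. Here M is the number of query rows (N_CTX in the kernel), so M/BM is something like 64 (for example N = 8192 with BM = 128). K and V are only a few MB per head, small enough that many of the re-reads can be served from L2 rather than HBM when programs for the same head run close together, whereas the eliminated S and P traffic (hundreds of MB) is far too large to cache. The net is the famous 5–10× wall-clock speedup vs naive.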
Causal masking — one extra line
For causal attention, mask scores where offs_n > offs_m before the softmax:
s = tl.where(offs_n[None, :] > offs_m[:, None], -float('inf'), s)
The −∞ entries contribute zero to exp(s − m) and zero to ℓ, so the recurrence still works. (Production kernels skip K/V blocks past the causal diagonal entirely — a 2× speedup on top of Flash Attention v1.)
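To see where the block-skipping saving comes from, the sketch below counts which K/V blocks a causal kernel actually has to visit. `causal_kv_blocks` is a hypothetical helper, and the sizes (N_CTX = 1024, BM = BN = 128) are arbitrary; a key column n is treated as visible to query row i when n ≤ i, matching the mask above.

```python
import math


def causal_kv_blocks(pid_m, BM, BN, N_CTX):
    """Start offsets of K/V blocks with at least one unmasked column
    for Q block `pid_m` (column n is visible to row i when n <= i)."""
    last_row = min((pid_m + 1) * BM, N_CTX) - 1
    return [j for j in range(0, N_CTX, BN) if j <= last_row]


N_CTX, BM, BN = 1024, 128, 128
num_q_blocks = math.ceil(N_CTX / BM)
visited = sum(len(causal_kv_blocks(i, BM, BN, N_CTX)) for i in range(num_q_blocks))
total = num_q_blocks * math.ceil(N_CTX / BN)
print(f"{visited} of {total} (Q block, K/V block) pairs needed")  # 36 of 64

# A masked score of -inf contributes exactly zero to both l and acc.
print(math.exp(-math.inf - 0.0))  # 0.0
```

The visited fraction approaches one half as the number of blocks grows, which is where the roughly 2× figure comes from. Blocks that straddle the diagonal still need the per-element mask.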
What this kernel doesn't have (yet)
- The backward pass. Flash Attention's backward needs to reconstruct S on the fly during gradient computation. The original paper has the algorithm; lesson 14 sketches the structure.
- Dropout. A random mask is applied to P before the P · V product, which needs RNG state per program.
- Sliding window / ALiBi / position biases. Added in the score computation (step 5a).
- FP8. Hopper-specific. Requires careful per-row dequant.
Production kernels like FlashAttention v3 and FlashInfer add all of these and more, plus warp specialisation (TMA producer warps + mma consumer warps). They beat hand-Triton by 20–40% on H100. Don't try to out-engineer them — when you can, just call FlashInfer.
What's next
The kernel exists. Two questions remain: how fast is it (lesson 13: autotune, num_stages, the full pitfall checklist) and how do you ship it (lesson 14: backward pass, profiling, the decision tree).