Triton kernels · Lesson 12 of 14

Flash Attention — block-tiled attention

The flagship Triton kernel. It composes every primitive from lessons 04–11 in service of one idea: never materialise the O(N²) attention matrix in HBM. Block the queries, stream the keys/values, carry the online (m, ℓ) recurrence, and cut HBM traffic by an order of magnitude or more.

The problem in one diagram

Naive attention runs three kernels and round-trips two N × N intermediates through HBM:

  1. Q·Kᵀ → S (N × N), written to HBM
  2. softmax(S) → P (N × N), written to HBM
  3. P·V → O (N × d)

HBM traffic at N = 8192, d = 128, bf16:

  Q, K, V reads:   3 × 8192 × 128 × 2 = 6 MB
  S write + read:  8192 × 8192 × 2 × 2 = 256 MB  ← scales as N²!
  P write + read:  8192 × 8192 × 2 × 2 = 256 MB
  O write:         8192 × 128 × 2 = 2 MB
  Total:           ~520 MB. At N = 16K it is about 2 GB. The N² intermediates dominate.

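The naive version is bandwidth-bound by S and P, whose size scales as N². Compute grows as N² too (each of the two matmuls costs about 2N²d FLOPs), so the problem is the ratio of math to bytes. Q·Kᵀ does 2N²d FLOPs for the 2N² bytes of S it writes, only d FLOPs per byte (128 at d = 128), which is below what a tensor-core GPU needs to stay compute-bound; the softmax in between does only a handful of FLOPs per element it moves. At long context we spend most of the time waiting on HBM, moving far more bytes than the math requires.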
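A small calculator makes this arithmetic reproducible. It is a sketch under stated assumptions: one (batch, head), bf16 (2 bytes per element), and each naive kernel reading its inputs and writing its output exactly once, so S and P are each written once and read back once. The function names are illustrative, not from any library.

```python
MIB = 2**20

def naive_attention_bytes(n: int, d: int, elem_bytes: int = 2) -> int:
    """HBM bytes for naive three-kernel attention, one (batch, head)."""
    qkv_reads = 3 * n * d * elem_bytes   # Q, K, V each read once
    s_traffic = 2 * n * n * elem_bytes   # S: written by Q·Kᵀ, read back by softmax
    p_traffic = 2 * n * n * elem_bytes   # P: written by softmax, read back by P·V
    o_write = n * d * elem_bytes         # O written once
    return qkv_reads + s_traffic + p_traffic + o_write

def flash_compulsory_bytes(n: int, d: int, elem_bytes: int = 2) -> int:
    """Q, K, V read once and O written once; S and P never touch HBM."""
    return 4 * n * d * elem_bytes

def qk_flops_per_s_byte(d: int, elem_bytes: int = 2) -> float:
    """Q·Kᵀ intensity measured against the S it writes: 2·N²·d FLOPs / (N²·elem_bytes) bytes."""
    return 2 * d / elem_bytes

for n in (8192, 16384):
    naive = naive_attention_bytes(n, 128) / MIB
    flash = flash_compulsory_bytes(n, 128) / MIB
    print(f"N={n}: naive {naive:.0f} MiB, flash (compulsory) {flash:.0f} MiB")
print("Q·Kᵀ FLOPs per byte of S:", qk_flops_per_s_byte(128))
```

At N = 8192 this gives 520 MiB against 8 MiB, and at N = 16384, 2064 MiB against 16 MiB. The Flash figure counts compulsory traffic only; K/V re-reads per Q block come on top.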

The Flash Attention idea

Three observations:

  1. Output O is small (N × d, with d ≪ N). We never need the full S or P; we only need to produce O.
  2. Softmax can be done online (lesson 10's recurrence). We can update the per-row (m, ℓ) one chunk of S at a time without ever writing S.
  3. P · V can also be done online. Since P is exp(S − m) / ℓ and we hold (m, ℓ) per row, we can stream PV and rescale as we go — same trick as the ℓ rescale.

Putting them together: block the Q rows, stream the K/V columns. For each block of queries:

  1. Initialise O tile = 0, m = −∞, ℓ = 0 — all in registers.
  2. For each block of (K, V): compute the partial scores S_block, update the online (m, ℓ), update O with the rescale.
  3. At the end, divide O by ℓ and write it out.

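No S or P tile is ever written to HBM; each lives on-chip for a single loop iteration. The compulsory HBM traffic drops to Q, K and V in and O out: 6 + 2 = 8 MB at N = 8192, against hundreds of MB for the naive version. K and V are also re-read once per Q block, which the bandwidth section below accounts for.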
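The three steps can be checked on the CPU. Below is a pure-Python emulation (no Triton, no GPU; `flash_attention`, `naive_attention`, `bm` and `bn` are illustrative names standing in for the kernel's BM and BN) that follows steps 1–3 literally for each Q block and compares the result against naive softmax(QKᵀ/√d)·V. It is a correctness sketch only and is far slower than either real implementation.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention(q, k, v):
    """Reference: build each full score row, softmax it, multiply by V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        s = [scale * dot(qi, kj) for kj in k]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        total = sum(w)
        out.append([sum(w[n] * v[n][c] for n in range(len(v))) / total
                    for c in range(len(v[0]))])
    return out

def flash_attention(q, k, v, bm=2, bn=3):
    """Block the Q rows, stream K/V blocks, carry (m, ℓ, acc) per row."""
    d = len(v[0])
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i0 in range(0, len(q), bm):                  # one "program" per Q block
        rows = q[i0:i0 + bm]
        m = [-math.inf] * len(rows)                   # step 1: running max ...
        l = [0.0] * len(rows)                         # ... running ℓ ...
        acc = [[0.0] * d for _ in rows]               # ... and unnormalised O
        for j0 in range(0, len(k), bn):               # step 2: stream K/V blocks
            kb, vb = k[j0:j0 + bn], v[j0:j0 + bn]
            for r, qi in enumerate(rows):
                s = [scale * dot(qi, kj) for kj in kb]
                m_new = max(m[r], max(s))
                alpha = math.exp(m[r] - m_new)        # rescale factor for the old state
                p = [math.exp(x - m_new) for x in s]
                l[r] = l[r] * alpha + sum(p)
                acc[r] = [acc[r][c] * alpha + sum(p[n] * vb[n][c] for n in range(len(vb)))
                          for c in range(d)]
                m[r] = m_new
        for r in range(len(rows)):                    # step 3: divide by ℓ, "write" O
            out.append([x / l[r] for x in acc[r]])
    return out

random.seed(0)
q = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(5)]   # 5 queries, d = 4
k = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(7)]   # 7 keys
v = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(7)]
ref, got = naive_attention(q, k, v), flash_attention(q, k, v, bm=2, bn=3)
print("max abs diff:", max(abs(a - b) for ra, rb in zip(ref, got) for a, b in zip(ra, rb)))
```

Python slicing makes the last Q and K/V blocks ragged rather than zero-padded, so this emulation needs no boundary mask; a Triton kernel with fixed-size tiles does.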

The complete kernel — annotated

Here is a minimal Flash Attention forward (causal mask omitted for clarity; the causal-masking section below adds it):

import triton
import triton.language as tl

@triton.jit
def flash_attn_fwd(
    Q, K, V, O,
    sqz, sqh, sqm, sqd,                       # Q strides (Z=batch, H=head, M=seq, D=head_dim)
    skz, skh, skn, skd,                       # K strides
    svz, svh, svn, svd,                       # V strides
    soz, soh, som, sod,                       # O strides
    Z, H, N_CTX,
    sm_scale,                                 # softmax scale, normally 1/sqrt(D), computed on the host
    BM: tl.constexpr, BN: tl.constexpr, D: tl.constexpr,   # powers of two (tl.arange requires it)
):
    # 1 · which Q block am I, and which (batch, head)?
    pid_m = tl.program_id(0)
    bh    = tl.program_id(1)
    z = bh // H
    h = bh % H

    # 2 · pointers for this (z, h)
    Q += z*sqz + h*sqh
    K += z*skz + h*skh
    V += z*svz + h*svh
    O += z*soz + h*soh

    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_d = tl.arange(0, D)

    # 3 · load my Q block into registers (stays there for the whole K/V loop)
    q = tl.load(Q + offs_m[:, None]*sqm + offs_d[None, :]*sqd,
                mask=offs_m[:, None] < N_CTX, other=0.0)

    # 4 · online state in registers
    m   = tl.full([BM], -float('inf'), dtype=tl.float32)        # running max per row
    l   = tl.zeros([BM], dtype=tl.float32)                       # running ℓ per row
    acc = tl.zeros([BM, D], dtype=tl.float32)                    # running O per row, not yet divided by ℓ

    # 5 · sliding window over K and V columns
    for j in range(0, N_CTX, BN):
        offs_n = j + tl.arange(0, BN)
        k = tl.load(K + offs_n[:, None]*skn + offs_d[None, :]*skd,
                    mask=offs_n[:, None] < N_CTX, other=0.0)
        v = tl.load(V + offs_n[:, None]*svn + offs_d[None, :]*svd,
                    mask=offs_n[:, None] < N_CTX, other=0.0)

        # 5a · partial scores  S_block = q · kᵀ  (BM × BN)
        s = tl.dot(q, tl.trans(k)) * sm_scale
        # zero-filled out-of-range keys would score 0, not -inf: mask them so they add nothing to ℓ or acc
        s = tl.where(offs_n[None, :] < N_CTX, s, -float('inf'))

        # 5b · ONLINE SOFTMAX UPDATE
        m_new = tl.maximum(m, tl.max(s, axis=1))                  # new row max
        alpha = tl.exp(m - m_new)                                  # rescale factor for old state
        p     = tl.exp(s - m_new[:, None])                         # P_block w.r.t. new max
        l     = l*alpha + tl.sum(p, axis=1)                        # new ℓ
        acc   = acc*alpha[:, None] + tl.dot(p.to(v.dtype), v)      # new O (rescaled, then add this block)
        m     = m_new

    # 6 · final normalise: O = acc / ℓ
    acc = acc / l[:, None]
    tl.store(O + offs_m[:, None]*som + offs_d[None, :]*sod,
             acc.to(O.dtype.element_ty),                           # cast back to O's dtype (e.g. bf16)
             mask=offs_m[:, None] < N_CTX)
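One detail in the loop is easy to get wrong. When N_CTX is not a multiple of BN, the last K/V block is partly out of range, and the masked loads fill those K rows with zeros. Their scores therefore come out as 0, not −∞, and exp(0 − m) leaks into ℓ. The pure-Python sketch below (a hypothetical single-row emulation with a scalar V per key, not Triton code) mimics the zero-filled tail and shows the result drifting unless padded columns are forced to −∞ before the max and exp.

```python
import math

def online_row(scores, values, bn, n_ctx, mask_tail):
    """Stream one query row in blocks of bn, zero-padding the tail like a masked tl.load."""
    n_blocks = math.ceil(n_ctx / bn)
    pad = n_blocks * bn - n_ctx
    padded_s = scores + [0.0] * pad          # zero-filled K rows score exactly 0
    padded_v = values + [0.0] * pad          # zero-filled V rows (other=0.0)
    m, l, acc = -math.inf, 0.0, 0.0
    for j in range(0, n_blocks * bn, bn):
        s = padded_s[j:j + bn]
        if mask_tail:                        # the fix: out-of-range columns become -inf
            s = [x if j + t < n_ctx else -math.inf for t, x in enumerate(s)]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]
        l = l * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, padded_v[j:j + bn]))
        m = m_new
    return acc / l

def exact_row(scores, values):
    mx = max(scores)
    w = [math.exp(x - mx) for x in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

scores = [0.5, -1.0, 2.0, 0.25, 1.5]     # 5 real keys, so BN = 4 leaves 3 padded columns
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print("exact   ", exact_row(scores, values))
print("masked  ", online_row(scores, values, bn=4, n_ctx=5, mask_tail=True))
print("unmasked", online_row(scores, values, bn=4, n_ctx=5, mask_tail=False))
```

The masked run matches the exact softmax-weighted sum; the unmasked run is pulled towards zero because three phantom keys with score 0 and value 0 inflate ℓ.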

Reading the kernel: what each piece is doing

| Line | Lesson it uses | What it does |
|---|---|---|
| pid_m = tl.program_id(0) | 03 | One program per BM-block of queries; the second grid axis indexes (batch × head). |
| q = tl.load(... mask=...) | 04 | Load the Q tile; mask the M (sequence) boundary. |
| m, l, acc = ... | 06, 10 | The online state: m is the running max, l is ℓ, acc is the running output before normalising. |
| s = tl.dot(q, tl.trans(k)) | 05, 09 | The partial attention scores: a tile matmul, lowered to tensor cores. |
| m_new, alpha, l, acc = ... | 10 | The online softmax update. alpha = exp(m − m_new) is the rescale factor. Both l and acc are multiplied by α; that is the step that retroactively corrects previous blocks for the new running max. |
| acc = acc / l[:, None] | 10 | Final divide. Now acc is the true softmax(QKᵀ)·V row. |

Why the rescale is exact

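After processing K/V blocks 0 … j, let C_j be the columns seen so far and m_i row i's running max over them. The kernel maintains an unnormalised accumulator and its normaliser:

$$\text{acc}_i = \sum_{n \in C_j} e^{S_{in} - m_i}\, V_n, \qquad \ell_i = \sum_{n \in C_j} e^{S_{in} - m_i}.$$

We don't know the final max yet, so the current m stands in for it. When m grows to m_new, every earlier term must change from exp(S − m) to exp(S − m_new) = exp(S − m) · exp(m − m_new) = exp(S − m) · α, so multiplying acc and ℓ by α corrects all previous blocks at once. The new block's terms are computed against m_new directly. At the end the reference max cancels:

$$\frac{\text{acc}_i}{\ell_i} = \frac{\sum_n e^{S_{in} - m_i}\, V_n}{\sum_n e^{S_{in} - m_i}} = \sum_n \operatorname{softmax}(S_{i,:})_n\, V_n,$$

so m only keeps the exponentials from overflowing; it never changes the result.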
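To see the invariant hold block by block, here is a pure-Python trace for one query row (hypothetical helper `trace_row`, scalar V per key for brevity). After every K/V block it records (m, ℓ, acc) and checks that ℓ · exp(m) equals the plain sum Σ exp(S) over the columns seen so far, which is exactly the quantity the α rescale preserves.

```python
import math

def trace_row(scores, values, bn):
    """Return (m, ℓ, acc) after each block of bn columns for one query row."""
    m, l, acc = -math.inf, 0.0, 0.0
    history = []
    for j in range(0, len(scores), bn):
        s, v = scores[j:j + bn], values[j:j + bn]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)          # 0.0 on the first block, since m = -inf
        p = [math.exp(x - m_new) for x in s]
        l = l * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, v))
        m = m_new
        history.append((m, l, acc))
    return history

scores = [1.0, 0.5, 3.0, -2.0, 4.0, 0.0]
values = [1.0, -1.0, 2.0, 0.5, 3.0, 1.0]
for step, (m, l, acc) in enumerate(trace_row(scores, values, bn=2)):
    seen = scores[: 2 * (step + 1)]
    print(f"block {step}: m={m:.1f}  ℓ={l:.4f}  "
          f"ℓ·e^m={l * math.exp(m):.4f}  Σe^S={sum(map(math.exp, seen)):.4f}")
```

The running max steps 1.0 → 3.0 → 4.0 while ℓ · e^m tracks Σ e^S exactly, and the final acc / ℓ is the softmax-weighted average of the values.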

Figure: one program holds a BM × D Q tile plus its running state (acc: BM × D, m: BM, ℓ: BM) in registers while BN × D blocks of K and V stream past. For each (K, V) block:

  1. S = Q · Kᵀ (tl.dot)
  2. m_new = max(m, rowmax S); α = exp(m − m_new)
  3. P = exp(S − m_new)
  4. ℓ ← α · ℓ + rowsum P
  5. acc ← α · acc + P · V (tl.dot again)

The bandwidth saving, quantified

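| Quantity | Naive | Flash Attention |
|---|---|---|
| Reads of Q | 1× | 1× (once per program) |
| Reads of K | 1× | ~N/BM× (each K block read once per Q block) |
| Reads of V | 1× | ~N/BM× (same) |
| S / P writes & reads | ~4N² elements (S write + read, P write + read) | 0 |
| O writes | 1× | 1× |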

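The K/V re-read looks bad. With N = 8192 and BM = 128, the N/BM = 64 programs of one (batch, head) each stream all of K and V: 64 × 4 MB = 256 MB of loads. Two things keep that cheap. First, in practice those programs run at roughly the same time (pid_m is the fastest-varying grid axis), so most of the repeated loads hit in L2 rather than HBM. Second, the re-read traffic scales as N²·d/BM and shrinks with a larger BM, while the S/P traffic it replaces grows as N² with nothing to tune. The net is the oft-cited 5–10× wall-clock speedup over naive attention.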
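The trade-off can be tabulated directly. This sketch (illustrative function names; one (batch, head); bf16) counts the K and V bytes *requested* by all Q-block programs, which is an upper bound on their HBM traffic since L2 hits are not modelled, next to the naive S/P traffic they replace.

```python
MIB = 2**20

def kv_load_bytes(n: int, d: int, bm: int, elem_bytes: int = 2) -> int:
    """K and V bytes requested when every Q block streams all of K and V."""
    q_blocks = -(-n // bm)                   # ceil(n / bm) programs per (batch, head)
    return q_blocks * 2 * n * d * elem_bytes

def sp_traffic_bytes(n: int, elem_bytes: int = 2) -> int:
    """Naive S and P traffic: S written and read, P written and read."""
    return 4 * n * n * elem_bytes

for bm in (64, 128, 256):
    kv = kv_load_bytes(8192, 128, bm) / MIB
    sp = sp_traffic_bytes(8192) / MIB
    print(f"BM={bm}: K/V loads {kv:.0f} MiB (before L2 hits) vs naive S/P {sp:.0f} MiB")
```

Doubling BM halves the requested K/V bytes (512, 256, 128 MiB at N = 8192), while the S/P figure stays at 512 MiB; the practical cap on BM comes from register and shared-memory pressure, not from this arithmetic.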

Causal masking — one extra line

For causal attention, mask scores where offs_n > offs_m before the softmax:

s = tl.where(offs_n[None, :] > offs_m[:, None], -float('inf'), s)

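The −∞ entries contribute exp(−∞) = 0 to both P and ℓ, so the recurrence still works. It relies on the loop starting at block 0: column 0 is visible to every row, so m is finite after the first block, and a later block that is fully masked for some row gives α = 1 and P = 0 rather than exp(−∞ − (−∞)) = NaN. (Production kernels also stop the K/V loop at the causal diagonal, since blocks entirely above it are fully masked; that skips nearly half the blocks at long context, close to a 2× saving over masking every block.)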
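Block skipping can be checked the same way on the CPU. The sketch below is a pure-Python emulation (hypothetical names `causal_flash` and `naive_causal`; self-attention, so the number of keys equals the number of queries) that applies the causal mask per element and stops each Q block's K/V loop at the diagonal. In the Triton kernel the analogous change would be ending the loop at min((pid_m + 1) * BM, N_CTX) instead of N_CTX; that is shown here only as an emulation, not as tested Triton code.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_causal(q, k, v):
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        s = [scale * dot(qi, k[j]) for j in range(i + 1)]     # only keys j <= i
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / sum(w)
                    for c in range(len(v[0]))])
    return out

def causal_flash(q, k, v, bm=2, bn=2):
    """Blocked causal attention; returns (output, number of K/V blocks visited)."""
    n, d = len(q), len(v[0])
    scale = 1.0 / math.sqrt(len(q[0]))
    out, visited = [], 0
    for i0 in range(0, n, bm):
        rows = list(range(i0, min(i0 + bm, n)))
        m = {i: -math.inf for i in rows}
        l = {i: 0.0 for i in rows}
        acc = {i: [0.0] * d for i in rows}
        # blocks starting at or past i0 + bm lie wholly above the diagonal: skip them
        for j0 in range(0, min(i0 + bm, n), bn):
            visited += 1
            cols = range(j0, min(j0 + bn, n))
            for i in rows:
                s = [scale * dot(q[i], k[j]) if j <= i else -math.inf for j in cols]
                m_new = max(m[i], max(s))             # block 0 always has column 0, so m is finite
                alpha = math.exp(m[i] - m_new)
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * alpha + sum(p)
                acc[i] = [acc[i][c] * alpha + sum(pn * v[j][c] for pn, j in zip(p, cols))
                          for c in range(d)]
                m[i] = m_new
        out.extend([x / l[i] for x in acc[i]] for i in rows)
    return out, visited

random.seed(1)
n, d = 6, 4
q = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
k = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
v = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
out, visited = causal_flash(q, k, v, bm=2, bn=2)
print(visited, "of", math.ceil(n / 2) ** 2, "K/V blocks visited")
```

With six tokens and 2 × 2 blocks this visits 6 of 9 blocks; as the number of blocks per side grows, the visited fraction approaches one half.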

What this kernel doesn't have (yet)

Production kernels such as FlashAttention-3 and FlashInfer add many optimisations this kernel lacks, including warp specialisation (TMA producer warps feeding MMA consumer warps). They beat hand-written Triton by 20–40% on H100. Don't try to out-engineer them; when they cover your case, just call FlashInfer.

When to write a Triton attention kernel
When FlashInfer doesn't cover your variant: a custom mask, a learned position bias, fusion with a downstream op, MLA's compressed-KV path, and so on. For vanilla MHA / GQA, call FlashInfer.


What's next

The kernel exists. Two questions remain: how fast is it (lesson 13: autotune, num_stages, the full pitfall checklist) and how do you ship it (lesson 14: backward pass, profiling, the decision tree).