Softmax — the online reduction
Three implementations side by side: naive 3-pass (overflows), 2-pass numerically stable (correct, but reduces over the row twice), and 1-pass online (the Milakov–Gimelshein recurrence, the trick that makes Flash Attention possible). The third is the one you'll reuse for the rest of your Triton career.
The three implementations in code
1 · Naive 3-pass (broken)
# pass 1: compute exp
e = tl.exp(x)
# pass 2: sum
s = tl.sum(e, axis=0)
# pass 3: divide
y = e / s
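Overflows for any moderately large x. tl.exp(20) (about 4.9e8) is fine in fp32; tl.exp(100) is +inf. Dividing by an infinite sum then turns the +inf lanes into NaN and every finite lane into 0. Don't use this.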
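The cutoff lies between those two examples. fp32 tops out at about 3.4e38, so exp(x) overflows once x passes the natural log of that value. Here is a plain-Python check of the arithmetic. No Triton is needed, and `FP32_MAX` is just the standard IEEE-754 single-precision maximum written out by hand:

```python
import math

# Largest finite IEEE-754 single-precision (float32) value.
FP32_MAX = 3.4028234663852886e38

# exp(x) exceeds FP32_MAX, and so overflows to +inf in fp32, once x passes ln(FP32_MAX).
threshold = math.log(FP32_MAX)

print(f"fp32 exp overflows for x > {threshold:.2f}")  # → fp32 exp overflows for x > 88.72
print(math.exp(20) < FP32_MAX)   # → True  (exp(20) is about 4.85e8)
print(math.exp(100) < FP32_MAX)  # → False (exp(100) is about 2.69e43)
```

Python floats are fp64, so these comparisons are exact enough to show which side of the fp32 limit each value falls on.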
2 · Numerically stable 2-pass
# pass 1: max
m = tl.max(x, axis=0)
# pass 2: sum of shifted exponentials
e = tl.exp(x - m)
s = tl.sum(e, axis=0)
# then: divide (elementwise, not a reduction)
y = e / s
This works. It's what you write when the whole row fits in one tile. The cost: two reductions over the row, each a warp-shuffle + SMEM round-trip. For row lengths up to ~4K this is fine.
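For reference, here is the same 2-pass computation as a plain-Python sketch over a list that stands in for one tile. The function name `softmax_two_pass` is illustrative. Python floats are fp64, so this demonstrates the algorithm, not Triton's numerics. The input is chosen so that the naive version would fail: `math.exp(1000.0)` raises `OverflowError` in Python.

```python
import math

def softmax_two_pass(row):
    """Reference 2-pass softmax on a list of floats (a stand-in for one tile)."""
    m = max(row)                          # reduction 1: row max
    e = [math.exp(x - m) for x in row]    # every exponent is <= 0, so each e <= 1
    s = sum(e)                            # reduction 2: sum of shifted exponentials
    return [v / s for v in e]             # elementwise normalize

row = [1000.0, 1001.0, 1002.0]
y = softmax_two_pass(row)
print([round(v, 4) for v in y])  # → [0.09, 0.2447, 0.6652]
```

Softmax is unchanged when the same constant is subtracted from every element, so this result is exactly the softmax of `[0, 1, 2]`.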
3 · Online 1-pass (the Milakov–Gimelshein recurrence)
Stream the row in chunks. Carry forward only the running max m and the running sum-of-exp ℓ. Update both together:
# initial state
m = -float('inf')
l = 0.0
for chunk in chunks_of(row):  # chunks_of is schematic: yields successive tiles of the row
    cm = tl.max(chunk, axis=0)  # max of this chunk
    new_m = tl.maximum(m, cm)
    # rescale old ℓ to the new max, add chunk's ℓ
    l = l * tl.exp(m - new_m) + tl.sum(tl.exp(chunk - new_m), axis=0)
    m = new_m
# final pass: recompute each chunk's exp(x - m) and divide by ℓ
# (or stash each chunk's exp(x - new_m) during the first pass; since it was computed
#  against an older max, multiply it by exp(new_m - m) before dividing by ℓ)
The magic is the rescale: l = l * exp(m - new_m) + chunk_l_at_new_m. When the max grows, the old ℓ (which was computed against the old max) is rescaled by exp(old_m − new_m) — a number ≤ 1 — to "correct" for the change.
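To check the recurrence end to end without a GPU, here is a plain-Python sketch of the same loop. The `chunk` parameter stands in for the tile width, and the helper name `online_softmax` is hypothetical. The second pass recomputes exp(x − m) against the final max, as in the pseudocode above.

```python
import math

def online_softmax(row, chunk=4):
    """Online softmax sketch: one streaming pass for (m, l), one pass to normalize.

    `chunk` plays the role of the tile width; it is an illustrative parameter.
    """
    m, l = -math.inf, 0.0
    # Pass 1: stream the row, keeping only the running max m and running sum l.
    for start in range(0, len(row), chunk):
        c = row[start:start + chunk]
        new_m = max(m, max(c))
        # Rescale the old sum to the new max, then add this chunk's terms at the new max.
        # On the first chunk m is -inf, so exp(m - new_m) is exp(-inf) == 0.0.
        l = l * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in c)
        m = new_m
    # Pass 2: recompute each exponential against the final max and divide by l.
    return [math.exp(x - m) / l for x in row]

row = [3.0, -1.0, 7.5, 2.0, 9.0, 0.5, 8.0, -4.0, 1.0]
y = online_softmax(row, chunk=4)
print(round(sum(y), 12))  # → 1.0
```

Changing `chunk` affects only the floating-point rounding, not the result. That includes `chunk=1`, a fully scalar stream. The running (m, ℓ) pair is all the state the first pass needs.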
The Triton kernel — row-per-program, 2-pass version
For typical sequence lengths, the 2-pass version is the right one to write. We'll save the 1-pass online recurrence for Flash Attention (lesson 12).
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(x_ptr, y_ptr, stride_x_row, stride_y_row, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    # row pointers (assumes each row is contiguous, i.e. unit stride along the row)
    x_row = x_ptr + row * stride_x_row
    y_row = y_ptr + row * stride_y_row
    # load (one row per program); use -inf so masked lanes lose the max
    x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
    x = x.to(tl.float32)  # compute in fp32 so fp16/bf16 inputs are handled accurately
    # 2-pass softmax in registers
    m = tl.max(x, axis=0)
    e = tl.exp(x - m)
    s = tl.sum(e, axis=0)
    y = e / s
    tl.store(y_row + offs, y, mask=mask)  # store converts back to y's element type
The other=-float('inf') on the load is the lesson-04 boundary rule: lanes past N must not win the max, so fill them with the identity of max, which is −∞. They then drop out of the sum for free, because exp(−∞ − m) = 0.
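The boundary rule is easy to check off-GPU. This plain-Python sketch emulates one program's masked tile. Lanes at or past N are filled with −∞, as the load's `other=` does, and the masked store keeps only the first N results. The helper name `masked_tile_softmax` is hypothetical.

```python
import math

def masked_tile_softmax(row, BLOCK):
    """Emulate one program's masked tile: lanes >= N are loaded as -inf (other=-inf)."""
    N = len(row)
    assert BLOCK >= N, "one tile must cover the whole row"
    x = [row[i] if i < N else -math.inf for i in range(BLOCK)]  # masked tl.load
    m = max(x)                              # padded -inf lanes never win the max
    e = [math.exp(v - m) for v in x]        # exp(-inf - m) == 0.0 on padded lanes
    s = sum(e)                              # so padded lanes add nothing to s
    y = [v / s for v in e]
    return y[:N], y[N:]                     # masked tl.store keeps only y[:N]

kept, dropped = masked_tile_softmax([2.0, -3.0, 0.5, 4.0, 1.0], BLOCK=8)
print(dropped)  # → [0.0, 0.0, 0.0]
```

With `other=0.0` instead, each padded lane would add exp(0 − m) to s, so the N stored probabilities would no longer sum to 1.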
The wrapper
import torch
import triton

def softmax(x: torch.Tensor) -> torch.Tensor:
    assert x.ndim == 2, "expects a 2-D (M, N) tensor"
    x = x.contiguous()  # the kernel assumes unit stride along each row
    M, N = x.shape  # M rows, each of length N
    y = torch.empty_like(x)
    # one program per row; tile width = nearest power of two ≥ N
    BLOCK = triton.next_power_of_2(N)
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    softmax_kernel[(M,)](
        x, y, x.stride(0), y.stride(0), N,
        BLOCK=BLOCK, num_warps=num_warps,
    )
    return y
Two things to call out:
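- next_power_of_2(N): tl.arange needs a power-of-two length, and BLOCK is a tl.constexpr, so it is fixed at compile time. Round N up; the boundary mask handles the slack.
- num_warps scales with BLOCK: more warps spread the BLOCK elements over more threads. Each thread then holds fewer values in registers, and the reductions run across more lanes. The autotuner can pick this, but for softmax a simple staircase is fine.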
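The launch heuristic can be tabulated without a GPU. This sketch re-implements the wrapper's staircase in plain Python. `launch_config` is a hypothetical helper, and `next_power_of_2` here is a stand-in intended to match `triton.next_power_of_2` for positive N.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()

def launch_config(N: int):
    """Hypothetical helper: the BLOCK / num_warps staircase from the softmax wrapper."""
    BLOCK = next_power_of_2(N)
    num_warps = 4
    if BLOCK >= 2048:
        num_warps = 8
    if BLOCK >= 4096:
        num_warps = 16
    return BLOCK, num_warps

for N in (100, 1000, 2048, 3000, 5000):
    print(N, launch_config(N))
```

Note that N = 3000 already gets BLOCK = 4096 and 16 warps, because the staircase keys on the rounded-up BLOCK rather than on N itself.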
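What happens when N outgrows the largest practical BLOCK
If the row is longer than the largest practical tile (~8192 elements), the one-program-per-row approach breaks down. You can't keep much more than ~8K elements in one program's registers, and past that point register pressure forces spills to slow local memory. Options: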
- Multiple programs per row, online merge. Each program processes a chunk, writes its (m, ℓ) pair to a small auxiliary buffer; a second kernel merges them and rescales. The full online recurrence.
- Per-program streaming loop. One program walks its row in BK-wide chunks (BK being a tile width that does fit in registers), applying the online recurrence in registers. The kernel touches the row twice (once for the recurrence, once for the final divide), but everything stays per-program and you avoid the auxiliary buffer.
Both are variants of the same idea. Lesson 12 uses option 2 inside Flash Attention.
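Option 1 relies on the (m, ℓ) merge being associative and commutative, up to floating-point rounding, so per-chunk partials can be combined in any order. Below is a plain-Python sketch of that shape. `partial_stats` and `merge` are hypothetical helper names, and the list `partials` plays the role of the auxiliary buffer.

```python
import math
from functools import reduce

def partial_stats(chunk):
    """What one program would write to the auxiliary buffer: (max, sum of exp(x - max))."""
    m = max(chunk)
    return m, sum(math.exp(x - m) for x in chunk)

def merge(a, b):
    """Combine two (m, l) partials by rescaling both sums to the larger max."""
    (m_a, l_a), (m_b, l_b) = a, b
    m = max(m_a, m_b)
    return m, l_a * math.exp(m_a - m) + l_b * math.exp(m_b - m)

row = [0.3, 5.0, -2.0, 4.5, 9.5, 1.0, 7.0, 2.5]
chunks = [row[i:i + 3] for i in range(0, len(row), 3)]  # three "programs"; the last chunk is short
partials = [partial_stats(c) for c in chunks]           # phase 1: independent per-chunk stats
m, l = reduce(merge, partials)                          # phase 2: merge the partials
y = [math.exp(x - m) / l for x in row]                  # final normalize pass
```

Reducing the partials in reverse order gives the same (m, ℓ) up to rounding, so the merge can also run as a tree across programs.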
The recurrence, derived from scratch
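Suppose you've processed elements $x[0..i]$ and have $m = \max(x[0..i])$ and $\ell = \sum_{j \le i} e^{x_j - m}$. The next chunk has max $m_c$ and $\ell_c = \sum_{j \in \text{chunk}} e^{x_j - m_c}$.
The new combined max is $m' = \max(m, m_c)$. The new combined $\ell$ should be $\sum_{j \le \text{new end}} e^{x_j - m'}$. Split the sum at the chunk boundary:
$$\ell' = \sum_{j \le i} e^{x_j - m'} + \sum_{j \in \text{chunk}} e^{x_j - m'}$$
Multiply each term of the first sum by $1 = e^{m - m}$ and each term of the second by $1 = e^{m_c - m_c}$, then pull the constant factors out of the sums:
$$\ell' = e^{m - m'} \sum_{j \le i} e^{x_j - m} + e^{m_c - m'} \sum_{j \in \text{chunk}} e^{x_j - m_c}$$
The two sums are just $\ell$ and $\ell_c$. So:
$$\ell' = e^{m - m'}\,\ell + e^{m_c - m'}\,\ell_c$$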
That's the recurrence in three lines of algebra. The first term rescales the old ℓ for the change in max; the second contributes the chunk's ℓ rescaled to the same max. Both factors are ≤ 1 because $m' \ge m$ and $m' \ge m_c$, so nothing overflows.
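A tiny worked instance makes the indices concrete. Take an already-processed prefix x = (1, 2) and a one-element chunk (3):

```latex
\begin{aligned}
m &= 2, \qquad \ell = e^{1-2} + e^{2-2} = e^{-1} + 1 \\
m_c &= 3, \qquad \ell_c = e^{3-3} = 1 \\
m' &= \max(2, 3) = 3 \\
\ell' &= e^{m - m'}\,\ell + e^{m_c - m'}\,\ell_c = e^{-1}\,(e^{-1} + 1) + e^{0}\cdot 1 = e^{-2} + e^{-1} + 1 \\
\textstyle\sum_{j} e^{x_j - m'} &= e^{1-3} + e^{2-3} + e^{3-3} = e^{-2} + e^{-1} + 1
\end{aligned}
```

The recurrence and the direct sum agree term by term.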
Reading the kernel like an interviewer would
A common interview question on Triton: "What's the difference between the 2-pass softmax and Flash Attention's softmax?" Answer in one breath:
- The 2-pass version assumes the whole row fits in registers, that is, in one program's tile.
- Flash Attention's softmax is the online recurrence above, applied inside the K-loop of attention. It updates (m, ℓ) one chunk of attention scores at a time and rescales its running output accumulator by the same factor exp(m − m'). It never materialises the full N×N score matrix in HBM.
You will be expected to write the recurrence on a whiteboard. It's three lines. Memorise them.
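The accumulator rescale is easy to see in a toy. In Flash Attention, the running output row is multiplied by the same factor exp(m − m') as ℓ, and the division by ℓ can be deferred to the end, as FlashAttention-2 does. This plain-Python sketch does that for a single query, with scalar values standing in for the V rows. `online_weighted_sum` is a hypothetical name, and the code shows the pattern, not a kernel.

```python
import math

def online_weighted_sum(scores, values, chunk=2):
    """One query row: sum_j softmax(scores)_j * values[j], streamed chunk by chunk.

    Scalar values stand in for V rows; a sketch of the pattern, not a kernel.
    """
    m, l, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), chunk):
        s = scores[start:start + chunk]
        v = values[start:start + chunk]
        new_m = max(m, max(s))
        alpha = math.exp(m - new_m)                 # same rescale factor as the l update
        p = [math.exp(x - new_m) for x in s]        # unnormalized probabilities for this chunk
        l = l * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, v))  # rescale the accumulator too
        m = new_m
    return acc / l                                  # single divide at the end

scores = [0.5, 2.0, -1.0, 3.0, 1.5]
values = [10.0, -2.0, 4.0, 1.0, 7.0]
out = online_weighted_sum(scores, values)
```

The full probability vector never exists at once. Only (m, ℓ, acc) survive between chunks, and that is why the attention matrix never has to be written out.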
What's next
Softmax fuses a max reduction, a sum reduction, and a divide. Lesson 11 fuses a different pair: a mean (or RMS) reduction with a scale. The kernel is half as long, but the pattern (stream once, fuse the statistic with the scaling) is identical.