Triton kernels · lesson 10 of 14

Softmax — the online reduction

Three implementations side by side: naive 3-pass (overflows), numerically stable 2-pass (works, but reduces over the row twice: once for the max, once for the sum), and online 1-pass (the Milakov–Gimelshein recurrence, which computes the max and the sum together in a single sweep; it is the trick that makes Flash Attention possible). The third one is the one you'll reuse for the rest of your Triton career.

The three implementations in code

1 · Naive 3-pass (broken)

# pass 1: compute exp
e = tl.exp(x)
# pass 2: sum
s = tl.sum(e, axis=0)
# pass 3: divide
y = e / s

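Overflows for any moderately large x. In fp32, exp(x) becomes +inf once x exceeds about 88.7 (the natural log of the largest finite fp32 value, ≈ 3.4 × 10³⁸): tl.exp(20) ≈ 4.9 × 10⁸ is fine, tl.exp(100) is +inf, and inf / inf turns the output into NaN. Don't use this.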
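The failure is easy to reproduce on the host. A minimal NumPy sketch in fp32 to match the kernel; `naive_softmax` and `stable_softmax` are illustrative names, the first mirroring snippet 1 and the second previewing the max-subtracted fix described next:

```python
import numpy as np

def naive_softmax(x):
    # Snippet 1: exp, sum, divide, with no max subtraction.
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    # Snippet 2: subtract the row max before exponentiating.
    m = x.max()
    e = np.exp(x - m)
    return e / e.sum()

x = np.array([1.0, 2.0, 100.0], dtype=np.float32)

# Silence NumPy's overflow/invalid warnings; the NaN in the output is the point.
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(x))   # exp(100) is inf in fp32, so the last entry is inf / inf = nan
print(stable_softmax(x))      # finite, sums to 1, nearly all mass on the last entry
```

For small inputs such as `[1, 2, 3]` the two functions agree to rounding error; they only diverge once some exponent leaves the fp32 range.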

2 · Numerically stable 2-pass

# pass 1: max
m = tl.max(x, axis=0)
# pass 2: sum of shifted exponentials
e = tl.exp(x - m)
s = tl.sum(e, axis=0)
# then: divide (elementwise, no third reduction)
y = e / s

This works. It's what you write when the whole row fits in one tile. The cost: two reductions over the row, each a warp-shuffle + SMEM round-trip. For row lengths up to ~4K this is fine.
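Subtracting m is exact, not an approximation: the factor e^{−m} cancels between numerator and denominator, and afterwards every exponent is ≤ 0.

```latex
\operatorname{softmax}(x)_i
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_j e^{x_j}}
  = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}},
  \qquad m = \max_j x_j
```

Each term e^{x_j − m} lies in (0, 1] and the max element contributes exactly 1, so the denominator lies in [1, N] and cannot overflow.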

3 · Online 1-pass (the Milakov–Gimelshein recurrence)

Stream the row in chunks. Carry forward only the running max m and the running sum-of-exp ℓ. Update both together:

# initial state
m = -float('inf')
l = 0.0

for chunk in chunks_of(row):                 # chunks_of: stand-in for stepping through BK-wide tiles
    cm = tl.max(chunk, axis=0)               # max of this chunk
    new_m = tl.maximum(m, cm)
    # rescale old ℓ to the new max, add chunk's ℓ
    l = l * tl.exp(m - new_m) + tl.sum(tl.exp(chunk - new_m), axis=0)
    m = new_m

# final pass — recompute each chunk and divide by ℓ
# (or stash partial e values during the first pass and divide them here)
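A NumPy reference of the recurrence is handy for checking a Triton kernel against. This is a sketch under stated assumptions: a 1-D float64 row, and a hypothetical chunk width `bk` playing the role of BK. It takes the first option in the comment above (recompute in the final sweep) rather than stashing partial exponentials:

```python
import numpy as np

def online_softmax(row, bk):
    """Softmax of a 1-D row, streamed in bk-wide chunks.

    Sweep 1 runs the online recurrence, carrying only the running max m and
    the running sum l of exp(x - m). Sweep 2 recomputes exp(x - m) and divides.
    """
    m = -np.inf   # running max (identity of max)
    l = 0.0       # running sum of exp(x - m) (identity of sum)
    for start in range(0, row.size, bk):
        chunk = row[start:start + bk]       # the last chunk may be short; no mask needed on the host
        new_m = max(m, chunk.max())
        # Rescale the old sum to the new max, then add this chunk measured at the new max.
        l = l * np.exp(m - new_m) + np.exp(chunk - new_m).sum()
        m = new_m
    # Final sweep: recompute each exponential against the final max and normalise.
    return np.exp(row - m) / l
```

Whatever `bk` is, `online_softmax(row, bk)` should match a direct max-subtracted softmax to rounding error; on the first iteration exp(−∞ − new_m) is 0, so the initial l = 0 drops out cleanly.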

The magic is the rescale: l = l * exp(m - new_m) + chunk_l_at_new_m. When the max grows, the old ℓ (which was computed against the old max) is rescaled by exp(old_m − new_m) — a number ≤ 1 — to "correct" for the change.

Why this matters
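In a 2-pass softmax you have to load the row twice: once for the max, once for the sum. The online recurrence fuses both into a single load. A standalone softmax still needs one more sweep to divide by ℓ, but in attention that division is applied once to the output accumulator, so the scores never need a second sweep. For attention, the "row" is a row of attention scores $S_{ij} = Q_i \cdot K_j$, and the full score matrix S is N × N. The 1-pass version is what makes Flash Attention's "never materialise the full attention matrix" possible.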
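To see why attention gets away with one sweep, here is a NumPy sketch for a single query row. It assumes the scores `s` (one row of S) and a value matrix `v` are already given; `streaming_weighted_sum` is a hypothetical name, and this covers only the softmax-and-accumulate arithmetic, not a Flash Attention kernel. The factor that rescales ℓ also rescales the output accumulator, so normalisation is one division at the end:

```python
import numpy as np

def streaming_weighted_sum(s, v, bk):
    """softmax(s) @ v for one query row, streaming bk-wide blocks of
    scores s (shape [N]) and values v (shape [N, d]) in a single sweep."""
    m = -np.inf                     # running max of the scores seen so far
    l = 0.0                         # running sum of exp(s - m)
    acc = np.zeros(v.shape[1])      # running, unnormalised sum of exp(s - m) * v
    for start in range(0, s.size, bk):
        s_blk = s[start:start + bk]
        v_blk = v[start:start + bk]
        new_m = max(m, s_blk.max())
        alpha = np.exp(m - new_m)   # <= 1; corrects everything accumulated against the old max
        p = np.exp(s_blk - new_m)   # this block's weights, measured at the new max
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ v_blk
        m = new_m
    return acc / l                  # the only normalisation, done once at the end
```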

The Triton kernel — row-per-program, 2-pass version

For typical sequence lengths the 2-pass version is the right one to write. The 1-pass online recurrence we'll save for Flash Attention (lesson 12).

import triton
import triton.language as tl

@triton.jit
def softmax_kernel(x_ptr, y_ptr, stride_x_row, stride_y_row, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < N

    # row pointers (elements within a row are assumed contiguous)
    x_row = x_ptr + row * stride_x_row
    y_row = y_ptr + row * stride_y_row

    # load (one row per program); use -inf so masked lanes lose the max
    x = tl.load(x_row + offs, mask=mask, other=-float('inf')).to(tl.float32)

    # 2-pass softmax in registers, computed in fp32
    m = tl.max(x, axis=0)
    e = tl.exp(x - m)
    s = tl.sum(e, axis=0)
    y = e / s

    # tl.store casts y to the element type of y_ptr
    tl.store(y_row + offs, y, mask=mask)

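The other=-float('inf') on the load is the lesson-04 boundary rule: lanes past N must not win the max. Fill them with the identity of max, which is −∞. As a bonus, exp(−∞ − m) = 0, so those lanes also add nothing to the sum, and the store mask keeps them from being written.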
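A host-side NumPy emulation of that rule, using a hypothetical `load` helper as a stand-in for the masked `tl.load`. It also shows what goes wrong with `other=0.0`: in this example zero does not win the max, but each padded lane still adds exp(0 − m) to the sum:

```python
import numpy as np

N, BLOCK = 5, 8                                   # row length and its power-of-two tile width
row = np.array([0.5, -1.0, 3.0, 2.0, 0.0])
offs = np.arange(BLOCK)
mask = offs < N

def load(row, other):
    # Host-side stand-in for tl.load(..., mask=mask, other=other)
    x = np.full(BLOCK, other)
    x[mask] = row
    return x

x = load(row, other=-np.inf)
m = x.max()                    # padded lanes can't win the max
e = np.exp(x - m)              # exp(-inf - m) == 0: padded lanes add nothing to the sum
y = e / e.sum()

bad = load(row, other=0.0)     # the wrong fill value
e_bad = np.exp(bad - bad.max())
y_bad = e_bad / e_bad.sum()    # three extra exp(0 - 3) terms inflate the sum
```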

The wrapper

import torch
import triton

def softmax(x: torch.Tensor) -> torch.Tensor:
    assert x.ndim == 2, "expects a 2-D (M, N) tensor"
    x = x.contiguous()                        # the kernel assumes unit stride within a row
    M, N = x.shape                            # M rows, each of length N
    y = torch.empty_like(x)
    # one program per row; tile width = nearest power-of-two ≥ N
    BLOCK = triton.next_power_of_2(N)
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    softmax_kernel[(M,)](
        x, y, x.stride(0), y.stride(0), N,
        BLOCK=BLOCK, num_warps=num_warps,
    )
    return y

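Two things to call out:

  1. BLOCK is rounded up to a power of two because tl.arange needs a power-of-two range; the mask on the load and store hides the extra lanes.
  2. num_warps grows with BLOCK so that each thread keeps a manageable number of row elements in registers as the tile widens.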

What happens when N exceeds BLOCK

If the row is longer than the largest practical tile (~8192 elements), the one-program-per-row approach breaks down: a single program can't hold much more than ~8K elements in registers without heavy spilling. Options:

  1. Multiple programs per row, online merge. Each program processes a chunk, writes its (m, ℓ) pair to a small auxiliary buffer; a second kernel merges them and rescales. The full online recurrence.
  2. Per-program streaming loop. One program walks its row in BK-wide chunks (BK being a constexpr chunk width, like BLOCK), applying the online recurrence in registers. The kernel touches the row twice (once for the recurrence, once for the final divide), but everything stays per-program and you avoid the auxiliary buffer.

Both are variants of the same idea. Lesson 12 uses option 2 inside Flash Attention.
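Option 1 works because two partial (m, ℓ) pairs merge with the same rescale as the recurrence. A NumPy sketch of just that arithmetic, with hypothetical helpers `partial_stats` and `merge`, and `np.array_split` standing in for the per-program chunks; it leaves out the auxiliary buffer and the second kernel launch:

```python
import numpy as np
from functools import reduce

def partial_stats(chunk):
    """What each program in option 1 would write out: (max, sum of exp(x - max))."""
    m = chunk.max()
    return m, np.exp(chunk - m).sum()

def merge(a, b):
    """Combine two (m, l) pairs into the pair for the union of their elements."""
    (m_a, l_a), (m_b, l_b) = a, b
    m = max(m_a, m_b)
    return m, l_a * np.exp(m_a - m) + l_b * np.exp(m_b - m)

row = np.random.default_rng(0).normal(scale=20.0, size=10_000)
chunks = np.array_split(row, 7)                      # stand-in for 7 programs on one row
m, l = reduce(merge, (partial_stats(c) for c in chunks))
y = np.exp(row - m) / l                              # the second kernel's rescale-and-divide
```

Because `merge` is associative and commutative up to rounding, the partials can be combined in any order, for example as a tree reduction.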

The recurrence, derived from scratch

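Suppose you've processed elements x[0..i] and have $m = \max_{j \le i} x_j$ and $\ell = \sum_{j \le i} e^{x_j - m}$. The next chunk has max $m_c$ and $\ell_c = \sum_{j \in \text{chunk}} e^{x_j - m_c}$.

The new combined max is $m' = \max(m, m_c)$. The new combined ℓ should be $\ell' = \sum_{j \le \text{new end}} e^{x_j - m'}$. Split the sum at the chunk boundary:

$$\ell' = \sum_{j \le i} e^{x_j - m'} + \sum_{j \in \text{chunk}} e^{x_j - m'}$$

In each sum, add and subtract that block's own max inside the exponent, $x_j - m' = (x_j - m) + (m - m')$ in the first and $x_j - m' = (x_j - m_c) + (m_c - m')$ in the second, then pull the constant factor out of the sum:

$$\ell' = e^{m - m'} \sum_{j \le i} e^{x_j - m} + e^{m_c - m'} \sum_{j \in \text{chunk}} e^{x_j - m_c}$$

The two sums are just $\ell$ and $\ell_c$. So:

$$\ell' = e^{m - m'}\,\ell + e^{m_c - m'}\,\ell_c$$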

That's the recurrence in three lines of algebra. The first term rescales the old ℓ for the change in max; the second contributes the chunk's ℓ rescaled to the same max. Because m' is the larger of m and m_c, one of the two factors is exactly 1 and the other is at most 1, so neither term can overflow.

Reading the kernel like an interviewer would

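A common interview question on Triton: "What's the difference between the 2-pass softmax and Flash Attention's softmax?" Answer in one breath: the 2-pass version takes the max and then the sum over a row that sits whole in one tile; Flash Attention streams the row in chunks, carries only the running max m and running sum ℓ, and rescales by exp(m − m') whenever the max grows, so the full attention row never has to exist at once.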

You will be expected to write the recurrence on a whiteboard. It's three lines. Memorise them.

Which softmax is right for this row length?

The hard cliff is at the max tile size (~8192 elements). Past that, you can't materialise the row in registers and must stream.
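The same decision can be written down as a plain-Python helper. A sketch under stated assumptions: `pick_softmax` and `MAX_TILE` are hypothetical names, the 8192 cutoff is this lesson's rule of thumb rather than a Triton limit, the warp thresholds copy the wrapper above, and `next_power_of_2` re-implements `triton.next_power_of_2` so the snippet runs without Triton installed:

```python
def next_power_of_2(n):
    # Smallest power of two >= n (n >= 1); same result as triton.next_power_of_2
    return 1 << (n - 1).bit_length()

MAX_TILE = 8192  # assumed practical cliff for one-row-per-program (rule of thumb, not a Triton constant)

def pick_softmax(n):
    """Return (variant, BLOCK, num_warps) for a row of length n."""
    if n > MAX_TILE:
        # Row no longer fits in one tile: stream it with the online recurrence.
        return "online (streaming)", None, None
    block = next_power_of_2(n)
    num_warps = 4
    if block >= 2048:
        num_warps = 8
    if block >= 4096:
        num_warps = 16
    return "2-pass (one tile)", block, num_warps

for n in (1000, 3000, 8192, 20000):
    print(n, pick_softmax(n))
```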

What's next

Softmax fuses a max reduction with a sum reduction with a divide. Lesson 11 fuses a different pair: a mean (or RMS) reduction with a scale. The kernel is half as long but the pattern — stream once, fuse the stat with the scaling — is identical.