Triton kernels · Lesson 06 / 14 · Reductions

Reductions and online algorithms

Reducing along one axis of a tile is at the core of most transformer kernels: softmax, layernorm and attention all do it. tl.sum, tl.max, tl.min and tl.cumsum do the visible work. Underneath, the compiler emits a warp-shuffle and shared-memory sequence for you. The lesson ends with the online softmax recurrence, the trick that makes Flash Attention possible.

The DSL surface

Op                               Semantics                       Axis
tl.sum(x, axis)                  Σ over axis                     0 or 1 (2D), 0 (1D)
tl.max(x, axis)                  max over axis                   same
tl.min(x, axis)                  min over axis                   same
tl.cumsum(x, axis)               prefix sum (inclusive)          same
tl.argmax(x, axis)               index of max                    same
tl.reduce(x, axis, combine_fn)   generic associative reduction   same

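All of these take a tile (1D or 2D) and work along the specified axis. Every op except tl.cumsum is a reduction and returns a tile with that axis dropped. tl.cumsum is a scan, so it returns a tile of the same shape. Example: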

x = tl.load(...)                # (BM, BN)
row_sum = tl.sum(x, axis=1)     # (BM,)
col_max = tl.max(x, axis=0)     # (BN,)
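The sketch below pins down the shape rules in plain Python rather than Triton. It applies the same ops to a small 2×3 tile stored as nested lists. The helpers reduce_axis and cumsum_axis1 are hypothetical: they mirror the results and shapes of the tl.* calls, not their GPU implementation.

```python
# Plain-Python model of tile ops on a 2D tile (a list of rows).
# Hypothetical helpers: they mimic the shape rules of tl.sum / tl.max /
# tl.argmax / tl.cumsum, not how Triton computes them.

def reduce_axis(tile, axis, fn):
    """Apply fn to every 1D slice along `axis`; that axis disappears."""
    if axis == 1:                       # across columns -> one value per row
        return [fn(row) for row in tile]
    cols = list(zip(*tile))             # axis == 0: down rows -> one value per column
    return [fn(list(col)) for col in cols]

def cumsum_axis1(tile):
    """Inclusive prefix sum along axis 1; the shape is unchanged."""
    out = []
    for row in tile:
        acc, run = 0.0, []
        for v in row:
            acc += v
            run.append(acc)
        out.append(run)
    return out

x = [[1.0, 5.0, 2.0],
     [-3.0, 0.5, 4.0]]                  # shape (BM=2, BN=3)

row_sum = reduce_axis(x, 1, sum)        # shape (BM,)
col_max = reduce_axis(x, 0, max)        # shape (BN,)
row_argmax = reduce_axis(x, 1, lambda r: max(range(len(r)), key=r.__getitem__))
prefix = cumsum_axis1(x)                # shape (BM, BN): a scan keeps the axis
print(row_sum, col_max, row_argmax, prefix)
```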

What a reduction compiles to

A tile-axis reduction lowers to a two-stage hardware operation, plus an optional broadcast:

Stage 1: intra-warp reduction (warp shuffles, no shared memory). 32 lanes → __shfl_xor (5 steps) → 1 partial per warp.
Stage 2: inter-warp reduction (shared memory). One partial per warp → SMEM → 1 final result.
Stage 3 (optional): broadcast. Load the result from SMEM into every lane, for ops like x / result that need it everywhere.

tl.sum(x, axis=...) is a single call in your source. The compiler emits all of this; you never write a __syncthreads.

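The cost comes in three steps. First, each thread folds the elements it holds in registers. Then log2(32) = 5 shuffle steps reduce across the lanes of a warp. Finally, if the reduced axis spans more than one warp, one round-trip through SMEM combines the warps. Take a 1024-element tile on 4 warps: that is 128 threads, so typically 8 elements per thread. Each reduction then costs 7 register adds, 5 shuffles and 1 SMEM round-trip, which is fast.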
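The sketch below models these steps in plain Python for one 1024-element sum. It assumes 4 warps × 32 lanes and that thread t owns 8 contiguous elements; the real register layout is chosen by the compiler. The xor butterfly mimics what __shfl_xor does. This is not Triton code, and the Stage 3 broadcast is omitted.

```python
# Simulate a 1024-element sum on 4 warps x 32 lanes (8 elements per thread).
# Assumed layout: thread t owns elements [8t, 8t + 8). Real layouts may differ.
WARP = 32
NUM_WARPS = 4
N = 1024
PER_THREAD = N // (WARP * NUM_WARPS)        # 8

data = [float(i % 7) for i in range(N)]

# Step 0: each thread reduces its own registers sequentially.
threads = [sum(data[t * PER_THREAD:(t + 1) * PER_THREAD])
           for t in range(WARP * NUM_WARPS)]
register_adds = PER_THREAD - 1              # 7 adds to fold 8 values

# Stage 1: butterfly shuffles inside each warp (lane i reads lane i ^ offset).
warp_partials = []
shuffle_steps = 0
for w in range(NUM_WARPS):
    lanes = threads[w * WARP:(w + 1) * WARP]
    offset, steps = WARP // 2, 0
    while offset > 0:
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(WARP)]
        offset //= 2
        steps += 1
    shuffle_steps = steps                   # log2(32) = 5, the same for every warp
    warp_partials.append(lanes[0])          # after the butterfly every lane holds the warp total

# Stage 2: one partial per warp goes through shared memory and is combined.
smem = list(warp_partials)
total = sum(smem)

print(register_adds, shuffle_steps, len(smem), total)
```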

Naive softmax in Triton

Let's apply this. Standard softmax for a row of length N:

$$\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$

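Done literally, this overflows. In float32, exp(x) is already inf once x exceeds about 88.7, and inf / inf gives NaN. The standard fix is the max-subtract trick. It is exact, because multiplying the numerator and the denominator by e^(−m) leaves the ratio unchanged: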

$$\mathrm{softmax}(x)_i = \frac{\exp(x_i - m)}{\sum_j \exp(x_j - m)}, \qquad m = \max_j x_j$$

So a numerically stable softmax is three passes over the row: find m, sum exp(x − m), divide. In Triton:

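import triton
import triton.language as tl

@triton.jit
def softmax_row(x_ptr, y_ptr, N, BLOCK: tl.constexpr):
    # One program per row. Assumes rows are contiguous (row stride == N)
    # and BLOCK is a power of two >= N, so the whole row fits in one tile.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    # Padding lanes load -inf: the identity for max, and exp(-inf) == 0 in the sum.
    x = tl.load(x_ptr + row*N + offs, mask=mask, other=-float('inf'))
    m = tl.max(x, axis=0)                      # reduction 1: row max
    e = tl.exp(x - m)
    s = tl.sum(e, axis=0)                      # reduction 2: sum of exp(x - m)
    y = e / s
    tl.store(y_ptr + row*N + offs, y, mask=mask)

# Launch sketch (assumes contiguous (n_rows, N) tensors x, y on the GPU):
# softmax_row[(n_rows,)](x, y, N, BLOCK=triton.next_power_of_2(N))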

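The kernel does two reductions, both compiled to warp shuffles plus SMEM, and processes one row per program. This works, and it's what most beginners write. But it has a subtle limitation: it assumes the whole row fits in one tile. Some rows are longer than the largest tile you can afford to hold in registers. 2048 elements is a common working figure, but the real limit depends on dtype, num_warps and register pressure. For those rows you can't process the row as a single tile.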
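A host-side reference is handy for testing the kernel above, and it also shows why the max-subtract matters. The sketch below is plain Python with hypothetical helper names. Python floats are float64, so overflow starts near x ≈ 709.8 rather than float32's ≈ 88.7. The naive formula raises OverflowError, while the three-pass version does not.

```python
import math

def softmax_naive(xs):
    """Literal formula exp(x_i) / sum_j exp(x_j); overflows for large x."""
    es = [math.exp(v) for v in xs]
    s = sum(es)
    return [e / s for e in es]

def softmax_3pass(xs):
    """Max-subtract softmax: find m, sum exp(x - m), divide."""
    m = max(xs)                               # pass 1 (tl.max in the kernel)
    s = sum(math.exp(v - m) for v in xs)      # pass 2 (tl.sum)
    return [math.exp(v - m) / s for v in xs]  # pass 3 (elementwise divide)

big = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(big)
    naive_overflowed = False
except OverflowError:                         # math.exp raises instead of returning inf
    naive_overflowed = True

y = softmax_3pass(big)
print(naive_overflowed, y)
```

Softmax is shift-invariant, so softmax_3pass(big) matches softmax_3pass([0.0, 1.0, 2.0]) exactly.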

The two-tile problem and the online trick

Suppose the row is 8192 elements and your tile is BLOCK=2048. You'd want to process it in 4 chunks. But softmax depends on a global max and a global sum. How do you fold partial computations from each chunk into a final answer?

This is where the online softmax recurrence (Milakov & Gimelshein, 2018) comes in. Take two chunks, each summarised by its own max m₁ or m₂. Each also carries its own sum of exponentials relative to that max, ℓₖ = Σ_{j ∈ chunk k} e^(x_j − mₖ). They combine as:

$$m' = \max(m_1, m_2)$$
$$\ell' = e^{m_1 - m'}\,\ell_1 + e^{m_2 - m'}\,\ell_2$$

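That's the entire recurrence. m' is the running max of the combined chunks, and ℓ' is the sum of e^(x − m') over the combined chunks. The clever step is the rescaling. Each ℓₖ was measured relative to its own max mₖ, and multiplying it by e^(mₖ − m') turns every term e^(x_j − mₖ) into e^(x_j − m'), which corrects for the change in the max. Since m' ≥ mₖ, the factor is at most 1, so nothing overflows.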
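The rescaling argument can be checked numerically. In the plain-Python sketch below, chunk_stats and merge are hypothetical names. Merging two chunks' statistics with the recurrence gives the same (m, ℓ) as computing them directly over the concatenated row.

```python
import math

def chunk_stats(xs):
    """(m, l) for one chunk: its max and the sum of exp(x - m)."""
    m = max(xs)
    return m, sum(math.exp(v - m) for v in xs)

def merge(m1, l1, m2, l2):
    """Online softmax recurrence: rescale each partial l to the combined max."""
    m = max(m1, m2)
    return m, math.exp(m1 - m) * l1 + math.exp(m2 - m) * l2

a = [0.5, -1.0, 2.0, 0.0]
b = [3.0, 1.5, -2.0, 2.5]

m_merged, l_merged = merge(*chunk_stats(a), *chunk_stats(b))
m_direct, l_direct = chunk_stats(a + b)
print(m_merged, l_merged, l_direct)
```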

[Figure: streaming over chunks while keeping (m, ℓ) in registers. Chunk 1 gives m₁ = max and ℓ₁ = Σe^(x−m₁). Chunk 2 gives m₂ and ℓ₂. The online merge produces m', ℓ' in one register pair.]

Key property: after the loop you have one (m, ℓ) for the whole row, and you never had to materialise the row; you stream it through registers. The outputs then take one more streaming pass, which re-reads each chunk and writes e^(x − m) / ℓ.

Why this matters: for attention, the row is N tokens long. Flash Attention avoids materialising the N×N attention scores in HBM, and the online trick is what lets you stream them instead. Lesson 10 derives this fully; lesson 12 uses it inside the attention kernel.
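Putting the recurrence in a loop gives a two-pass softmax that only ever holds one chunk plus the (m, ℓ) pair. The plain-Python sketch below uses a hypothetical function online_softmax with a chunk size block. In a Triton kernel each chunk would roughly correspond to one masked tl.load of BLOCK elements. That mapping is an assumption of this sketch, not code from the lesson.

```python
import math

def online_softmax(xs, block):
    """Two-pass softmax that streams `xs` in chunks of `block` elements."""
    m, l = -math.inf, 0.0                     # running (max, sum of exp(x - max))
    for start in range(0, len(xs), block):    # pass 1: statistics only
        chunk = xs[start:start + block]
        m_new = max(m, max(chunk))
        # Rescale the old l to the new max, then add this chunk's terms.
        # On the first chunk, exp(-inf - m_new) == 0.0, so l starts cleanly.
        l = l * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in chunk)
        m = m_new
    out = []
    for start in range(0, len(xs), block):    # pass 2: re-read each chunk, divide
        out.extend(math.exp(v - m) / l for v in xs[start:start + block])
    return out, m, l

# 8192 elements in chunks of 2048: 4 chunks.
row = [math.sin(i) * 10 for i in range(8192)]
y, m, l = online_softmax(row, block=2048)
```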

Generic associative reductions with tl.reduce

What if you want a reduction that isn't sum or max — say, the (max, sum-of-exp) pair used in online softmax, as one combined reduction? Triton lets you supply the combine function:

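@triton.jit
def combine_max_sum(m1, l1, m2, l2):
    m = tl.maximum(m1, m2)
    # Rescale each partial ℓ to the new max. If a partial's max already equals m
    # (including two all-padding partials with m == -inf, where exp(-inf - -inf)
    # would be NaN), keep its ℓ unchanged.
    l = (tl.where(m1 == m, l1, l1 * tl.exp(m1 - m))
         + tl.where(m2 == m, l2, l2 * tl.exp(m2 - m)))
    return m, l

# Inside the kernel, with x and mask loaded as in softmax_row above.
# Build per-lane (m, ℓ) tiles: each element is its own one-element chunk
# whose max is its value and whose ℓ relative to that max is e^(x-x) = 1.
# Padding lanes (x = -inf) must contribute ℓ = 0, not 1.
m_tile = x                                  # (BLOCK,)  per-lane local max
l_tile = tl.where(mask, 1.0, 0.0)           # (BLOCK,)  per-lane local ℓ
m, l = tl.reduce((m_tile, l_tile), axis=0, combine_fn=combine_max_sum)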

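The function must be associative, because the compiler applies it as a tree (within each thread, then across lanes, then across warps), not left-to-right. The order in which partials meet is also up to the compiler, so it is safest to make the function commutative as well. The (max, sum) combine is both, up to floating-point rounding, and that's what makes it work as a tile reduction. In practice you'll often hand-write the streaming loop (lesson 12) instead of using tl.reduce with a tuple combiner, but the API is there.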
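Associativity is easy to check off the GPU. The plain-Python sketch below uses a combine that mirrors the (max, sum) combiner. It adds a guard so that two all-padding partials (both maxes -inf) don't produce NaN via exp(-inf - (-inf)). The sketch folds the same per-element (m, ℓ) pairs two ways: left-to-right, and as a pairwise tree like the one a compiler might build. The fold helpers are hypothetical.

```python
import math

def combine(m1, l1, m2, l2):
    """Python mirror of a (max, sum-of-exp) combiner, guarded for -inf maxes."""
    m = max(m1, m2)
    a = l1 if m1 == m else l1 * math.exp(m1 - m)   # keep l1 if its max is already m
    b = l2 if m2 == m else l2 * math.exp(m2 - m)
    return m, a + b

def fold_left(pairs):
    """Sequential left-to-right reduction."""
    acc = pairs[0]
    for p in pairs[1:]:
        acc = combine(*acc, *p)
    return acc

def fold_tree(pairs):
    """Pairwise tree reduction; len(pairs) must be a power of two."""
    while len(pairs) > 1:
        pairs = [combine(*pairs[i], *pairs[i + 1]) for i in range(0, len(pairs), 2)]
    return pairs[0]

xs = [1.0, -2.0, 3.5, 0.25, 2.0, -1.5]           # 6 real values
BLOCK = 8
# Per-lane tiles: real lanes get (x, 1.0), padding lanes get (-inf, 0.0).
pairs = [(v, 1.0) for v in xs] + [(-math.inf, 0.0)] * (BLOCK - len(xs))

seq = fold_left(pairs)
tree = fold_tree(pairs)
print(seq, tree)
```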

Multi-dim reductions

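For a 2D tile x[M, N], axis=0 reduces down the M rows and returns an (N,) tile, one value per column. axis=1 reduces across the N columns and returns an (M,) tile, one value per row.

For attention, you reduce over the key axis. That is axis=1 if you laid out the scores as (M, N), with M = query rows and N = key sequence length, which gives one max and one ℓ per query.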

Watch out for the boundary tile

For a reduction, the other= value that tl.load writes into masked-off lanes must be the identity of the reduction:

x = tl.load(p+offs, mask=offs<N, other=0.0)        # for tl.sum
m = tl.load(p+offs, mask=offs<N, other=-float('inf')) # for tl.max

If you load with other=0.0 and then do tl.max, the boundary lanes contribute 0. That is wrong whenever every real value in the row is negative, because the padding 0 wins. Use -inf for max, +inf for min, 0 for sum, 1 for product. The lesson 04 table is the law.
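The plain-Python sketch below reproduces the failure. It pads a 5-element row of negative values out to an 8-lane tile, once with 0.0 and once with each reduction's identity. The pad helper is hypothetical and stands in for tl.load's other= on masked lanes.

```python
import math

def pad(values, block, other):
    """Stand-in for tl.load(..., mask=offs < N, other=other) on a BLOCK-lane tile."""
    return values + [other] * (block - len(values))

row = [-4.0, -1.5, -7.0, -2.25, -3.0]            # N = 5, all negative
BLOCK = 8

wrong_max = max(pad(row, BLOCK, 0.0))            # padding 0.0 beats every real value
right_max = max(pad(row, BLOCK, -math.inf))      # -inf is the identity for max
right_min = min(pad(row, BLOCK, math.inf))       # +inf is the identity for min
right_sum = sum(pad(row, BLOCK, 0.0))            # 0 is the identity for sum
right_prod = math.prod(pad(row, BLOCK, 1.0))     # 1 is the identity for product
print(wrong_max, right_max, right_min, right_sum, right_prod)
```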

Watching the online recurrence converge

Stream a row through the recurrence one chunk at a time. Each step processes one chunk and updates the running (m, ℓ). Measured against the final max, the running ℓ only grows. After the last chunk it equals the denominator of a 3-pass softmax, even though you only ever held one register pair, not the whole row.
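The convergence can also be traced in text. The plain-Python sketch below streams a row in chunks and records the running (m, ℓ) after each step. It then rescales each recorded ℓ to the final max, ℓ·e^(m − m_final), so every entry is on the same scale as the true denominator. The row and the chunk size are arbitrary choices for illustration.

```python
import math

row = [0.3, 2.0, -1.0, 4.5, 1.0, 6.0, 0.5, 5.5]
BLOCK = 2

m, l = -math.inf, 0.0
history = []                                   # running (m, l) after each chunk
for start in range(0, len(row), BLOCK):
    chunk = row[start:start + BLOCK]
    m_new = max(m, max(chunk))
    l = l * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in chunk)
    m = m_new
    history.append((m, l))

m_final = m
true_denominator = sum(math.exp(v - m_final) for v in row)
# Put every step on the final max's scale so the values are comparable.
trace = [lh * math.exp(mh - m_final) for mh, lh in history]
for step, value in enumerate(trace, 1):
    print(f"step {step}: l = {value:.4f}  (true {true_denominator:.4f})")
```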

What's next

You've now met the core primitives: tl.load, tl.dot, tl.sum/tl.max, tl.reduce. Lesson 07 starts putting them together with the simplest possible end-to-end kernel, vector add, and walks through the full Triton workflow: write, launch, mask, autotune, benchmark, verify.