Reductions and online algorithms
Half of every transformer kernel is "reduce a tile axis". tl.sum, tl.max, tl.min, tl.cumsum do the visible work; under them is a warp-shuffle + shared-memory dance the compiler emits for you. The lesson ends with the online softmax recurrence — the trick that makes Flash Attention possible.
The DSL surface
| Op | Semantics | Axis |
|---|---|---|
| tl.sum(x, axis) | Σ over axis | 0 or 1 (2D), 0 (1D) |
| tl.max(x, axis) | max over axis | same |
| tl.min(x, axis) | min over axis | same |
| tl.cumsum(x, axis) | prefix sum (inclusive) | same |
| tl.argmax(x, axis) | index of max | same |
| tl.reduce(x, axis, combine_fn) | generic associative reduction | same |
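All of these take a tile (1D or 2D) and operate along the specified axis. The reductions (tl.sum, tl.max, tl.min, tl.argmax, tl.reduce) return a tile with that axis dropped; tl.cumsum is a scan, so it returns a tile of the same shape as its input. Example: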
x = tl.load(...) # (BM, BN)
row_sum = tl.sum(x, axis=1) # (BM,)
col_max = tl.max(x, axis=0) # (BN,)
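To make the shape rules concrete without a GPU, here is a minimal pure-Python sketch that mimics the same axis semantics on a tiny tile stored as nested lists. The list comprehensions are stand-ins for the Triton ops, not what Triton executes.

```python
from itertools import accumulate

# A (BM, BN) = (2, 3) tile as nested lists; each inner list is one row.
x = [[1.0, 5.0, 2.0],
     [4.0, 0.0, 3.0]]

# axis=1 collapses the column index: one value per row, shape (BM,)
row_sum = [sum(row) for row in x]
# axis=0 collapses the row index: one value per column, shape (BN,)
col_max = [max(col) for col in zip(*x)]
# argmax along axis=1: position of the largest entry in each row
row_argmax = [row.index(max(row)) for row in x]
# cumsum along axis=1 is inclusive and keeps the full (BM, BN) shape
row_cumsum = [list(accumulate(row)) for row in x]

print(row_sum, col_max, row_argmax)
print(row_cumsum)
```

Note that only the scan keeps both axes; every other result has one axis fewer than the input.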
What a reduction compiles to
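A tile-axis reduction lowers to a two-stage hardware operation: first, each warp reduces its own values with register-to-register shuffles; then, if the tile spans more than one warp, the per-warp partial results are combined through shared memory (SMEM).
The cost: roughly log2(W) shuffle steps within a warp of W = 32 lanes, then one round-trip through SMEM if you have more than one warp. For a 1024-lane tile on 4 warps (128 threads, so each thread first folds its own 8 elements in registers), that's 5 shuffles + 1 SMEM round-trip per reduction, which is fast.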
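The step count above can be checked with a small simulation. This is a sketch of the general shuffle-then-SMEM pattern in plain Python, not the code the Triton compiler emits. The contiguous element-to-thread assignment and the helper names `warp_reduce_sum` and `tile_reduce_sum` are assumptions made for illustration.

```python
WARP_SIZE = 32  # lanes per warp on NVIDIA GPUs

def warp_reduce_sum(lanes):
    """Tree-reduce one warp's 32 values; returns (total, shuffle_steps)."""
    vals = list(lanes)
    steps = 0
    offset = WARP_SIZE // 2
    while offset > 0:
        # Lane i adds the value held by lane i + offset (shuffle-down pattern).
        vals = [vals[i] + (vals[i + offset] if i + offset < WARP_SIZE else 0.0)
                for i in range(WARP_SIZE)]
        offset //= 2
        steps += 1
    return vals[0], steps  # lane 0 ends up holding the warp total

def tile_reduce_sum(tile, num_warps):
    threads = num_warps * WARP_SIZE
    per_thread = len(tile) // threads
    # Each thread first sums the elements it owns, entirely in registers.
    regs = [sum(tile[t * per_thread:(t + 1) * per_thread]) for t in range(threads)]
    smem = []   # models one shared-memory slot per warp
    steps = 0
    for w in range(num_warps):
        total, steps = warp_reduce_sum(regs[w * WARP_SIZE:(w + 1) * WARP_SIZE])
        smem.append(total)
    # The per-warp partials are combined after a single SMEM round-trip.
    return sum(smem), per_thread, steps

tile = [float(i) for i in range(1024)]
total, per_thread, shuffle_steps = tile_reduce_sum(tile, num_warps=4)
print(total, per_thread, shuffle_steps)
```

The shuffle count depends only on the warp width, not on the tile size: a larger tile adds register-local work per thread, while the cross-lane traffic stays at five steps.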
Naive softmax in Triton
Let's apply this. Standard softmax for a row of length N:
$$\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
Done literally, this overflows: in float32, exp(x) already evaluates to inf for x above roughly 88.7. The standard fix is the max-subtract trick:
$$\mathrm{softmax}(x)_i = \frac{\exp(x_i - m)}{\sum_j \exp(x_j - m)}, \qquad m = \max_j x_j$$
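The max-subtract identity above can be checked on the CPU. The sketch below uses Python floats, which are float64, so the overflow threshold is higher than on a float32 GPU tile, and `math.exp` raises `OverflowError` rather than returning inf. The function names are illustrative.

```python
import math

def softmax_naive(xs):
    exps = [math.exp(v) for v in xs]   # raises OverflowError for large v
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(xs):
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]   # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]   # same differences as `small`, shifted by 999

try:
    softmax_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

stable_large = softmax_stable(large)
naive_small = softmax_naive(small)
print(naive_overflowed)
print(stable_large)
print(naive_small)
```

Because softmax is invariant to adding a constant to every input, the stable result for `large` matches the naive result for `small` up to rounding.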
So a numerically stable softmax is three passes over the row: find m, sum exp(x − m), divide. In Triton:
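import triton
import triton.language as tl

@triton.jit
def softmax_row(x_ptr, y_ptr, N, BLOCK: tl.constexpr):
    # One program per row; assumes a contiguous (rows, N) layout and BLOCK >= N.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    # Padded lanes load -inf: they never win the max and add exp(-inf) = 0 to the sum.
    x = tl.load(x_ptr + row * N + offs, mask=mask, other=-float('inf'))
    m = tl.max(x, axis=0)   # 1 reduction
    e = tl.exp(x - m)
    s = tl.sum(e, axis=0)   # 1 reduction
    y = e / s
    tl.store(y_ptr + row * N + offs, y, mask=mask)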
Two reductions, both of which compile to warp-shuffle + SMEM. One row per program. This works, and it's what most beginners write. But it has a subtle problem: it assumes the whole row fits in one tile, that is, BLOCK ≥ N (with BLOCK a power of two, as tl.arange requires). If the row is longer than the maximum practical tile size (typically 2048), you can't process it in one program.
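The single-tile assumption is easy to see in a CPU model of one program of softmax_row. The following is a sketch only: `softmax_row_emulated` and `next_power_of_2` are hypothetical helpers written for this illustration, and the lists stand in for Triton tiles.

```python
import math

def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

def softmax_row_emulated(row, block):
    """Pure-Python model of one program of softmax_row (hypothetical helper)."""
    n = len(row)
    assert block >= n, "single-tile kernel: the whole row must fit in one tile"
    # Masked load: lanes with offs >= N receive other = -inf.
    x = [row[i] if i < n else -math.inf for i in range(block)]
    m = max(x)                           # plays the role of tl.max
    e = [math.exp(v - m) for v in x]     # padded lanes become exp(-inf) = 0.0
    s = sum(e)                           # plays the role of tl.sum
    y = [v / s for v in e]
    return y[:n]                         # masked store writes only the first N lanes

row = [0.5, -1.0, 3.0, 2.0, 0.0]
block = next_power_of_2(len(row))
y = softmax_row_emulated(row, block)
print(block, y)
```

Calling `softmax_row_emulated(row, block=4)` on this five-element row trips the assertion, which is the CPU analogue of the problem described above.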
The two-tile problem and the online trick
Suppose the row is 8192 elements and your tile is BLOCK=2048. You'd want to process it in 4 chunks. But softmax depends on a global max and a global sum. How do you fold partial computations from each chunk into a final answer?
This is where the online softmax recurrence (Milakov & Gimelshein, 2018) comes in. Given two chunks with their own running max $m_1$, $m_2$ and running sum-of-exponentials $\ell_1$, $\ell_2$:
$$m' = \max(m_1, m_2)$$
$$\ell' = e^{m_1 - m'}\,\ell_1 + e^{m_2 - m'}\,\ell_2$$
That's the entire recurrence. $m'$ is the running max of the combined chunks. $\ell'$ is the sum-of-exp-of-(x − new-max) for the combined chunks. The clever step: each previous $\ell$ is rescaled by $e^{\text{old } m - \text{new } m}$, which corrects for the change in the max.
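As a sanity check, the recurrence can be run chunk by chunk on the CPU and compared against the direct two-pass computation. This is a minimal sketch in plain Python rather than a Triton kernel. The function name `online_softmax_stats` and the test row are made up for the example.

```python
import math

def online_softmax_stats(row, block):
    """Stream `row` in chunks of `block`, carrying a running (m, l) pair."""
    m = -math.inf   # running max over the chunks seen so far
    l = 0.0         # running sum of exp(x - m) over the chunks seen so far
    for start in range(0, len(row), block):
        chunk = row[start:start + block]
        m_chunk = max(chunk)
        l_chunk = sum(math.exp(v - m_chunk) for v in chunk)
        m_new = max(m, m_chunk)
        # Rescale both partial sums to the new max, then add them.
        l = math.exp(m - m_new) * l + math.exp(m_chunk - m_new) * l_chunk
        m = m_new
    return m, l

row = [math.sin(i) * 10.0 for i in range(8192)]
m, l = online_softmax_stats(row, block=2048)   # 4 chunks, as in the text

# Reference: global max first, then one sum over the whole row.
m_ref = max(row)
l_ref = sum(math.exp(v - m_ref) for v in row)
print(m == m_ref, abs(l - l_ref) / l_ref)
```

The initial state (m = -inf, ℓ = 0) acts as the identity: on the first chunk the old term is multiplied by exp(-inf) = 0, so no special case is needed.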
Generic associative reductions with tl.reduce
What if you want a reduction that isn't sum or max — say, the (max, sum-of-exp) pair used in online softmax, as one combined reduction? Triton lets you supply the combine function:
@triton.jit
def combine_max_sum(m1, l1, m2, l2):
    m = tl.maximum(m1, m2)
    l = l1 * tl.exp(m1 - m) + l2 * tl.exp(m2 - m)
    return m, l

# Inside a kernel, after x has been loaded:
# Build per-lane (m, ℓ) tiles: each lane treated as its own one-element chunk
# whose max is its value and whose ℓ relative to that max is 1.
# Caveat: two lanes padded with -inf combine to exp(-inf - (-inf)) = NaN, so when
# N < BLOCK pad x with a large finite negative value instead of -inf.
m_tile = x                                    # (BLOCK,) per-lane local max
l_tile = tl.full(x.shape, 1.0, tl.float32)    # (BLOCK,) per-lane local ℓ = e^(x-x) = 1
m, l = tl.reduce((m_tile, l_tile), axis=0, combine_fn=combine_max_sum)
The function must be associative — the compiler will apply it in a tree reduction, not left-to-right. The (max, sum) recurrence above is associative; that's what makes it work as a tile reduction. (In practice you'll often hand-write the streaming loop — lesson 12 — instead of using tl.reduce with a tuple combiner, but the API is there.)
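The associativity claim can be tested numerically by reducing the same (m, ℓ) pairs in two different orders. The sketch below is a pure-Python twin of the combiner. `tree_reduce` is a hypothetical helper that pairs neighbours level by level, one plausible order for a parallel reduction; the order Triton actually uses is not specified here.

```python
import math
from functools import reduce

def combine_max_sum(a, b):
    """Pure-Python twin of the Triton combiner; a and b are (m, l) pairs."""
    m1, l1 = a
    m2, l2 = b
    m = max(m1, m2)
    return m, l1 * math.exp(m1 - m) + l2 * math.exp(m2 - m)

def tree_reduce(items, fn):
    """Pairwise reduction: combine neighbours, then neighbours of the results."""
    items = list(items)
    while len(items) > 1:
        paired = [fn(items[i], items[i + 1]) for i in range(0, len(items) - 1, 2)]
        if len(items) % 2:
            paired.append(items[-1])
        items = paired
    return items[0]

x = [0.3, -2.0, 4.5, 1.0, 4.4, -0.7, 2.2, 3.9]
pairs = [(v, 1.0) for v in x]   # each lane: max = its value, l = 1

m_tree, l_tree = tree_reduce(pairs, combine_max_sum)
m_seq, l_seq = reduce(combine_max_sum, pairs)   # strict left-to-right order
print(m_tree, l_tree, l_seq)

# Pitfall: two lanes padded with -inf give (-inf) - (-inf) = NaN in the exponent.
_, l_bad = combine_max_sum((-math.inf, 0.0), (-math.inf, 0.0))
print(l_bad)
```

The two orders agree only up to floating-point rounding, which is the usual meaning of "associative" for floating-point reductions. The NaN case follows from IEEE arithmetic, so it is worth guarding against whenever padded lanes can meet each other in the tree.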
Multi-dim reductions
For a 2D tile x[M, N]:
- tl.sum(x, axis=0) collapses the row axis (one sum per column) → (N,)
- tl.sum(x, axis=1) collapses the column axis (one sum per row) → (M,)
- tl.sum(x) reduces both → scalar
For attention, you reduce the K-axis (axis=1 if you laid out scores as (M, N) with N = key sequence length).
Watch out for the boundary tile
For a reduction, the other argument on the load must be the identity of the reduction:
x = tl.load(p+offs, mask=offs<N, other=0.0) # for tl.sum
m = tl.load(p+offs, mask=offs<N, other=-float('inf')) # for tl.max
If you load with other=0.0 and then do tl.max, the boundary lanes contribute 0, which is wrong if the real data is negative (the boundary 0s will win). Use -inf for max, +inf for min, 0 for sum, 1 for product. The lesson 04 table is the law.
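The failure mode is easy to reproduce on the CPU. In this sketch, `masked_load` is a hypothetical stand-in for tl.load with a mask, and Python's built-in reductions play the role of the Triton ones.

```python
import math

def masked_load(data, block, other):
    """Model of a masked load: lanes past len(data) receive `other`."""
    return [data[i] if i < len(data) else other for i in range(block)]

data = [-3.0, -1.5, -7.0]   # N = 3 real values, all negative
BLOCK = 4                   # leaves one padded lane

wrong_max = max(masked_load(data, BLOCK, other=0.0))         # padded 0 wins
right_max = max(masked_load(data, BLOCK, other=-math.inf))
right_min = min(masked_load(data, BLOCK, other=math.inf))
right_sum = sum(masked_load(data, BLOCK, other=0.0))
right_prod = math.prod(masked_load(data, BLOCK, other=1.0))

print(wrong_max, right_max, right_min, right_sum, right_prod)
```

With the identity as padding, each result equals the reduction over the three real values alone; with the wrong padding, the max silently becomes 0.0.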
Interactive · watch the online recurrence converge
[Interactive demo: streams a row through the recurrence one chunk at a time. The "true" softmax denominator is drawn as a dashed line, and the running ℓ catches up to it.]
What's next
You've met every primitive: tl.load, tl.dot, tl.sum/tl.max, tl.reduce. Lesson 07 puts them together in the simplest possible end-to-end kernel — vector add — with the full Triton workflow: write, launch, mask, autotune, benchmark, verify.