Reductions — sum, max, softmax
A reduction collapses many values into one. Done naïvely, it's a serial chain; done well, it's a parallel tree that uses warp shuffles for the inner stage, shared memory for the block stage, and atomics for the global stage. Almost every "produce one number from many" in ML is a variant of this template.
The serial story
Sum N elements. Serial: for (i=0; i<N; i++) acc += x[i]. That's N additions, each depending on the previous. Cannot be parallelised in this form.
The operators we reduce with are associative: (a + b) + c == a + (b + c). Floating-point addition is technically non-associative (every operation rounds), but for most ML purposes the small bit-level difference is acceptable. Associativity lets us reshape the computation into a tree of depth log₂(N), whose first level performs N/2 independent additions.
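The tree shape can be sketched as a single-block shared-memory kernel. This is a minimal sketch, assuming one block whose size is a power of two and exactly one input element per thread; `tree_sum` and `tree_sum_host` are illustrative names, not part of the lesson. It pairs element i with element i + half rather than with its direct neighbour, which gives the same tree depth.

```cuda
#include <cuda_runtime.h>
#include <vector>

// One block, blockDim.x a power of two, one element per thread.
// Each step halves the number of active threads, so the loop runs
// log2(blockDim.x) times with one barrier per step.
__global__ void tree_sum(const float* x, float* out) {
    extern __shared__ float buf[];
    buf[threadIdx.x] = x[threadIdx.x];
    __syncthreads();
    for (int half = blockDim.x / 2; half > 0; half /= 2) {
        if (threadIdx.x < half)
            buf[threadIdx.x] += buf[threadIdx.x + half];
        __syncthreads();  // outside the if: every thread must reach the barrier
    }
    if (threadIdx.x == 0) *out = buf[0];
}

// Hypothetical host wrapper; v.size() must be a power of two, at most 1024.
float tree_sum_host(const std::vector<float>& v) {
    int n = static_cast<int>(v.size());
    float *d_x = nullptr, *d_out = nullptr, h_out = 0.0f;
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, v.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    tree_sum<<<1, n, n * sizeof(float)>>>(d_x, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_x);
    cudaFree(d_out);
    return h_out;
}
```

For 32 inputs the loop body runs 5 times, which is the log₂(N) depth described above.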
2D · log-N tree reduction
The cleanest mental model. N=32 leaves at the top; each step pairs adjacent values and adds them, halving the active set. After log₂(N)=5 steps a single value sits at the bottom. In the SMEM version, each step costs one __syncthreads() call.
The three-level CUDA reduction
For a large array on the GPU, the standard pattern has four stages: a per-thread accumulation followed by three levels of combining (warp, block, grid):
- Per-thread: each thread reads several elements and accumulates them serially in a register. ("Stride loop.")
- Per-warp: the 32 threads of a warp combine their per-thread sums using warp shuffles. Result in lane 0.
- Per-block: each warp writes its lane-0 result to SMEM, the first warp combines them, and the block has its sum.
- Per-grid: blocks combine via global atomics or a second kernel.
Stage 1 — per-thread stride loop
// each thread reads elements tid, tid + stride, tid + 2*stride, ...
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
float acc = 0;
for (int i = tid; i < N; i += stride) acc += x[i];
This pattern (also called a grid-stride loop) is coalesced: at iteration k, the 32 threads of a warp read the 32 consecutive elements tid + k*stride. Each thread handles about N / (gridDim.x · blockDim.x) elements; choose the grid and block size so that each thread does enough work to amortise launch and reduction overhead.
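One way to pick the grid size is to launch just enough blocks to fill the device once and let the stride loop absorb the rest. The following is a sketch under that assumption; `pick_grid_size` and `partial_sums` are hypothetical names, and the right sizing for a real kernel is something to measure rather than take from this helper.

```cuda
#include <cuda_runtime.h>
#include <algorithm>

// Stage-1 kernel used only so the occupancy query has something to inspect.
// out must hold gridDim.x * blockDim.x floats.
__global__ void partial_sums(const float* x, float* out, int N) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    float acc = 0.0f;
    for (int i = tid; i < N; i += stride) acc += x[i];
    out[tid] = acc;  // one partial sum per thread
}

// Hypothetical helper: enough blocks to fill every SM once,
// but never more blocks than the input needs, and never fewer than one.
int pick_grid_size(int N, int block_size) {
    int device = 0, num_sms = 0, blocks_per_sm = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm, partial_sums, block_size, 0);
    int full_device = num_sms * blocks_per_sm;
    int needed = (N + block_size - 1) / block_size;
    return std::max(1, std::min(needed, full_device));
}
```

With a grid capped this way, a large N translates into more loop iterations per thread instead of more blocks, which is the amortisation the paragraph above asks for.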
Stage 2 — warp-level reduction via shuffles
The 32 threads of a warp each have a partial sum in acc. To combine them, use __shfl_down_sync:
// tree reduction: fold lanes 16-31 into 0-15, then 8-15 into 0-7, etc.
for (int offset = 16; offset > 0; offset /= 2)
    acc += __shfl_down_sync(0xFFFFFFFF, acc, offset);
// after this, lane 0 has the warp's total (the other lanes hold partial sums)
Five rounds of shuffle (offsets 16, 8, 4, 2, 1) combine all 32 lanes into lane 0. Each round halves the active range. No SMEM, no __syncthreads() — warp shuffles are very cheap (the _sync variants do enforce convergence among the masked lanes since Volta, so they aren't literally free, but the cost is on the order of a single instruction, far below SMEM-with-barrier).
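The shuffle loop is often wrapped in a small device function. The sketch below shows the `__shfl_down_sync` form, which leaves the total in lane 0 only, next to a `__shfl_xor_sync` variant, which leaves it in every lane and is handy when all threads need the result. The function names and the demo kernel are illustrative, and both assume a full warp of 32 active lanes.

```cuda
#include <cuda_runtime.h>

// Shift-down tree: only lane 0 ends with the full sum.
__device__ float warp_sum_down(float v) {
    for (int offset = 16; offset > 0; offset /= 2)
        v += __shfl_down_sync(0xFFFFFFFFu, v, offset);
    return v;
}

// XOR exchange: every lane ends with the full sum.
__device__ float warp_sum_xor(float v) {
    for (int mask = 16; mask > 0; mask /= 2)
        v += __shfl_xor_sync(0xFFFFFFFFu, v, mask);
    return v;
}

// Launch with exactly one warp (32 threads); lane i contributes i + 1.
__global__ void warp_demo(float* down_out, float* xor_out) {
    float v = static_cast<float>(threadIdx.x + 1);
    down_out[threadIdx.x] = warp_sum_down(v);
    xor_out[threadIdx.x] = warp_sum_xor(v);
}

// Hypothetical host wrapper: copies both 32-element result arrays back.
void run_warp_demo(float down_out[32], float xor_out[32]) {
    float *d_down = nullptr, *d_xor = nullptr;
    cudaMalloc(&d_down, 32 * sizeof(float));
    cudaMalloc(&d_xor, 32 * sizeof(float));
    warp_demo<<<1, 32>>>(d_down, d_xor);
    cudaMemcpy(down_out, d_down, 32 * sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(xor_out, d_xor, 32 * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_down);
    cudaFree(d_xor);
}
```

Both variants cost the same five shuffles; the choice only depends on whether one lane or all lanes consume the result.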
Stage 3 — block-level reduction via SMEM
__shared__ float warp_sums[32];  // up to 32 warps per block (1024 threads)
int lane = threadIdx.x % 32;
int warp = threadIdx.x / 32;
if (lane == 0) warp_sums[warp] = acc;
__syncthreads();
// first warp reduces the warp_sums (assumes blockDim.x is a multiple of 32)
if (warp == 0) {
    // lanes beyond the number of warps contribute 0
    acc = (lane < blockDim.x/32) ? warp_sums[lane] : 0.0f;
    for (int offset = 16; offset > 0; offset /= 2)
        acc += __shfl_down_sync(0xFFFFFFFF, acc, offset);
    if (lane == 0) {
        // acc now holds the block's total
        atomicAdd(global_sum, acc);  // or write to block_sums[blockIdx.x]
    }
}
Notice the structure: warp shuffles inside, SMEM for cross-warp communication, atomic at the global level. Each stage uses the cheapest synchronisation available — the previous lesson's hierarchy is doing real work here.
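Stages 1 to 3 plus the atomic can be assembled into one kernel. This is a sketch rather than a tuned implementation: it assumes blockDim.x is a multiple of 32 and at most 1024, and the host wrapper `reduce_sum_host` with its default launch configuration is a hypothetical convenience for trying it out.

```cuda
#include <cuda_runtime.h>
#include <vector>

__device__ float warp_reduce_sum(float v) {
    for (int offset = 16; offset > 0; offset /= 2)
        v += __shfl_down_sync(0xFFFFFFFFu, v, offset);
    return v;  // full sum in lane 0
}

// Assumes blockDim.x is a multiple of 32 and *global_sum starts at 0.
__global__ void reduce_sum(const float* x, float* global_sum, int N) {
    __shared__ float warp_sums[32];

    // Stage 1: per-thread grid-stride accumulation.
    float acc = 0.0f;
    int stride = gridDim.x * blockDim.x;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += stride)
        acc += x[i];

    // Stage 2: warp-level shuffle reduction.
    acc = warp_reduce_sum(acc);

    // Stage 3: cross-warp combine through shared memory.
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;
    int num_warps = blockDim.x / 32;
    if (lane == 0) warp_sums[warp] = acc;
    __syncthreads();
    if (warp == 0) {
        acc = (lane < num_warps) ? warp_sums[lane] : 0.0f;
        acc = warp_reduce_sum(acc);
        // Stage 4: one atomic per block.
        if (lane == 0) atomicAdd(global_sum, acc);
    }
}

// Hypothetical host wrapper with an arbitrary default launch configuration.
float reduce_sum_host(const std::vector<float>& v, int blocks = 64, int threads = 256) {
    int n = static_cast<int>(v.size());
    float *d_x = nullptr, *d_sum = nullptr, h_sum = 0.0f;
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_sum, sizeof(float));
    cudaMemcpy(d_x, v.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_sum, 0, sizeof(float));  // the atomic accumulates into this
    reduce_sum<<<blocks, threads>>>(d_x, d_sum, n);
    cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_x);
    cudaFree(d_sum);
    return h_sum;
}
```

Note that the accumulator must be zeroed before every launch; forgetting the `cudaMemset` is a common source of wrong sums with the atomic approach.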
Animated · three-layer reduction in flight
A 16-block × 4-warp × 32-lane reduction collapses layer by layer. Each warp first does its shuffle reduction (lane 0 catches the result), then lanes 0..3 of warp 0 read the 4 warp-sums from SMEM, then blocks 0..15 each issue an atomic add to a single global. Three layers of fan-in.
Stage 4 — global combination
Two options for the across-block sum:
- atomicAdd to a global counter. Simple, with modest contention if the kernel launches many blocks. atomicAdd(float*) is available since CC 2.0; atomicAdd(double*) requires CC 6.0+. Both are fine for ML loss / norm reductions.
- Write per-block sums + relaunch a small reduction. One block does the final combine. Used when you want deterministic floating-point summation (atomic order is not deterministic).
For most ML reductions, atomicAdd on the final result is fine. For loss values that need bit-exact reproducibility, do the two-kernel version.
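The two-kernel version might look like the sketch below. It assumes blockDim.x is a multiple of 32 and a fixed launch configuration (the order of additions, and therefore the rounding, depends on the grid and block size); all function names are illustrative.

```cuda
#include <cuda_runtime.h>
#include <vector>

// Block-wide sum; the return value is valid in thread 0 only.
__device__ float block_sum(float v) {
    __shared__ float warp_sums[32];
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;
    int num_warps = blockDim.x / 32;
    for (int o = 16; o > 0; o /= 2) v += __shfl_down_sync(0xFFFFFFFFu, v, o);
    if (lane == 0) warp_sums[warp] = v;
    __syncthreads();
    if (warp != 0) return 0.0f;
    v = (lane < num_warps) ? warp_sums[lane] : 0.0f;
    for (int o = 16; o > 0; o /= 2) v += __shfl_down_sync(0xFFFFFFFFu, v, o);
    return v;
}

// Pass 1: block b writes its total to block_sums[b]; no atomics involved.
__global__ void partial_sums(const float* x, float* block_sums, int N) {
    float acc = 0.0f;
    int stride = gridDim.x * blockDim.x;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += stride)
        acc += x[i];
    acc = block_sum(acc);
    if (threadIdx.x == 0) block_sums[blockIdx.x] = acc;
}

// Pass 2: a single block combines the per-block totals in a fixed order.
__global__ void final_sum(const float* block_sums, float* out, int num_blocks) {
    float acc = 0.0f;
    for (int i = threadIdx.x; i < num_blocks; i += blockDim.x)
        acc += block_sums[i];
    acc = block_sum(acc);
    if (threadIdx.x == 0) *out = acc;
}

// Hypothetical host wrapper; results are repeatable for a fixed (blocks, threads).
float two_pass_sum_host(const std::vector<float>& v, int blocks = 64, int threads = 256) {
    int n = static_cast<int>(v.size());
    float *d_x = nullptr, *d_block_sums = nullptr, *d_out = nullptr, h_out = 0.0f;
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_block_sums, blocks * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, v.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    partial_sums<<<blocks, threads>>>(d_x, d_block_sums, n);
    final_sum<<<1, threads>>>(d_block_sums, d_out, blocks);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_x);
    cudaFree(d_block_sums);
    cudaFree(d_out);
    return h_out;
}
```

The second launch costs a few microseconds, but every addition now happens in an order fixed by thread and block indices rather than by which block reaches the atomic first.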
Softmax — reduction's most-shipped customer
Softmax over a row of length N: y[i] = exp(x[i] - max) / sum(exp(x[j] - max)). Two reductions over the same row: max, then sum-of-exp. The naïve PyTorch implementation does this in three separate kernels with three HBM round-trips per row. A fused softmax kernel does it in a single launch:
__global__ void softmax_row(const float* x, float* y, int N) {
    int row = blockIdx.x;  // one thread block per row
    // each thread handles about N/blockDim.x elements via a stride loop
    float local_max = -INFINITY;
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        local_max = fmaxf(local_max, x[row*N + i]);
    // warp reduce + block reduce → row_max
    // (block_reduce_max / block_reduce_sum are helpers not shown here;
    //  they must return the result to every thread in the block)
    float row_max = block_reduce_max(local_max);
    float local_sum = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        local_sum += expf(x[row*N + i] - row_max);
    float row_sum = block_reduce_sum(local_sum);
    // write result
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        y[row*N + i] = expf(x[row*N + i] - row_max) / row_sum;
}
One thread block per row; one pass over the row for max, one for sum-of-exp, one for the final divide. Provided the row stays resident in cache (or registers) between passes, HBM traffic is 2× the row size (one read, one write), versus the naïve 4–6× for unfused. The lesson-16 fusion principle applied to a reduction.
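The softmax kernel above calls block_reduce_max and block_reduce_sum without defining them. One possible shape for these helpers is sketched below, assuming blockDim.x is a multiple of 32; the template `block_reduce`, the functor names, and the `row_stats` test kernel are all hypothetical. The key difference from the stage-3 snippet is the broadcast: every thread needs row_max and row_sum, so the result goes back through shared memory.

```cuda
#include <cuda_runtime.h>
#include <math.h>
#include <vector>

struct SumOp { __device__ float operator()(float a, float b) const { return a + b; } };
struct MaxOp { __device__ float operator()(float a, float b) const { return fmaxf(a, b); } };

// Reduces v across the block with op and returns the result to EVERY thread.
template <typename Op>
__device__ float block_reduce(float v, float identity, Op op) {
    __shared__ float warp_vals[32];
    __shared__ float result;
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;
    int num_warps = blockDim.x / 32;
    for (int o = 16; o > 0; o /= 2) v = op(v, __shfl_down_sync(0xFFFFFFFFu, v, o));
    if (lane == 0) warp_vals[warp] = v;
    __syncthreads();
    if (warp == 0) {
        v = (lane < num_warps) ? warp_vals[lane] : identity;
        for (int o = 16; o > 0; o /= 2) v = op(v, __shfl_down_sync(0xFFFFFFFFu, v, o));
        if (lane == 0) result = v;
    }
    __syncthreads();       // make result visible to all warps
    float r = result;
    __syncthreads();       // allow a later call to reuse the shared buffers
    return r;
}

__device__ float block_reduce_max(float v) { return block_reduce(v, -INFINITY, MaxOp()); }
__device__ float block_reduce_sum(float v) { return block_reduce(v, 0.0f, SumOp()); }

// Test kernel: one block computes max and sum of x[0..N).
__global__ void row_stats(const float* x, int N, float* out) {
    float local_max = -INFINITY, local_sum = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        local_max = fmaxf(local_max, x[i]);
        local_sum += x[i];
    }
    float mx = block_reduce_max(local_max);
    float sm = block_reduce_sum(local_sum);
    if (threadIdx.x == 0) { out[0] = mx; out[1] = sm; }
}

// Hypothetical host wrapper: returns {max, sum}.
std::vector<float> row_stats_host(const std::vector<float>& v, int threads = 256) {
    int n = static_cast<int>(v.size());
    float *d_x = nullptr, *d_out = nullptr;
    std::vector<float> h_out(2, 0.0f);
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_out, 2 * sizeof(float));
    cudaMemcpy(d_x, v.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    row_stats<<<1, threads>>>(d_x, n, d_out);
    cudaMemcpy(h_out.data(), d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_x);
    cudaFree(d_out);
    return h_out;
}
```

The two extra barriers are the price of the broadcast; a kernel that only needs the result in one thread can skip them, as the stage-3 snippet does.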
The online-softmax trick (FlashAttention's heart)
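The kernel above still makes two reduction passes over the data (one for max, one for sum). The online-softmax trick at the heart of FlashAttention does it in one pass with a running (m, ℓ) state. For each new element x: m' = max(m, x), then ℓ' = ℓ · exp(m − m') + exp(x − m'), starting from m = −∞ and ℓ = 0.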
When a new element arrives, you rescale the running normaliser to the new max. After all elements have been seen, ℓ equals Σ exp(xᵢ - m) up to rounding. Two reductions become one. This generalises to any running statistic with the same shape (max, mean, variance via Welford's algorithm), which is why LayerNorm, RMSNorm, and FlashAttention all share this skeleton.
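A running (m, ℓ) state can be sketched as a small struct with a merge function, used both inside each thread's loop and across lanes. The version below assumes one warp (32 threads) per row to keep the cross-lane step to shuffles only; `MaxSum`, `merge`, and the host wrapper are illustrative names. It still reads the row a second time to write the outputs; only the two reductions are fused into one pass.

```cuda
#include <cuda_runtime.h>
#include <math.h>
#include <vector>

struct MaxSum { float m; float l; };  // running max, and sum of exp(x - m)

// Merge two partial states, rescaling both sums to the larger max.
__device__ MaxSum merge(MaxSum a, MaxSum b) {
    float m = fmaxf(a.m, b.m);
    if (m == -INFINITY) return MaxSum{m, 0.0f};  // both empty: avoid exp(-inf + inf)
    return MaxSum{m, a.l * expf(a.m - m) + b.l * expf(b.m - m)};
}

// Launch with one block of exactly 32 threads per row.
__global__ void online_softmax_row(const float* x, float* y, int N) {
    const float* xr = x + static_cast<size_t>(blockIdx.x) * N;
    float* yr = y + static_cast<size_t>(blockIdx.x) * N;

    // Single pass: each lane folds its elements into a running (m, l).
    MaxSum s = {-INFINITY, 0.0f};
    for (int i = threadIdx.x; i < N; i += 32)
        s = merge(s, MaxSum{xr[i], 1.0f});  // one element: max = x, sum = exp(0) = 1

    // XOR exchange so every lane ends with the row's (m, l).
    for (int mask = 16; mask > 0; mask /= 2) {
        MaxSum other;
        other.m = __shfl_xor_sync(0xFFFFFFFFu, s.m, mask);
        other.l = __shfl_xor_sync(0xFFFFFFFFu, s.l, mask);
        s = merge(s, other);
    }

    for (int i = threadIdx.x; i < N; i += 32)
        yr[i] = expf(xr[i] - s.m) / s.l;
}

// Hypothetical host wrapper for a single row.
std::vector<float> online_softmax_host(const std::vector<float>& row) {
    int n = static_cast<int>(row.size());
    float *d_x = nullptr, *d_y = nullptr;
    std::vector<float> out(n, 0.0f);
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_y, n * sizeof(float));
    cudaMemcpy(d_x, row.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    online_softmax_row<<<1, 32>>>(d_x, d_y, n);
    cudaMemcpy(out.data(), d_y, n * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_x);
    cudaFree(d_y);
    return out;
}
```

Because `merge` is associative, the same function serves as the per-element update and as the combiner in the shuffle stage, which is what lets the state ride the usual reduction skeleton.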
The atomic-vs-tree decision
| Final stage | Latency | Throughput | Deterministic? |
|---|---|---|---|
| Atomic add | Low (1 kernel) | Limited by atomic contention | No (order varies) |
| Tree (2 kernels) | Higher (2 launches) | Better at high block counts | Yes (with proper care) |
| cooperative_groups grid reduce | Best (1 kernel, grid barrier) | Limited to one wave of blocks | Yes |
For most ML kernels (loss, gradient norm, layer norm), atomics are fine. For high-precision scientific computing, the two-kernel tree is the standard answer.
Variations on a theme
Every common reduction looks like the template above with different per-thread / per-warp combining operators:
- Sum: +
- Max: fmaxf
- Argmax: tuple of (value, index); combiner picks the max-value
- Welford mean/variance: the online formula keeps a running (n, mean, M2) triple
- Sum-of-squares (norms): per-thread accumulates x*x, then sum-reduce
- Top-k: a small heap per thread; combine via merge
All have the same skeleton: per-thread accumulate, warp-shuffle combine, SMEM block combine, atomic/tree global combine.
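Two of the combiners in the list, argmax and Welford, can be sketched as overloads of a `combine` function dropped into the same shuffle loop. This is an illustration under simplifying assumptions: a single warp of 32 threads, ties in argmax broken toward the smaller index, and the standard pairwise merge formula for Welford triples. The struct and function names are hypothetical.

```cuda
#include <cuda_runtime.h>
#include <limits.h>
#include <math.h>
#include <vector>

struct ArgMax { float val; int idx; };
struct Welford { float n, mean, m2; };  // count, mean, sum of squared deviations

// Larger value wins; ties go to the smaller index.
__device__ ArgMax combine(ArgMax a, ArgMax b) {
    if (b.val > a.val || (b.val == a.val && b.idx < a.idx)) return b;
    return a;
}

// Pairwise merge of two Welford triples.
__device__ Welford combine(Welford a, Welford b) {
    float n = a.n + b.n;
    if (n == 0.0f) return a;  // both empty
    float delta = b.mean - a.mean;
    Welford r;
    r.n = n;
    r.mean = a.mean + delta * (b.n / n);
    r.m2 = a.m2 + b.m2 + delta * delta * (a.n * b.n / n);
    return r;
}

// One warp. out[0] = argmax index, out[1] = mean, out[2] = population variance.
__global__ void warp_stats(const float* x, int N, float* out) {
    ArgMax am = {-INFINITY, INT_MAX};   // empty state
    Welford w = {0.0f, 0.0f, 0.0f};     // empty state
    for (int i = threadIdx.x; i < N; i += 32) {
        am = combine(am, ArgMax{x[i], i});
        w = combine(w, Welford{1.0f, x[i], 0.0f});
    }
    for (int o = 16; o > 0; o /= 2) {
        ArgMax am2 = {__shfl_down_sync(0xFFFFFFFFu, am.val, o),
                      __shfl_down_sync(0xFFFFFFFFu, am.idx, o)};
        Welford w2 = {__shfl_down_sync(0xFFFFFFFFu, w.n, o),
                      __shfl_down_sync(0xFFFFFFFFu, w.mean, o),
                      __shfl_down_sync(0xFFFFFFFFu, w.m2, o)};
        am = combine(am, am2);
        w = combine(w, w2);
    }
    if (threadIdx.x == 0) {  // only lane 0 holds valid totals
        out[0] = static_cast<float>(am.idx);
        out[1] = w.mean;
        out[2] = w.m2 / w.n;
    }
}

// Hypothetical host wrapper: returns {argmax index, mean, variance}.
std::vector<float> warp_stats_host(const std::vector<float>& v) {
    int n = static_cast<int>(v.size());
    float *d_x = nullptr, *d_out = nullptr;
    std::vector<float> h_out(3, 0.0f);
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_out, 3 * sizeof(float));
    cudaMemcpy(d_x, v.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    warp_stats<<<1, 32>>>(d_x, n, d_out);
    cudaMemcpy(h_out.data(), d_out, 3 * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_x);
    cudaFree(d_out);
    return h_out;
}
```

Multi-field states simply shuffle each field separately; the loop structure is unchanged from the plain sum.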
2D · the template, applied to softmax · LayerNorm · grad-norm
Three side-by-side reductions, all sharing the same warp → block → grid skeleton, differing only in the per-lane combiner and the final post-processing.
Interactive · stage-by-stage time and bandwidth
Per-stage cost as a function of N: per-thread reads (HBM-bound), per-warp shuffle (near-free), per-block SMEM (low), global atomic (constant). As N grows from 1k to 1B, the per-thread stage dominates.
In short: per-thread accumulate, warp shuffle, SMEM with __syncthreads, then atomicAdd. The online variant (Welford / FlashAttention) does it in one pass. LayerNorm, softmax, gradient norms: all the same skeleton.