
Reductions — sum, max, softmax

A reduction collapses many values into one. Done naïvely, it's a serial chain; done well, it's a parallel tree that uses warp shuffles for the inner stage, shared memory for the block stage, and atomics for the global stage. Almost every "produce one number from many" in ML is a variant of this template.

The serial story

Sum N elements. Serial: for (i=0; i<N; i++) acc += x[i]. That's N additions, each depending on the previous. Cannot be parallelised in this form.

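The operators we reduce with are associative: (a + b) + c == a + (b + c), and likewise for max. Floating-point addition is not strictly associative, because each regrouping rounds differently (max is exactly associative even in floating point). For most ML purposes, though, the small bit-level difference is acceptable. Associativity lets us reshape the computation into a tree, with depth ⌈log₂(N)⌉ and N/2 additions at the first level.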

[Figure: parallel reduction tree, depth log₂(N). Leaves x₀ … x₇ combine pairwise into x₀+x₁, x₂+x₃, x₄+x₅, x₆+x₇; then (x₀+x₁)+(x₂+x₃) and (x₄+x₅)+(x₆+x₇); then the sum of all 8.]
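Here is a minimal sketch of the two evaluation orders: the serial chain and the pairwise tree from the figure. It is written in plain C++ as a CPU stand-in for the GPU, so it runs anywhere. The function names serial_sum and pairwise_sum are illustrative and do not come from any library.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Serial chain: ((x0 + x1) + x2) + ...; each add waits on the previous one.
float serial_sum(const std::vector<float>& x) {
    float acc = 0.0f;
    for (float v : x) acc += v;
    return acc;
}

// Pairwise tree over x[lo, hi): sum each half, then add the two halves.
// Depth is ceil(log2(N)); the halves are independent, so a GPU can do
// every addition at one level of the tree in parallel.
float pairwise_sum(const std::vector<float>& x, std::size_t lo, std::size_t hi) {
    assert(lo < hi);  // non-empty range
    if (hi - lo == 1) return x[lo];
    std::size_t mid = lo + (hi - lo) / 2;
    return pairwise_sum(x, lo, mid) + pairwise_sum(x, mid, hi);
}

float pairwise_sum(const std::vector<float>& x) {
    return x.empty() ? 0.0f : pairwise_sum(x, 0, x.size());
}
```

With x = {1e8f, 1.0f, -1e8f, 1.0f}, serial_sum returns 1 and pairwise_sum returns 0. Near 1e8, adjacent floats are 8 apart, so every "large + 1" rounds the 1 away, and the two orders drop different 1s (the exact sum is 2). This is the bit-level difference the paragraph above accepts.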

[Figure (interactive): log-N tree reduction over N = 32 values, stepped one halving at a time.]

The cleanest mental model: N = 32 leaves at the top. At each step, lane i reads lane i + stride, adds, and writes the result back, and the stride halves: 16, 8, 4, 2, 1. The active set halves with it (16, 8, 4, 2, 1 lanes). After log₂(32) = 5 steps, a single value sits in lane 0. In a shared-memory (SMEM) version, each step must be followed by a __syncthreads() before the next step reads its inputs, so the barrier count grows with the depth of the tree.
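The strided variant can be modelled on the CPU as below. This is a sketch: strided_tree_reduce is a hypothetical name, and the inner loop stands in for one parallel round of lanes. Within a step, reads touch indices ≥ stride and writes touch indices < stride. Running the lanes one after another therefore gives the same answer as running them in parallel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU stand-in for the SMEM tree: at each step, lane i < stride adds in
// lane i + stride. On a GPU each step is one parallel round followed by a
// __syncthreads(), so we count one barrier per step.
// Returns the number of steps; the total ends up in v[0].
int strided_tree_reduce(std::vector<float>& v) {
    std::size_t n = v.size();
    assert(n > 0 && (n & (n - 1)) == 0);  // power-of-two length, like N = 32
    int barriers = 0;
    for (std::size_t stride = n / 2; stride > 0; stride /= 2) {
        for (std::size_t i = 0; i < stride; ++i)  // the active lanes
            v[i] += v[i + stride];
        ++barriers;                               // __syncthreads() would go here
    }
    return barriers;
}
```

For 32 elements this reports 5 steps, matching log₂(32).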

The three-level CUDA reduction

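For a large array on the GPU, the standard pattern has four nested stages. A per-thread stage comes first, followed by the three communication levels (warp, block, grid) that give the pattern its name: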

  1. Per-thread: each thread reads several elements and accumulates them serially in a register. ("Stride loop.")
  2. Per-warp: the 32 threads of a warp combine their per-thread sums using warp shuffles. Result in lane 0.
  3. Per-block: the warps write their lane-0 results to SMEM, then the first warp combines them with the same shuffle loop, and the block has its sum.
  4. Per-grid: blocks combine via global atomics or a second kernel.

Stage 1 — per-thread stride loop

// each thread reads element i, i+gridSize, i+2*gridSize, ...
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
float acc = 0;
for (int i = tid; i < N; i += stride) acc += x[i];

The "stride" pattern (also called a grid-stride loop) is coalesced. At iteration k, the 32 threads of a warp have consecutive tid values, so they read x[tid + k·stride] as 32 consecutive floats (128 bytes). Each thread handles roughly N / (gridDim.x · blockDim.x) elements. Choose grid and block size so each thread does enough work to amortise the fixed cost of the later warp, block, and global stages.
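Below is a CPU model of the grid-stride loop, assuming a 1D grid (grid_stride_partials is a made-up helper name). It records how many times each element is read, which makes the coverage property easy to check. Every index is visited exactly once, and per-thread work differs by at most one element.

```cpp
#include <cassert>
#include <vector>

// CPU model of the stage-1 grid-stride loop: "thread" tid of a grid with
// blocks * threads_per_block threads reads x[tid], x[tid + stride], ...
// Returns one partial sum per thread; touched[i] counts reads of x[i].
std::vector<float> grid_stride_partials(const std::vector<float>& x,
                                        int blocks, int threads_per_block,
                                        std::vector<int>& touched) {
    assert(blocks > 0 && threads_per_block > 0);
    const int N = static_cast<int>(x.size());
    const int stride = blocks * threads_per_block;  // gridDim.x * blockDim.x
    std::vector<float> partial(stride, 0.0f);
    touched.assign(x.size(), 0);
    for (int tid = 0; tid < stride; ++tid)
        for (int i = tid; i < N; i += stride) {
            partial[tid] += x[i];
            ++touched[i];
        }
    return partial;
}
```

Take N = 1000 with 2 blocks × 128 threads (stride 256). Threads 0..231 each read 4 elements and threads 232..255 read 3, since 1000 = 3 · 256 + 232. At any iteration k, the 32 threads of a warp touch 32 adjacent indices, which is the coalescing argument above.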

Stage 2 — warp-level reduction via shuffles

The 32 threads of a warp each have a partial sum in acc. To combine them, use __shfl_down_sync:

// tree reduction: fold lanes 16-31 into 0-15, then 8-15 into 0-7, etc.
for (int offset = 16; offset > 0; offset /= 2)
    acc += __shfl_down_sync(0xFFFFFFFF, acc, offset);
// after this, lane 0 has the warp's total

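Five rounds of shuffle (offsets 16, 8, 4, 2, 1) combine all 32 lanes into lane 0. Each round halves the number of lanes whose result still matters. Lanes whose source lane (lane + offset) falls past lane 31 get their own value back from __shfl_down_sync, so the upper lanes end with junk that nobody reads. No SMEM, no __syncthreads(): warp shuffles are very cheap. Since Volta, the _sync variants do enforce convergence among the masked lanes, so they aren't literally free, but the cost is on the order of a single instruction, far below SMEM-with-barrier. If every lane needs the total, the __shfl_xor_sync variant (a true butterfly) leaves the sum in all 32 lanes after the same five rounds.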
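To see why only lane 0 is trustworthy, here is a CPU model of one warp. It is a sketch under the shuffle semantics stated above; shfl_down_add, shfl_xor_add, and warp_reduce are hypothetical helpers, not CUDA intrinsics. Each call models one shuffle-and-add round in which all 32 lanes read before any lane writes.

```cpp
#include <array>
#include <cassert>

using Warp = std::array<float, 32>;  // one value per lane

// Models acc += __shfl_down_sync(0xFFFFFFFF, acc, offset) for a full warp:
// lane i receives lane i + offset's value, or its *own* value when
// i + offset > 31 (what the hardware returns for out-of-range lanes).
Warp shfl_down_add(const Warp& v, int offset) {
    assert(offset > 0 && offset < 32);
    Warp out{};
    for (int i = 0; i < 32; ++i)
        out[i] = v[i] + (i + offset < 32 ? v[i + offset] : v[i]);
    return out;
}

// Models acc += __shfl_xor_sync(0xFFFFFFFF, acc, offset): lane i pairs with
// lane i ^ offset, so every lane ends up with the full sum (a butterfly).
Warp shfl_xor_add(const Warp& v, int offset) {
    assert(offset > 0 && offset < 32);
    Warp out{};
    for (int i = 0; i < 32; ++i) out[i] = v[i] + v[i ^ offset];
    return out;
}

// The five-round loop from the text, with either shuffle flavour.
Warp warp_reduce(Warp v, bool butterfly) {
    for (int offset = 16; offset > 0; offset /= 2)
        v = butterfly ? shfl_xor_add(v, offset) : shfl_down_add(v, offset);
    return v;
}
```

With lane values 0..31, the shfl_down version leaves 496 in lane 0 (other lanes hold partial or doubled values), while the xor version leaves 496 in every lane. The xor form is the natural choice when every thread needs the reduced value afterwards, e.g. a row max that all threads then subtract.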

[Figure: __shfl_down_sync tree reduction, 5 rounds → lane 0 holds the sum. Offset 16: L0 += L16 · L1 += L17 · … · L15 += L31. Offset 8: L0 += L8 · L1 += L9 · … · L7 += L15. Offsets 4, 2, 1: repeat → L0 holds the warp's sum.]

Stage 3 — block-level reduction via SMEM

__shared__ float warp_sums[32];   // up to 32 warps per block
int lane = threadIdx.x % 32;
int warp = threadIdx.x / 32;
if (lane == 0) warp_sums[warp] = acc;
__syncthreads();

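// first warp reduces the warp_sums
// (assumes blockDim.x is a multiple of 32, so every warp is full and the
//  0xFFFFFFFF mask below is valid)
if (warp == 0) {
    // lanes beyond the block's warp count contribute 0, the identity for +
    acc = (threadIdx.x < blockDim.x/32) ? warp_sums[lane] : 0.0f;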
    for (int offset = 16; offset > 0; offset /= 2)
        acc += __shfl_down_sync(0xFFFFFFFF, acc, offset);
    if (lane == 0) {
        // acc now holds the block's total
        atomicAdd(global_sum, acc);   // or write to block_sums[blockIdx.x]
    }
}

Notice the structure: warp shuffles inside, SMEM for cross-warp communication, atomic at the global level. Each stage uses the cheapest synchronisation available — the previous lesson's hierarchy is doing real work here.

[Figure (animated): three-layer reduction · warp · block · grid.]

Consider a 16-block × 4-warp × 32-lane reduction (2048 values) collapsing layer by layer:

  1. Warp (shuffle): each warp runs its shuffle tree reduction, and lane 0 catches the result. 2048 → 64 survivors.
  2. Block (SMEM): in each block, lanes 0..3 of warp 0 read the 4 warp sums from SMEM and combine them. 64 → 16.
  3. Grid (atomic): each of the 16 blocks issues one atomicAdd into a single global value. 16 → 1.

Three layers of fan-in.

Stage 4 — global combination

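Two options for the across-block sum:

  1. Atomic: lane 0 of each block does atomicAdd(global_sum, acc), as in the stage-3 code. This takes one kernel, but the order in which blocks arrive varies from run to run, so the floating-point result can differ in the last bits.
  2. Two kernels: each block writes its total to block_sums[blockIdx.x]. A second launch then reduces block_sums in a fixed order, so the result is reproducible run to run.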

For most ML reductions, atomicAdd on the final result is fine. For loss values that need bit-exact reproducibility, do the two-kernel version.
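Here is a CPU sketch of why the two stage-4 options differ in reproducibility. atomic_style and two_kernel_fixed are illustrative names, and the arrival vector stands in for the order in which blocks happen to finish on a given run.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Option 1 (atomicAdd): blocks add into one global in whatever order they
// finish; `arrival` lists block ids in that run-dependent order.
float atomic_style(const std::vector<float>& block_sums,
                   const std::vector<std::size_t>& arrival) {
    assert(arrival.size() == block_sums.size());
    float global = 0.0f;
    for (std::size_t b : arrival) global += block_sums[b];
    return global;
}

// Option 2 (two kernels): block_sums is already written; the second kernel
// combines it in a fixed pairwise order by index, so the rounding is the
// same on every run regardless of block completion order.
float two_kernel_fixed(const std::vector<float>& block_sums) {
    std::vector<float> v = block_sums;
    while (v.size() > 1) {
        std::vector<float> next;
        for (std::size_t i = 0; i + 1 < v.size(); i += 2) next.push_back(v[i] + v[i + 1]);
        if (v.size() % 2) next.push_back(v.back());  // odd one out carries over
        v = next;
    }
    return v.empty() ? 0.0f : v[0];
}
```

With block sums {1e8f, 1.0f, -1e8f, 1.0f}, the arrival orders {0, 1, 2, 3} and {0, 2, 1, 3} give 1 and 2 under atomic_style, while two_kernel_fixed always gives 0. The fixed order buys reproducibility, not accuracy: the exact sum here is 2.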

Softmax — reduction's most-shipped customer

Softmax over a row of length N: y[i] = exp(x[i] - max) / sum(exp(x[j] - max)), where max is the row maximum. Subtracting it keeps exp from overflowing and cancels in the ratio. Two reductions over the same row: max, then sum-of-exp. A naïve, unfused PyTorch implementation does this in three separate kernels with three HBM round-trips per row. A fused softmax kernel does it all in one launch:

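#include <math.h>   // expf, fmaxf, INFINITY

// One block per row. block_reduce_max / block_reduce_sum are helpers (not
// shown) built from the warp-shuffle + SMEM stages above; they must return
// the block's result to *every* thread and be safe to call back to back.
__global__ void softmax_row(const float* x, float* y, int N) {
    int row = blockIdx.x;
    // each thread handles ~N/blockDim.x elements via a stride loop
    float local_max = -INFINITY;
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        local_max = fmaxf(local_max, x[row*N + i]);
    // warp reduce + block reduce → row_max
    float row_max = block_reduce_max(local_max);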

    float local_sum = 0;
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        local_sum += expf(x[row*N + i] - row_max);
    float row_sum = block_reduce_sum(local_sum);

    // write result
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        y[row*N + i] = expf(x[row*N + i] - row_max) / row_sum;
}

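One thread block per row; one pass over the row for max, one for sum-of-exp, one for the final divide. HBM traffic is ideally 2× the row size (one read, one write). That assumes the second and third passes re-read x from L1/L2 cache (or from a copy staged in SMEM or registers) rather than from HBM. The naïve unfused version costs 4–6×. This is the lesson-16 fusion principle applied to a reduction.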
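Below is a plain C++ reference for one row, mirroring the kernel's three loops without the parallel reductions (softmax_ref is a hypothetical name). It is a sketch meant as a correctness oracle when testing a GPU kernel, not a performance model.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// CPU reference for softmax_row over one row: the same three passes as the
// kernel (max, sum of exp, normalise), done serially.
std::vector<float> softmax_ref(const std::vector<float>& x) {
    assert(!x.empty());
    float row_max = -std::numeric_limits<float>::infinity();
    for (float v : x) row_max = std::fmax(row_max, v);         // pass 1: max
    float row_sum = 0.0f;
    for (float v : x) row_sum += std::exp(v - row_max);        // pass 2: sum of exp
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)                 // pass 3: divide
        y[i] = std::exp(x[i] - row_max) / row_sum;
    return y;
}
```

For x = {1000, 1000} it returns {0.5, 0.5}. Without the max subtraction, expf(1000) overflows to infinity and the division produces NaN.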

The online-softmax trick (FlashAttention's heart)

The kernel above still makes two reduction passes over the data (one for max, one for sum) before the write pass. The online-softmax trick, which FlashAttention builds on, folds both into one pass with a running (m, ℓ) state:

m_new = max(m_old, x),    ℓ_new = ℓ_old · exp(m_old - m_new) + exp(x - m_new)

When a new element arrives, you "rescale" the running normaliser to the new max: multiplying ℓ by exp(m_old - m_new) re-expresses every term already summed relative to m_new. After all elements have been seen, ℓ equals Σ exp(xᵢ - m) for the final max m, up to rounding. Two reductions become one. This generalises to any running statistic with the same shape (max, mean, variance via Welford's algorithm), which is why LayerNorm, RMSNorm, and FlashAttention all share this skeleton.
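Here is a minimal CPU sketch of the running (m, ℓ) state, assuming finite inputs; OnlineState, push, and merge are illustrative names. push is the update rule above. merge is an addition of this sketch rather than something stated above: it combines two partial states by rescaling both to the larger max. That is what lets the one-pass version run as a warp/block tree instead of a serial scan.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Running state: m = max so far, l = sum of exp(x_i - m) over elements seen.
struct OnlineState {
    float m = -std::numeric_limits<float>::infinity();
    float l = 0.0f;
};

// Fold in one element: m_new = max(m_old, x),
// l_new = l_old * exp(m_old - m_new) + exp(x - m_new).
OnlineState push(OnlineState s, float x) {
    assert(std::isfinite(x));
    float m_new = std::fmax(s.m, x);
    // for the empty state, exp(-inf - m_new) = 0, so l starts cleanly at 1
    s.l = s.l * std::exp(s.m - m_new) + std::exp(x - m_new);
    s.m = m_new;
    return s;
}

// Combine two partial states (e.g. from two warps) by rescaling both
// normalisers to the larger max.
OnlineState merge(const OnlineState& a, const OnlineState& b) {
    float m = std::fmax(a.m, b.m);
    if (m == -std::numeric_limits<float>::infinity()) return a;  // both empty
    OnlineState r;
    r.m = m;
    r.l = a.l * std::exp(a.m - m) + b.l * std::exp(b.m - m);
    return r;
}
```

Folding {0.5, -1, 3, 2} element by element, or merging the states of {0.5, -1} and {3, 2}, both end with m = 3 and ℓ equal (up to rounding) to the two-pass Σ exp(xᵢ - 3).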

The atomic-vs-tree decision

Final stage | Latency | Throughput | Deterministic?
--- | --- | --- | ---
Atomic add | Low (1 kernel) | Limited by atomic contention | No (order varies)
Tree (2 kernels) | Higher (2 launches) | Better at high block counts | Yes (with proper care)
cooperative_groups grid reduce | Best (1 kernel, grid barrier) | Limited to one wave of blocks | Yes

For most ML kernels (loss, gradient norm, layer norm), atomics are fine. For high-precision scientific computing, the two-kernel tree is the standard answer.

Variations on a theme

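Every common reduction in this lesson follows the template above: sum, max, the (m, ℓ) pair behind softmax, mean and variance for LayerNorm, and the sum of squares behind a gradient norm. Only the per-thread / per-warp combining operator changes.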

All have the same skeleton: per-thread accumulate, warp-shuffle combine, SMEM block combine, atomic/tree global combine.
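For the (μ, σ²) case, here is a sketch of the per-lane state and its pairwise combiner. It uses Welford's update for single values and Chan et al.'s parallel formula to merge two partial states. The names are illustrative, and double is used for clarity; a GPU kernel would more likely keep these in float registers.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <initializer_list>

// Per-lane state: count, mean, and M2 (sum of squared deviations from the
// mean); variance = M2 / n.
struct Moments {
    std::int64_t n = 0;
    double mean = 0.0;
    double m2 = 0.0;
};

// Welford's update: fold one value into the running state.
Moments push(Moments s, double x) {
    s.n += 1;
    double delta = x - s.mean;
    s.mean += delta / static_cast<double>(s.n);
    s.m2 += delta * (x - s.mean);
    return s;
}

// Chan et al.'s pairwise combine: the operator a warp/block tree would apply.
Moments merge(const Moments& a, const Moments& b) {
    if (a.n == 0) return b;
    if (b.n == 0) return a;
    Moments r;
    r.n = a.n + b.n;
    double delta = b.mean - a.mean;
    double nb_over_n = static_cast<double>(b.n) / static_cast<double>(r.n);
    r.mean = a.mean + delta * nb_over_n;
    r.m2 = a.m2 + b.m2 + delta * delta * static_cast<double>(a.n) * nb_over_n;
    return r;
}

// Serial fold, i.e. what one thread's stride loop does.
Moments moments_of(std::initializer_list<double> xs) {
    Moments s;
    for (double x : xs) s = push(s, x);
    return s;
}

// LayerNorm's post-process needs 1 / sqrt(variance + eps).
double inv_std(const Moments& s, double eps) {
    assert(s.n > 0);
    return 1.0 / std::sqrt(s.m2 / static_cast<double>(s.n) + eps);
}
```

For {2, 4, 4, 4, 5, 5, 7, 9}, folding all eight values and merging the states of the two halves give the same mean (5) and variance (M2 / n = 4).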

[Figure (interactive): the template, applied to softmax · LayerNorm · grad-norm.]

Three side-by-side reductions share the same warp → block → grid skeleton: per-lane init, pairwise combine, final post-processing. They differ only in the per-lane state and the combiner. LogSumExp (for softmax) tracks (m, ℓ); LayerNorm tracks (μ, σ²); grad-norm tracks Σx².
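The shared skeleton can be written once as a generic tree, with the reduction-specific parts passed in. This is a sketch (tree_reduce, grad_norm, and max_reduce are hypothetical names). The state here is a single float, but a struct state such as (m, ℓ) or (μ, σ²) plugs in the same way with a matching combine.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Shared skeleton: map each element to a per-lane state, then combine
// states pairwise in a tree. Only `init` and `combine` change per reduction.
template <class State, class Init, class Combine>
State tree_reduce(const std::vector<float>& x, Init init, Combine combine) {
    assert(!x.empty());
    std::vector<State> s;
    for (float v : x) s.push_back(init(v));                // per-lane init
    for (std::size_t stride = 1; stride < s.size(); stride *= 2)
        for (std::size_t i = 0; i + stride < s.size(); i += 2 * stride)
            s[i] = combine(s[i], s[i + stride]);           // pairwise combine
    return s[0];
}

// grad-norm: state is a partial sum of squares; post-process is sqrt.
float grad_norm(const std::vector<float>& g) {
    float sumsq = tree_reduce<float>(g, [](float v) { return v * v; },
                                     [](float a, float b) { return a + b; });
    return std::sqrt(sumsq);
}

// max: state is the value itself; no post-processing.
float max_reduce(const std::vector<float>& x) {
    return tree_reduce<float>(x, [](float v) { return v; },
                              [](float a, float b) { return std::fmax(a, b); });
}
```

grad_norm({3, 4}) returns 5 (√(9 + 16)). The pairwise loop also handles lengths that are not a power of two: an element without a partner simply waits for a later level.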

[Figure (interactive): reduction time breakdown by stage, as a function of N.]

In a toy model, stage 1 (per-thread reads) is HBM-bound: its time is N · 4 bytes / HBM bandwidth for float32 input. The other stages add up to microseconds regardless of N: the per-warp shuffle is nearly free, the per-block SMEM combine is cheap, and the global atomic costs a roughly constant amount. Sweep N from 1k to 1B and the per-thread stage dominates at large N.
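The toy model's stage-1 term is simple enough to compute by hand. Below is a sketch (stage1_seconds is a hypothetical helper). The bandwidth is an input you supply; the 2 TB/s used below is an assumed round figure, not a specific GPU's spec.

```cpp
#include <cassert>
#include <cstdint>

// Toy model from the text: stage 1 streams N float32 values from HBM once.
// hbm_bytes_per_s is an assumed bandwidth; plug in your GPU's figure.
double stage1_seconds(std::int64_t n, double hbm_bytes_per_s) {
    assert(n >= 0 && hbm_bytes_per_s > 0.0);
    const double bytes = static_cast<double>(n) * 4.0;  // sizeof(float)
    return bytes / hbm_bytes_per_s;
}
```

At an assumed 2 TB/s, N = 10⁹ floats take (4 · 10⁹ B) / (2 · 10¹² B/s) = 2 ms, while N = 10³ takes 2 ns. That is far below the microsecond-scale cost of the other stages.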
Takeaway
Reductions are trees. The kernel template is per-thread accumulate → warp-shuffle combine → SMEM block combine → global atomic. Each stage uses the cheapest sync mechanism that fits: registers, then warp shuffles, then __syncthreads, then atomicAdd. The online variant (Welford / FlashAttention) does it in one pass. LayerNorm, softmax, gradient norms — all the same skeleton.