Lesson 19 / 24 · Fused kernel anatomy

Anatomy of a fused kernel

A working fused RMSNorm kernel in ~30 lines of Triton, read line by line: programming model, masked loads, on-chip reductions, epilogue, autotune choice. By the end you should be able to read any modern fused kernel and predict its performance class without running it.

The question this lesson answers

You know what kernels look like in the abstract (Part I) and you know what they do for serving (Part II). What does real production kernel code look like, and what are the moving parts? We use a fused RMSNorm because (a) it appears in every transformer block, (b) it touches three habits from lesson 01 (tile, fuse, coalesce), and (c) it's small enough to fit on one screen.

What RMSNorm computes

For each row $x \in \mathbb{R}^D$ of the input:

$$y_i = \frac{x_i \, w_i}{\sqrt{\frac{1}{D}\sum_{j=1}^{D} x_j^2 + \varepsilon}}, \qquad i = 1, \dots, D$$

Three operations are chained per row: square and sum (a reduction over D), divide by the square root, and multiply by the per-feature weight. Written naively in PyTorch, this dispatches at least five kernels (square, mean, add eps, rsqrt, multiply; six if the multiplies by the scale and by w run separately), and full-size N×D tensors cross HBM three or four times in each direction. A fused kernel reads x twice (we'll explain why two, not one, below), keeps the reduction state in registers, and writes y once.
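To make the fusion argument concrete, here is a minimal pure-Python sketch on plain lists (no GPU, no PyTorch). In `rmsnorm_unfused`, a hypothetical name, every step is a separate pass that materializes its full output the way eager ops would; `rmsnorm_fused`, also hypothetical, does the per-row reduction and writes y once. It illustrates the data flow only, not the kernels PyTorch actually launches.

```python
import math

def rmsnorm_unfused(X, w, eps=1e-6):
    """Eager-style: each line is one 'kernel' that reads its inputs in full
    and materializes a new output (in HBM, on a real GPU)."""
    sq = [[v * v for v in row] for row in X]                      # kernel 1: square (full-size write)
    mean = [sum(row) / len(row) for row in sq]                    # kernel 2: mean (one scalar per row)
    shifted = [m + eps for m in mean]                             # kernel 3: add eps
    inv = [1.0 / math.sqrt(s) for s in shifted]                   # kernel 4: rsqrt
    scaled = [[v * r for v in row] for row, r in zip(X, inv)]     # kernel 5: x * inv_rms (full-size write)
    return [[v * wi for v, wi in zip(row, w)] for row in scaled]  # kernel 6: * w (full-size write)

def rmsnorm_fused(X, w, eps=1e-6):
    """Fused: per row, reduce to one scalar kept 'on chip', then write y once."""
    out = []
    for row in X:
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv_rms * wi for v, wi in zip(row, w)])
    return out
```

Both produce the same numbers; the difference is how many full-size intermediates (`sq`, `scaled`) the unfused version writes out and reads back.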

Triton primitives, in one paragraph (just enough to read the code)

@triton.jit JIT-compiles a Python function into a kernel. Inside, tl.program_id(axis) returns the current program's id along that grid axis. tl.arange(0, BLOCK) creates a compile-time-sized vector (its length must be a power of two); operations on it are spread SIMD-style across the BLOCK lanes. tl.load(ptr, mask, other) performs a coalesced load in which masked-off lanes receive other instead of touching memory. tl.sum, tl.max, tl.rsqrt and tl.exp are vector primitives that compile to warp shuffles or hardware intrinsics. Anything marked tl.constexpr is known at compile time and drives code generation (tile sizes, dtypes). @triton.autotune registers several configs, benchmarks them the first time each new value of the key arguments is seen, and caches the winner. num_warps is the number of warps each program uses; num_stages is the depth of the software pipeline Triton emits to hide memory latency. Lesson 22 expands on all of this; the line-by-line table below explains every primitive used here.
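The mask-and-other semantics are easiest to see on a toy model. The sketch below uses hypothetical helpers `arange` and `masked_load` that imitate, on Python lists, what tl.arange and tl.load(..., mask, other) do for one program; it is not how Triton implements them.

```python
def arange(block):
    """Toy model of tl.arange(0, block): lane indices 0..block-1."""
    if block <= 0 or block & (block - 1):
        raise ValueError("BLOCK must be a power of two")  # same constraint as tl.arange
    return list(range(block))

def masked_load(buf, offsets, mask, other=0.0):
    """Toy model of tl.load(ptr + offsets, mask=mask, other=other).
    Masked-off lanes never index buf; they simply receive `other`."""
    return [buf[o] if m else other for o, m in zip(offsets, mask)]

# one row with D=6 features, tiled with BLOCK=4: the second tile is partial
row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
D, BLOCK = len(row), 4
cols = arange(BLOCK)
tiles = []
for off in range(0, D, BLOCK):
    idx = [off + c for c in cols]    # off + cols
    mask = [i < D for i in idx]      # idx < D
    tiles.append(masked_load(row, idx, mask))
```

The last tile comes back as `[5.0, 6.0, 0.0, 0.0]`: the two out-of-range lanes read nothing and hold the `other` value, which is why they are harmless in a sum of squares.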

The Triton kernel, annotated

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_D': 256}, num_warps=4),
        triton.Config({'BLOCK_D': 512}, num_warps=4),
        triton.Config({'BLOCK_D': 1024}, num_warps=8),
        triton.Config({'BLOCK_D': 2048}, num_warps=8),
    ],
    key=['D'],                       # autotune separately per feature size
)
@triton.jit
def rmsnorm_fwd(
    x_ptr, w_ptr, y_ptr,             # data pointers
    N, D,                            # rows, features
    eps,                             # numerical stability
    BLOCK_D: tl.constexpr,           # compile-time tile width
):
    row = tl.program_id(0)           # one block per row
    cols = tl.arange(0, BLOCK_D)     # vector of column indices for this block

    # accumulator in fp32 — bf16 would lose precision on the sum
    sumsq = tl.zeros((1,), dtype=tl.float32)

    # streaming pass 1: read x in chunks, accumulate sum of squares
    for off in range(0, D, BLOCK_D):
        idx  = off + cols
        mask = idx < D
        x_blk = tl.load(x_ptr + row * D + idx, mask=mask, other=0.0).to(tl.float32)
        sumsq += tl.sum(x_blk * x_blk, axis=0)

    # one scalar per row — held in a register
    inv_rms = tl.rsqrt(sumsq / D + eps)

    # streaming pass 2: write y = x * w * inv_rms
    for off in range(0, D, BLOCK_D):
        idx  = off + cols
        mask = idx < D
        x_blk = tl.load(x_ptr + row * D + idx, mask=mask, other=0.0).to(tl.float32)
        w_blk = tl.load(w_ptr + idx,           mask=mask, other=0.0).to(tl.float32)
        y_blk = (x_blk * inv_rms * w_blk).to(tl.bfloat16)
        tl.store(y_ptr + row * D + idx, y_blk, mask=mask)
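One way to check your reading of the kernel is to mirror a single program in plain Python. The sketch below, with the hypothetical name `rmsnorm_row_tiled`, follows the same two streaming passes, the same `idx < D` mask and `other=0.0` padding, but runs on one row held in a list and ignores dtypes (Python floats are fp64, so there is no bf16 rounding).

```python
import math

def rmsnorm_row_tiled(x, w, eps, block_d):
    """Pure-Python mirror of rmsnorm_fwd for ONE row, i.e. one program."""
    D = len(x)
    cols = range(block_d)                                         # tl.arange(0, BLOCK_D)

    # pass 1: stream the row tile by tile, accumulating the sum of squares
    sumsq = 0.0
    for off in range(0, D, block_d):
        idx = [off + c for c in cols]
        mask = [i < D for i in idx]
        x_blk = [x[i] if m else 0.0 for i, m in zip(idx, mask)]  # other=0.0
        sumsq += sum(v * v for v in x_blk)                        # tl.sum(x_blk * x_blk, axis=0)

    inv_rms = 1.0 / math.sqrt(sumsq / D + eps)                    # tl.rsqrt(sumsq / D + eps)

    # pass 2: re-read x and w tile by tile and write y
    y = [0.0] * D
    for off in range(0, D, block_d):
        for c in cols:
            i = off + c
            if i < D:                                             # masked store
                y[i] = x[i] * inv_rms * w[i]
    return y
```

Because padded lanes contribute zero to `sumsq`, the result does not depend on `block_d` (up to floating-point summation order), which is what leaves the autotuner free to pick any BLOCK_D.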

Read it line by line

| Line / concept | What it does | Why it matters |
|---|---|---|
| @triton.autotune(configs, key=['D']) | Compiles several variants, benchmarks them on the first call for each D, caches the winner. | The right BLOCK_D depends on D and on the hardware. You don't pick it; the autotuner does. |
| row = tl.program_id(0) | Each program instance handles one row. | Decouples problem decomposition from hardware; Triton maps program instances onto SMs. |
| cols = tl.arange(0, BLOCK_D) | Vector of column indices for this tile. | BLOCK_D is known at compile time, so the work can be vectorized across the row. |
| sumsq = tl.zeros(..., dtype=fp32) | Accumulator in fp32. | A sum of bf16 squares would lose precision. Critical for numerical correctness. |
| for off in range(0, D, BLOCK_D) | Streams the row in BLOCK_D chunks. | When D > BLOCK_D the row is processed tile by tile, so the register footprint is bounded by BLOCK_D rather than by D. |
| mask = idx < D | Mask for the last (partial) tile. | Handles a D that is not a multiple of BLOCK_D without padding. |
| tl.load(..., mask=mask, other=0.0) | Coalesced bf16 read with a predicate. | Out-of-bounds lanes read 0, which is safe for the sumsq accumulation. |
| .to(tl.float32) | Upcast for accumulation. | Same reason as the fp32 accumulator. The cast is a cheap register-level conversion, negligible next to the HBM load. |
| tl.sum(x*x, axis=0) | Tree reduction across the BLOCK_D vector. | Compiles to warp shuffles (lesson 07), plus a small shared-memory step to combine partial sums across warps when num_warps > 1. |
| tl.rsqrt(sumsq/D + eps) | Hardware reciprocal square root. | A single special-function instruction, computed once per row; every element then multiplies by it instead of dividing by a square root. |
| Pass 2 reads x again | Streaming computation of y. | See the "why two passes" callout below: one extra HBM read of x is traded for not stashing a 16 KB+ row in SMEM, which would limit occupancy. |
| y.to(bfloat16) | Downcast for storage. | A bf16 output keeps downstream memory traffic low. |
| tl.store(..., mask=mask) | Masked write. | The last tile may be partial. |
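The autotune behaviour in the first row can be modeled in a few lines. The class below is a hypothetical sketch of the caching logic described there (benchmark every config the first time a key value appears, then reuse the winner); Triton's real autotuner adds warmup, more careful timing, and other machinery, so treat this as an illustration only.

```python
import time

class Autotuner:
    """Hypothetical stand-in for the behaviour of @triton.autotune(key=['D'])."""

    def __init__(self, configs, run, repeats=3):
        self.configs = configs   # e.g. [{'BLOCK_D': 256}, {'BLOCK_D': 512}, ...]
        self.run = run           # run(config, D) executes the "kernel" once
        self.repeats = repeats
        self.cache = {}          # key value (D) -> winning config

    def _bench(self, cfg, D):
        best = float("inf")
        for _ in range(self.repeats):
            t0 = time.perf_counter()
            self.run(cfg, D)
            best = min(best, time.perf_counter() - t0)
        return best

    def __call__(self, D):
        if D not in self.cache:  # first call for this D: benchmark every config
            self.cache[D] = min(self.configs, key=lambda c: self._bench(c, D))
        return self.run(self.cache[D], D)  # later calls reuse the cached winner
```

Which config wins depends on timing noise here, but the cost structure is the point: one benchmarking sweep per distinct D, then a single run per call.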

Why two passes, not one (the SMEM trade-off)

You might ask: why read x twice? Couldn't we load it into shared memory on the first pass and reuse it on the second? The answer is that x for one row is D · 2 bytes: at D=8192 that is 16 KB per row. An H100 SM has 228 KB of SMEM in total, shared by all resident blocks, so stashing the full row caps residency at about 14 blocks per SM from SMEM alone, and the cap shrinks as D grows (7 at D=16384). Reading x twice costs one extra read of N·D·2 bytes, about 16.8 MB at N=1024, or roughly 5 µs at 3.35 TB/s in the worst case where none of the re-read is served from L2. The alternative ties up 16 KB of scarce SMEM per program and limits occupancy. The trade goes to the streaming version. (For smaller D, say D=512 = 1 KB per row, the row is cheap to keep on-chip and the choice flips.)
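The arithmetic in this callout fits in one function. The sketch assumes the H100-class numbers used in the text (228 KB of SMEM per SM, 3.35 TB/s HBM); `smem_row_tradeoff` is a hypothetical name, and the residency figure counts only the shared-memory limit, while the per-SM caps on resident warps and blocks can bind first.

```python
def smem_row_tradeoff(N, D, bytes_per_elem=2, smem_per_sm=228 * 1024, hbm_bw=3.35e12):
    """Both sides of the two-pass trade for RMSNorm (H100-class defaults)."""
    row_bytes = D * bytes_per_elem                   # one row of x stashed in SMEM
    programs_by_smem = smem_per_sm // row_bytes      # SMEM-only residency cap per SM
    extra_read_bytes = N * D * bytes_per_elem        # pass 2 re-reads all of x once
    extra_read_us = extra_read_bytes / hbm_bw * 1e6  # worst case: no L2 hits
    return row_bytes, programs_by_smem, extra_read_bytes, extra_read_us

row_bytes, cap, extra, us = smem_row_tradeoff(N=1024, D=8192)
```

At N=1024, D=8192 this gives 16 KB per row, a cap of 14 programs per SM, and an extra read of 16,777,216 bytes, about 5 µs.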

The structure that recurs in every fused kernel

prologue → load → on-chip math → epilogue

1. Prologue: map program → tile; compute base addresses.
2. Load (HBM → registers): coalesced and masked; upcast if accumulating.
3. On-chip math: elementwise ops plus reductions, with no HBM round trips.
4. Epilogue (registers → HBM): downcast if needed; masked, coalesced store.

If a reduction crosses tiles, there are three options: two passes (compute the reduction, then use it), one pass with shared memory holding the row, or an online algorithm (FlashAttention's online softmax is the canonical example; see lesson 12).

Three knobs you tune: (1) BLOCK / tile size, which sets the register and SMEM footprint; (2) num_warps, which trades occupancy against work per warp; (3) num_stages, the pipeline depth for memory hiding.
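The online option is the least obvious of the three, so here is a small sketch of the idea on plain lists: the normalizer of a softmax (the max and the sum of exponentials) computed in a single pass over tiles, rescaling the running sum whenever the running max changes. The function names are hypothetical; this shows the rescaling trick that online softmax relies on, not FlashAttention itself.

```python
import math

def softmax_normalizer_two_pass(x):
    m = max(x)                                    # pass 1: compute the reduction
    return m, sum(math.exp(v - m) for v in x)     # pass 2: use it

def softmax_normalizer_online(x, block):
    m, s = -math.inf, 0.0                         # running max, running sum of exp(v - m)
    for off in range(0, len(x), block):
        tile = x[off:off + block]
        m_new = max(m, max(tile))
        # rescale the old sum to the new max, then add this tile's terms
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in tile)
        m = m_new
    return m, s
```

Each tile is visited once, yet the final (max, sum) pair matches the two-pass result; the price is one extra exp and multiply per tile for the rescale.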

How big are the savings? (worked example)

Take RMSNorm at N=1024 (batch·tokens), D=8192, bf16.

| Implementation | HBM reads | HBM writes | Launches | Notes |
|---|---|---|---|---|
| Naive PyTorch (5 ops) | ~3 × N·D·2 ≈ 50 MB | ~3 × N·D·2 ≈ 50 MB | 5 | Each op materializes its output. Approximate: mean produces one scalar per row, not a full tensor. |
| Fused (above kernel) | 2 × N·D·2 + D·2 ≈ 33.6 MB + 16 KB | N·D·2 ≈ 16.8 MB | 1 | Two streaming passes over x; the weights (size D) are shared across rows and small enough to fit in L2. |
| Fused, weights pinned in L2 | 2 × N·D·2 ≈ 33.6 MB | ≈ 16.8 MB | 1 | Once the first rows pull the 16 KB weight vector into L2 it stays there, so every other row reads w from L2, not HBM. |

Since RMSNorm is bandwidth-bound at this size, the speedup is roughly the ratio of HBM bytes: about 100 MB for the naive path vs about 50 MB fused, so ≈ 2× from collapsing 5 kernels into 1, plus 4 fewer launches. Real benchmarks land in that range.

What changes for harder kernels

The same template (prologue → load → math → epilogue) scales up; harder kernels just stack more stages.

Numerical traps the kernel must dodge

| Trap | Symptom | Fix |
|---|---|---|
| bf16 accumulator | Slow drift; sums over long vectors lose their low-order bits. | Accumulate in fp32 (as above). |
| Overflow in softmax exp | Inf/NaN once logits get large. | Subtract the (running) max before exp, as FlashAttention does. |
| Division by zero | NaN/Inf out of rsqrt on an all-zero row. | Add eps before rsqrt. |
| Mixed dtypes silently downcast | Loss spikes that appear only at scale. | Explicit .to(tl.float32) at the right edges. |
| Non-deterministic atomics in a reduction | Broken run-to-run reproducibility. | Avoid atomics across blocks; use a tree reduction inside a block. |
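The first trap can be reproduced without a GPU. The sketch below emulates a bf16 accumulator in Python by rounding every partial sum to bfloat16 (hypothetical helper `to_bf16`: round-half-to-even on the fp32 bit pattern, no NaN/Inf handling). Summing 4096 ones stalls at 256, because bf16 carries only 8 significant bits, so 256 + 1 rounds back to 256; an fp32 accumulator reaches 4096.

```python
import struct

def to_bf16(v):
    """Round a Python float to the nearest bfloat16 (ties to even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", v))[0]  # fp32 bit pattern
    rounding = 0x7FFF + ((bits >> 16) & 1)               # ties-to-even on the 16 dropped bits
    bits = ((bits + rounding) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_fp32(v):
    """Round a Python float to the nearest fp32."""
    return struct.unpack("<f", struct.pack("<f", v))[0]

def sum_with_accumulator(values, acc_round):
    acc = 0.0
    for v in values:
        acc = acc_round(acc + v)  # every partial sum is stored at the accumulator's precision
    return acc

ones = [1.0] * 4096
bf16_sum = sum_with_accumulator(ones, to_bf16)  # stalls at 256.0
fp32_sum = sum_with_accumulator(ones, to_fp32)  # 4096.0
```

Real activations are not all ones, but the ceiling is the same: once the running sum is large enough, each new square is smaller than half a bf16 step and simply disappears.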

Fused vs naive HBM traffic: the model

For a given N, D, and number of separate kernels on the naive path, the estimate compares HBM bytes for both paths, the resulting bandwidth-bound time, and the launch-count difference.

Bandwidth-bound assumption: time ≈ bytes / HBM bandwidth. Fused does 2 reads of x + 1 read of w + 1 write of y. Naive does ops · (read + write) of full-size intermediates.
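The bandwidth-bound model can be written as a short function. Splitting the naive path into full-size ops (which dominate the bytes) and launches (which dominate overhead) is an assumption of this sketch, made so that the worked example above (about three full-size round trips, five launches) can be reproduced; `hbm_traffic` is a hypothetical name and the default bandwidth is the 3.35 TB/s figure used earlier.

```python
def hbm_traffic(N, D, full_size_ops, naive_launches=None, bytes_per_elem=2, hbm_bw=3.35e12):
    """Bandwidth-bound estimate: time ~= bytes / HBM bandwidth (no caching, no overlap)."""
    if naive_launches is None:
        naive_launches = full_size_ops
    full = N * D * bytes_per_elem                  # one full-size N x D tensor
    fused = 2 * full + D * bytes_per_elem + full   # 2 reads of x, 1 read of w, 1 write of y
    naive = full_size_ops * (full + full)          # each op reads and writes a full-size tensor
    return {
        "fused_bytes": fused,
        "naive_bytes": naive,
        "fused_us": fused / hbm_bw * 1e6,
        "naive_us": naive / hbm_bw * 1e6,
        "launches_saved": naive_launches - 1,
    }

r = hbm_traffic(N=1024, D=8192, full_size_ops=3, naive_launches=5)
```

For N=1024, D=8192 this gives about 100.7 MB naive vs 50.3 MB fused (a ratio just under 2×, roughly 30 µs vs 15 µs) and 4 launches saved.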

What this gives you for the next lessons

You can now read a fused kernel and pattern-match it to "I see a prologue, a load, on-chip math, an epilogue, with a reduction phase if applicable." The next lesson teaches the tools that tell you whether your kernel is hitting its potential — three different profilers for three different questions.