Anatomy of a fused kernel
A working fused RMSNorm kernel in ~30 lines of Triton, read line by line: program model, masked loads, on-chip reductions, epilogue, autotune choice. By the end you should be able to read any modern fused kernel and predict its performance class without running it.
The question this lesson answers
You know what kernels look like in the abstract (Part I) and you know what they do for serving (Part II). What does real production kernel code look like, and what are the moving parts? We use a fused RMSNorm because (a) it appears in every transformer block, (b) it touches three habits from lesson 01 (tile, fuse, coalesce), and (c) it's small enough to fit on one screen.
What RMSNorm computes
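For each row $x \in \mathbb{R}^D$ of the input, with a learned per-feature weight $w \in \mathbb{R}^D$ and a small constant $\varepsilon$:
$$y_i = \frac{x_i}{\sqrt{\frac{1}{D}\sum_{j=1}^{D} x_j^2 + \varepsilon}} \, w_i, \qquad i = 1, \dots, D$$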
Three operations are chained per row: square and sum (a reduction over D), divide by the square root, and multiply by the per-feature weight. Expressed in eager PyTorch, this dispatches at least 5 kernels (square, mean, add eps, rsqrt, mul) and streams full N×D tensors through HBM three or four times (x, x², x again, and the scaled intermediate). A fused kernel reads x twice (we'll explain why two, not one, below), keeps the reduction state in registers, and writes y once.
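To make the kernel count concrete, here is the unfused computation sketched in NumPy as the same op-by-op chain. It is a CPU stand-in, not the PyTorch code the lesson has in mind: NumPy has no rsqrt, so the reciprocal square root takes two ops, and each line allocates a fresh intermediate the way an eager kernel writes its output to HBM. The function name `rmsnorm_unfused` is hypothetical.

```python
import numpy as np


def rmsnorm_unfused(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Op-by-op RMSNorm: each line materializes a new array, like one eager kernel each."""
    sq = np.square(x)                       # reads x, writes a full N x D array
    ms = sq.mean(axis=-1, keepdims=True)    # reads x**2, writes N values
    ms_eps = ms + eps                       # N values
    inv_rms = 1.0 / np.sqrt(ms_eps)         # rsqrt (NumPy needs sqrt + divide), N values
    scaled = x * inv_rms                    # reads x again, writes a full N x D array
    return scaled * w                       # reads the intermediate, writes the output


y = rmsnorm_unfused(np.array([[3.0, 4.0], [1.0, -1.0]]), np.ones(2), eps=0.0)
```

With w = 1 and eps = 0, every output row has a mean square of 1 (up to rounding), which makes a quick correctness check for any fused version.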
Triton primitives, in one paragraph (just enough to read the code)
@triton.jit JIT-compiles a Python function into a GPU kernel. Inside it, tl.program_id(axis) returns the current program's index along that grid axis. tl.arange(0, BLOCK) creates a vector whose length is a compile-time constant (Triton requires a power of two); operations on it are spread across the threads of the program's warps. tl.load(ptr, mask, other) loads one element per vector lane; lanes where mask is false never touch memory and get other instead. The load is coalesced when consecutive lanes address consecutive elements, as they do in this kernel. Reductions such as tl.sum and tl.max compile to warp shuffles (plus a shared-memory step when the reduction spans several warps); elementwise math such as tl.rsqrt and tl.exp maps to hardware intrinsics. Anything marked tl.constexpr is known at compile time and drives code generation (tile sizes, dtypes). @triton.autotune registers several configs, benchmarks them the first time it sees each new value of the key arguments, and caches the winner. num_warps is the number of warps each program uses; num_stages is the depth of the software pipeline Triton emits for loops to hide memory latency. Lesson 22 expands all of this; the line-by-line table below explains every primitive used here.
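The masking rule is the piece most worth internalizing. A pure-Python sketch of one program walking a row whose length is not a multiple of BLOCK makes it concrete; `masked_load` is a hypothetical stand-in for tl.load, not Triton's implementation. The masked-off lanes would index past the end of the row, and the mask is what keeps them from ever doing so.

```python
def masked_load(memory, offsets, mask, other):
    """Hypothetical stand-in for tl.load(ptr + offsets, mask=mask, other=other).

    Lanes whose mask is False never touch memory; they get `other` instead.
    """
    return [memory[o] if m else other for o, m in zip(offsets, mask)]


BLOCK = 8                                  # tile width; tl.arange needs a power of two
D = 10                                     # row length, deliberately not a multiple of BLOCK
row_x = [float(i + 1) for i in range(D)]   # one row of x: 1.0 .. 10.0

cols = list(range(BLOCK))                  # plays the role of tl.arange(0, BLOCK)
tiles = []
for off in range(0, D, BLOCK):             # same loop shape as the kernel's passes
    idx = [off + c for c in cols]
    mask = [i < D for i in idx]            # mask = idx < D
    tiles.append(masked_load(row_x, idx, mask, other=0.0))

print(tiles[1])  # [9.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Because masked lanes contribute 0.0, a sum of squares over the padded tiles equals the sum over the real row, which is why `other=0.0` is safe for the accumulation.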
The Triton kernel, annotated
```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_D': 256}, num_warps=4),
        triton.Config({'BLOCK_D': 512}, num_warps=4),
        triton.Config({'BLOCK_D': 1024}, num_warps=8),
        triton.Config({'BLOCK_D': 2048}, num_warps=8),
    ],
    key=['D'],  # autotune separately per feature size
)
@triton.jit
def rmsnorm_fwd(
    x_ptr, w_ptr, y_ptr,    # data pointers (x, y: contiguous [N, D]; w: [D])
    N, D,                   # rows, features (N only sizes the launch grid)
    eps,                    # numerical stability
    BLOCK_D: tl.constexpr,  # compile-time tile width, supplied by the autotuner
):
    row = tl.program_id(0)          # one program per row
    cols = tl.arange(0, BLOCK_D)    # vector of column indices for one tile
    # accumulator in fp32: bf16 would lose precision on the sum
    sumsq = tl.zeros((1,), dtype=tl.float32)
    # streaming pass 1: read x in chunks, accumulate the sum of squares
    for off in range(0, D, BLOCK_D):
        idx = off + cols
        mask = idx < D
        x_blk = tl.load(x_ptr + row * D + idx, mask=mask, other=0.0).to(tl.float32)
        sumsq += tl.sum(x_blk * x_blk, axis=0)
    # one value per row, held in a register
    inv_rms = tl.rsqrt(sumsq / D + eps)
    # streaming pass 2: write y = x * inv_rms * w
    for off in range(0, D, BLOCK_D):
        idx = off + cols
        mask = idx < D
        x_blk = tl.load(x_ptr + row * D + idx, mask=mask, other=0.0).to(tl.float32)
        w_blk = tl.load(w_ptr + idx, mask=mask, other=0.0).to(tl.float32)
        y_blk = (x_blk * inv_rms * w_blk).to(tl.bfloat16)
        tl.store(y_ptr + row * D + idx, y_blk, mask=mask)


def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Host-side launcher: x is a contiguous [N, D] GPU tensor, w is [D].
    assert x.dim() == 2 and x.is_contiguous() and w.is_contiguous()
    assert w.shape == (x.shape[1],)
    N, D = x.shape
    y = torch.empty((N, D), device=x.device, dtype=torch.bfloat16)
    # BLOCK_D is not passed: the autotuner fills it in from the winning config
    rmsnorm_fwd[(N,)](x, w, y, N, D, eps)
    return y
```
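Because the kernel needs a GPU, here is a CPU emulation in NumPy that follows it loop for loop, useful for checking the logic and as a reference in tests. Assumptions: NumPy has no bfloat16, so this sketch takes float16 input and leaves the output in float32; the outer Python loop stands in for the grid of N programs; `rmsnorm_fwd_emulated` is a hypothetical name.

```python
import numpy as np


def rmsnorm_fwd_emulated(x: np.ndarray, w: np.ndarray, eps: float, BLOCK_D: int) -> np.ndarray:
    """CPU emulation of rmsnorm_fwd: one outer iteration per Triton program (row)."""
    N, D = x.shape
    y = np.empty((N, D), dtype=np.float32)      # no bfloat16 in NumPy: keep fp32 output
    cols = np.arange(BLOCK_D)                   # tl.arange(0, BLOCK_D)
    for row in range(N):                        # grid = (N,): program_id(0) == row
        sumsq = np.float32(0.0)                 # fp32 accumulator
        for off in range(0, D, BLOCK_D):        # pass 1: sum of squares
            idx = off + cols
            mask = idx < D
            safe = np.where(mask, idx, 0)       # keep gather indices in bounds
            x_blk = np.where(mask, x[row, safe], 0.0).astype(np.float32)  # other=0.0
            sumsq += np.sum(x_blk * x_blk, dtype=np.float32)
        inv_rms = np.float32(1.0) / np.sqrt(sumsq / np.float32(D) + np.float32(eps))
        for off in range(0, D, BLOCK_D):        # pass 2: re-read x, scale, store
            idx = off + cols
            mask = idx < D
            safe = np.where(mask, idx, 0)
            x_blk = np.where(mask, x[row, safe], 0.0).astype(np.float32)
            w_blk = np.where(mask, w[safe], 0.0).astype(np.float32)
            y_blk = x_blk * inv_rms * w_blk
            y[row, idx[mask]] = y_blk[mask]     # masked store
    return y
```

The check that matters is that every BLOCK_D config agrees with the float64 formula, including tile widths that don't divide D; that is exactly what the mask guarantees.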
Read it line by line
| Line / concept | What it does | Why it matters |
|---|---|---|
| @triton.autotune(configs, key=['D']) | Compile multiple versions; benchmark them on the first call for each new D; cache the winner. | The right BLOCK_D depends on D and the hardware. You don't pick; the autotuner does. |
| row = tl.program_id(0) | Each program instance handles one row. | Decouples problem decomposition from hardware: the launcher asks for N programs and the GPU schedules them onto SMs. |
| cols = tl.arange(0, BLOCK_D) | Vector of column indices for this tile. | BLOCK_D is a compile-time constant (a power of two), so the compiler can vectorize loads and math across the row. |
| sumsq = tl.zeros(..., dtype=fp32) | Accumulator in fp32. | A sum of bf16 squares would lose precision. Critical for numerical correctness. |
| for off in range(0, D, BLOCK_D) | Stream the row in BLOCK_D chunks. | Handles any D with bounded register use. Because the full row is never kept on-chip, pass 2 must read x again: one extra HBM pass instead of spending SMEM on the row. |
| mask = idx < D | Mask for the last (partial) tile. | Handles D that is not a multiple of BLOCK_D without padding. |
| tl.load(..., mask=mask, other=0.0) | Coalesced bf16 read with a predicate. | Masked-off lanes don't touch memory and read as 0, which adds nothing to sumsq. |
| .to(tl.float32) | Upcast for accumulation. | Same reason as the fp32 accumulator. The cast is a cheap register-to-register conversion with no extra memory traffic. |
| tl.sum(x*x, axis=0) | Tree reduction across the BLOCK_D vector. | Compiles to warp shuffles (lesson 07), plus a shared-memory step when the tile spans several warps. |
| tl.rsqrt(sumsq/D + eps) | Hardware reciprocal square root. | A single special-function instruction instead of a sqrt followed by a division. It runs once per row, so its cost is negligible. |
| Pass 2 reads x again | Streaming compute of y. | See "Why two passes, not one" below: trades one extra read of x for not stashing a 16 KB row (at D=8192) in SMEM, which would limit how many blocks fit per SM. |
| y.to(bfloat16) | Downcast for storage. | A bf16 output halves write traffic versus fp32 and keeps traffic low for downstream kernels. |
| tl.store(..., mask=mask) | Masked write. | The last tile may be partial. |
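The autotuner's contract (benchmark every config once per new key value, then reuse the cached winner) can be sketched in a few lines of Python. Everything here is hypothetical: `ToyAutotuner` is not Triton's class, and `fake_cost` is a made-up stand-in for timing a compiled kernel (it counts loop trips plus wasted masked lanes), not a performance model.

```python
from typing import Callable, Dict, List

CONFIGS: List[dict] = [{"BLOCK_D": b} for b in (256, 512, 1024, 2048)]


def fake_cost(cfg: dict, D: int) -> float:
    """Made-up stand-in for a benchmark: loop trips plus the fraction of masked-off lanes."""
    block = cfg["BLOCK_D"]
    trips = -(-D // block)                 # ceil(D / BLOCK_D) iterations per pass
    wasted = trips * block - D             # lanes masked off in the last tile
    return trips + wasted / block


class ToyAutotuner:
    """Benchmark each config the first time a key value (here D) appears; cache the winner."""

    def __init__(self, configs: List[dict], bench: Callable[[dict, int], float]):
        self.configs = configs
        self.bench = bench
        self.cache: Dict[int, dict] = {}

    def pick(self, D: int) -> dict:
        if D not in self.cache:            # first call for this D: try every config
            self.cache[D] = min(self.configs, key=lambda cfg: self.bench(cfg, D))
        return self.cache[D]               # later calls: no benchmarking at all


tuner = ToyAutotuner(CONFIGS, fake_cost)
print(tuner.pick(8192), tuner.pick(1000))
```

In the real decorator the benchmark is an actual timed launch on the GPU, which is why the first call for each new D is noticeably slower than the calls that follow.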
Why two passes, not one (the SMEM trade-off)
You might ask: why read x twice? Couldn't we load the row into shared memory on the first pass and reuse it on the second? One row of x is D · 2 bytes, so at D=8192 that's 16 KB per row. An H100 SM has 228 KB of SMEM in total, shared by every block resident on it, so stashing the full row caps residency at about 14 blocks per SM from SMEM alone, before register and warp limits are counted, and leaves that much less SMEM for anything else. Reading x a second time costs one extra pass over x: N·D·2 ≈ 16.8 MB at N=1024, about 5 µs at 3.35 TB/s, and less in practice when part of the re-read row is still in L2. The trade goes to the streaming version. (For smaller D, say D=512 = 1 KB per row, holding the row on-chip is cheap and the choice flips.)
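The numbers in that trade-off are simple enough to recompute for other shapes. This sketch assumes the H100 figures quoted above (228 KB of SMEM per SM, 3.35 TB/s of HBM bandwidth) and the worst case where pass 2 misses L2 entirely. The SMEM cap ignores per-block reserved SMEM and the hardware's own blocks-per-SM limit, so for small rows it is only an upper bound. `streaming_vs_stash` is a hypothetical helper.

```python
def streaming_vs_stash(N: int, D: int, bytes_per_elem: int = 2,
                       smem_per_sm: int = 228 * 1024, hbm_bytes_per_s: float = 3.35e12):
    """Back-of-envelope numbers for the two-pass decision (worst case: pass 2 misses L2)."""
    row_bytes = D * bytes_per_elem                   # SMEM needed to stash one row
    blocks_by_smem = smem_per_sm // row_bytes        # residency cap from SMEM alone
    extra_read_bytes = N * D * bytes_per_elem        # pass 2 re-reads x once
    extra_read_us = extra_read_bytes / hbm_bytes_per_s * 1e6
    return row_bytes, blocks_by_smem, extra_read_bytes, extra_read_us


for D in (8192, 512):
    row_bytes, blocks, extra, us = streaming_vs_stash(N=1024, D=D)
    print(f"D={D}: row={row_bytes} B, SMEM cap={blocks} blocks/SM, "
          f"extra read={extra / 1e6:.1f} MB (~{us:.1f} us)")
```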
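The structure that recurs in every fused kernel
The kernel above follows a four-stage template: a prologue (program_id, the column vector, accumulator setup), a load (masked tl.load plus the upcast), on-chip math (the sum-of-squares reduction, rsqrt, and scaling), and an epilogue (downcast and masked tl.store). When the math includes a reduction, as here, it forms its own phase that must finish before the rest of the math can use its result.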
How big are the savings? (worked example)
Take RMSNorm at N=1024 (batch·tokens), D=8192, bf16.
| Implementation | HBM reads | HBM writes | Launches | Notes |
|---|---|---|---|---|
| Naive PyTorch (5 ops) | ~3× N·D·2 ≈ 50.3 MB | ~3× N·D·2 ≈ 50.3 MB | 5 | Each op materializes its output. Approximate: mean, add eps, and rsqrt produce one value per row, not a full tensor. |
| Fused (above kernel) | 2× N·D·2 + D·2 ≈ 33.5 MB + 16 KB | N·D·2 ≈ 16.8 MB | 1 | Two streaming passes over x; the weights (size D) are shared across rows and small enough to stay in L2. |
| Fused, weights pinned in L2 | 2× N·D·2 ≈ 33.5 MB | ≈ 16.8 MB | 1 | Same kernel with the one-time 16 KB weight fetch rounded away: once w is in L2 it stays, so every row's pass 2 reads it from L2, not HBM. |
Because RMSNorm is bandwidth-bound at this size, the speedup is roughly the ratio of HBM bytes: about 100 MB moved by the naive path versus about 50 MB fused, so ≈ 2× from going from 5 kernels to 1, plus four fewer kernel launches. Real benchmarks land in that range.
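The table's arithmetic fits in a small calculator you can rerun for other shapes. It hard-codes the same assumptions as the table: bf16 elements, a naive path that reads and writes about three full N×D tensors each, a fused path that reads x twice plus w once and writes y once, and the 3.35 TB/s bandwidth figure used above. Launch overhead is not modeled; `hbm_traffic` is a hypothetical helper.

```python
def hbm_traffic(N: int, D: int, bytes_per_elem: int = 2,
                naive_full_tensor_passes: int = 3, hbm_bytes_per_s: float = 3.35e12) -> dict:
    """HBM bytes and bandwidth-bound time for naive vs fused RMSNorm (same model as the table)."""
    tensor = N * D * bytes_per_elem
    naive = 2 * naive_full_tensor_passes * tensor          # ~3 full reads + ~3 full writes
    fused = (2 * tensor + D * bytes_per_elem) + tensor     # x twice + w once, then y once
    return {
        "naive_MB": naive / 1e6,
        "fused_MB": fused / 1e6,
        "naive_us": naive / hbm_bytes_per_s * 1e6,
        "fused_us": fused / hbm_bytes_per_s * 1e6,
        "speedup": naive / fused,
    }


stats = hbm_traffic(N=1024, D=8192)
print({k: round(v, 1) for k, v in stats.items()})
```

Raising `naive_full_tensor_passes` models a naive path with more materialized intermediates; the fused side does not change.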
What changes for harder kernels
The same template (prologue → load → math → epilogue) scales up. The harder kernels just stack more stages:
- Matmul (lessons 05 and 09): a 2D grid of output tiles over (M, N), with each program looping over K. The accumulator tile stays in registers; the epilogue writes the finished tile.
- FlashAttention (lesson 12): a Q tile (outer) and K/V tiles (inner) with an online softmax in the inner loop. The state that survives across iterations (running max, running sum, output accumulator) stays in registers.
- Producer-consumer kernels (FA3-style): two warp groups — one issues async copies, the other consumes — using Hopper's TMA. Hides memory latency behind compute.
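The FlashAttention bullet's loop-carried state is easiest to see in the softmax normalizer alone. This pure-Python sketch assumes the logits arrive in fixed-size blocks: it carries a running max and a running sum and rescales the sum whenever the max grows. A real FlashAttention kernel also rescales its output accumulator by the same factor, which this sketch omits; `online_softmax_stats` is a hypothetical name.

```python
import math
from typing import List, Tuple


def online_softmax_stats(logits: List[float], block: int = 4) -> Tuple[float, float]:
    """Return (m, l) with softmax(logits)[i] == exp(logits[i] - m) / l, in one streaming pass."""
    m = -math.inf   # running max: loop-carried state, kept in registers on a GPU
    l = 0.0         # running sum of exp(x - m)
    for off in range(0, len(logits), block):
        blk = logits[off:off + block]
        m_new = max(m, max(blk))
        # rescale what was summed under the old max, then add this block's terms
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in blk)
        m = m_new
    return m, l


logits = [1000.0, 999.0, 998.0, 1001.0, 5.0, -3.0]
m, l = online_softmax_stats(logits)
probs = [math.exp(x - m) / l for x in logits]
```

Subtracting the running max keeps every exp argument at or below 0, so logits around 1000 are harmless here, while a direct `math.exp(1000.0)` raises OverflowError.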
Numerical traps the kernel must dodge
| Trap | Symptom | Fix |
|---|---|---|
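| bf16 accumulator | Sums over long vectors lose low-order bits; once the running sum is large, small addends vanish entirely. | Accumulate in fp32 (as above). |
| Softmax overflow | exp of large logits overflows to inf, and inf/inf gives NaN. | Subtract the (running) max before exp, as FlashAttention does. |
| Division by zero | An all-zero row gives rsqrt(0) = inf, then 0 · inf = NaN. | Add eps before rsqrt. |
| Mixed dtypes silently downcast | Loss spikes that only show up at scale. | Explicit .to(tl.float32) right after each load and a single downcast at the store, as in the kernel above. |
| Non-deterministic atomics in reduction | Results differ run to run in the low bits, because float addition is not associative and atomic ordering varies. | Avoid atomics across blocks; use a tree reduction inside a block. |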
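Two of these traps reproduce on a CPU in a few lines of pure Python. `to_bf16` is a hypothetical helper that rounds a float to the nearest bfloat16 value by rounding the fp32 bit pattern (round-to-nearest-even). The second demo uses Python's 64-bit floats, but the non-associativity it shows is the same property that makes atomic reduction order matter.

```python
import struct


def to_fp32(v: float) -> float:
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", v))[0]


def to_bf16(v: float) -> float:
    """Round to the nearest bfloat16 value (ties to even), returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", v))[0]   # fp32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                     # round to nearest, ties to even
    bits &= 0xFFFF0000                                      # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(values, rounder):
    """Sequential sum that rounds after every add, like a low-precision accumulator register."""
    acc = 0.0
    for v in values:
        acc = rounder(acc + v)
    return acc


ones = [1.0] * 4096
bf16_sum = accumulate(ones, to_bf16)   # stalls once adjacent bf16 values are 2 apart
fp32_sum = accumulate(ones, to_fp32)

# Order dependence: the same three numbers, summed in two orders.
big = 1e16
left_to_right = (big + 1.0) - big      # 1.0 is absorbed when added to 1e16 first
reordered = (big - big) + 1.0
```

The bf16 accumulator stops at 256 because at that magnitude neighboring bf16 values are 2 apart, so adding 1.0 rounds straight back. This is the "lose low-order bits" failure taken to its extreme.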
What this gives you for the next lessons
You can now read a fused kernel and pattern-match it to "I see a prologue, a load, on-chip math, an epilogue, with a reduction phase if applicable." The next lesson teaches the tools that tell you whether your kernel is hitting its potential — three different profilers for three different questions.