RMSNorm — fused stat + scale
Per-row mean-square reduction, rsqrt, multiply by a learned weight. Three ops in math, but in Triton it's one kernel — and the pattern generalises to any normalisation: LayerNorm, GroupNorm, anything that's "compute a tile-axis statistic, then rescale by it".
The math
$$\mathrm{RMSNorm}(x)_i = \frac{x_i \, w_i}{\sqrt{\frac{1}{N} \sum_{j=1}^{N} x_j^2 + \varepsilon}}$$
One reduction (the sum of squares), one rsqrt, one multiply. The weight w is a learned per-channel scale.
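As a sanity check on the formula, here is a minimal pure-Python reference (no Triton, no GPU). The name rmsnorm_ref and the list-of-rows layout are illustrative assumptions, not part of the lesson's code; the sketch is only meant as something to compare a kernel's output against.

```python
import math

def rmsnorm_ref(rows, w, eps=1e-6):
    """Reference RMSNorm over each row of a list of rows (hypothetical helper)."""
    out = []
    for x in rows:
        n = len(x)
        mean_sq = sum(v * v for v in x) / n        # the single reduction
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)   # one rsqrt per row
        out.append([v * inv_rms * wi for v, wi in zip(x, w)])  # scale, then weight
    return out

# Row [3, 4] has mean square 12.5, so each element is divided by sqrt(12.5).
y = rmsnorm_ref([[3.0, 4.0]], w=[1.0, 0.5], eps=0.0)
print(y)
```

The three commented lines correspond to the reduction, the rsqrt, and the multiply described above.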
Eager PyTorch evaluates this as ~5 launches:
x * x: elementwise square
.mean(-1): row-wise reduction
+ eps, rsqrt: two small kernels
x * inv_rms[:, None]: broadcast multiply
* w: broadcast multiply by weight
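Each launch reads and writes full-size tensors in HBM. Counting the intermediates (x * x, then x * inv_rms), that is about 4 full-size reads and 3 full-size writes. Triton fuses this to 1 read of x, 1 read of w (a single row, negligible next to x), and 1 write of y. Roughly 3× less bandwidth.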
The kernel
import triton
import triton.language as tl

@triton.jit
def rmsnorm_kernel(
    x_ptr, w_ptr, y_ptr,
    stride_x_row, stride_y_row,
    N, eps,
    BLOCK: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x_row = x_ptr + row * stride_x_row
    y_row = y_ptr + row * stride_y_row
    # 1 · load row + weight (weight is shared across rows)
    x = tl.load(x_row + offs, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(w_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    # 2 · reduction: mean of squares (padded lanes load 0.0, so they add nothing)
    var = tl.sum(x * x, axis=0) / N
    # 3 · rsqrt + scale + weight, all in registers
    rstd = 1.0 / tl.sqrt(var + eps)
    y = (x * rstd) * w
    # 4 · store, cast to the output tensor's dtype rather than hard-coding bf16
    tl.store(y_row + offs, y.to(y_ptr.dtype.element_ty), mask=mask)
The whole thing is one pass over the row. The reduction is one tl.sum (warp shuffle + SMEM, lesson 06). Everything else is per-lane register math.
Why we promote to fp32 inside
The accumulator for x*x needs more precision than bf16 — long rows accumulate too much error otherwise. The pattern is identical to the fp32 accumulator inside tl.dot (lesson 05): compute statistics in fp32, store the result in the input dtype.
var in bf16 looks fine on small models and breaks down on wide hidden dimensions, because the reduction runs along the row (the hidden dimension), not along the sequence. The mean-square of a 4K-element row in bf16 has ~5 bits of precision, enough to be silently wrong. Always .to(tl.float32) before the reduction.
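To see the failure mode without a GPU, the sketch below emulates bf16 and fp32 rounding in pure Python and accumulates x*x one element at a time. The helpers to_bf16, to_f32 and mean_square are hypothetical names introduced here. Sequential accumulation is a deliberately pessimistic assumption: tl.sum reduces in a tree, so the real error is smaller, but the mechanism (small addends vanishing against a large running sum) is the same.

```python
import struct

def to_f32(value):
    """Round a Python float to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", value))[0]

def to_bf16(value):
    """Round a finite float to the nearest bfloat16 (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias; ties go to even
    bits &= 0xFFFF0000                    # keep sign, exponent, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def mean_square(row, round_fn):
    """Accumulate x*x sequentially, rounding the accumulator after every add."""
    acc = 0.0
    for x in row:
        acc = round_fn(acc + round_fn(x * x))
    return acc / len(row)

row = [1.0] * 4096                        # true mean square is exactly 1.0
print("bf16 accumulator:", mean_square(row, to_bf16))
print("fp32 accumulator:", mean_square(row, to_f32))
```

With a bf16 accumulator the running sum stalls at 256, because 257 is not representable with 8 significant bits and the tie rounds back down. The reported mean square is 0.0625 instead of 1.0.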
The wrapper
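import torch
import triton

def rmsnorm(x, w, eps=1e-6):
    assert x.is_cuda and w.is_cuda
    # The kernel indexes columns with unit stride and holds a whole row per program.
    assert x.dim() == 2 and x.stride(1) == 1 and w.is_contiguous()
    M, N = x.shape
    y = torch.empty_like(x)
    BLOCK = triton.next_power_of_2(N)
    num_warps = 4 if BLOCK < 2048 else (8 if BLOCK < 4096 else 16)
    rmsnorm_kernel[(M,)](
        x, w, y,
        x.stride(0), y.stride(0),
        N, eps,
        BLOCK=BLOCK, num_warps=num_warps,
    )
    return y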
One program per row. Same shape as softmax (lesson 10) — the row-per-program pattern is the canonical kernel for any per-row reduction.
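The launch parameters the wrapper picks can be previewed on the CPU. The sketch below re-implements the power-of-two rounding in plain Python (intended to match triton.next_power_of_2 for n >= 1) and reuses the wrapper's num_warps heuristic. launch_config and the padded-lane fraction are illustrative additions, not something the lesson defines.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

def launch_config(n_cols):
    """Return (BLOCK, num_warps, fraction of lanes that are masked padding)."""
    block = next_power_of_2(n_cols)
    num_warps = 4 if block < 2048 else (8 if block < 4096 else 16)
    padded_fraction = (block - n_cols) / block
    return block, num_warps, padded_fraction

for n in (1024, 2048, 4096, 5120):
    print(n, launch_config(n))
```

A row width just above a power of two (5120 rounds up to 8192) leaves a large share of lanes masked off, which is worth knowing when a hidden size is not a power of two.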
LayerNorm — same kernel, two statistics
LayerNorm subtracts the mean before computing variance. Two reductions instead of one:
x_f32 = x.to(tl.float32)
mean = tl.sum(x_f32, axis=0) / N           # reduction 1
xc = tl.where(mask, x_f32 - mean, 0.0)     # zero the padded lanes, which would otherwise hold -mean
var = tl.sum(xc * xc, axis=0) / N          # reduction 2
rstd = 1.0 / tl.sqrt(var + eps)
y = xc * rstd * w + b                      # broadcast scale + bias
tl.store(...)
Two warp-shuffle reductions instead of one — about 1.3× as expensive as RMSNorm. RMSNorm exists in modern transformers (LLaMA, Mistral, Qwen) specifically because it's faster.
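One detail is easy to miss when the row is padded up to BLOCK: padded lanes load as 0.0, which is harmless for the mean, but after centering they hold -mean and leak into the variance unless they are zeroed again. The pure-Python emulation below shows the effect. layernorm_var and its mask_centered flag are hypothetical names for illustration.

```python
def layernorm_var(row, block, mask_centered):
    """Variance of `row` computed on a zero-padded tile of width `block`."""
    n = len(row)
    padded = row + [0.0] * (block - n)     # what a masked load with other=0.0 yields
    mask = [i < n for i in range(block)]
    mean = sum(padded) / n                 # padding adds 0, so the mean is correct
    centered = [v - mean for v in padded]  # padded lanes now hold -mean
    if mask_centered:
        centered = [c if m else 0.0 for c, m in zip(centered, mask)]
    return sum(c * c for c in centered) / n

row = [1.0, 2.0, 3.0]                      # true variance is 2/3
print("masked:  ", layernorm_var(row, block=4, mask_centered=True))
print("unmasked:", layernorm_var(row, block=4, mask_centered=False))
```

RMSNorm does not have this problem, since it never subtracts anything from the padded zeros.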
The fusion-with-residual trick
In a transformer block the actual pattern is:
y = norm(x + residual)    # add then norm
# or
x = x + residual          # update residual
y = norm(x)
residual = x
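You can fuse the residual add into the norm kernel: load x and residual, add, then proceed with the variance reduction. Unfused, the add kernel writes the post-residual value to HBM and the norm kernel reads it straight back. Fused, that read disappears, and so does the write if nothing downstream needs the sum. This is a 10–15% kernel speedup on long-context transformer blocks.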
x = tl.load(...) + tl.load(...) # x = x + residual
# rest of norm proceeds as above
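For testing a fused kernel, it helps to have a reference for what it should return. The sketch below is a pure-Python stand-in under the assumption that the fused kernel produces two outputs, the normalised row and the updated residual. fused_add_rmsnorm_ref is a hypothetical name, and the code models the arithmetic only, not the memory traffic.

```python
import math

def fused_add_rmsnorm_ref(x, residual, w, eps=1e-6):
    """Reference for one row: h = x + residual, y = RMSNorm(h) * w."""
    h = [a + b for a, b in zip(x, residual)]        # the fused residual add
    mean_sq = sum(v * v for v in h) / len(h)        # reduction over the summed row
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    y = [v * rstd * wi for v, wi in zip(h, w)]
    return y, h                                     # h becomes the next residual

y, new_residual = fused_add_rmsnorm_ref([1.0, 2.0], [2.0, 2.0], [1.0, 1.0], eps=0.0)
print(y, new_residual)
```

Returning h alongside y mirrors the second form above, where the residual stream is updated in the same pass.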
Backward — the only norm with a tricky backward
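Forward RMSNorm is one row reduction. Backward is more involved: the input gradient needs its own row reduction (over the gradient times w times x), and the weight gradient needs a reduction across rows, which the row-per-program layout does not give you for free. LayerNorm's backward adds a second row reduction, over the gradient times w. We won't derive them here; lesson 14 shows the full torch.autograd.Function wrapping forward and backward Triton kernels. Note that this is one of the operations where having a hand-fused kernel really matters: PyTorch's eager backward is ~10 launches.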
Why the row-per-program pattern is so common
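Look at the kernel shape: tl.program_id(0) picks a row, one masked load brings the row into registers, one tl.sum reduces it, everything after that is per-lane math, and one masked store writes the row back.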
The pattern is so common it's worth a name: row-per-program reduction kernel. Recognise it in code reviews, and you'll see it everywhere.
Interactive · eager PyTorch vs fused Triton on RMSNorm
Pick (rows, columns). See the eager HBM traffic vs the fused HBM traffic and the predicted wall-clock on an H100.
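A static version of the same comparison can be sketched in a few lines. The model below assumes the kernels are purely bandwidth-bound, counts 7 full-size passes for eager (4 reads, 3 writes, including intermediates) against 2 for the fused kernel, and uses 3.35 TB/s as the peak HBM bandwidth of an H100 SXM part. All of these are assumptions of the sketch, and launch overhead and cache effects are ignored.

```python
# Assumed peak HBM bandwidth of an H100 SXM, in bytes per second.
H100_HBM_BYTES_PER_S = 3.35e12

def rmsnorm_traffic(rows, cols, bytes_per_elem=2):
    """Return (eager_bytes, fused_bytes) of HBM traffic for one RMSNorm call."""
    tensor_bytes = rows * cols * bytes_per_elem
    weight_bytes = cols * bytes_per_elem
    # Eager: x*x (read+write), mean (read), x*inv_rms (read+write), *w (read+write).
    eager = 7 * tensor_bytes + weight_bytes
    # Fused: read x once, read w once, write y once.
    fused = 2 * tensor_bytes + weight_bytes
    return eager, fused

def predicted_us(num_bytes, bandwidth=H100_HBM_BYTES_PER_S):
    """Lower-bound wall-clock in microseconds if bandwidth is the only limit."""
    return num_bytes / bandwidth * 1e6

eager, fused = rmsnorm_traffic(rows=4096, cols=4096)   # bf16 input
print(f"eager: {eager / 1e6:.1f} MB, {predicted_us(eager):.1f} us")
print(f"fused: {fused / 1e6:.1f} MB, {predicted_us(fused):.1f} us")
print(f"ratio: {eager / fused:.2f}x")
```

Treat the microsecond figures as a floor rather than a prediction: real kernels rarely reach peak bandwidth, and for small shapes launch latency dominates.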
What's next
You've built five kernels. Lesson 12 is the synthesis: Flash Attention combines lesson 09's matmul tile loop, lesson 10's online softmax recurrence, and lesson 11's row-per-program structure into one kernel that never materialises the full attention matrix. Every primitive you've met shows up.