Lesson 11 / 14

RMSNorm — fused stat + scale

Per-row mean-square reduction, rsqrt, multiply by a learned weight. Three ops in math, but in Triton it's one kernel — and the pattern generalises to any normalisation: LayerNorm, GroupNorm, anything that's "compute a tile-axis statistic, then rescale by it".

The math

$$\mathrm{RMSNorm}(x)_i = \frac{x_i \, w_i}{\sqrt{\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \varepsilon}}$$

One reduction (the sum of squares), one rsqrt, one multiply. The weight w is a learned per-channel scale.
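A quick worked instance pins the formula down. The numbers are a toy example chosen for illustration (N = 2, unit weights, ε = 0):

```latex
\begin{aligned}
x &= (3,\ 4), \quad w = (1,\ 1), \quad \varepsilon = 0, \quad N = 2 \\
\tfrac{1}{N}\textstyle\sum_j x_j^2 &= \tfrac{9 + 16}{2} = 12.5, \qquad \sqrt{12.5} \approx 3.5355 \\
\mathrm{RMSNorm}(x) &\approx \left(\tfrac{3}{3.5355},\ \tfrac{4}{3.5355}\right) \approx (0.8485,\ 1.1314)
\end{aligned}
```

The output's mean square is (0.72 + 1.28) / 2 = 1, which is the point of the normalisation; w then rescales each channel.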

Eager PyTorch evaluates this as about six kernel launches:

  1. x * x — elementwise square
  2. .mean(-1) — row-wise reduction
  3. + eps, rsqrt — two small kernels
  4. x * inv_rms[:, None] — broadcast multiply
  5. * w — broadcast multiply by weight

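Each launch streams a full-size tensor through HBM. In a simplified count that tracks only x and y, eager reads x about 3 times and writes y once; the fused Triton kernel reads x once, reads w once (only N elements), and writes y once. That is 4 M×N-sized transfers against 2, so about 2× less traffic under this count. The count also ignores the intermediates eager writes and re-reads between launches (x², x·rstd), which push the real gap past 3×.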
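To see where the intermediates come from, here is a sketch that tallies M×N-sized HBM transfers per eager launch. It assumes each step in the list above runs as its own kernel with no fusion (no torch.compile) and ignores the small M- and N-sized vectors; the op labels are just names for the steps, not an API.

```python
# Hypothetical accounting model: one unit = one M x N tensor moved through HBM.
# M-sized (per-row) and N-sized (weight) vectors are ignored as negligible.
EAGER_LAUNCHES = [
    # (step, M x N tensors read, M x N tensors written)
    ("x * x",                1, 1),  # reads x, writes x^2
    (".mean(-1)",            1, 0),  # reads x^2, writes an M-sized vector
    ("+ eps",                0, 0),  # M-sized only
    ("rsqrt",                0, 0),  # M-sized only
    ("x * inv_rms[:, None]", 1, 1),  # reads x, writes x * rstd
    ("* w",                  1, 1),  # reads x * rstd, writes y
]
FUSED = (1, 1)  # one read of x, one write of y (w is N-sized, ignored)

def total_transfers(launches):
    reads = sum(r for _, r, _ in launches)
    writes = sum(wr for _, _, wr in launches)
    return reads, writes

reads, writes = total_transfers(EAGER_LAUNCHES)
print(f"eager: {len(EAGER_LAUNCHES)} launches, {reads} reads + {writes} writes = {reads + writes}")
print(f"fused: 1 launch, {sum(FUSED)} transfers, ratio {(reads + writes) / sum(FUSED):.1f}x")
```

Under this model eager moves 7 full-size tensors against the fused kernel's 2, about 3.5×.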

The kernel

import triton
import triton.language as tl

@triton.jit
def rmsnorm_kernel(
    x_ptr, w_ptr, y_ptr,
    stride_x_row, stride_y_row,
    N, eps,
    BLOCK: tl.constexpr,
):
    row  = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < N

    x_row = x_ptr + row * stride_x_row
    y_row = y_ptr + row * stride_y_row

    # 1 · load row + weight (weight is shared across rows)
    x = tl.load(x_row + offs, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(w_ptr + offs, mask=mask, other=0.0).to(tl.float32)

    # 2 · reduction: mean of squares
    var  = tl.sum(x * x, axis=0) / N

    # 3 · rsqrt + scale + weight, all in registers
    rstd = 1.0 / tl.sqrt(var + eps)
    y    = (x * rstd) * w

    # 4 · store, cast to output dtype
    tl.store(y_row + offs, y.to(y_ptr.dtype.element_ty), mask=mask)

The whole thing is one pass over the row. The reduction is one tl.sum (warp shuffle + SMEM, lesson 06). Everything else is per-lane register math.
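To check the kernel's logic without a GPU, here is a pure-Python emulation of one program. It pads the row to BLOCK lanes, applies the mask the way tl.load's other=0.0 does, and divides by N rather than BLOCK. The function names are hypothetical; this is a sketch for testing the math, not a replacement for the kernel.

```python
import math

def next_power_of_2(n: int) -> int:
    # Same result as triton.next_power_of_2 for n >= 1.
    return 1 << (n - 1).bit_length()

def rmsnorm_row_emulated(x_row, w, eps=1e-6):
    """CPU emulation of one rmsnorm_kernel program (one row), for checking only."""
    N = len(x_row)
    BLOCK = next_power_of_2(N)
    offs = range(BLOCK)
    mask = [o < N for o in offs]
    # 1 · masked loads: out-of-range lanes read other=0.0
    x = [x_row[o] if m else 0.0 for o, m in zip(offs, mask)]
    wv = [w[o] if m else 0.0 for o, m in zip(offs, mask)]
    # 2 · reduction over all BLOCK lanes; padded lanes add 0, so divide by N, not BLOCK
    var = sum(v * v for v in x) / N
    # 3 · rsqrt + scale + weight
    rstd = 1.0 / math.sqrt(var + eps)
    y = [(xv * rstd) * wk for xv, wk in zip(x, wv)]
    # 4 · masked store: only the first N lanes are written back
    return [y[o] for o, m in zip(offs, mask) if m]

def rmsnorm_rows(x, w, eps=1e-6):
    # The grid: one independent program per row.
    return [rmsnorm_row_emulated(row, w, eps) for row in x]

# N = 3 pads to BLOCK = 4; the padded lane must not change the result.
print(rmsnorm_row_emulated([1.0, 2.0, 2.0], [1.0, 1.0, 1.0], eps=0.0))
```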

Why we promote to fp32 inside

The accumulator for x*x needs more precision than bf16 — long rows accumulate too much error otherwise. The pattern is identical to the fp32 accumulator inside tl.dot (lesson 05): compute statistics in fp32, store the result in the input dtype.

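Common bug
Computing var in bf16 looks fine on small models and breaks down as the row length N (the hidden size) grows. bf16 keeps only 8 significant bits, so once a partial sum of squares is a few hundred times larger than the terms being added, those terms are rounded away; for a 4K-element row the mean square can be silently wrong. Always .to(tl.float32) before the reduction.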
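The rounding-away effect is easy to reproduce on the CPU. This sketch emulates bf16 by rounding each fp32 value to its top 16 bits (round-half-to-even) and accumulates the squares with a sequential running sum. That is a worst case: a tree reduction such as tl.sum loses less, but the mechanism (small addends vanishing against a large partial sum) is the same. Python's fp64 float stands in for the wide accumulator.

```python
import struct

def to_bf16(v: float) -> float:
    """Round a Python float to the nearest bfloat16 value (round-half-to-even).

    Goes through fp32 first, then keeps the top 16 bits of the fp32 bit pattern.
    NaN/inf are not handled; fine for the finite values used here.
    """
    bits = struct.unpack("<I", struct.pack("<f", v))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on the 16 dropped bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def mean_square_bf16_sequential(xs):
    # Worst case: the running sum is kept in bf16 and rounded after every add.
    acc = 0.0
    for v in xs:
        acc = to_bf16(acc + to_bf16(v * v))
    return acc / len(xs)

def mean_square_wide(xs):
    # Reference accumulator: Python floats (fp64), standing in for fp32 here.
    return sum(v * v for v in xs) / len(xs)

row = [1.0] * 4096
print(mean_square_bf16_sequential(row))  # 0.0625
print(mean_square_wide(row))             # 1.0
```

With 4096 ones the bf16 running sum stalls at 256, because 256 + 1 rounds back to 256 in bf16, so the mean square comes out 16× too small and rstd 4× too large.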

The wrapper

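import torch
import triton

def rmsnorm(x, w, eps=1e-6):
    assert x.is_cuda and w.is_cuda
    assert x.dim() == 2 and x.stride(1) == 1, "kernel assumes 2-D input with contiguous rows"
    assert w.shape == (x.shape[1],), "w must hold one scale per column"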
    M, N = x.shape
    y = torch.empty_like(x)
    BLOCK = triton.next_power_of_2(N)
    num_warps = 4 if BLOCK < 2048 else (8 if BLOCK < 4096 else 16)
    rmsnorm_kernel[(M,)](
        x, w, y,
        x.stride(0), y.stride(0),
        N, eps,
        BLOCK=BLOCK, num_warps=num_warps,
    )
    return y

One program per row. Same shape as softmax (lesson 10) — the row-per-program pattern is the canonical kernel for any per-row reduction.
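One way to read the num_warps heuristic is that it keeps the number of elements each thread holds roughly constant. This sketch mirrors the wrapper's BLOCK and num_warps choice in plain Python (next_power_of_2 reimplemented to match triton.next_power_of_2) and reports the fraction of masked-off lanes and the approximate elements per thread, assuming the BLOCK elements are spread evenly over num_warps × 32 threads.

```python
def next_power_of_2(n: int) -> int:
    # Matches triton.next_power_of_2 for n >= 1.
    return 1 << (n - 1).bit_length()

def launch_config(N: int):
    """Mirror of the wrapper's BLOCK / num_warps choice, runnable without a GPU."""
    BLOCK = next_power_of_2(N)
    num_warps = 4 if BLOCK < 2048 else (8 if BLOCK < 4096 else 16)
    idle = (BLOCK - N) / BLOCK               # fraction of lanes masked off
    per_thread = BLOCK // (num_warps * 32)   # approx. elements held per thread
    return BLOCK, num_warps, idle, per_thread

for N in (768, 2048, 4096, 5120, 8192):
    print(N, launch_config(N))
```

Up to BLOCK = 4096 each thread holds about 8 elements; past that the warp count is capped at 16, so per-thread work grows, and a non-power-of-2 width such as 5120 leaves 37.5% of the lanes masked off.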

LayerNorm — same kernel, two statistics

LayerNorm subtracts the mean before computing variance. Two reductions instead of one:

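# inside the row-per-program kernel: x, w, offs, mask, y_row as before; b loaded like w
x_f32 = x.to(tl.float32)
mean  = tl.sum(x_f32, axis=0) / N            # reduction 1 (padded lanes loaded as 0)
xc    = tl.where(mask, x_f32 - mean, 0.0)    # re-zero padded lanes, else each adds mean² to var
var   = tl.sum(xc * xc, axis=0) / N          # reduction 2
rstd  = 1.0 / tl.sqrt(var + eps)
y     = xc * rstd * w + b                    # per-channel scale + bias
tl.store(y_row + offs, y.to(y_ptr.dtype.element_ty), mask=mask)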

Two warp-shuffle reductions instead of one, plus a load of the bias b: about 1.3× as expensive as RMSNorm. Modern transformers (LLaMA, Mistral, Qwen) use RMSNorm largely because it is cheaper.
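LayerNorm's second reduction hides a masking trap that RMSNorm doesn't have: padded lanes (offs >= N) load 0.0, but after subtracting the mean they hold -mean, and each one adds mean² to the variance unless it is zeroed again, for example with tl.where(mask, x_f32 - mean, 0.0). This pure-Python emulation of a BLOCK-wide row (a hypothetical helper, CPU only) shows the size of the error.

```python
def layernorm_stats_padded(x_row, BLOCK, zero_padded_lanes=True):
    """Emulate LayerNorm's two statistics over a masked BLOCK-wide row (illustration only)."""
    N = len(x_row)
    mask = [i < N for i in range(BLOCK)]
    x = [x_row[i] if mask[i] else 0.0 for i in range(BLOCK)]  # like other=0.0
    mean = sum(x) / N                                          # padded zeros don't change the sum
    if zero_padded_lanes:
        xc = [x[i] - mean if mask[i] else 0.0 for i in range(BLOCK)]  # like tl.where(mask, x - mean, 0.0)
    else:
        xc = [v - mean for v in x]                             # padded lanes become -mean
    var = sum(v * v for v in xc) / N
    return mean, var

row = [1.0, 2.0, 3.0]  # N = 3, true mean 2, true variance 2/3
print(layernorm_stats_padded(row, BLOCK=4))
print(layernorm_stats_padded(row, BLOCK=4, zero_padded_lanes=False))
```

With N = 3 and BLOCK = 4, the single padded lane adds mean²/N = 4/3, tripling the variance from 2/3 to 2.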

The fusion-with-residual trick

In a transformer block the actual pattern is:

y = norm(x + residual)         # add then norm
# or
x = x + residual; y = norm(x); residual = x   # update residual

You can fuse the residual add into the norm kernel: load x and residual, add them, then run the variance reduction on the sum. That saves one HBM read of the post-residual value (plus a kernel launch), worth a 10–15% kernel speedup on long-context transformer blocks.

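x  = tl.load(x_row + offs, mask=mask, other=0.0).to(tl.float32)
x += tl.load(res_row + offs, mask=mask, other=0.0).to(tl.float32)  # x = x + residual (res_row: hypothetical residual row pointer)
tl.store(res_row + offs, x.to(res_row.dtype.element_ty), mask=mask)  # optional: write the sum back as the new residual
# rest of the norm proceeds as above, on the fp32 sum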
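The saving can be counted the same way as the forward traffic. This sketch tallies M×N-sized HBM transfers for a residual add followed by a norm, run separately versus fused; it is a simplified model that ignores w and launch overhead, and the function name is hypothetical.

```python
def add_norm_transfers(fused: bool, write_residual: bool = True) -> int:
    """M x N-sized HBM transfers for `s = x + residual; y = norm(s)` (simplified model)."""
    if not fused:
        add_kernel = 2 + 1   # read x, read residual, write s
        norm_kernel = 1 + 1  # read s back, write y
        return add_kernel + norm_kernel
    # Fused: read x and residual, write y; optionally also write s as the new residual.
    return 2 + 1 + (1 if write_residual else 0)

unfused = add_norm_transfers(fused=False)
fused = add_norm_transfers(fused=True)
fused_no_residual = add_norm_transfers(fused=True, write_residual=False)
print(unfused, fused, fused_no_residual)  # 5 4 3
```

In this model, fusing while still writing the new residual cuts 5 transfers to 4, at most a 20% traffic reduction for the pair of ops; if the sum is not needed afterwards, it drops to 3.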

Backward: where RMSNorm gets tricky

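Forward RMSNorm is one row reduction. Backward needs reductions in two directions: dx needs one row reduction, Σj gj·wj·xj (g is the upstream gradient), which fits the row-per-program pattern; dw needs the sum of g·x·rstd over all rows, a cross-row (column) reduction that doesn't, so it is usually handled with per-program partial sums followed by a second reduction pass. We won't build the kernels here; lesson 14 shows the full torch.autograd.Function wrapping forward and backward Triton kernels. This is one of the operations where a hand-fused kernel really matters: PyTorch's eager backward is ~10 launches.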
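Differentiating the forward formula, with r = rstd and g = dL/dy, gives ∂L/∂x_k = r·w_k·g_k - (r³·x_k / N)·Σj gj·wj·xj for each row, and ∂L/∂w_i = Σ over rows of g_i·x_i·r. The sketch below implements both for a single row in plain Python and checks them against central finite differences; it is a reference for testing a backward kernel, not the lesson 14 implementation, and the helper names are hypothetical.

```python
import math

def rmsnorm_fwd(x, w, eps):
    # Forward for one row: y_i = x_i * w_i * rstd.
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [xi * wi * rstd for xi, wi in zip(x, w)], rstd

def rmsnorm_bwd_row(x, w, g, eps):
    """Return dL/dx for one row and this row's contribution to dL/dw."""
    N = len(x)
    _, rstd = rmsnorm_fwd(x, w, eps)
    # The only row reduction dx needs: sum_j g_j * w_j * x_j.
    s = sum(gj * wj * xj for gj, wj, xj in zip(g, w, x))
    dx = [rstd * wk * gk - rstd ** 3 * xk * s / N for xk, wk, gk in zip(x, w, g)]
    # dL/dw sums this term over ALL rows: the cross-row part of the backward.
    dw_row = [gi * xi * rstd for gi, xi in zip(g, x)]
    return dx, dw_row

def loss(x, w, g, eps):
    # Linear test loss L = sum_i g_i * y_i, so dL/dy = g exactly.
    y, _ = rmsnorm_fwd(x, w, eps)
    return sum(gi * yi for gi, yi in zip(g, y))

def central_diff(f, v, k, h=1e-6):
    # Numerical df/dv[k] by central differences.
    vp = list(v); vp[k] += h
    vm = list(v); vm[k] -= h
    return (f(vp) - f(vm)) / (2 * h)

x, w, g, eps = [0.5, -1.0, 2.0], [1.0, 0.5, -2.0], [0.3, -0.7, 1.1], 1e-6
dx, dw_row = rmsnorm_bwd_row(x, w, g, eps)
err_x = max(abs(dx[k] - central_diff(lambda v: loss(v, w, g, eps), x, k)) for k in range(3))
err_w = max(abs(dw_row[k] - central_diff(lambda v: loss(x, v, g, eps), w, k)) for k in range(3))
print(err_x, err_w)  # both tiny: finite-difference noise only
```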

Why the row-per-program pattern is so common

Look at the kernel shape:

[Figure: M programs, each owning one row of BLOCK lanes; all reductions are intra-row. Row 0 → program 0, row 1 → program 1, row 2 → program 2, and so on.]

No cross-program communication. No atomics. The grid scales linearly with the row count. This pattern is the right one for softmax, RMSNorm, LayerNorm, and any per-row reduce-then-scale.

The pattern is so common it's worth a name: the row-per-program reduction kernel. Once you recognise it in code review, you'll see it everywhere.
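The "no cross-program communication" property can be checked directly: if every program reads only its own row and writes only its own output row, the result cannot depend on the order in which programs run. This sketch emulates a 1-D grid on the CPU (run_grid and make_rms_program are hypothetical helpers) and runs it in order and in a shuffled order.

```python
import random

def run_grid(program, M, order):
    """Emulate a 1-D launch grid of M programs executed in the given order."""
    out = [None] * M
    for pid in order:
        out[pid] = program(pid)  # program pid writes only output row pid
    return out

def make_rms_program(x, w, eps=1e-6):
    def program(pid):
        row = x[pid]  # reads only its own input row (plus the shared w)
        rstd = 1.0 / (sum(v * v for v in row) / len(row) + eps) ** 0.5
        return [v * rstd * wi for v, wi in zip(row, w)]
    return program

x = [[float(i + j) for j in range(4)] for i in range(6)]  # M = 6 rows, N = 4
w = [1.0, 0.5, 2.0, 1.0]
prog = make_rms_program(x, w)
in_order = run_grid(prog, len(x), range(len(x)))
shuffled = list(range(len(x)))
random.Random(0).shuffle(shuffled)
print(run_grid(prog, len(x), shuffled) == in_order)  # True: order doesn't matter
```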

Eager PyTorch vs fused Triton on RMSNorm: HBM traffic

For an M × N input with 2-byte (bf16) elements, using the simplified count from the start of the lesson:

Eager: 3 reads of x + 1 read of w + 1 write of y = 4 · M · N · 2 + N · 2 bytes.
Fused: 1 read of x + 1 read of w + 1 write of y = 2 · M · N · 2 + N · 2 bytes.
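The two byte counts above turn into a rough time estimate by dividing by memory bandwidth. This sketch assumes bf16 elements and a peak HBM bandwidth of 3.35 TB/s (the advertised figure for H100 SXM); real kernels reach only part of peak, and eager PyTorch also pays per-launch overhead, so treat the times as lower bounds.

```python
def rmsnorm_hbm_bytes(M: int, N: int, elem_bytes: int = 2):
    """Simplified traffic model: eager = 3 reads of x + 1 write of y, fused = 1 + 1; w read once."""
    xy = M * N * elem_bytes  # one full pass over an M x N tensor
    w = N * elem_bytes       # the weight vector
    eager = 4 * xy + w
    fused = 2 * xy + w
    return eager, fused

def predicted_us(nbytes: int, bandwidth_bytes_per_s: float = 3.35e12) -> float:
    # Assumption: a purely memory-bound kernel running at peak H100 SXM bandwidth.
    return nbytes / bandwidth_bytes_per_s * 1e6

eager, fused = rmsnorm_hbm_bytes(4096, 4096)
print(eager, fused)  # 134225920 67117056
print(f"{predicted_us(eager):.1f} us vs {predicted_us(fused):.1f} us")
```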

What's next

You've built five kernels. Lesson 12 is the synthesis: Flash Attention combines lesson 09's matmul tile loop, lesson 10's online softmax recurrence, and lesson 11's row-per-program structure into one kernel that never materialises the full attention matrix. Every primitive you've met shows up.