BatchNorm, LayerNorm, RMSNorm, GroupNorm
All four are the same recipe (subtract a mean, divide by a standard deviation, scale and shift; RMSNorm skips the mean) applied along different axes. The axis choice and the operating regime (train vs eval, batch=1, variable sequence length) determine which is correct. Get it wrong and your model is silently broken.
Why normalize at all
Before the 2015 paper that introduced BatchNorm, training deep networks meant fighting what that paper called internal covariate shift: each layer's output distribution moves as upstream parameters update, so the downstream layer is forever chasing a moving target. (Santurkar et al. 2018 later argued the dominant effect is actually loss-landscape smoothing, but the practical recipe stands.) The fix is to fix the moments. Concretely:
y = γ · (x − μ) / √(σ² + ε) + β
where μ, σ² are the mean and variance computed across some axes of x, ε is a small constant that keeps the division stable, and γ, β are learned per-channel affine parameters that let the model "un-normalize" if the original scale was meaningful. The interesting question is always: which axes do you compute μ and σ² over?
The axis cheat sheet
Take a 4D tensor X ∈ ℝ^{B × C × H × W} (convnet) or 3D X ∈ ℝ^{B × L × D} (transformer).
| Norm | Axes of normalisation (convnet) | Axes (transformer) | Per-sample stats? |
|---|---|---|---|
| BatchNorm | B, H, W → one (μ, σ) per channel | B, L → one (μ, σ) per feature | No — averaged across batch |
| LayerNorm | C, H, W → one (μ, σ) per sample | D → one (μ, σ) per token | Yes |
| InstanceNorm | H, W → one (μ, σ) per (sample, channel) | L → one (μ, σ) per (sample, feature) | Yes |
| GroupNorm | C/G, H, W → one (μ, σ) per (sample, group) | D/G → one (μ, σ) per (sample, group) | Yes |
| RMSNorm | same as LN axis, but no μ subtraction | D, no μ | Yes |
The difference is purely "which axes are reduced". Same scale-shift mechanic, completely different downstream consequences.
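To make the "which axes are reduced" point concrete, here is a small NumPy sketch for the convnet case. The tensor sizes and the group count G are arbitrary choices for illustration; only the shapes of the resulting statistics matter.

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, H, W, G = 4, 8, 5, 5, 2          # illustrative sizes, not from any real model
x = rng.normal(size=(B, C, H, W))

# Axes reduced by each norm on a (B, C, H, W) tensor.
reduce_axes = {
    "BatchNorm": (0, 2, 3),      # one statistic per channel
    "LayerNorm": (1, 2, 3),      # one statistic per sample
    "InstanceNorm": (2, 3),      # one statistic per (sample, channel)
}
stat_shapes = {name: x.mean(axis=axes).shape for name, axes in reduce_axes.items()}

# GroupNorm: split C into G groups, then reduce inside each group.
xg = x.reshape(B, G, C // G, H, W)
stat_shapes["GroupNorm"] = xg.mean(axis=(2, 3, 4)).shape

# RMSNorm: LayerNorm's axes, but a root-mean-square instead of (mean, std).
stat_shapes["RMSNorm"] = np.sqrt((x ** 2).mean(axis=(1, 2, 3))).shape

for name, shape in stat_shapes.items():
    print(f"{name:12s} statistics shape: {shape}")
```

Any statistic whose shape includes B is per-sample, which is exactly the last column of the table.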
BatchNorm — the original, and its serving headache
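The recipe at training time:
μ_c = mean of x[b, c, h, w] over (b, h, w)
σ_c² = variance of x[b, c, h, w] over (b, h, w)
y[b, c, h, w] = γ_c · (x[b, c, h, w] − μ_c) / √(σ_c² + ε) + β_c
One (μ_c, σ_c) per channel, averaged over batch and spatial positions. At test time, you can't average over a batch (you might have batch size 1). So BN keeps a running EMA of μ, σ² during training and uses those at eval:
μ_run ← (1 − m) · μ_run + m · μ_c, and likewise for σ²_run, where m is the EMA momentum
y[b, c, h, w] = γ_c · (x[b, c, h, w] − μ_run) / √(σ²_run + ε) + β_c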
This train/eval mismatch is BN's defining issue. Three failure modes that ship to production:
- Forgetting model.eval(). You serve in training mode → BN uses the batch statistics of whatever single inference request you have → wildly wrong outputs.
- Tiny batches. BN's batch stats become high-variance at small B. Distributed-data-parallel splits a batch across N ranks; per-rank batch can be tiny. Fix: SyncBN (all-reduce the per-channel sums across ranks), at the cost of a collective every BN layer.
- Sequence models. BatchNorm across tokens conflates information across positions in a way that breaks autoregressive generation (you'd leak future tokens into past predictions). This is a main reason transformers don't use BatchNorm.
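The first failure mode is easy to reproduce. Below is a minimal NumPy sketch, not any framework's implementation: TinyBatchNorm is a hypothetical teaching class, and the momentum, sizes, and input distributions are assumptions picked for the demo. It trains the running statistics on data centred at 5, then serves one request centred at 9 in both modes.

```python
import numpy as np

class TinyBatchNorm:
    """Per-channel BatchNorm for (B, C, H, W) arrays. Hypothetical teaching class."""

    def __init__(self, num_channels, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_channels)
        self.beta = np.zeros(num_channels)
        self.running_mean = np.zeros(num_channels)
        self.running_var = np.ones(num_channels)
        self.momentum, self.eps = momentum, eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mean = x.mean(axis=(0, 2, 3))
            var = x.var(axis=(0, 2, 3))   # biased variance everywhere, a simplification
            m = self.momentum
            # EMA update of the statistics that eval mode will use later.
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            mean, var = self.running_mean, self.running_var
        bshape = (1, -1, 1, 1)            # broadcast (C,) over B, H, W
        x_hat = (x - mean.reshape(bshape)) / np.sqrt(var.reshape(bshape) + self.eps)
        return self.gamma.reshape(bshape) * x_hat + self.beta.reshape(bshape)

rng = np.random.default_rng(0)
bn = TinyBatchNorm(num_channels=3)
for _ in range(200):                      # "training": let the EMA settle
    bn(rng.normal(loc=5.0, scale=2.0, size=(32, 3, 4, 4)))
trained_mean = bn.running_mean.copy()

# One inference request whose values sit well above the training mean.
request = rng.normal(loc=9.0, scale=2.0, size=(1, 3, 4, 4))

bn.training = False
out_eval = bn(request)    # running stats: the shift survives as a positive output
bn.training = True
out_train = bn(request)   # request's own stats: the shift is normalised away

print("eval-mode output mean: ", out_eval.mean())
print("train-mode output mean:", out_train.mean())
```

In eval mode the output keeps the information that this request is unusually large. In training mode the request is standardised against itself, so that information is gone, and the call also pollutes the running statistics as a side effect.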
LayerNorm — what the transformer chose
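LN computes (μ, σ) across the feature axis of each individual token. For one token's feature vector x ∈ ℝ^D:
μ = (1/D) Σ_d x_d
σ² = (1/D) Σ_d (x_d − μ)²
y_d = γ_d · (x_d − μ) / √(σ² + ε) + β_d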
Three things that LN gets right for transformers:
- No batch dimension involved. Train/eval semantics are identical. Batch size 1 works. Variable sequence lengths work. Distributed training doesn't need a collective inside the norm.
- Per-token normalisation. Each token gets the same statistical treatment regardless of position or batch. This matches the transformer's "each token is processed identically" inductive bias.
- The affine parameters γ, β are per-feature, not per-token. The model learns a per-channel scale-and-shift, just like BN.
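The points above can be checked with a few lines of NumPy. This is a sketch under assumed sizes (B, L, D are arbitrary), with γ and β fixed at their usual initial values; it shows that normalising a batch of one with a shorter sequence gives the same per-token result as normalising the full batch.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last (feature) axis of a (B, L, D) array."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)      # biased (population) variance
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
B, L, D = 4, 6, 16                           # illustrative sizes
x = rng.normal(size=(B, L, D))
gamma, beta = np.ones(D), np.zeros(D)        # per-feature affine, shape (D,)

y_batch = layer_norm(x, gamma, beta)
y_single = layer_norm(x[:1, :3], gamma, beta)  # batch of 1, shorter sequence

# Each token is normalised on its own, so the batch and length do not matter.
print(np.allclose(y_batch[:1, :3], y_single))
```

Because γ and β have shape (D,), they broadcast over the batch and sequence axes without any reshaping.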
RMSNorm — the simplification that ate the LLM stack
RMSNorm (Zhang & Sennrich, 2019) strips the centering step:
RMS(x) = √((1/D) Σ_d x_d² + ε)
y_d = γ_d · x_d / RMS(x)
That's the entire change: no μ, no β (bias). Why does this work?
- Cheap. One pass over x instead of two (compute mean, then center, then variance). For a (B, L, D) tensor in a 70B model with L = 8192, this is real memory bandwidth saved.
- Empirically near-identical. Zhang & Sennrich reported that RMSNorm matches LN's quality on their benchmarks, and its adoption in LLM pretraining suggests the same holds there. The centering step appears to do essentially no work.
- Numerical stability in bf16. Removing the (x − μ) subtraction removes a catastrophic-cancellation site. Important for fp8/bf16 training.
Most major open LLMs since the original LLaMA use RMSNorm. The interview signal is being able to say "the centering step is empirically unnecessary; the RMS scaling captures everything that matters."
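A quick way to see what the centering step does and does not buy you is to compare the two norms on centred and shifted inputs. The NumPy sketch below uses assumed sizes and a random γ; layer_norm and rms_norm are illustrative helpers, not a library API.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-6):
    # No mean subtraction and no bias: just divide by the root-mean-square.
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return gamma * x / rms

rng = np.random.default_rng(0)
D = 64
x = rng.normal(size=(2, 5, D))
gamma = rng.normal(size=D)

# On zero-mean features the two norms coincide (with beta = 0).
x_centered = x - x.mean(axis=-1, keepdims=True)
same_when_centered = np.allclose(
    rms_norm(x_centered, gamma), layer_norm(x_centered, gamma, beta=0.0)
)

# With a large shared offset they differ: LN removes it, RMSNorm does not.
x_shifted = x + 3.0
differs_when_shifted = not np.allclose(
    rms_norm(x_shifted, gamma), layer_norm(x_shifted, gamma, beta=0.0)
)

print(same_when_centered, differs_when_shifted)
```

So the empirical claim amounts to saying that real residual streams look much more like the first case than the second.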
Pre-norm vs post-norm — the choice that decides whether your model trains
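The original "Attention Is All You Need" used post-norm, x_{l+1} = LN(x_l + f(x_l)), where f is the attention or MLP block. Pre-norm moves the norm inside the residual branch: x_{l+1} = x_l + f(LN(x_l)).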
Pre-norm wins for deep transformers because:
- Gradient flow through the residual is clean. For x_{l+1} = x_l + f(x_l), the Jacobian ∂x_{l+1}/∂x_l is I + f'(x_l). The I is the identity path. Pre-norm keeps this identity term untouched; post-norm multiplies the whole sum, identity path included, by LN's Jacobian.
- Initialisation is more forgiving. At init, each block does almost nothing (output ≈ 0), so pre-norm's residual stream x_l + 0 ≈ x_l stays clean. Post-norm immediately re-normalises and discards the initialisation's scale.
- Pre-norm trains 100+ layers without warmup tricks. Post-norm requires careful warmup and gradient clipping to reach the same depth.
The trade-off: post-norm's final residual stream has bounded magnitude (each layer renormalises), so the output is more numerically stable. Pre-norm's residual stream grows in scale with depth, so for very deep models you need a final norm before the LM head to keep the output logits bounded. Modern LLMs such as LLaMA, Mistral, and Qwen are pre-norm plus a final norm.
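The difference in residual-stream scale can be seen in a toy forward pass. In the NumPy sketch below, each block is a hypothetical stand-in (one random tanh layer) rather than real attention or an MLP, and the depth and widths are arbitrary; the point is only how the two wirings treat the stream.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def rms(x):
    return float(np.sqrt((x ** 2).mean()))

rng = np.random.default_rng(0)
tokens, D, depth = 16, 64, 24                  # illustrative sizes
# Hypothetical block: a single random tanh layer per depth level.
weights = [rng.normal(scale=D ** -0.5, size=(D, D)) for _ in range(depth)]

def block(x, w):
    return np.tanh(x @ w)

x_pre = x_post = rng.normal(size=(tokens, D))
for w in weights:
    x_pre = x_pre + block(layer_norm(x_pre), w)      # pre-norm: norm inside the branch
    x_post = layer_norm(x_post + block(x_post, w))   # post-norm: norm after the sum

pre_rms, post_rms = rms(x_pre), rms(x_post)
print(f"pre-norm stream RMS:  {pre_rms:.2f}")
print(f"post-norm stream RMS: {post_rms:.2f}")
```

The post-norm stream is pinned to unit scale at every layer, while the pre-norm stream accumulates block outputs and grows with depth, which is why a final norm before the LM head is needed.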
The numerical-stability layer — what every modern transformer adds
| Component | Why |
|---|---|
| RMSNorm in fp32, even when activations are bf16 | The reduction Σ x_d² loses precision in bf16 (only about 8 significand bits) and can overflow outright in fp16 for large D and large activations. Compute RMS in fp32, then return to bf16 for the elementwise divide. |
| ε = 1e-6 or 1e-8 | Smaller ε → more aggressive normalisation. 1e-5 (the PyTorch LayerNorm default) leaks under fp8. 1e-6 is the LLaMA default. |
| γ initialised to 1 | So the norm initially passes through. If the scale is initialised differently (e.g., the small per-channel constants used by LayerScale), the block starts as a near no-op and learns up. |
| Final LN before LM head | The residual stream grows with depth in pre-norm. A final norm bounds the logit magnitude. Without it, output softmax saturates → tiny gradient → stuck. |
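The first row of the table can be illustrated in NumPy. NumPy has no bf16 type, so this sketch uses float16 as a stand-in low-precision format (an assumption of the example); there the failure shows up as an outright overflow of the sum of squares. rms_norm_upcast is an illustrative helper, not a library function.

```python
import numpy as np

def rms_norm_upcast(x, gamma, eps=1e-6):
    """RMSNorm that does its arithmetic in float32 and returns the input dtype."""
    x32 = x.astype(np.float32)
    rms = np.sqrt((x32 ** 2).mean(axis=-1, keepdims=True) + eps)
    return (gamma.astype(np.float32) * x32 / rms).astype(x.dtype)

D = 4096
x16 = np.full((1, D), 8.0, dtype=np.float16)   # large-ish activations, wide layer
gamma = np.ones(D, dtype=np.float16)

# Naive low-precision reduction: 4096 * 64 exceeds float16's maximum (65504).
with np.errstate(over="ignore"):
    naive_sum = (x16 ** 2).sum(axis=-1, dtype=np.float16)

y = rms_norm_upcast(x16, gamma)

print("naive float16 sum of squares:", naive_sum)
print("upcast RMSNorm output (first 3):", y[0, :3], y.dtype)
```

The upcast version costs a temporary float32 copy of the row but keeps the output dtype unchanged, so the rest of the network still runs in low precision.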
GroupNorm — the convnet's escape hatch
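GroupNorm (Wu & He, 2018) sits between LayerNorm (one group = all channels) and InstanceNorm (one group = one channel). For a (B, C, H, W) tensor with G groups, the C channels are split into G groups of C/G channels each, and μ, σ² are computed over the C/G · H · W values of each group, separately for every sample.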
One (μ, σ) per (sample, group). Why it exists: when batch size is small (object detection, segmentation, video), BN's per-channel stats are noisy. GroupNorm gives you enough channels-per-group for stable stats without averaging across the batch. Standard G = 32.
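The "sits between LayerNorm and InstanceNorm" claim can be verified directly. The NumPy sketch below omits the affine step (γ = 1, β = 0) for brevity and uses arbitrary tensor sizes; group_norm and normalise_over are illustrative helpers.

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    """GroupNorm on (B, C, H, W) without the affine step (gamma=1, beta=0)."""
    B, C, H, W = x.shape
    assert C % num_groups == 0, "C must be divisible by the number of groups"
    xg = x.reshape(B, num_groups, C // num_groups, H, W)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(B, C, H, W)

def normalise_over(x, axes, eps=1e-5):
    """Reference normalisation over an explicit set of axes."""
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 64, 7, 7))            # batch of 2, as in a detection setup

# G = 1 reproduces LayerNorm over (C, H, W); G = C reproduces InstanceNorm.
matches_layer_norm = np.allclose(group_norm(x, 1), normalise_over(x, (1, 2, 3)))
matches_instance_norm = np.allclose(group_norm(x, 64), normalise_over(x, (2, 3)))

# The statistics never touch the batch axis, so batch size 1 gives the same result.
batch_independent = np.allclose(group_norm(x, 32)[:1], group_norm(x[:1], 32))

print(matches_layer_norm, matches_instance_norm, batch_independent)
```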
The interview probes — what each norm is doing under the hood
- What does γ in BN actually do? γ is the per-channel scale. Without γ, BN's output has unit variance — but maybe the original input had useful magnitude information (e.g., learned attention weights). γ is how the network "undoes" the norm where it hurts and "applies" it where it helps. β is the same for the mean.
- Why does BN add regularisation? Because each sample sees the batch's mean and variance, two passes over the same network with different batches give different outputs for the same sample. This induces noise — empirically equivalent to ~0.05 dropout. Removing BN often requires adding explicit dropout to recover the regularisation.
- Why does LN reduce to "divide by RMS" if you remove β? The learnable affine pair γ, β can absorb whatever centering the network finds useful — there's no expressive gain from doing the centering inside the norm itself. Empirically, residual-stream pre-activations are nearly zero-mean by symmetry (a residual sum of approximately-symmetric block outputs), so the explicit μ subtraction contributes little. RMSNorm just removes the redundancy.
- Can you train with no norm at all? Yes — ReZero (Bachlechner et al. 2021) initialises each residual block's output multiplier α to 0, so at init the network is identity. The model learns α up as training proceeds. Eliminates norm entirely, modest accuracy hit. Used in some niche models, not mainstream.
- What's the FLOP cost of LN? For a (B, L, D) tensor: one pass for μ (B·L·D adds), one for σ² (B·L·D MACs), one elementwise scale-shift (B·L·D MACs). Total: ~3·B·L·D FLOPs. Compare to a linear layer (B·L·D·D' MACs): LN is negligible unless D' is tiny. But a naive LN is memory-bandwidth bound, with up to three full passes over the tensor, so it can take a noticeable share of kernel time in small models. Fused kernels (the FlashAttention codebase ships one for LN) collapse those passes to amortise the cost.
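The regularisation probe above is easy to demonstrate: in training mode, the same sample normalises differently depending on its batch mates. This NumPy sketch uses a feature-vector BatchNorm without affine parameters and arbitrary sizes, which are assumptions of the example.

```python
import numpy as np

def batch_norm_train(x, eps=1e-5):
    """Training-mode BatchNorm over the batch axis of an (N, features) array."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 8))

# The same sample placed in two different random batches of 16.
batch_a = np.concatenate([sample, rng.normal(size=(15, 8))])
batch_b = np.concatenate([sample, rng.normal(size=(15, 8))])

out_a = batch_norm_train(batch_a)[0]
out_b = batch_norm_train(batch_b)[0]
noise = np.abs(out_a - out_b).max()

print("max difference for the same sample:", noise)
```

The gap shrinks as the batch grows, which is one way to see why large-batch BN regularises less than small-batch BN.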
Where things go subtly wrong
| Bug | Symptom | Diagnosis |
|---|---|---|
| BN running stats drift | Model trains fine, evaluates poorly. Switching to train() at eval time "fixes" it. | Running EMA isn't tracking the actual distribution (e.g., fine-tuning on different data). Fix: recompute running stats from training data, or unfreeze BN. |
| LN in mixed precision overflows | Loss → NaN after a few thousand steps; gradient norms show one layer's norm has blown up. | Compute LN's variance reduction in fp32 even in bf16/fp16 nets. Most frameworks do this by default; some custom kernels don't. |
| Wrong axis specified | Training works, but accuracy is much worse than expected. | BN on a transformer (normalising across tokens) leaks future tokens into the past. LN on the wrong axis of a CNN gives per-spatial-pixel norms (rare and wrong). |
| Norm before residual | Deep stacks can eventually blow up; the block is written as x_l + LN(Block(x_l)) instead of x_l + Block(LN(x_l)). | The first form feeds the block the raw residual stream, whose scale grows with depth, and normalises only what is added back. The second normalises the block input, which is the correct pre-norm pattern. |
| γ, β shape mismatch | Loss explodes or stays constant. | γ, β must broadcast over the reduced axes. For BN on (B, C, H, W) with channel-wise norm, γ has shape (C,) and must be reshaped to (1, C, 1, 1) so that it broadcasts over B, H, W. Easy to mis-shape. |
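The last row is worth seeing once, because the wrong version does not always raise. In this NumPy sketch the sizes are chosen so that W equals C (a deliberate assumption of the example), which lets a bare (C,) γ broadcast silently along the wrong axis.

```python
import numpy as np

B, C, H, W = 2, 3, 3, 3                  # W == C on purpose, so nothing raises
x = np.ones((B, C, H, W))
gamma = np.array([1.0, 10.0, 100.0])     # intended: one scale per channel

wrong = x * gamma                        # broadcasts along the last axis (W)
right = x * gamma.reshape(1, C, 1, 1)    # broadcasts over B, H, W as intended

print("wrong, along W at channel 0: ", wrong[0, 0, 0])
print("right, along C at pixel (0,0):", right[0, :, 0, 0])
```

When W differs from C the wrong version fails loudly with a broadcasting error, which is the lucky case.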
Interview prompts you should be ready for
- "Why is BatchNorm rare in transformers?" (BN normalises across the batch and time axes. For an autoregressive LM, the batch contains sequences of different lengths and the time axis includes future positions you mustn't see at inference. The train/eval running-stats mismatch is also bad for variable-length inputs. LN — per-token, per-sample — sidesteps both issues.)
- "Why does LLaMA use RMSNorm instead of LN?" (Three reasons: (1) cheaper — one reduction pass instead of two; (2) empirically matches LN's loss within noise on autoregressive pretraining; (3) more bf16-friendly because removing the centering step removes a catastrophic-cancellation site. The cost is one removed degree of freedom (β bias), which doesn't matter for high-D residual streams.)
- "You trained a model with BN and the test accuracy is way lower than val. What happened?" (Almost always: forgot model.eval(), so BN is using batch statistics from one test sample. Or: running stats drifted because the val and test data have different distributions. Verify by toggling eval() and watching the gap.)
- "How does pre-norm vs post-norm change gradient flow?" (Post-norm: gradient must pass through LN's derivative every layer; for L layers, the product of LN derivatives can vanish or explode. Pre-norm: gradient passes through identity in the residual path → flows cleanly; the only contribution is from the block.)
- "What's the memory cost of LN vs RMSNorm?" (Both store γ (D params) and, unlike BN, no running statistics. LN also stores β (D params); RMSNorm doesn't. The difference is trivial. The bigger cost is activations saved for backward: both keep the B·L·D input, LN additionally keeps μ and σ per token, RMSNorm just the RMS. Those statistics are one scalar per token, so the saving is small next to the input itself.)
- "Implement LN in pseudocode for a (B, L, D) tensor." (mean = X.mean(-1, keepdim=True); var = X.var(-1, keepdim=True, unbiased=False); x_norm = (X - mean) / sqrt(var + eps); y = x_norm * gamma + beta. Four lines. Trip wires: keepdim=True for broadcast; unbiased=False matches the training-time biased estimator.)
- "Your detection model is using batch size 2 per GPU because images are huge. Which norm?" (GroupNorm with G=32. BN with B=2 has high-variance stats; LayerNorm with feature axis = all channels conflates information; GroupNorm gives enough channels per group for stable stats without averaging across samples.)