Initialization & gradient stability
Until 2010, training deep networks was a lottery. Then a one-line insight — "preserve variance through layers" — turned a research mystery into engineering. The same logic underlies Kaiming init, residuals, LayerNorm, and μP.
The problem in one diagram
Push an input forward through a deep linear net y = W_L W_{L-1} … W_1 x. If each W_l is initialised so that ‖W_l x‖ ≈ 2‖x‖, then by layer L: ‖y‖ ≈ 2^L ‖x‖. For L=80, that's 2^{80} ≈ 10^{24}: activations blow up and the loss overflows to inf/NaN. Conversely, if each ‖W_l x‖ ≈ 0.5 ‖x‖, then by L=80 activations are ≈ 10^{-24} times their input scale, and gradients flowing back through the same matrices vanish just as fast. The init must preserve magnitudes.
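A quick numerical check of the 2^L argument, written as a sketch: a stack of random Gaussian layers with Var(W_ij) = g²/d, so that E‖W x‖² = g²‖x‖² and each layer scales the norm by roughly g. The width (64), depth (80), seed and function names are arbitrary illustrative choices.

```python
import math
import random


def matvec(W, x):
    """Dense matrix-vector product for a list-of-rows matrix."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def norm(x):
    return math.sqrt(sum(v * v for v in x))


def log10_growth(depth, d, gain, seed=0):
    """log10(||y|| / ||x||) for y = W_depth ... W_1 x.

    Entries of each W are iid N(0, gain^2 / d), so E||W x||^2 = gain^2 ||x||^2:
    every layer multiplies the norm by roughly `gain`.
    """
    rng = random.Random(seed)
    std = gain / math.sqrt(d)
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    x_norm = norm(x)
    for _ in range(depth):
        W = [[rng.gauss(0.0, std) for _ in range(d)] for _ in range(d)]
        x = matvec(W, x)
    return math.log10(norm(x) / x_norm)


for g in (0.5, 1.0, 2.0):
    print(f"gain {g}: log10(||y|| / ||x||) = {log10_growth(80, 64, g):+.1f}")
```

Expect values roughly near -24, 0 and +24; the exact figures depend on the seed, because products of random matrices fluctuate around the idealised g^L.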
Derive Xavier (Glorot) init
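Take a linear layer y = W x with x ∈ ℝ^{d_in}, W ∈ ℝ^{d_out × d_in}. Assume x's entries are iid with zero mean and variance 1, and W's entries are iid with zero mean and variance σ_W², independent of x. Each output is a sum of d_in independent zero-mean terms, so
Var(y_i) = Σ_j Var(W_ij x_j) = d_in · σ_W² · Var(x_j) = d_in · σ_W².
For Var(y) = Var(x) = 1, choose σ_W² = 1/d_in. This is Xavier (forward).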
Symmetric argument on backprop: ∂L/∂x = W^⊤ ∂L/∂y, so each ∂L/∂x_j is a sum of d_out terms and Var(∂L/∂x_j) = d_out · σ_W² · Var(∂L/∂y_i). For variance preservation backward, σ_W² = 1/d_out.
Glorot's compromise: σ_W² = 2/(d_in + d_out). The uniform distribution with this variance is U[-√(6/(d_in+d_out)), +√(6/(d_in+d_out))], since Var(U[-a, a]) = a²/3. This uniform form is the "Xavier" init in most frameworks.
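A minimal stdlib sketch of Glorot's uniform variant (the helper names are made up for illustration). It checks that U[-a, a] with a = √(6/(d_in+d_out)) has variance 2/(d_in + d_out), then prints how much the compromise scales forward and backward variance for a non-square layer.

```python
import math
import random


def xavier_uniform(d_out, d_in, rng):
    """d_out x d_in matrix with entries ~ U[-a, a], a = sqrt(6 / (d_in + d_out))."""
    a = math.sqrt(6.0 / (d_in + d_out))  # Var(U[-a, a]) = a^2 / 3 = 2 / (d_in + d_out)
    return [[rng.uniform(-a, a) for _ in range(d_in)] for _ in range(d_out)]


def empirical_variance(values):
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


d_in, d_out = 256, 512
W = xavier_uniform(d_out, d_in, random.Random(0))
var_w = empirical_variance([w for row in W for w in row])
target = 2.0 / (d_in + d_out)
print(f"empirical {var_w:.6f} vs target {target:.6f}")
# For a non-square layer the compromise preserves neither direction exactly:
print(f"forward variance factor  d_in * var  = {d_in * var_w:.3f}")   # about 2/3 here
print(f"backward variance factor d_out * var = {d_out * var_w:.3f}")  # about 4/3 here
```

PyTorch's `torch.nn.init.xavier_uniform_` uses the same bound, multiplied by an optional `gain`.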
Kaiming (He) init — accounting for ReLU
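Xavier assumed a linear layer. ReLU zeroes every pre-activation with z < 0, which for a zero-mean, symmetric z is half of them. The right derivation tracks the second moment, not the variance (ReLU's output has a positive mean, so Var(ReLU(z)) ≠ E[ReLU(z)²]). For the next layer y = W · ReLU(z) with zero-mean W, each term W_ij ReLU(z_j) has variance σ_W² · E[ReLU(z_j)²], and the symmetry of z gives
E[ReLU(z)²] = (1/2) · E[z²] = (1/2) · Var(z), so Var(y_i) = d_in · σ_W² · E[ReLU(z_j)²] = (1/2) · d_in · σ_W² · Var(z).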
So if we want E[y²] preserved through linear-then-ReLU, we double the weight variance: σ_W² = 2/d_in. This is Kaiming init (He et al. 2015). Two flavours:
- fan_in (default): σ_W² = 2/d_in. Preserves forward variance.
- fan_out: σ_W² = 2/d_out. Preserves backward variance.
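The factor of 2 compounds with depth. A sketch, assuming a plain ReLU MLP with fan_in scaling and iid Gaussian weights (width 128, depth 30 and the seed are arbitrary): with Var(W_ij) = 2/d_in the pre-activation second moment stays O(1), while Xavier's forward variance 1/d_in halves it at every layer.

```python
import math
import random


def final_preact_second_moment(depth, d, var_times_fan_in, seed=0):
    """E[z^2] of the last pre-activation of a plain ReLU MLP at init.

    Weights are iid N(0, var_times_fan_in / d): 2.0 gives Kaiming (fan_in),
    1.0 gives the forward Xavier variance.
    """
    rng = random.Random(seed)
    std = math.sqrt(var_times_fan_in / d)
    h = [rng.gauss(0.0, 1.0) for _ in range(d)]  # network input
    z = h
    for _ in range(depth):
        W = [[rng.gauss(0.0, std) for _ in range(d)] for _ in range(d)]
        z = [sum(w * v for w, v in zip(row, h)) for row in W]  # linear
        h = [max(v, 0.0) for v in z]                            # ReLU
    return sum(v * v for v in z) / d


kaiming = final_preact_second_moment(30, 128, 2.0)
xavier = final_preact_second_moment(30, 128, 1.0)
print(f"Kaiming: {kaiming:.3g}   Xavier: {xavier:.3g}   ratio: {xavier / kaiming:.3g}")
```

Both runs use the same seed, so they draw identical standard-normal samples; since ReLU is positively homogeneous, the ratio comes out as (1/2)^30 ≈ 9.3e-10 up to rounding. The Kaiming value is a single random draw, so it stays O(1) rather than sitting exactly at 1.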
For other activations, scale by the activation's gain, σ_W² = gain²/d_in (fan_in mode):
| Activation | Gain | σ_W² |
|---|---|---|
| Linear | 1 | 1/d_in (Xavier) |
| tanh | 5/3 | (5/3)²/d_in |
| ReLU | √2 | 2/d_in (Kaiming) |
| GELU / SiLU | ≈√2 | ~2/d_in |
| Sigmoid | 1 (small effect) | ~1/d_in |
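Under the same second-moment argument, an activation φ gets the gain that keeps a unit-variance pre-activation at unit variance: gain² = 1 / E[φ(z)²] for z ~ N(0, 1). A sketch that evaluates this expectation with Simpson's rule (the integration range and step count are arbitrary choices). It recovers √2 for ReLU; for tanh (about 1.6) and GELU (about 1.53) it gives values near, but not identical to, the table's rounded gains, which are conventions rather than outputs of this exact calculation (PyTorch's `torch.nn.init.calculate_gain('tanh')` returns 5/3, for example).

```python
import math


def gaussian_expectation(f, lo=-12.0, hi=12.0, n=4000):
    """E[f(z)] for z ~ N(0, 1) via composite Simpson's rule (n must be even)."""
    h = (hi - lo) / n
    total = 0.0
    for i in range(n + 1):
        z = lo + i * h
        weight = 1 if i in (0, n) else (4 if i % 2 else 2)
        total += weight * f(z) * math.exp(-0.5 * z * z)
    return total * h / 3.0 / math.sqrt(2.0 * math.pi)


def second_moment_gain(phi):
    """Gain g such that Var(W_ij) = g^2 / d_in preserves unit pre-activation variance."""
    return 1.0 / math.sqrt(gaussian_expectation(lambda z: phi(z) ** 2))


def silu(z):
    """Numerically stable z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z)) if z >= 0 else z * math.exp(z) / (1.0 + math.exp(z))


activations = {
    "linear": lambda z: z,
    "relu": lambda z: max(z, 0.0),
    "tanh": math.tanh,
    "gelu": lambda z: 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0))),  # exact (erf) GELU
    "silu": silu,
}
for name, phi in activations.items():
    print(f"{name:>6}: gain ≈ {second_moment_gain(phi):.3f}")
```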
Why residuals "rescue" deep training
Even with Kaiming init, gradients through L layers compound. Each layer's Jacobian has singular values spread around 1, but the product of many such random matrices has a singular-value spectrum that spreads wider with depth: a few directions are amplified and most are attenuated. Empirically, plain feedforward nets with standard init become hard to train past ~30 layers.
The residual connection x_{l+1} = x_l + Block(x_l) changes the geometry. The backward pass:
∂L/∂x_l = (I + ∂Block(x_l)/∂x_l)^⊤ ∂L/∂x_{l+1} = ∂L/∂x_{l+1} + (∂Block(x_l)/∂x_l)^⊤ ∂L/∂x_{l+1}
The I term ensures the gradient at layer l contains an unscaled copy of the gradient at layer l+1. Even if the block's Jacobian is small, the gradient flows through the identity. Across many layers the block terms add to that copy instead of multiplying it, so vanishing becomes much harder. It is not strictly impossible, because a block term can partly cancel the identity term, and signal propagation at extreme depth still needs care (part of the motivation for Bachlechner's ReZero).
Combined with LayerNorm (which keeps each block's input at unit scale), this is why transformers train at L=80 without exotic init.
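A toy check of the identity-path argument, as a sketch: each block is reduced to a single random linear map (a simplifying assumption; real blocks are nonlinear), so the backward pass needs only the weights. With a per-layer gain of 0.5, the plain stack's gradient shrinks geometrically, while the residual stack keeps the identity copy at every layer. Function name, sizes and seed are illustrative.

```python
import math
import random


def grad_norm_ratio(depth, d, block_gain, residual, seed=0):
    """||dL/dx_0|| / ||dL/dx_depth|| for a toy stack of linear layers.

    plain:    x_{l+1} = W_l x_l          =>  g_l = W_l^T g_{l+1}
    residual: x_{l+1} = x_l + W_l x_l    =>  g_l = g_{l+1} + W_l^T g_{l+1}
    Entries of W_l are iid N(0, block_gain^2 / d). For linear layers the
    backward pass depends only on the weights, so no forward pass is needed.
    """
    rng = random.Random(seed)
    std = block_gain / math.sqrt(d)
    weights = [[[rng.gauss(0.0, std) for _ in range(d)] for _ in range(d)]
               for _ in range(depth)]
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]  # stand-in for dL/dx_depth
    top = math.sqrt(sum(v * v for v in g))
    for W in reversed(weights):
        wt_g = [sum(W[i][j] * g[i] for i in range(d)) for j in range(d)]  # W^T g
        g = [a + b for a, b in zip(g, wt_g)] if residual else wt_g
    return math.sqrt(sum(v * v for v in g)) / top


plain = grad_norm_ratio(40, 64, 0.5, residual=False)
resid = grad_norm_ratio(40, 64, 0.5, residual=True)
print(f"plain: {plain:.2e}   residual: {resid:.2e}")
```

In this toy the residual gradient does more than survive: it grows, because each layer adds an independent W^⊤g term and E‖g‖² picks up a factor of about 1 + 0.25 per layer here. That depth-wise accumulation is the same effect the zero-init trick below is aimed at.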
The output-layer trick — zero init
For very deep models, even residual + norm can struggle at init: each (pre-LN) block adds a random, roughly unit-scale update to the residual stream, so the stream's variance grows roughly linearly with depth and is large after 80 blocks. Modern fix: initialise the final projection of each block to zero. Then at step 0 every block outputs 0, and each block acts as the identity on the residual stream. The model learns up from there.
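A sketch of the effect, assuming a pre-LN residual stack whose blocks are a hypothetical two-layer ReLU MLP, Block(x) = W_out · ReLU(W_in · LN(x)), with LayerNorm at its init values (γ = 1, β = 0). With a randomly initialised W_out each block adds roughly unit variance to the stream; with W_out = 0 the stream passes through unchanged. Sizes and seed are arbitrary.

```python
import math
import random


def layer_norm(x, eps=1e-5):
    """LayerNorm at init (gamma = 1, beta = 0)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def stream_second_moment(depth, d, zero_init_out, seed=0):
    """E[x^2] of the residual stream after `depth` toy pre-LN blocks at init."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    for _ in range(depth):
        W_in = [[rng.gauss(0.0, math.sqrt(2.0 / d)) for _ in range(d)] for _ in range(d)]
        if zero_init_out:
            W_out = [[0.0] * d for _ in range(d)]  # final projection zeroed at init
        else:
            W_out = [[rng.gauss(0.0, math.sqrt(1.0 / d)) for _ in range(d)] for _ in range(d)]
        u = layer_norm(x)
        h = [max(0.0, sum(w * v for w, v in zip(row, u))) for row in W_in]  # ReLU(W_in LN(x))
        x = [xi + sum(w * v for w, v in zip(row, h)) for xi, row in zip(x, W_out)]  # residual add
    return sum(v * v for v in x) / d


print("random W_out:", round(stream_second_moment(32, 64, zero_init_out=False), 1))
print("zero W_out:  ", round(stream_second_moment(32, 64, zero_init_out=True), 3))
```

With 32 blocks the first line should land in the tens (about 1 + 32 in expectation, with seed-dependent scatter), while the second reproduces the input's second moment exactly.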
Variants:
- ReZero (Bachlechner 2021): multiply each block output by a learnable α_l, initialised to 0. x_{l+1} = x_l + α_l · Block(x_l). Eliminates the need for LayerNorm; modest accuracy hit.
- LayerScale (Touvron 2021, used in CaiT, ConvNeXt): per-channel learnable scale on block output, initialised to a small constant (ConvNeXt uses 1e-6; CaiT shrinks the value as depth grows). Helps stability for very deep ViTs.
- Zero-init final linear: in transformers, initialise the FFN's W₃ (down projection) and attention's W_O (output projection) to zero. Each block starts as a no-op.
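The three variants differ only in how the block output is scaled before the residual add. A minimal sketch (function names are illustrative), which also shows why an α initialised to 0 does not stay stuck there: ∂L/∂α_l = ⟨∂L/∂x_{l+1}, Block(x_l)⟩ carries no α_l factor, so it is nonzero whenever the block output is.

```python
def residual_add(x, block_out, scale):
    """x_{l+1} = x_l + scale * Block(x_l).

    ReZero: scale is a single learnable scalar alpha_l, initialised to 0.
    LayerScale: scale is a learnable per-channel vector, initialised small.
    (Zero-init final linear: scale stays 1 and Block itself outputs 0 at init.)
    """
    if isinstance(scale, (int, float)):
        scale = [scale] * len(x)
    return [xi + s * bi for xi, s, bi in zip(x, scale, block_out)]


def rezero_alpha_grad(grad_next, block_out):
    """dL/d(alpha_l) = <dL/dx_{l+1}, Block(x_l)>; no alpha_l factor appears."""
    return sum(g * b for g, b in zip(grad_next, block_out))


x = [1.0, -2.0, 0.5]
block_out = [0.3, 0.1, -0.4]
print(residual_add(x, block_out, 0.0))                 # ReZero at init: x unchanged
print(residual_add(x, block_out, [1e-6] * 3))          # LayerScale: tiny perturbation
print(rezero_alpha_grad([0.5, -1.0, 2.0], block_out))  # nonzero, so alpha can move
```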
μP — the init that lets you sweep LR at small scale
From the scaling laws lesson: standard parameterisation has the optimal LR drift with width. μP fixes this with rescaled init and rescaled per-layer LR. The key rescalings (LR column for Adam):
| Layer | Standard init | μP init | μP LR (relative) |
|---|---|---|---|
| Embedding | O(1) | O(1) | 1 |
| Hidden layers | O(1/√d) | O(1/√d) | 1/d |
| LM head (readout) | O(1/√d) | O(1/d) | 1/d |
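With these scalings, the optimal LR found at a small width (say 256) carries over, approximately, to a much larger one (say 8192). This is μTransfer: it lets you run hyperparameter sweeps on small proxy models instead of at frontier scale, where direct sweeps are unaffordable.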
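A sketch of how the scalings turn into per-layer settings, assuming Adam, a base width d_base at which init scales and LR were tuned, and a target width d. The function, its arguments and the dictionary keys are hypothetical, not any library's API, and the numbers in the usage lines are illustrative. Every quantity equals its tuned value at d = d_base and then follows the width scaling; for the LM head the sketch uses an Adam LR ∝ 1/d, which is what μP prescribes for a readout with no output multiplier.

```python
import math


def mup_adam_settings(d, d_base, lr, emb_std, hidden_std, head_std):
    """Per-layer init std and Adam LR at width d, given values tuned at width d_base.

    Embedding: init and LR independent of width.
    Hidden:    init std ~ 1/sqrt(d), LR ~ 1/d.
    LM head:   init std ~ 1/d,       LR ~ 1/d.
    """
    m = d / d_base  # width multiplier
    return {
        "embedding": {"init_std": emb_std, "lr": lr},
        "hidden": {"init_std": hidden_std / math.sqrt(m), "lr": lr / m},
        "lm_head": {"init_std": head_std / m, "lr": lr / m},
    }


tuned = dict(lr=3e-3, emb_std=1.0, hidden_std=1 / 16, head_std=1 / 256)  # illustrative values
base = mup_adam_settings(256, 256, **tuned)
wide = mup_adam_settings(8192, 256, **tuned)
for layer in ("embedding", "hidden", "lm_head"):
    print(layer, base[layer], "->", wide[layer])
```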
Loss spikes — diagnosing one in a real run
You're training a 70B model. At step 4283, loss spikes from 2.1 to 87, then NaN. Walk through:
- Check gradient norm. Did it explode? If yes, suspect an outlier batch (inspect that batch's inputs). If no, look for a numerical issue in the forward pass or loss computation.
- Check activation norms per layer. Is one layer's pre-activation orders of magnitude larger than the others'? Possible causes: a badly scaled init in that layer, or drift that has accumulated from an earlier issue.
- Check learning rate. Is it currently at peak? Spikes often happen right after warmup ends.
- Check optimizer state. Has any of Adam's second-moment estimates v̂ shrunk so far that the update m̂ / (√v̂ + ε) becomes explosive?
- Mitigations: (a) gradient clipping at ‖g‖ ≤ 1.0; (b) skip the bad batch and continue; (c) reduce peak LR; (d) lower β₂ on Adam to make the optimizer more responsive.
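Mitigations (a) and (b) take a few lines each. A stdlib-only sketch: gradients are flattened into lists of floats, and the spike rule (flag a non-finite loss, or one above `factor` times the recent mean) is a hypothetical heuristic with an arbitrary threshold, not a standard recipe. In PyTorch, `torch.nn.utils.clip_grad_norm_` performs the same global-norm clipping on real parameter tensors.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale all gradients together so their joint L2 norm is at most max_norm.

    `grads` is a list of flat gradient lists, one per parameter tensor.
    Returns (clipped_grads, norm_before_clipping); log the norm every step.
    """
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in tensor] for tensor in grads], total


def is_spike(loss, recent_losses, factor=3.0):
    """Hypothetical skip rule: non-finite loss, or loss > factor * recent mean."""
    if not math.isfinite(loss):
        return True
    if not recent_losses:
        return False
    return loss > factor * sum(recent_losses) / len(recent_losses)


clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm_before, clipped)              # norm 5.0, rescaled to about [[0.6], [0.8]]
print(is_spike(87.0, [2.1, 2.0, 2.2]))   # True: skip this batch
print(is_spike(2.15, [2.1, 2.0, 2.2]))   # False
```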
Industry practice: save checkpoints frequently, resume from before the spike, skip the offending batch range. Don't try to "train through" a spike; the optimizer state is usually contaminated.
Where things go subtly wrong
| Bug | Symptom | Diagnosis |
|---|---|---|
| Wrong fan | Loss decreases but slowly; per-layer gradient norms are bimodal. | fan_in preserves forward variance, fan_out backward variance; most frameworks default to fan_in. For non-square layers (d_in ≠ d_out), switching from fan_in to fan_out multiplies the weight variance by d_in/d_out. |
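| Mixed-precision overflow at init | Loss NaN at step 0. | fp16's largest finite value is 65504, so large init values or activations can overflow to inf in matmuls. bf16 keeps fp32's exponent range (so it rarely overflows) but has only 7 stored mantissa bits, so it loses precision. Init in fp32, then cast to the low-precision format for compute. |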
| Embedding init too large | Early layers' activations are dominated by the embedding; later blocks barely move the residual stream. | PyTorch's nn.Embedding defaults to N(0, 1); GPT-2-style recipes use a much smaller std of 0.02. Match the scale the rest of the model expects. |
| LayerNorm γ init not 1 | Loss starts much higher than expected and takes 1k+ steps to recover. | γ should start at 1.0 (and β at 0) so that LN initially only normalizes, with no extra rescaling. Some custom implementations init γ randomly; that is a bug. |
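| BatchNorm in eval mode without proper running stats | Wrong (sometimes NaN) outputs at inference, although training looked fine. | BN running stats are accumulated during training-mode forward passes. If you save the model before enough steps, the running stats are poor. Fix: train longer, or recompute the stats with training-mode forward passes over training data (no weight updates). |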
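The fp16/bf16 range difference from the mixed-precision row is easy to see with Python's `struct` module, whose `'e'` format packs IEEE half precision (fp16). `struct` has no bfloat16 code, so this sketch emulates bf16 by rounding the fp32 bit pattern to its top 16 bits (round to nearest even, finite inputs only). The helper names are illustrative.

```python
import struct


def to_fp16(x):
    """Round to IEEE fp16; struct raises OverflowError beyond fp16's range (max 65504)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Round to bfloat16 by keeping the top 16 bits of the fp32 pattern (nearest even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]


print(to_fp16(65504.0))  # largest finite fp16 value
try:
    to_fp16(70000.0)
except OverflowError as err:
    print("fp16 overflow:", err)
print(to_bf16(70000.0))  # bf16 shares fp32's exponent range: still finite
print(to_bf16(257.0))    # but only 8 significant bits: 257 rounds to 256
```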
Interview prompts you should be ready for
- "Derive Kaiming init from variance preservation." (For Y = Wx with x having second moment E[x²] = σ², W ~ N(0, σ_W²), E[Y²] = d_in · σ_W² · σ². After ReLU, half the mass is zeroed: E[ReLU(Y)²] = (1/2) · d_in · σ_W² · σ². For second-moment preservation, set σ_W² = 2/d_in. Note: we track second moments (not variance) because ReLU's output has a positive mean. Standard derivation expected on a whiteboard in 60 seconds.)
- "Why don't residual networks suffer from vanishing gradients?" (∂L/∂x_l contains a "+I" term from the residual, so the gradient at layer l includes an unmodified copy of the gradient at l+1. The block contribution adds; it doesn't multiply. This makes vanishing much harder, though deep enough stacks still need norm and zero-init tricks to remain stable.)
- "You init a 100-layer ResNet and it diverges. Diagnose." (Common: residual sum grows in depth without a final norm; or LayerNorm γ initialised wrong; or per-block scale not initialised small. Fix: init each block's final linear to zero (or to small scale), add LayerScale, ensure a final LN before the output projection.)
- "Why does the LM-head init matter less than hidden-layer init?" (It matters a lot, just differently. For the LM head, the output is the logits, which feed into softmax → cross-entropy. If the LM head is too large at init, the logits are large in magnitude → softmax saturates → the model starts out confidently wrong, the initial loss sits far above log(V), and early gradients are dominated by a few large logits. Init it small or zero so the initial loss is ~log(V) and gradients flow evenly.)
- "μP vs Kaiming — what's the same and what's different?" (Kaiming preserves activation variance through a single layer. μP preserves activation variance and parameter update magnitude across scales — so the optimal LR doesn't change with width. μP is Kaiming plus per-layer LR scaling plus a smaller LM-head init.)
- "What's the right init for a SwiGLU FFN?" (SwiGLU has three matrices: W₁ (gate), W₂ (up), W₃ (down). W₁ and W₂ should be init ~ N(0, 1/d), so the input to SiLU and the up branch both have unit variance. W₃ (down, fan_in 8d/3) can be zero-init for "block starts as no-op" stability, or N(0, 1/(8d/3)) for balanced variance. Many recent recipes zero-init W₃ or scale its init down with depth.)
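A stdlib sketch of the SwiGLU answer above, assuming the hidden width h = 8d/3 and zero-mean, unit-variance inputs (as after a norm layer); the helper names are hypothetical. With W₃ zero-initialised, the block's output is exactly zero at init, so the residual stream passes through untouched.

```python
import math
import random


def silu(z):
    """Numerically stable z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z)) if z >= 0 else z * math.exp(z) / (1.0 + math.exp(z))


def normal_matrix(rows, cols, var, rng):
    std = math.sqrt(var)
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def init_swiglu(d, rng, zero_init_down=True):
    """Return (W1 gate, W2 up, W3 down) with hidden width h = 8d/3."""
    h = int(8 * d / 3)
    w1 = normal_matrix(h, d, 1.0 / d, rng)  # gate: SiLU input variance ~1
    w2 = normal_matrix(h, d, 1.0 / d, rng)  # up: variance ~1
    if zero_init_down:
        w3 = [[0.0] * h for _ in range(d)]  # block starts as a no-op
    else:
        w3 = normal_matrix(d, h, 1.0 / h, rng)  # "balanced": variance 1/(8d/3)
    return w1, w2, w3


def swiglu(x, w1, w2, w3):
    """W3 (SiLU(W1 x) * (W2 x))."""
    gate = [silu(sum(w * v for w, v in zip(row, x))) for row in w1]
    up = [sum(w * v for w, v in zip(row, x)) for row in w2]
    hidden = [g * u for g, u in zip(gate, up)]
    return [sum(w * v for w, v in zip(row, hidden)) for row in w3]


rng = random.Random(0)
d = 48
x = [rng.gauss(0.0, 1.0) for _ in range(d)]
print(swiglu(x, *init_swiglu(d, rng))[:4])                        # zero-init W3: all zeros
print(swiglu(x, *init_swiglu(d, rng, zero_init_down=False))[:4])  # balanced init: nonzero
```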