Initialization & gradient stability

Until 2010, training deep networks was a lottery. Then a one-line insight — "preserve variance through layers" — turned a research mystery into engineering. The same logic underlies Kaiming init, residuals, LayerNorm, and μP.

The problem in one diagram

Forward through a deep linear net y = W_L W_{L-1} … W_1 x. If each W_l is initialised so that ‖W_l x‖ ≈ 2‖x‖, then by layer L: ‖y‖ ≈ 2^L ‖x‖. For L=80, that's 2^{80} ≈ 10^{24}. Activations blow up, loss is NaN. Conversely, if each ‖W_l x‖ ≈ 0.5 ‖x‖, by L=80, activations are 10^{-24}. Gradients vanish. The init must preserve magnitudes.
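The arithmetic above is easy to reproduce. The sketch below assumes every layer scales the norm by the same constant factor, which is the idealisation used in the paragraph; the function name is illustrative.

```python
import math

def compounded_norm(per_layer_gain: float, depth: int, x_norm: float = 1.0) -> float:
    """Norm after `depth` layers if every layer scales the norm by `per_layer_gain`."""
    return x_norm * per_layer_gain ** depth

# Report the order of magnitude for exploding, preserved, and vanishing cases.
for gain in (2.0, 1.0, 0.5):
    y = compounded_norm(gain, depth=80)
    print(f"per-layer gain {gain}: |y| is about 10^{math.log10(y):.1f}")
```

Only a per-layer gain of 1 keeps the magnitude usable at depth 80, which is the constraint every scheme below tries to satisfy.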

The principle, stated once
Initialise so that the variance of activations is preserved as you go forward, and the variance of gradients is preserved as you go backward. Then deep networks train. Every "init scheme" — Xavier, Kaiming, μP, ReZero, T-Fixup — is a way of doing exactly this for a particular layer type.

Derive Xavier (Glorot) init

Take a linear layer y = W x with x ∈ ℝ^{d_in}, W ∈ ℝ^{d_out × d_in}. Assume x's entries are iid with zero mean and variance 1, and W's entries are iid with zero mean and variance σ_W², independent of x.

Var(y_i) = Var( Σ_j W_{ij} x_j )
         = Σ_j Var(W_{ij} x_j)      (independence)
         = d_in · σ_W² · 1          (each term has variance σ_W² · 1)
         = d_in · σ_W²

For Var(y) = Var(x) = 1, choose σ_W² = 1/d_in. This is Xavier (forward).
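A quick Monte Carlo check of Var(y_i) = d_in · σ_W². This is a sketch under the same assumptions as the derivation (zero-mean Gaussian weights and inputs, all independent); the function name, sample count, and seed are arbitrary choices.

```python
import random
import statistics

def linear_output_variance(d_in: int, sigma_w2: float,
                           n_samples: int = 2000, seed: int = 0) -> float:
    """Empirical variance of y_i = sum_j W_ij x_j, with fresh W and x on every draw."""
    rng = random.Random(seed)
    std_w = sigma_w2 ** 0.5
    ys = []
    for _ in range(n_samples):
        y = sum(rng.gauss(0.0, std_w) * rng.gauss(0.0, 1.0) for _ in range(d_in))
        ys.append(y)
    return statistics.pvariance(ys)

d_in = 64
var_xavier = linear_output_variance(d_in, sigma_w2=1.0 / d_in)  # should land near 1
var_unit = linear_output_variance(d_in, sigma_w2=1.0)           # should land near d_in
print(var_xavier, var_unit)
```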

Symmetric argument on backprop: ∂L/∂x = W^⊤ ∂L/∂y. Var(∂L/∂x_j) = d_out · σ_W². For variance preservation backward, σ_W² = 1/d_out.

Glorot's compromise: σ_W² = 2/(d_in + d_out). Equivalently, sample from U[-√(6/(d_in+d_out)), +√(6/(d_in+d_out))]. This is the "Xavier" init in most frameworks.
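The uniform bound follows from Var(U[-a, a]) = a²/3, so a = √(6/(d_in+d_out)) gives variance 2/(d_in+d_out). A minimal sketch that samples such a matrix and checks the variance empirically; the helper names are illustrative, not a framework API.

```python
import math
import random
import statistics

def xavier_variance(d_in: int, d_out: int) -> float:
    """Glorot's compromise between 1/d_in (forward) and 1/d_out (backward)."""
    return 2.0 / (d_in + d_out)

def xavier_uniform_bound(d_in: int, d_out: int) -> float:
    """Half-width a of U[-a, a] whose variance a^2 / 3 equals xavier_variance."""
    return math.sqrt(6.0 / (d_in + d_out))

def sample_xavier_uniform(d_in: int, d_out: int, rng: random.Random) -> list:
    a = xavier_uniform_bound(d_in, d_out)
    return [[rng.uniform(-a, a) for _ in range(d_in)] for _ in range(d_out)]

rng = random.Random(0)
W = sample_xavier_uniform(128, 64, rng)
entries = [w for row in W for w in row]
empirical_var = statistics.pvariance(entries)
print(empirical_var, xavier_variance(128, 64))
```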

Kaiming (He) init — accounting for ReLU

Xavier assumed a linear layer. ReLU kills half the inputs (those with z < 0). The right derivation tracks the second moment, not the variance (ReLU's output has a positive mean, so Var(ReLU(z)) ≠ E[ReLU(z)²]):

E[ReLU(z)²] = (1/2) · E[z²]   — for symmetric mean-zero z, half the mass is zeroed

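So if we want E[y²] preserved through linear-then-ReLU, we double the weight variance: σ_W² = 2/d_in. This is Kaiming init (He et al. 2015). Two flavours: fan_in mode (σ_W² = 2/d_in) preserves forward activations, and fan_out mode (σ_W² = 2/d_out) preserves backward gradients.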

For other activations:

Activation | Gain | σ_W²
Linear | 1 | 1/d_in (Xavier)
tanh | 5/3 | (5/3)²/d_in
ReLU | √2 | 2/d_in (Kaiming)
GeLU / SiLU | ≈√2 | ~2/d_in
Sigmoid | 1 (small effect) | ~1/d_in
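The table reduces to std = gain / √d_in. The sketch below encodes the gains as a lookup and checks the ReLU half-moment identity E[ReLU(z)²] = (1/2)·E[z²] by sampling. The dictionary and function names are illustrative, and the GeLU/SiLU entries are the approximate values from the table, not exact constants.

```python
import math
import random

# Gains copied from the table above (GeLU / SiLU values are approximate).
GAINS = {
    "linear": 1.0,
    "tanh": 5.0 / 3.0,
    "relu": math.sqrt(2.0),
    "gelu": math.sqrt(2.0),
    "silu": math.sqrt(2.0),
    "sigmoid": 1.0,
}

def init_std(activation: str, fan_in: int) -> float:
    """Weight std for a layer feeding `activation`: gain / sqrt(fan_in)."""
    return GAINS[activation] / math.sqrt(fan_in)

def relu_second_moment_ratio(n_samples: int = 50_000, seed: int = 0) -> float:
    """Estimate E[ReLU(z)^2] / E[z^2] for z ~ N(0, 1)."""
    rng = random.Random(seed)
    zs = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]
    return sum(max(z, 0.0) ** 2 for z in zs) / sum(z * z for z in zs)

ratio = relu_second_moment_ratio()
print(ratio)                          # should land near 0.5
print(init_std("relu", 512) ** 2)     # Kaiming variance 2 / d_in
```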

Why residuals "rescue" deep training

Even with Kaiming init, gradients through L layers compound: each layer's Jacobian has singular values spread around 1, but the product of many random matrices has a singular-value distribution that spreads further from 1 as depth grows. Empirically, plain feedforward nets become hard to train past ~30 layers regardless of init.

The residual connection x_{l+1} = x_l + Block(x_l) changes the geometry. The backward pass:

∂L/∂x_l = ∂L/∂x_{l+1} · (I + ∂Block/∂x_l)

The I term ensures the gradient at layer l contains a copy of the gradient at layer l+1, unscaled. Even if the block's Jacobian is small, the gradient flows through the identity. The identity path gives every layer a gradient route that bypasses the block Jacobians, so vanishing is much harder, though not strictly impossible at extreme depth (Bachlechner's ReZero motivation).

Combined with LayerNorm (which keeps activations at unit scale), this is why transformers train at L=80 without exotic init.
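The effect of the I term can be seen by pushing a gradient vector backward through random Jacobians, with and without the identity added. This is a toy sketch, not a real network: each block Jacobian is assumed to be an independent Gaussian matrix with entries of std scale/√d, and all names are illustrative.

```python
import math
import random

def random_jacobian(d: int, scale: float, rng: random.Random) -> list:
    """Stand-in for a block Jacobian: iid N(0, scale^2 / d) entries."""
    std = scale / math.sqrt(d)
    return [[rng.gauss(0.0, std) for _ in range(d)] for _ in range(d)]

def vec_mat(g: list, J: list) -> list:
    """Row vector times matrix: (g J)_j = sum_i g_i J_ij."""
    return [sum(g[i] * J[i][j] for i in range(len(g))) for j in range(len(J[0]))]

def backward_norm(depth: int, d: int = 16, scale: float = 0.5,
                  residual: bool = False, seed: int = 0) -> float:
    """Norm of dL/dx at the bottom layer, starting from a unit-norm gradient at the top."""
    rng = random.Random(seed)
    g = [1.0 / math.sqrt(d)] * d
    for _ in range(depth):
        gJ = vec_mat(g, random_jacobian(d, scale, rng))
        # Residual: g (I + J) = g + g J.  Plain: g J only.
        g = [a + b for a, b in zip(g, gJ)] if residual else gJ
    return math.sqrt(sum(v * v for v in g))

plain = backward_norm(depth=50, residual=False)
with_residual = backward_norm(depth=50, residual=True)
print(f"plain: {plain:.3e}   residual: {with_residual:.3e}")
```

In this toy setting the plain gradient collapses while the residual one does not vanish; it grows instead, which is one reason residual stacks are paired with normalisation.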

The output-layer trick — zero init

For very deep models, even residual + norm can struggle at init: each block does some random thing, and the sum across 80 blocks is large. Modern fix: initialise the final projection of each block to zero. Then at step 0, every block outputs 0, and the residual stream is identity. The model learns up from there.

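Variants: a small (rather than exactly zero) init for the final projection; a learnable per-block scalar initialised at zero (ReZero); a learnable per-channel scale initialised small (LayerScale).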

The interview soundbite
Each block should start as approximately the identity function. Zero-initing the final linear in each sub-block achieves this. The model then learns to do useful work, layer by layer, in a stable regime. Without this, the random init of all 80 blocks compounds into large random noise, and the model spends its first ~100 steps undoing it before any useful learning starts.
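A small simulation of the residual stream at init shows the difference. It assumes a toy block of the form W_out · ReLU(W_in · x) with no normalisation; the dimensions, the 1/hidden variance for the non-zero case, and all names are illustrative choices.

```python
import math
import random

def gaussian_matrix(rows: int, cols: int, std: float, rng: random.Random) -> list:
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]

def matvec(W: list, x: list) -> list:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def norm(v: list) -> float:
    return math.sqrt(sum(a * a for a in v))

def residual_stream_growth(depth: int, d: int = 16, hidden: int = 32,
                           zero_init_out: bool = True, seed: int = 0) -> float:
    """Return |x_L| / |x_0| for x_{l+1} = x_l + W_out relu(W_in x_l) at init."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    start = norm(x)
    for _ in range(depth):
        W_in = gaussian_matrix(hidden, d, math.sqrt(2.0 / d), rng)  # Kaiming, fan_in
        if zero_init_out:
            W_out = [[0.0] * hidden for _ in range(d)]              # block is a no-op
        else:
            W_out = gaussian_matrix(d, hidden, math.sqrt(1.0 / hidden), rng)
        h = [max(a, 0.0) for a in matvec(W_in, x)]
        x = [a + b for a, b in zip(x, matvec(W_out, h))]
    return norm(x) / start

zero_growth = residual_stream_growth(80, zero_init_out=True)
random_growth = residual_stream_growth(80, zero_init_out=False)
print(zero_growth, random_growth)
```

With the final projection zeroed, the stream passes through all 80 blocks unchanged; with a random final projection and no norm, the block outputs accumulate and the stream grows by many orders of magnitude.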

μP — the init that lets you sweep LR at small scale

From the scaling laws lesson: under standard parameterisation, the optimal LR drifts with width. μP fixes this with rescaled init and rescaled per-layer LR. The key rescalings:

Layer | Standard init | μP init | μP LR (relative)
Embedding | O(1) | O(1) | 1
Hidden layers | O(1/√d) | O(1/√d) | 1/d
LM head (readout) | O(1/√d) | O(1/d) | 1

With these scalings, the optimal LR found at width 256 transfers to width 8192. This is μTransfer, one of the few practical ways to tune hyperparameters for frontier LLMs, where sweeping at full scale is too expensive.
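The table can be turned into a small lookup. This is a sketch of the scalings exactly as listed above, not the API of any μP library: the O(·) constants are assumed to be 1, the "1/d" learning rate is expressed relative to a hypothetical base width at which the LR was tuned, and all names are made up for illustration.

```python
import math
from dataclasses import dataclass

@dataclass
class LayerScaling:
    init_std: float       # std of the weight init
    lr_multiplier: float  # multiply the base LR (tuned at base_width) by this

def mup_scalings(width: int, base_width: int = 256) -> dict:
    """Per-layer init std and LR multiplier following the table above."""
    return {
        "embedding": LayerScaling(init_std=1.0, lr_multiplier=1.0),
        # Hidden LR scales like 1/d, written here relative to the tuned width.
        "hidden": LayerScaling(init_std=1.0 / math.sqrt(width),
                               lr_multiplier=base_width / width),
        # Readout init shrinks like 1/d instead of 1/sqrt(d).
        "readout": LayerScaling(init_std=1.0 / width, lr_multiplier=1.0),
    }

small = mup_scalings(256)
large = mup_scalings(8192)
print(small["hidden"], large["hidden"])
```

At the base width every multiplier is 1, so the small-scale sweep is unchanged; only the wider model sees rescaled values.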

Loss spikes — diagnosing one in a real run

You're training a 70B model. At step 4283, loss spikes from 2.1 to 87, then NaN. Walk through:

  1. Check gradient norm. Did it explode? If yes — outlier batch (look at the batch's input). If no — numerical issue downstream.
  2. Check activation norms per layer. Is one layer's pre-activation orders of magnitude larger than others? Possible: dead init in that layer, or accumulated drift from earlier issue.
  3. Check learning rate. Is it currently at peak? Spikes often happen right after warmup ends.
  4. Check optimizer state. Has any Adam second-moment estimate (v) dropped so close to zero that the division by √v + ε produces an explosive update?
  5. Mitigations: (a) gradient clipping at ‖g‖ ≤ 1.0; (b) skip the bad batch and continue; (c) reduce peak LR; (d) lower β₂ on Adam to make the optimizer more responsive.

Industry practice: save checkpoints frequently, resume from before the spike, skip the offending batch range. Don't try to "train through" a spike; the optimizer state is usually contaminated.
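Two of the steps above, gradient clipping and deciding when to skip a batch, can be sketched in a few lines. Gradients are plain nested lists here to keep the example dependency-free; the spike threshold (3× the recent mean) and the function names are assumptions, not a standard.

```python
import math

def global_grad_norm(grads: list) -> float:
    """L2 norm over all gradient entries, treating the tensors as one long vector."""
    return math.sqrt(sum(g * g for tensor in grads for g in tensor))

def clip_by_global_norm(grads: list, max_norm: float = 1.0):
    """Rescale all gradients so the global norm is at most max_norm.

    Returns (clipped_grads, pre_clip_norm); log the pre-clip norm for diagnosis.
    """
    total = global_grad_norm(grads)
    if not math.isfinite(total):
        raise ValueError("non-finite gradient norm: skip this batch")
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in tensor] for tensor in grads], total

def is_loss_spike(recent_losses: list, new_loss: float, factor: float = 3.0) -> bool:
    """Flag NaN/inf losses, or losses far above the recent running mean."""
    if not math.isfinite(new_loss):
        return True
    if not recent_losses:
        return False
    return new_loss > factor * (sum(recent_losses) / len(recent_losses))

clipped, pre_clip = clip_by_global_norm([[3.0, 4.0]], max_norm=1.0)
print(pre_clip, global_grad_norm(clipped))
print(is_loss_spike([2.1, 2.1, 2.2], 87.0))
```

Clipping bounds the update size but does not repair contaminated optimizer state, so it complements checkpoint rollback rather than replacing it.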

Interactive · how deep can you go without a residual?

[Interactive demo: forward activation magnitude through a deep ReLU net. Vary depth and init scale. Without residuals, magnitudes explode or vanish quickly. Toggle "residual + norm" to see why modern transformers are trainable at depth 80. Readouts: final activation std, final / initial ratio, trainable?]
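The plain (no residual) case is easy to reproduce offline. The sketch below pushes one random input through a square ReLU net with weights of std gain/√width and reports the final / initial RMS ratio. Width, depth, and seed are arbitrary, and fresh weights are drawn on the fly rather than stored.

```python
import math
import random

def rms(v: list) -> float:
    return math.sqrt(sum(a * a for a in v) / len(v))

def final_over_initial(depth: int, width: int = 64,
                       gain: float = math.sqrt(2.0), seed: int = 0) -> float:
    """RMS of the last layer's activations divided by the RMS of the input."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    start = rms(x)
    std = gain / math.sqrt(width)
    for _ in range(depth):
        # One ReLU layer: each output unit gets its own row of fresh weights.
        x = [max(0.0, sum(rng.gauss(0.0, std) * xi for xi in x))
             for _ in range(width)]
    return rms(x) / start

ratios = {gain: final_over_initial(depth=20, gain=gain)
          for gain in (1.0, math.sqrt(2.0), 2.0)}
for gain, r in ratios.items():
    print(f"gain {gain:.3f}: final / initial = {r:.3e}")
```

Only the Kaiming gain √2 keeps the ratio near 1; even there it fluctuates from seed to seed at finite width, and the fluctuation grows with depth.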

Where things go subtly wrong

Bug | Symptom | Diagnosis
Wrong fan | Loss decreases but slowly; per-layer gradient norms are bimodal. | fan_in preserves the forward pass, fan_out the backward pass; most frameworks default to fan_in. For non-square layers (d_in ≠ d_out), the two choices differ in variance by the ratio d_out/d_in.
Mixed-precision overflow at init | Loss NaN at step 0. | fp16 has a narrow exponent range (max ≈ 65504), so large init values can overflow on multiplication. bf16 keeps fp32's exponent range but has only 7 mantissa bits, so it loses precision rather than overflowing. Init in fp32, then cast to the low-precision dtype for compute.
Embedding init too large | First few layers' activations dominated by embedding; rest of net is along for the ride. | Embedding init ~ N(0, 1) is standard. Some recipes use a smaller std of 0.02. Match what the rest of the model expects.
LayerNorm γ init not 1 | Loss starts much higher than expected and takes 1k+ steps to recover. | γ should be 1.0 so LN initially passes its normalised input through unchanged. Some custom implementations init γ randomly, which is a bug.
BatchNorm in eval mode without running stats | NaN at inference. | BN running stats are populated during training. If you save the model before sufficient steps, the running stats are bad. Fix: train longer, or recompute the running stats with a forward sweep over training data.
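The "wrong fan" row is worth quantifying for a non-square layer. A minimal sketch, assuming the ReLU gain √2; the function name and the example dimensions are illustrative.

```python
import math

def kaiming_std(d_in: int, d_out: int, mode: str = "fan_in",
                gain: float = math.sqrt(2.0)) -> float:
    """Weight std gain / sqrt(fan), where fan is d_in or d_out depending on mode."""
    if mode not in ("fan_in", "fan_out"):
        raise ValueError(f"unknown mode: {mode!r}")
    fan = d_in if mode == "fan_in" else d_out
    return gain / math.sqrt(fan)

# An FFN up-projection: 1024 inputs, 4096 outputs.
d_in, d_out = 1024, 4096
var_fan_in = kaiming_std(d_in, d_out, "fan_in") ** 2
var_fan_out = kaiming_std(d_in, d_out, "fan_out") ** 2
print(var_fan_in / var_fan_out)  # the two modes differ by d_out / d_in
```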

Interview prompts you should be ready for

  1. "Derive Kaiming init from variance preservation." (For Y = Wx with x having second moment E[x²] = σ², W ~ N(0, σ_W²), E[Y²] = d_in · σ_W² · σ². After ReLU, half the mass is zeroed: E[ReLU(Y)²] = (1/2) · d_in · σ_W² · σ². For second-moment preservation, set σ_W² = 2/d_in. Note: we track second moments (not variance) because ReLU's output has a positive mean. Standard derivation expected on a whiteboard in 60 seconds.)
  2. "Why don't residual networks suffer from vanishing gradients?" (∂L/∂x_l contains a "+I" term from the residual, so the gradient at layer l includes an unmodified copy of the gradient at l+1. The block contribution adds; it doesn't multiply. This makes vanishing much harder, though deep enough stacks still need norm and zero-init tricks to remain stable.)
  3. "You init a 100-layer ResNet and it diverges. Diagnose." (Common: residual sum grows in depth without a final norm; or LayerNorm γ initialised wrong; or per-block scale not initialised small. Fix: init each block's final linear to zero (or to small scale), add LayerScale, ensure a final LN before the output projection.)
  4. "Why does the LM-head init matter less than hidden-layer init?" (It matters a lot, just differently. The LM head produces the logits, which feed into softmax → cross-entropy. If the LM head is too large at init, the logits are large in magnitude, softmax puts nearly all its mass on an arbitrary token, and the initial loss sits far above log(V) with large, erratic gradients at the output. Init it small or zero so the initial loss is ~log(V) and training starts from a uniform prediction.)
  5. "μP vs Kaiming — what's the same and what's different?" (Kaiming preserves activation variance through a single layer. μP preserves activation variance and parameter update magnitude across scales — so the optimal LR doesn't change with width. μP is Kaiming plus per-layer LR scaling plus a smaller LM-head init.)
  6. "What's the right init for a SwiGLU FFN?" (SwiGLU has three matrices: W₁ (gate), W₂ (up), W₃ (down). W₁, W₂ should be init ~ N(0, 1/d) so the input to SiLU has unit variance. W₃ (down) can be zero-init for "block starts as no-op" stability, or N(0, 1/(8d/3)) for balanced. Modern recipes: zero-init W₃.)
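For prompt 4, the log(V) baseline is quick to verify. The sketch below draws logits from N(0, logit_std²) as a stand-in for the LM head's output at init and averages the cross-entropy over every possible target token; the vocabulary size and the std values are arbitrary.

```python
import math
import random

def log_sum_exp(logits: list) -> float:
    m = max(logits)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(z - m) for z in logits))

def mean_initial_loss(vocab_size: int, logit_std: float, seed: int = 0) -> float:
    """Cross-entropy at init, averaged over all possible target tokens."""
    rng = random.Random(seed)
    logits = [rng.gauss(0.0, logit_std) for _ in range(vocab_size)]
    # Averaging -log softmax(logits)[t] over all t gives lse - mean(logits).
    return log_sum_exp(logits) - sum(logits) / vocab_size

V = 1000
loss_zero_head = mean_initial_loss(V, logit_std=0.0)
loss_large_head = mean_initial_loss(V, logit_std=10.0)
print(loss_zero_head, math.log(V), loss_large_head)
```

A zero (or very small) LM head starts at log(V); a large one starts several times higher because the model is confidently wrong on almost every token.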
Takeaway
Init is variance preservation: Xavier for linear, Kaiming (×√2) for ReLU. Residuals make vanishing gradients much harder because the identity path carries the gradient past every block. LayerNorm + zero-init final block layers make 80-layer transformers train. μP makes hyperparameters transfer across model sizes. Loss spikes mean check gradient norm, activation norms, learning rate, optimizer state, in that order. The interview signal is being able to derive Kaiming and then explain why every modern trick (residual, norm, zero-init, μP) is the same principle in a different layer.