Backprop & autodiff from first principles
Every deep learning system is one trick: the chain rule, scheduled to traverse the computation graph in reverse. Re-derive it once on a two-layer MLP and you can debug any gradient, anywhere.
The one-line statement
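You have a scalar loss L and a chain of operations x → h₁ → h₂ → … → h_k → L. The chain rule says:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial h_k}\,\frac{\partial h_k}{\partial h_{k-1}}\cdots\frac{\partial h_2}{\partial h_1}\,\frac{\partial h_1}{\partial x}$$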
That's it. Everything else — autograd engines, tape-based autodiff, the cost of "save activations vs recompute" — is engineering around that product of Jacobians.
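The reason the chain rule is interesting and not trivial is that, for a neural net with input x ∈ ℝ^d, every Jacobian in that product is a matrix, not a scalar. Only the leftmost factor, ∂L/∂h_k, is a row vector, because L is a scalar. Matrix multiplication is associative, so every evaluation order gives the same answer. The order does, however, change the compute cost by orders of magnitude. Backprop picks one order: start at the loss end and fold in one Jacobian at a time, so the running product is always a vector. That choice is the whole game.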
Forward vs reverse — derive the cost asymmetry
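Take L: ℝ^d → ℝ built as a chain of k linear maps. The first is W_1 ∈ ℝ^{m×d}, then come W_2, …, W_{k−1}, each m × m, and finally a row W_k ∈ ℝ^{1×m} that produces the scalar. Take d ≈ m, so every matrix-vector product costs O(m²). The Jacobian of the full computation is the 1 × d product J = W_k · W_{k−1} · … · W_1. There are two ways to compute ∇L = J^⊤ · 1, where 1 is the seed ∂L/∂L: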
| Mode | Order | Cost | When it wins |
|---|---|---|---|
| Forward (JVP) | Input → output: apply W_1 first, propagating a tangent vector forward. | O(m² · k) per input direction (one JVP). The full gradient needs d such passes: O(d · m² · k) total. | Few inputs, many outputs (d ≪ output_dim). Rare in deep learning. |
| Reverse (VJP, "backprop") | Output → input: apply W_k^⊤ first, propagating a cotangent (gradient) backward from the scalar L. | O(m² · k) total, the same as one forward pass. | Many inputs, one (or few) outputs (d ≫ 1, output is a scalar loss). All of deep learning. |
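The whole point of reverse mode is what happens for a scalar output. Starting from the output end keeps every intermediate a vector: the gradient with respect to that layer's activation. Each step is therefore a matrix-vector product, never a matrix-matrix product. Forward mode gets the full gradient only by pushing all d tangent directions through. That amounts to carrying an m × d Jacobian explicitly through every layer.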
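The cost asymmetry can be checked by counting multiply-adds on a small random chain. This is a minimal sketch, not a reference implementation. It assumes W_1 is m × d, the middle maps are m × m, and the last map is 1 × m, so the output is a scalar. The helper names (`matvec`, `vecmat`, `build_chain`, `grad_forward_mode`, `grad_reverse_mode`) are hypothetical. Forward mode runs one JVP per input direction; reverse mode runs a single VJP sweep seeded with ∂L/∂L = 1.

```python
import random

def matvec(W, v):
    """Return (W v, multiply-adds used); W is a list of rows. One JVP step."""
    return [sum(w * vj for w, vj in zip(row, v)) for row in W], len(W) * len(v)

def vecmat(u, W):
    """Return (W^T u, multiply-adds used). One VJP step."""
    cols = len(W[0])
    out = [sum(u[i] * W[i][j] for i in range(len(W))) for j in range(cols)]
    return out, len(W) * cols

def build_chain(d, m, k, rng):
    """W_1: m x d, then k-2 maps of m x m, then W_k: 1 x m, so L(x) is a scalar."""
    shapes = [(m, d)] + [(m, m)] * (k - 2) + [(1, m)]
    return [[[rng.gauss(0.0, 1.0) for _ in range(c)] for _ in range(r)] for r, c in shapes]

def grad_forward_mode(Ws, d):
    """Forward mode: one full pass per input direction e_j yields dL/dx_j."""
    grad, cost = [], 0
    for j in range(d):
        v = [1.0 if i == j else 0.0 for i in range(d)]   # tangent direction e_j
        for W in Ws:                                      # W_1 first
            v, c = matvec(W, v)
            cost += c
        grad.append(v[0])
    return grad, cost

def grad_reverse_mode(Ws):
    """Reverse mode: seed dL/dL = 1, then pull it back through W_k^T, ..., W_1^T once."""
    u, cost = [1.0], 0
    for W in reversed(Ws):                                # W_k first
        u, c = vecmat(u, W)
        cost += c
    return u, cost

rng = random.Random(0)
d, m, k = 64, 32, 4
Ws = build_chain(d, m, k, rng)
g_fwd, cost_fwd = grad_forward_mode(Ws, d)
g_rev, cost_rev = grad_reverse_mode(Ws)
print(cost_fwd, cost_rev)  # → 264192 4128
```

Both modes return the same gradient. Forward mode pays exactly d = 64 times more here because every JVP costs as much as the single VJP sweep.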
Derive backprop on a two-layer MLP — the canonical interview question
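Setup. Input x ∈ ℝ^d, hidden h ∈ ℝ^m, output y ∈ ℝ^o, target t, MSE loss. The forward pass uses an elementwise activation σ and a linear output layer. The ½ is the usual convention, so that ∂L/∂y = y − t:

$$
\begin{aligned}
z_1 &= W_1 x + b_1, & W_1 &\in \mathbb{R}^{m\times d},\\
h &= \sigma(z_1),\\
z_2 &= W_2 h + b_2, & W_2 &\in \mathbb{R}^{o\times m},\\
y &= z_2,\\
L &= \tfrac12 \lVert y - t \rVert^2.
\end{aligned}
$$

Now go backward. Each step is one matrix-vector application of (∂output/∂input)^⊤:

$$
\begin{aligned}
\frac{\partial L}{\partial z_2} &= y - t,\\
\frac{\partial L}{\partial W_2} &= \frac{\partial L}{\partial z_2}\, h^\top, & \frac{\partial L}{\partial b_2} &= \frac{\partial L}{\partial z_2},\\
\frac{\partial L}{\partial h} &= W_2^\top \frac{\partial L}{\partial z_2},\\
\frac{\partial L}{\partial z_1} &= \frac{\partial L}{\partial h} \odot \sigma'(z_1),\\
\frac{\partial L}{\partial W_1} &= \frac{\partial L}{\partial z_1}\, x^\top, & \frac{\partial L}{\partial b_1} &= \frac{\partial L}{\partial z_1}.
\end{aligned}
$$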
Things to notice — every one of them is an interview question:
- Weight gradients are outer products of (downstream gradient) × (upstream activation). This is why activations from the forward pass must be saved (or recomputed) for the backward pass. You can't compute ∂L/∂W₂ without h, or ∂L/∂W₁ without x.
- Transposed weights propagate the signal back. ∂L/∂h = W₂^⊤ ∂L/∂z₂. This is what people mean when they say "the backward pass uses the transpose of the forward weights." For convolutions, the transpose is a transposed convolution; for attention, it's a softmax derivative composed with a matmul; etc.
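- The Hadamard product ⊙ σ'(z₁) is where activation choice bites. ReLU has σ' ∈ {0, 1}, so every unit with a negative pre-activation passes exactly zero gradient (about half of them at a symmetric init). A unit that is negative on every input never recovers (the dead ReLU problem). Sigmoid has σ' ≤ 0.25 and saturates to ~0 for large |z| (vanishing gradient).
- Bias gradient is a sum. The bias has shape (m,) and is broadcast across the batch, so its gradient sums the per-example gradients: ∂L/∂b = Σ_batch ∂L/∂z. That sum is what reduces a (B, m) upstream gradient to the (m,) bias gradient.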
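The backward equations can be checked numerically. This is a minimal pure-Python sketch, assuming σ = sigmoid, a linear output, and L = ½‖y − t‖², with illustrative sizes d = 3, m = 4, o = 2. The helper names are hypothetical. It implements the backward formulas by hand and compares every parameter gradient against central finite differences, the standard "gradient check".

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(params, x, t):
    """z1 = W1 x + b1, h = sigmoid(z1), y = W2 h + b2, L = 0.5 * ||y - t||^2."""
    W1, b1, W2, b2 = params
    z1 = [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(W1, b1)]
    h = [sigmoid(z) for z in z1]
    y = [sum(w * hj for w, hj in zip(row, h)) + b for row, b in zip(W2, b2)]
    loss = 0.5 * sum((yi - ti) ** 2 for yi, ti in zip(y, t))
    return loss, (x, h, y)          # x and h are saved for the weight gradients

def backward(params, cache, t):
    W1, b1, W2, b2 = params
    x, h, y = cache
    dz2 = [yi - ti for yi, ti in zip(y, t)]                        # dL/dz2 = y - t
    dW2 = [[g * hj for hj in h] for g in dz2]                      # outer(dz2, h)
    db2 = list(dz2)
    dh = [sum(W2[i][j] * dz2[i] for i in range(len(dz2))) for j in range(len(h))]  # W2^T dz2
    dz1 = [g * hj * (1.0 - hj) for g, hj in zip(dh, h)]            # sigmoid'(z1) = h (1 - h)
    dW1 = [[g * xj for xj in x] for g in dz1]                      # outer(dz1, x)
    db1 = list(dz1)
    return dW1, db1, dW2, db2

def slots(P):
    """(container, index) for every scalar in a bias vector or a weight matrix."""
    if isinstance(P[0], list):
        return [(row, j) for row in P for j in range(len(row))]
    return [(P, i) for i in range(len(P))]

def numerical_grads(params, x, t, eps=1e-6):
    """Central differences, one scalar parameter at a time."""
    out = []
    for P in params:
        g = []
        for container, idx in slots(P):
            old = container[idx]
            container[idx] = old + eps
            lp, _ = forward(params, x, t)
            container[idx] = old - eps
            lm, _ = forward(params, x, t)
            container[idx] = old
            g.append((lp - lm) / (2 * eps))
        out.append(g)
    return out

def flatten(G):
    return [v for row in G for v in row] if isinstance(G[0], list) else list(G)

rng = random.Random(0)
d, m, o = 3, 4, 2
params = [[[rng.gauss(0, 1) for _ in range(d)] for _ in range(m)], [0.0] * m,
          [[rng.gauss(0, 1) for _ in range(m)] for _ in range(o)], [0.0] * o]
x, t = [rng.gauss(0, 1) for _ in range(d)], [1.0, -1.0]
loss, cache = forward(params, x, t)
analytic = [flatten(G) for G in backward(params, cache, t)]
numeric = numerical_grads(params, x, t)
max_err = max(abs(a - b) for ga, gn in zip(analytic, numeric) for a, b in zip(ga, gn))
print(f"max |analytic - numeric| = {max_err:.1e}")
```

Note that `backward` reads `x` and `h` from the cache and never touches `z1`. Sigmoid's derivative can be written from its output, so saving h is enough.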
What "autograd" actually does
PyTorch/JAX/TF autograd is a three-step machine:
- Tape recording. Every op you call (matmul, add, relu) registers itself on a tape, i.e. the computation graph. The tape stores the inputs that will be needed for the backward formula. For `z = W @ x`, the formula is `dL/dW = dL/dz @ x.T`, so `x` must be saved.
- Backward traversal. Starting from `loss.backward()`, the engine walks the tape in reverse, calling each op's backward formula. Each backward formula consumes the gradient with respect to the op's output (PyTorch calls it `grad_output`). It produces gradients with respect to each of the op's inputs (`grad_input`).
- Gradient accumulation. If a tensor is used in multiple downstream paths (e.g. `h` feeds both an attention and an FFN), the gradients from all paths sum. This is the only place in the backward pass where `+=` is the correct operator.
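The three steps fit in a toy scalar engine. This is a hypothetical sketch, not PyTorch's implementation or API. The class name `Value` and its methods are illustrative. Each op records its parents and a backward closure. `backward()` orders the graph and walks it in reverse, and fan-out is handled by `+=`.

```python
import math

class Value:
    """A scalar node on the tape: its value, its gradient, and how to backprop through it."""

    def __init__(self, data, parents=(), backward_fn=None):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # inputs saved by the op that produced this node
        self._backward_fn = backward_fn  # grad_output -> one gradient per parent

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), lambda g: (g, g))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        # d(ab)/da = b and d(ab)/db = a, so both inputs must stay on the tape.
        return Value(self.data * other.data, (self, other),
                     lambda g: (g * other.data, g * self.data))

    __radd__ = __add__
    __rmul__ = __mul__

    def tanh(self):
        y = math.tanh(self.data)
        return Value(y, (self,), lambda g: (g * (1.0 - y * y),))

    def backward(self):
        # Order the graph so every node comes after all the nodes it depends on;
        # walking this order in reverse is walking the tape backwards.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0                  # seed: dL/dL = 1
        for node in reversed(order):
            if node._backward_fn is None:
                continue                 # leaf: nothing further upstream
            for parent, g in zip(node._parents, node._backward_fn(node.grad)):
                parent.grad += g         # fan-out: contributions from all paths sum

x, y = Value(2.0), Value(3.0)
f = x * y + x                            # x feeds two paths
f.backward()
print(x.grad, y.grad)                    # → 4.0 2.0
```

Here df/dx = y + 1 = 4 only because the two contributions to `x.grad` are summed. Overwriting with `=` instead of `+=` would silently drop one path.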
The implication: every op has two implementations, forward and backward. When you write a custom CUDA kernel, you must write both, or your training run errors. When you fuse kernels (lesson on Triton in system_ml), you fuse the forward; the backward fusion is a separate engineering task.
The activation memory problem — and what to do about it
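Activation memory scales as B × L × D × (layers), where B is batch size, L is sequence length (not the loss, in this section) and D is hidden dim. Consider a 70B model with L = 8192, B = 1, D = 8192, layers = 80. A single saved bf16 tensor of shape (B, L, D) per layer already costs 1 × 8192 × 8192 × 80 × 2 bytes ≈ 10.7 GB, before counting attention intermediates. Real activations are 5–10× that, because each layer saves several such tensors. At long sequence lengths and large batches, activations dominate per-GPU training memory, especially once weights and optimizer state are sharded.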
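The arithmetic above fits in a small helper. The function name is hypothetical, and, like the estimate, it counts only one bf16 (B, L, D) tensor per layer.

```python
def activation_bytes(batch, seq_len, hidden, layers, bytes_per_elem=2):
    """Bytes for one saved (batch, seq_len, hidden) tensor per layer; bf16 = 2 bytes."""
    return batch * seq_len * hidden * layers * bytes_per_elem

raw = activation_bytes(batch=1, seq_len=8192, hidden=8192, layers=80)
print(f"{raw:,} bytes = {raw / 1e9:.1f} GB = {raw / 2**30:.1f} GiB")
# → 10,737,418,240 bytes = 10.7 GB = 10.0 GiB
low, high = 5 * raw / 1e9, 10 * raw / 1e9   # the 5-10x multiplier quoted above
print(f"realistic range: {low:.0f}-{high:.0f} GB")
```

The formula is linear in every factor. Doubling the sequence length or the batch doubles activation memory, while parameter and optimizer memory stay fixed.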
| Strategy | Memory | Compute | Notes |
|---|---|---|---|
| Save everything | O(L × layers) | O(forward) | Baseline. OOMs on large models. |
| Gradient checkpointing (a.k.a. activation recompute) | O(L × √layers) | ~1.33× a full training step (about one extra forward pass in total, on top of forward + backward ≈ 3× forward) | Recompute activations during backward. Square-root memory at ~33% compute overhead. Standard for large LLM training. |
| Selective checkpointing | Custom | Custom | Save the expensive ops (attention matmul), recompute the cheap ones (norm, residual add). Sweet spot in 2026. |
| Reversible nets | O(L) | ~2× forward | Activations are reconstructible from outputs. Used in RevNet, Reformer. Rare in production. |
| CPU offloading | 0 GPU | +PCIe roundtrip per layer | Move activations to CPU between forward and backward. PCIe (~30 GB/s) vs HBM (~3 TB/s) is a 100× slowdown if not overlapped. |
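Checkpointing can be sketched on a toy chain. The sketch assumes 64 scalar layers h_{i+1} = tanh(w_i · h_i), with the loss taken to be the last activation. All function names are hypothetical, and this is not the API of `torch.utils.checkpoint`. The forward pass keeps only every `seg`-th activation. The backward pass recomputes one segment at a time from its checkpoint. With seg = √n, the live memory is about n/seg + seg = 2√n values, and each layer is recomputed once.

```python
import math
import random

def layer(w, h):
    return math.tanh(w * h)

def layer_backward(w, h_in, h_out, g_out):
    """For h_out = tanh(w * h_in): map dL/dh_out to (dL/dh_in, dL/dw)."""
    g_pre = g_out * (1.0 - h_out * h_out)
    return g_pre * w, g_pre * h_in

def grads_store_all(ws, h0):
    hs = [h0]
    for w in ws:
        hs.append(layer(w, hs[-1]))          # every activation kept: n + 1 values
    g, dws = 1.0, [0.0] * len(ws)            # loss = hs[-1], so dL/dh_n = 1
    for i in reversed(range(len(ws))):
        g, dws[i] = layer_backward(ws[i], hs[i], hs[i + 1], g)
    return g, dws, len(hs)

def grads_checkpointed(ws, h0, seg):
    n = len(ws)
    ckpts, h = {0: h0}, h0
    for i, w in enumerate(ws):               # forward: keep only every seg-th activation
        h = layer(w, h)
        if (i + 1) % seg == 0 and i + 1 < n:
            ckpts[i + 1] = h
    g, dws = 1.0, [0.0] * n
    peak, recomputed = 0, 0
    for start in sorted(ckpts, reverse=True):
        end = min(start + seg, n)
        hs = [ckpts[start]]
        for i in range(start, end):          # backward: recompute this segment first
            hs.append(layer(ws[i], hs[-1]))
            recomputed += 1
        peak = max(peak, len(ckpts) + len(hs) - 1)   # checkpoints + one live segment
        for i in reversed(range(start, end)):
            g, dws[i] = layer_backward(ws[i], hs[i - start], hs[i - start + 1], g)
    return g, dws, peak, recomputed

rng = random.Random(0)
ws = [rng.uniform(0.5, 1.5) for _ in range(64)]
g_full, dws_full, peak_full = grads_store_all(ws, 0.3)
g_ckpt, dws_ckpt, peak_ckpt, recomputed = grads_checkpointed(ws, 0.3, seg=8)  # seg = sqrt(64)
print(peak_full, peak_ckpt, recomputed)  # → 65 16 64
```

Both versions produce identical gradients. Peak stored activations drop from 65 to 16, and the price is exactly one extra forward pass (64 recomputed layers).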
Common gradient bugs — how to spot them in 30 seconds
| Symptom | Likely cause | How to confirm |
|---|---|---|
| Loss is NaN immediately | (1) Bad init (weights too large → exp overflow); (2) log(0) in the loss; (3) division by 0 in normalisation | Check parameter norms at step 0. Check that all losses are finite for a known-good batch. Lower LR or rescale init. |
| Loss is NaN after a few hundred steps | Loss spike (mixed precision underflow / overflow). Often a single bad batch + low LR cooling. | Log gradient norm per step. Spikes > 100× steady-state precede NaNs. Clip gradients; check loss scaling for fp16. |
| Gradients are exactly zero | (1) Dead ReLU: all pre-activations < 0; (2) softmax saturated → ∂softmax ≈ 0; (3) `requires_grad=False` on a tensor that should be trained; (4) `detach()` in the middle of the graph | Print `param.grad.norm()` per layer. Check for parameters whose `grad is None` after backward. |
| Loss decreases on train, doesn't on val | Classic overfit. Or: BN running statistics mismatched between train and eval (don't forget `model.eval()`). | Standard generalisation diagnostics. If switching between `model.train()` and `model.eval()` changes val loss by orders of magnitude, it's a BN issue. |
| Loss decreases too smoothly, no batch noise | Batch is identical every step (data loader bug), or LR is so small it's effectively averaging gradients over many steps. | Print a hash of the input batch each step. Should change. |
| One layer's gradient is 1000× larger | The layer is doing real work but the others are dead. Often: bad init or skip connections shorting out a stack. | Per-layer gradient norm. Should be roughly within an order of magnitude of each other. If not — bad init or LR schedule. |
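A per-layer norm check turns the last few rows into one report. This is a sketch, assuming gradients have already been pulled out as flat lists of floats, or `None` for a parameter that received no gradient. In PyTorch such a dict could be built with `{n: (p.grad.flatten().tolist() if p.grad is not None else None) for n, p in model.named_parameters()}`. The function name `grad_norm_report` and the 10× band around the median are hypothetical choices.

```python
import math
import statistics

def grad_norm_report(grads, ratio=10.0):
    """grads: name -> flat list of gradient values, or None if the parameter got no gradient.

    Returns (norms, missing, outliers): per-layer L2 norms, names whose grad is None
    (detached, or requires_grad=False), and names whose norm is more than `ratio`
    times above or below the median norm.
    """
    norms = {n: math.sqrt(sum(v * v for v in g)) for n, g in grads.items() if g is not None}
    missing = [n for n, g in grads.items() if g is None]
    if not norms:
        return norms, missing, []
    med = statistics.median(norms.values())
    outliers = [n for n, v in norms.items()
                if med > 0 and (v > ratio * med or v < med / ratio)]
    return norms, missing, outliers

grads = {
    "layer0.weight": [0.01, -0.02, 0.015],
    "layer1.weight": [0.02, 0.01, -0.01],
    "layer2.weight": [3.0, -4.0, 12.0],   # norm 13, ~500x the median
    "layer3.weight": [0.0, 0.0, 0.0],     # exactly zero
    "head.bias": None,                    # never received a gradient
}
norms, missing, outliers = grad_norm_report(grads)
print(missing, outliers)  # → ['head.bias'] ['layer2.weight', 'layer3.weight']
```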
The senior question: when does the gradient lie?
The gradient is the local linearisation of the loss. It is correct to first order. It misleads in three regimes:
- Sharp losses. At a kink (ReLU, hinge, max), the gradient is undefined, and autodiff returns one element of the set of subgradients. SGD follows it and hopes. Usually fine; occasionally pathological.
- Stochastic estimators. The gradient you compute is ∇_θ E_x[L(x, θ)] estimated as (1/B) Σ_x ∇_θ L(x, θ). Variance scales as σ²/B. Small batch + sparse gradients (e.g. classification with rare classes) → noisy direction. This is why batch size matters beyond hardware.
- Reparameterisation needed. If you sample inside the network (the latent code of a variational autoencoder, a discrete choice), the gradient through the sample is zero unless you reparameterise. The naïve gradient is correct, in that it is correctly zero, but it is uninformative. Reparam tricks rewrite the sample as a deterministic function of a noise variable so the gradient flows through the deterministic path. Examples are z = μ + σ·ε for a Gaussian latent and Gumbel-softmax for discrete choices.
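A small Monte Carlo experiment shows the difference. The objective is E[z²] with z ~ N(μ, 1), chosen because its true gradient, 2μ, is known in closed form. The sketch tracks d/dμ alongside each value as a (value, derivative) dual number, a forward-mode stand-in for autodiff. The function names are hypothetical. The opaque sample carries derivative 0. The reparameterised sample z = μ + ε carries derivative 1.

```python
import random

def square(z):
    """Forward-mode rule for z**2 on a dual number (value, d value / d mu)."""
    v, dv = z
    return v * v, 2.0 * v * dv

def grad_estimate(mu, n, reparameterise, seed=0):
    """Monte Carlo estimate of d/dmu E[z**2] with z ~ N(mu, 1); the true value is 2 * mu."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        if reparameterise:
            eps = rng.gauss(0.0, 1.0)
            z = (mu + eps, 1.0)               # z = mu + eps is deterministic in mu: dz/dmu = 1
        else:
            z = (rng.gauss(mu, 1.0), 0.0)     # opaque sample: no path from mu to z
        total += square(z)[1]
    return total / n

mu = 1.5
print(grad_estimate(mu, 100_000, reparameterise=False))  # → 0.0
print(grad_estimate(mu, 100_000, reparameterise=True))   # close to 2 * mu = 3.0
```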
The senior signal is being able to say the gradient is locally correct and globally useless in these cases, and to name the fix (subgradient handling, larger batches / control variates, reparameterisation).
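The σ²/B scaling behind the "larger batches" fix can be checked directly. This is a toy sketch, assuming a per-example loss ½(θ − x)² with x ~ N(0, σ²), so each per-example gradient θ − x has variance σ². The function names and σ = 2 are hypothetical. The code measures the variance of the mini-batch mean gradient across many independent batches.

```python
import random
import statistics

def minibatch_grad(theta, batch, rng, sigma=2.0):
    """Mean over a batch of d/dtheta 0.5 * (theta - x)**2 = theta - x, with x ~ N(0, sigma**2)."""
    return sum(theta - rng.gauss(0.0, sigma) for _ in range(batch)) / batch

def grad_variance(batch, trials=4000, theta=1.0, seed=0):
    """Empirical variance of the mini-batch gradient across independent batches."""
    rng = random.Random(seed)
    return statistics.variance([minibatch_grad(theta, batch, rng) for _ in range(trials)])

sigma2 = 2.0 ** 2
for B in (1, 4, 16, 64):
    print(B, round(grad_variance(B), 4), sigma2 / B)   # empirical vs sigma^2 / B
```

Quadrupling the batch cuts the variance by about 4×, so the noise in the gradient direction shrinks only as 1/√B.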
Interview prompts you should be ready for
- "Derive the gradient of cross-entropy through softmax." (Answer: ∂L/∂z_i = p_i − y_i, where p = softmax(z). Show that the softmax derivative ∂p_i/∂z_j = p_i(δ_{ij} − p_j) contracts with the one-hot target into p − y. Beautiful cancellation; expected to derive in 90 seconds on a whiteboard.)
- "You add a layer and the loss goes up. Walk me through three diagnoses." (a. Init: the new layer has a different scale; per-layer gradient norms are bimodal. b. Capacity: more parameters → harder to optimise at the same LR. c. Identity vs not: if the new layer's init isn't near-identity, you've broken the prior solution. Fix: initialise the new block's output (its final projection, or a scale on it) to zero or near zero, as in ReZero (scale 0) or LayerScale (a small constant). The new layer then starts as a no-op and learns up.)
- "What's the memory cost of training a 7B model in bf16 with Adam?" (Params: 7B × 2 bytes = 14 GB. Gradients: same, 14 GB. Adam moments m,v: each 7B × 4 bytes (fp32 for stability) = 28 GB × 2 = 56 GB. Optimizer fp32 master copy: 28 GB. Total: ~112 GB just for state, before activations. Doesn't fit on H100 80 GB → need FSDP or offloading.)
- "Why do residual connections help training?" (Several answers, all valid: (a) Identity gradient path → ∂L/∂x_l = ∂L/∂x_L · (1 + ∂(stack)/∂x_l), where x_L is the last block's output. The 1 carries ∂L/∂x_L straight back to x_l, so the gradient is not a pure product of per-layer Jacobians that can collapse to zero with depth. (b) Initialisation: each residual block initially does almost nothing, so the function is approximately identity at init. (c) Loss landscape: residual nets have visibly smoother losses, Goodfellow / Li et al. 2018.)
- "What's the difference between `loss.backward()` and `torch.autograd.grad(loss, [p])`?" (`.backward()` accumulates into `p.grad` in place. `torch.autograd.grad` returns the gradients functionally without touching `.grad`. The functional form is useful for meta-learning, higher-order gradients, and any case where you don't want side effects on the parameter's grad buffer.)
- "Implement a custom autograd Function in PyTorch for `y = x²`. Why does `torch.autograd.Function` exist instead of relying on op composition?" (`x²` doesn't need a custom Function, because `x * x` composes. Function exists for two cases. One is a closed-form gradient that's cheaper than what autodiff would compute, e.g., the joint forward/backward of attention in FlashAttention. The other is replacing the true gradient, e.g., the straight-through estimator for quantisation, which rounds in the forward pass and passes the gradient through as if it were the identity. Standard interview cross-check.)
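The softmax cross-entropy cancellation from the first prompt can be verified in three ways. This is a minimal pure-Python sketch with illustrative logits and a one-hot target; the function names are hypothetical. It computes the closed form p − y, the long-hand chain rule through the softmax Jacobian p_i(δ_ij − p_j), and central finite differences.

```python
import math

def softmax(z):
    m = max(z)                                  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(z, y):
    """L = -sum_i y_i log p_i with p = softmax(z); y is a probability vector (here one-hot)."""
    p = softmax(z)
    return -sum(yi * math.log(pi) for yi, pi in zip(y, p) if yi > 0)

def ce_grad_closed_form(z, y):
    p = softmax(z)
    return [pi - yi for pi, yi in zip(p, y)]    # the p - y cancellation

def ce_grad_via_jacobian(z, y):
    """Chain rule the long way: dL/dz_j = sum_i (dL/dp_i) * p_i * (delta_ij - p_j)."""
    p = softmax(z)
    dL_dp = [-yi / pi for yi, pi in zip(y, p)]
    n = len(z)
    return [sum(dL_dp[i] * p[i] * ((1.0 if i == j else 0.0) - p[j]) for i in range(n))
            for j in range(n)]

def ce_grad_numeric(z, y, eps=1e-6):
    g = []
    for j in range(len(z)):
        zp, zm = list(z), list(z)
        zp[j] += eps
        zm[j] -= eps
        g.append((cross_entropy(zp, y) - cross_entropy(zm, y)) / (2 * eps))
    return g

z = [2.0, -1.0, 0.5, 0.0]
y = [0.0, 0.0, 1.0, 0.0]
g_closed = ce_grad_closed_form(z, y)
g_jac = ce_grad_via_jacobian(z, y)
g_num = ce_grad_numeric(z, y)
print([round(v, 6) for v in g_closed])
```

The long-hand version reduces to p − y because Σ_i y_i (δ_ij − p_j) = y_j − p_j whenever y sums to 1. A side effect is that the gradient entries always sum to zero.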