Distillation foundations
A small model can learn more from a big model's answer than from the right answer. Soft targets, the temperature trick, why it works, and what to match.
The one idea, and three knobs
We deploy small models but train large ones. Knowledge distillation bridges the gap: instead of training a small student on the hard labels in a dataset, train it to match a large teacher's function — its full output distribution. A hard label says "class 7." A teacher says "probably 7, plausibly 2, definitely not 4," and that richer, denser signal is what the student learns from.
Every method in the field (classic image-classifier distillation, DistilBERT, the R1-distilled reasoning models) is a setting of three orthogonal knobs:
- WHAT to match — the teacher's response (logits), its internal features, the relations between examples, or whole sequences.
- HOW to measure the gap — temperature, forward vs reverse KL, JS / f-divergence.
- WHERE the data comes from — a fixed set, the teacher's own generations, or the student's on-policy rollouts.
Dark knowledge
Why is a distribution better than a label? Because the teacher's soft output encodes similarity structure that the one-hot label throws away. Shown a handwritten 7, a good teacher might output $p = (7\!:\!0.70,\ 2\!:\!0.20,\ 1\!:\!0.07,\ 9\!:\!0.02,\ \dots)$. The fact that a 7 is ~10× more confusable with a 2 than with a 9 (0.20 vs 0.02) is real information about the input; Hinton called it dark knowledge. The entropy of that target,
$$H(p) = -\sum_{k=1}^{K} p_k \log_2 p_k \approx 1.3\ \text{bits},$$
carries more than one bit of signal per example, against the 0 bits of a one-hot label you already knew. (Treating the unlisted 0.01 as a single class gives 1.27 bits; spreading it over more classes only raises H.) Two consequences: a soft target is K constraints per example (one per class) instead of one, and it is a far smoother, lower-variance regression target than a noisy one-hot sample. Capacity was rarely the bottleneck. A small student can usually represent the function; the hard part is finding it, and the teacher hands over a well-behaved target that makes the search easy.
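A quick check of the entropy arithmetic, as a pure-Python sketch. Lumping the unlisted 0.01 into one "other" bucket is an assumption; merging classes can only lower entropy, so the result is a lower bound on H.

```python
import math

def entropy_bits(probs):
    """Shannon entropy in bits; zero-probability entries contribute nothing."""
    return sum(-p * math.log2(p) for p in probs if p > 0)

# The teacher output for a handwritten 7 from the text. The unlisted 0.01 is
# lumped into one "other" bucket (an assumption that gives a lower bound on H).
teacher = {"7": 0.70, "2": 0.20, "1": 0.07, "9": 0.02, "other": 0.01}
one_hot = {"7": 1.0}   # the hard label

h_soft = entropy_bits(teacher.values())
h_hard = entropy_bits(one_hot.values())
print(f"soft target {h_soft:.2f} bits, hard label {h_hard:.2f} bits")  # 1.27 vs 0.00
print(f"P(2) / P(9) = {teacher['2'] / teacher['9']:.0f}")               # 10
```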
Soft targets & temperature
There is a snag: a well-trained teacher is confident. Its softmax often reads 0.999, 0.0009, 1e-7, …, so the dark knowledge is there but at magnitudes too small to move the cross-entropy. The fix is a temperature T dividing the logits z before the softmax:
$$q_i(T) = \frac{\exp(z_i / T)}{\sum_{j=1}^{K} \exp(z_j / T)}$$
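T = 1 is the native distribution; T → 0⁺ collapses to the one-hot argmax; T → ∞ flattens to uniform. Raising T from 1 softens the distribution, lifting the buried tail into a range where it materially affects the loss. With teacher probabilities q(T) and student probabilities p(T) (computed the same way from student logits v), the distillation objective is a weighted sum of a hard and a soft term:
$$\mathcal{L} = (1-\alpha)\,\mathrm{CE}\big(y,\ p(1)\big) + \alpha\,T^{2}\,\mathrm{KL}\big(q(T)\,\|\,p(T)\big)$$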
The hard term anchors the student to ground truth (at T=1); the soft term transfers dark knowledge; α ∈ [0,1] trades them off (typically 0.5–0.9 toward the soft term; α=1 when you have no labels for the transfer set). The mysterious T² is not a tuned constant — it falls out of the gradient.
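A minimal pure-Python sketch of the temperature softmax and the combined objective for a single example, assuming teacher logits z, student logits v and an integer class label. The logits, T = 4 and α = 0.9 are illustrative choices, not recommendations. It works in log space (log-sum-exp) so the KL term stays finite even when a probability underflows.

```python
import math

def log_softmax_t(logits, T=1.0):
    """log of exp(z_i / T) / sum_j exp(z_j / T), computed stably via log-sum-exp."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - lse for s in scaled]

def softmax_t(logits, T=1.0):
    return [math.exp(lp) for lp in log_softmax_t(logits, T)]

def distill_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.9):
    """(1 - alpha) * CE(y, p(1)) + alpha * T^2 * KL(q(T) || p(T)) for one example."""
    hard = -log_softmax_t(student_logits, 1.0)[label]   # hard term at T = 1
    log_q = log_softmax_t(teacher_logits, T)            # teacher, softened
    log_p = log_softmax_t(student_logits, T)            # student, softened
    soft = sum(math.exp(lq) * (lq - lp) for lq, lp in zip(log_q, log_p))  # KL(q || p)
    return (1 - alpha) * hard + alpha * T * T * soft

# Hypothetical confident teacher: at T = 1 its softmax is roughly 0.999, 0.0009, 1e-7.
teacher_logits = [10.0, 3.0, -6.0]
for T in (0.05, 1.0, 4.0, 1000.0):
    print(T, [round(p, 4) for p in softmax_t(teacher_logits, T)])

student_logits = [2.0, 1.0, -1.0]   # hypothetical student
print(distill_loss(student_logits, teacher_logits, label=0))
```

The printed rows trace the limits from the text: essentially one-hot at T = 0.05, the confident profile at T = 1, a visibly lifted tail at T = 4, and near-uniform at T = 1000.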
Why the T² factor
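Differentiate the soft cross-entropy $C = \mathrm{CE}\big(q(T), p(T)\big)$ with respect to a student logit $v_i$. (KL(q‖p) differs from C only by the teacher's entropy, which does not depend on v, so the two share this gradient.) The chain rule through $v_i/T$ gives a 1/T out front:
$$\frac{\partial C}{\partial v_i} = \frac{1}{T}\big(p_i(T) - q_i(T)\big)$$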
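At high T, Taylor-expand $\exp(z_i/T) \approx 1 + z_i/T$. With zero-meaned logits over K classes the denominator sums to $K + \sum_j z_j/T = K$, so the softmax becomes $q_i(T) \approx 1/K + z_i/(KT)$, and likewise $p_i(T) \approx 1/K + v_i/(KT)$ for the student. The 1/K cancels in the difference, leaving
$$\frac{\partial C}{\partial v_i} \approx \frac{1}{K T^{2}}\,(v_i - z_i)$$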
The soft-term gradient shrinks like 1/T². So raising T to expose dark knowledge would silently turn the soft loss off and let the hard term dominate. Multiplying the soft term by T² cancels this leading-order shrinkage, so the soft/hard balance stays fixed as you sweep temperature. (A missing or misplaced T² is a common bug in homegrown distillation code.)
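A numerical check of the 1/T² shrinkage, sketched in pure Python with hypothetical zero-mean logits for a four-class problem. It evaluates the exact gradient (p_i(T) − q_i(T))/T and prints its norm with and without the T² factor.

```python
import math

def softmax_t(logits, T):
    """Temperature softmax, max-shifted for numerical stability."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def soft_grad(v, z, T):
    """Exact dC/dv_i = (p_i(T) - q_i(T)) / T for C = CE(q(T), p(T))."""
    p, q = softmax_t(v, T), softmax_t(z, T)   # student p from v, teacher q from z
    return [(pi - qi) / T for pi, qi in zip(p, q)]

def norm(g):
    return math.sqrt(sum(x * x for x in g))

v = [1.0, 0.5, -0.5, -1.0]   # hypothetical student logits (zero mean)
z = [2.0, 0.0, -0.5, -1.5]   # hypothetical teacher logits (zero mean)
for T in (1, 2, 4, 8, 16, 32):
    g = norm(soft_grad(v, z, T))
    print(f"T={T:>2}  |grad| = {g:.6f}   T^2 * |grad| = {T * T * g:.4f}")
```

At higher T the raw norm drops roughly fourfold each time T doubles, while the T²-scaled column levels off.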
The same high-T expansion also gives a clean limit: a gradient proportional to $(v_i - z_i)$ is the gradient of a squared error, so
$$T^{2} C \approx \frac{1}{2K}\sum_{i=1}^{K} (v_i - z_i)^2 + \text{const}$$
At its softest extreme, distillation is just L2 regression of the student's logits onto the teacher's, the most direct possible function matching. The soft-target view also explains why distillation tends to beat label smoothing: both fight overconfidence, but label smoothing spreads a fixed, uniform ε over the wrong classes, while distillation's softening is learned and per-example. It knows "dog" is a better wrong answer than "car" for a given cat.
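The L2 limit can be checked the same way. This pure-Python sketch (hypothetical logits, zero-meaned as the derivation requires) compares the gradient of the T²-weighted soft term, T·(p_i − q_i), with the gradient (v_i − z_i)/K of the squared error (1/2K)·Σ(v_i − z_i)². The gap should shrink roughly like 1/T.

```python
import math

def softmax_t(logits, T):
    """Temperature softmax, max-shifted for numerical stability."""
    m = max(x / T for x in logits)
    exps = [math.exp(x / T - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def zero_mean(x):
    mu = sum(x) / len(x)
    return [xi - mu for xi in x]

v = zero_mean([3.0, 1.0, 0.0, -2.0])   # hypothetical student logits
z = zero_mean([4.0, 0.0, 1.0, -1.0])   # hypothetical teacher logits
K = len(v)
# Gradient of the squared error (1 / 2K) * sum_i (v_i - z_i)^2.
target = [(vi - zi) / K for vi, zi in zip(v, z)]

def max_gap(T):
    """Largest |T^2 * dC/dv_i - (v_i - z_i) / K|, using T^2 * dC/dv_i = T * (p_i - q_i)."""
    p, q = softmax_t(v, T), softmax_t(z, T)
    return max(abs(T * (pi - qi) - t) for pi, qi, t in zip(p, q, target))

for T in (1, 10, 100, 1000):
    print(f"T={T:>4}  max gap = {max_gap(T):.2e}")
```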
Figure: the temperature knob
Six fixed logits at increasing T: mass leaks from the peak into the tail, entropy climbs, and the top-2 gap shrinks. Two gradient read-outs show the soft term fading as 1/T² without the correction and staying flat with it.
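A static stand-in for the temperature knob: a pure-Python sketch that sweeps T over six logits and prints the peak mass, the tail mass, the entropy and the top-2 gap. The six logit values are hypothetical.

```python
import math

def softmax_t(logits, T):
    """Temperature softmax, max-shifted for numerical stability."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def entropy_bits(probs):
    return sum(-p * math.log2(p) for p in probs if p > 0)

logits = [6.0, 4.0, 2.5, 1.0, 0.0, -2.0]   # hypothetical six-class logits
max_h = math.log2(len(logits))              # entropy of the uniform distribution
rows = []   # (T, peak, tail, entropy, top-2 gap)
for T in (0.5, 1, 2, 4, 8, 16):
    p = sorted(softmax_t(logits, T), reverse=True)
    rows.append((T, p[0], sum(p[1:]), entropy_bits(p), p[0] - p[1]))
    print(f"T={T:>4}  peak={p[0]:.3f}  tail={sum(p[1:]):.3f}  "
          f"H={entropy_bits(p):.2f}/{max_h:.2f} bits  top-2 gap={p[0] - p[1]:.3f}")
```

Entropy rises and the top-2 gap falls monotonically with T, approaching log₂6 ≈ 2.58 bits and 0 respectively.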
Why it works — four lenses, none complete
There is no single reason distillation works; there are four good partial ones, and they sometimes disagree. The method outruns the theory, so keep all four on the table.
| Lens | Mechanism | Where it fails |
|---|---|---|
| 1 · Dark knowledge | soft target = many constraints + variance reduction | silent on capacity gap; weak when target ≈ one-hot |
| 2 · Regularization | adaptive, per-example label smoothing curbs overconfidence | doesn't explain gains from very long training |
| 3 · Function matching | smooth teacher target turns a jagged optimization into an easy one | needs lots of compute; capacity gap still binds for tiny students |
| 4 · Bayesian | teacher ≈ posterior predictive P(y \| x); student regresses on it | most teachers are poorly calibrated, not true posteriors |
The WHAT knob: response, feature, relation
The simplest recipe is response distillation: match the teacher's output distribution with exactly the loss above. It is black-box (it needs only the teacher's outputs, even just sampled text) and works with unlabeled transfer data. The core loop, in PyTorch:
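```python
import torch
import torch.nn.functional as F
# Assumes teacher, student, opt, transfer_loader, T and alpha are already defined.
teacher.eval()
for x, y in transfer_loader:
    with torch.no_grad():
        zT = teacher(x)  # teacher logits, no gradient
    zS = student(x)      # student logits
    # KL(q(T) || p(T)); "batchmean" gives per-example KL (default "mean" also divides by K)
    soft = F.kl_div(F.log_softmax(zS / T, dim=-1), F.softmax(zT / T, dim=-1),
                    reduction="batchmean") * (T * T)  # T^2 restores the gradient scale
    hard = F.cross_entropy(zS, y) if y is not None else 0.0  # ground truth at T = 1
    loss = (1 - alpha) * hard + alpha * soft
    opt.zero_grad(); loss.backward(); opt.step()
```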
(For LLMs with a 50k-token vocabulary, store only the top-k teacher probabilities per position, with k between 8 and 128, and renormalize. The tail is mostly noise, and full fp16 logits cost about 100 KB per token (50,000 × 2 bytes), which runs to petabytes over a corpus of tens of billions of tokens.) DistilBERT is the canonical case: 6 layers vs 12, ~60% faster, ~97% of GLUE retained, distilled during pretraining with a combined KL + MLM + cosine-embedding loss.
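The top-k storage trick for LLM teachers, sketched in pure Python for a single token position. The helper names, the toy 8-token distribution, k = 64 in the storage arithmetic, and int32 token ids are all assumptions for illustration, not a particular pipeline's format.

```python
import heapq
import math

def topk_sparse_target(teacher_probs, k):
    """Keep the k largest teacher probabilities at one position and renormalize them.

    Returns ({token_id: prob}, kept_mass). Dropping the tail treats it as zero,
    which slightly sharpens the target.
    """
    top = heapq.nlargest(k, enumerate(teacher_probs), key=lambda item: item[1])
    kept_mass = sum(p for _, p in top)
    return {tok: p / kept_mass for tok, p in top}, kept_mass

def sparse_soft_ce(target, student_log_probs):
    """Soft cross-entropy against a sparse target: only stored token ids contribute."""
    return -sum(p * student_log_probs[tok] for tok, p in target.items())

# Toy 8-token vocabulary (hypothetical numbers) and a uniform student.
teacher_probs = [0.5, 0.2, 0.1, 0.08, 0.05, 0.04, 0.02, 0.01]
target, kept = topk_sparse_target(teacher_probs, k=3)
print(target, round(kept, 3))
student_log_probs = [math.log(1 / 8)] * 8
print(round(sparse_soft_ce(target, student_log_probs), 4))   # log 8 = 2.0794

# Storage per token position, assuming fp16 values and int32 token ids.
full_bytes = 50_000 * 2          # every logit: 100,000 bytes
sparse_bytes = 64 * (4 + 2)      # k = 64 (id, prob) pairs: 384 bytes
print(full_bytes, sparse_bytes, full_bytes // sparse_bytes)   # 100000 384 260
```

Renormalizing over the kept entries makes the sparse target a proper distribution, so it can be plugged straight into the soft cross-entropy.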
When logits are too thin a pipe (a capacity gap stalls response KD), reach inside the teacher:
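- Feature distillation (FitNets) matches a hidden layer through a learned projector r(·) that bridges the width gap: $\mathcal{L}_{\text{hint}} = \lVert r(h_S) - h_T \rVert^2$, where $h_S$ and $h_T$ are the student's and teacher's hidden activations. Attention transfer is the projector-free version: collapse the activations over channels to a 2D attention map and match that.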
- Relational distillation (RKD) matches the geometry between examples (pairwise distances and angles) rather than absolute activations. The distance term compares $\psi(a,b) = \lVert a-b \rVert / \mu$, where μ is the mean pairwise distance in the mini-batch. The key argument is invariance: rotating or uniformly scaling the student's feature space changes no downstream computation (the next linear layer can undo it), so for a student that matches the teacher up to such a transform, a pointwise loss is wrongly large while the relational loss correctly stays at zero. Only true geometry mismatch is penalized.
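The invariance argument can be checked numerically. This pure-Python sketch implements a FitNets-style hint loss with a hypothetical linear projector W and the RKD distance term, then compares them on a student whose features are the teacher's rotated by 60° and scaled by 3. The features are hypothetical, and the squared-error comparison is a simplification (the RKD paper uses a Huber loss).

```python
import math

def hint_loss(h_s, h_t, W):
    """FitNets-style ||r(h_S) - h_T||^2 with a linear projector r(h) = W h (W is hypothetical)."""
    r = [sum(w * x for w, x in zip(row, h_s)) for row in W]
    return sum((a - b) ** 2 for a, b in zip(r, h_t))

def pairwise_dists(feats):
    return [math.dist(feats[i], feats[j])
            for i in range(len(feats)) for j in range(i + 1, len(feats))]

def rkd_distance_loss(f_s, f_t):
    """RKD distance term: psi(a, b) = ||a - b|| / mu, with mu the mean pairwise
    distance in the batch, computed separately for student and teacher."""
    d_s, d_t = pairwise_dists(f_s), pairwise_dists(f_t)
    mu_s, mu_t = sum(d_s) / len(d_s), sum(d_t) / len(d_t)
    return sum((a / mu_s - b / mu_t) ** 2 for a, b in zip(d_s, d_t)) / len(d_s)

def rotate_scale(p, theta, s):
    c, sn = math.cos(theta), math.sin(theta)
    return [s * (c * p[0] - sn * p[1]), s * (sn * p[0] + c * p[1])]

# Hypothetical 2-D features for a batch of four examples.
teacher_feats = [[1.0, 0.0], [0.0, 2.0], [-1.5, 0.5], [0.5, -1.0]]
# Student = teacher rotated by 60 degrees and scaled by 3: same geometry.
student_feats = [rotate_scale(p, math.pi / 3, 3.0) for p in teacher_feats]

identity = [[1.0, 0.0], [0.0, 1.0]]   # widths already match, so r is the identity here
pointwise = sum(hint_loss(s, t, identity) for s, t in zip(student_feats, teacher_feats))
relational = rkd_distance_loss(student_feats, teacher_feats)
print(f"pointwise {pointwise:.3f}, relational {relational:.1e}")

# A genuine geometry change (one example moved) is penalized by the relational loss.
warped = [p[:] for p in student_feats]
warped[0] = [warped[0][0] + 2.0, warped[0][1]]
print(f"relational after warping one point {rkd_distance_loss(warped, teacher_feats):.3f}")
```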
Feature and relational KD are white-box (need the teacher's internals) and usually ride on top of the response loss rather than replacing it.
When the teacher is the student, a peer, or a crowd
Distillation is sold as compression, but it is independently a regularizer. Born-again networks make this concrete: a same-size student trained on a teacher's soft targets often beats the teacher, plausibly because the learned, per-example softening steers it to a flatter, better-generalizing minimum. Three relatives drop the pretrained teacher entirely:
- Self-distillation: the teacher is the model's own EMA or its deeper layers. An EMA is a low-pass filter over the SGD iterates and tends to land in a flatter region.
- Deep mutual learning: a cohort of students teach each other toward consensus.
- Codistillation: the distributed version, in which workers exchange outputs instead of gradients.
The recurring discipline: ablate against a label-smoothing baseline, because some "self-distillation" gains are just smoothing in a costume.
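A toy illustration of the low-pass-filter half of the EMA claim, as a pure-Python sketch: noisy one-dimensional "SGD" on the quadratic 0.5(w − 1)² with an EMA teacher tracking the iterates. The learning rate, noise level and decay of 0.99 are assumptions, and a 1-D quadratic has no notion of flatness, so this shows only the smoothing, not the flat-minimum part.

```python
import random
import statistics

def ema_update(teacher_w, student_w, decay=0.99):
    """One EMA step: teacher <- decay * teacher + (1 - decay) * student."""
    return decay * teacher_w + (1 - decay) * student_w

random.seed(0)
w, ema = 0.0, 0.0
raw_tail, ema_tail = [], []
for step in range(5000):
    grad = (w - 1.0) + random.gauss(0.0, 1.0)   # gradient of 0.5 * (w - 1)^2 plus noise
    w -= 0.1 * grad                             # noisy SGD step on the student
    ema = ema_update(ema, w)                    # EMA teacher follows the iterates
    if step >= 1000:                            # skip the burn-in
        raw_tail.append(w)
        ema_tail.append(ema)

print(f"raw iterates: mean {statistics.fmean(raw_tail):.3f}, std {statistics.pstdev(raw_tail):.3f}")
print(f"EMA teacher : mean {statistics.fmean(ema_tail):.3f}, std {statistics.pstdev(ema_tail):.3f}")
```

Both sequences centre on the minimum at w = 1, but the EMA's spread is several times smaller: it averages away the high-frequency SGD noise.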