Speculative decoding
Draft K cheap tokens, verify them in one parallel target forward, keep the prefix that survives a ratio test. Same output distribution as plain decode, ~2× faster.
The asymmetry that makes it work
Lesson 01, Consequence B: decode is memory-bound. Each step the target model reads all of its weights and its whole KV cache once, yet does only one token's worth of matmul. The compute is tiny relative to the bytes moved.
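That has an underappreciated corollary: as long as K stays well below the point where the forward turns compute-bound, a K-token parallel forward through the target model takes essentially the same wall-clock time as a 1-token forward:
$$t_{\text{target}}(K \text{ tokens}) \approx t_{\text{target}}(1 \text{ token})$$
Why? Both forwards move essentially the same bytes: one pass over the weights and the full KV cache. The compute scales with K, but compute was never the bottleneck. So a K-token parallel verify costs about as much as one ordinary decode step.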
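To make the asymmetry concrete, here is a toy roofline model: a forward takes whichever is longer, moving the bytes or doing the FLOPs. Every number in it (a 7B-parameter target in 16-bit weights, a 1 GB KV cache, 2 TB/s of memory bandwidth, 300 TFLOP/s of matmul throughput) is a hypothetical placeholder, not a measurement, and attention FLOPs are ignored.

```python
# Toy roofline model of one target forward over k_tokens positions.
# All hardware/model numbers are hypothetical placeholders, not measurements.


def forward_time_s(k_tokens: int,
                   n_params: float = 7e9,       # hypothetical 7B-parameter target
                   bytes_per_param: float = 2,  # 16-bit weights
                   kv_bytes: float = 1e9,       # hypothetical KV-cache size
                   mem_bw: float = 2e12,        # assumed bytes/s of memory bandwidth
                   peak_flops: float = 300e12   # assumed FLOP/s of matmul throughput
                   ) -> float:
    # Memory term: weights and KV cache are read once, regardless of k_tokens.
    memory_time = (n_params * bytes_per_param + kv_bytes) / mem_bw
    # Compute term: ~2 FLOPs per parameter per token (attention FLOPs ignored).
    compute_time = 2 * n_params * k_tokens / peak_flops
    # Whichever resource saturates first sets the wall-clock time.
    return max(memory_time, compute_time)


for k in (1, 4, 8, 64, 256):
    print(f"K={k:4d}  t={forward_time_s(k) * 1e3:.2f} ms")
```

With these made-up numbers the memory term is 7.5 ms and the compute term only catches up around K ≈ 160, so K = 1, 4, 8 and 64 all cost the same; only K = 256 is slower. Real crossover points depend on hardware, batch size and context length.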
Speculative decoding exploits exactly this: instead of generating 1 token per target step, generate K cheap guesses with a small model, then verify them all in one expensive-but-not-K×-expensive target step.
The algorithm
From Leviathan et al. 2023 (and concurrently Chen et al. at DeepMind). Given target $M_T$, draft $M_D$, and speculative length $K$.
One round:
- Draft $K$ tokens sequentially from the small model:
  $$\text{for } i = 1..K:\quad q_i = p_D(\cdot \mid \text{ctx} + x_{1..i-1}), \quad x_i \sim q_i$$
- Target verifies in ONE parallel forward over $K+1$ positions:
  $$p_T^{(i)}(\cdot) = p_T(\cdot \mid \text{ctx} + x_{1..i-1}), \quad i = 1..K+1$$
- Accept-reject by the ratio test. For each $i = 1..K$ in order, sample $r \sim U(0,1)$:
  $$\text{accept } x_i \iff r < \min\!\left(1, \frac{p_T^{(i)}(x_i)}{q_i(x_i)}\right)$$
- On the first rejection, at position $j$, resample from the residual:
  $$\tilde p(x) = \frac{\max(0,\, p_T^{(j)}(x) - q_j(x))}{Z}, \qquad Z = \sum_x \max(0,\, p_T^{(j)}(x) - q_j(x))$$
  Append this corrected token and discard the rest of the draft.
- If all $K$ are accepted, sample one bonus token from $p_T^{(K+1)}$. The round yields $K+1$ tokens.
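The round above can be sketched in plain Python over toy dictionaries of token probabilities. This is a minimal illustration, not a serving implementation: the "parallel" verify is just a loop over the k+1 target conditionals, and `toy_target` / `toy_draft` are hypothetical stand-ins for real models.

```python
import random
from typing import Callable, Dict, List, Sequence

Dist = Dict[str, float]                  # token -> probability
Model = Callable[[Sequence[str]], Dist]  # context -> next-token distribution


def sample(dist: Dist, rng: random.Random) -> str:
    """Draw one token; random.choices accepts unnormalized weights."""
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens])[0]


def speculative_round(ctx: List[str], draft: Model, target: Model,
                      k: int, rng: random.Random) -> List[str]:
    """One round of speculative sampling; returns the 1..k+1 emitted tokens."""
    # 1) Draft k tokens sequentially, remembering each draft distribution q_i.
    drafted: List[str] = []
    q: List[Dist] = []
    for _ in range(k):
        q_i = draft(ctx + drafted)
        q.append(q_i)
        drafted.append(sample(q_i, rng))

    # 2) Target conditionals at all k+1 positions. A real system gets these
    #    from ONE batched forward; this toy just calls the target k+1 times.
    p = [target(ctx + drafted[:i]) for i in range(k + 1)]

    # 3) Ratio test, left to right.
    out: List[str] = []
    for i, x in enumerate(drafted):
        if rng.random() < min(1.0, p[i].get(x, 0.0) / q[i][x]):
            out.append(x)  # accepted: move on to the next draft position
            continue
        # 4) First rejection: resample from max(0, p_T - q) / Z and stop.
        resid = {t: max(0.0, pt - q[i].get(t, 0.0)) for t, pt in p[i].items()}
        z = sum(resid.values())
        out.append(sample({t: v / z for t, v in resid.items()}, rng))
        return out

    # 5) All k accepted: one bonus token from the (k+1)-th target conditional.
    out.append(sample(p[k], rng))
    return out


# Hypothetical toy models over a 3-token vocabulary (not real LMs).
VOCAB = ["a", "b", "c"]


def toy_target(ctx: Sequence[str]) -> Dist:
    last = ctx[-1] if ctx else "a"  # favours repeating the last token
    return {t: (0.6 if t == last else 0.2) for t in VOCAB}


def toy_draft(ctx: Sequence[str]) -> Dist:
    return {t: 1 / 3 for t in VOCAB}  # a deliberately crude uniform draft


rng = random.Random(0)
print(speculative_round(["a"], toy_draft, toy_target, k=4, rng=rng))
```

Because the uniform draft disagrees with the peaked target, rounds often stop early. Passing `toy_target` as the draft makes every ratio equal 1, so every round accepts all k tokens and adds the bonus token.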
Unbiasedness — why this is exact, not approximate
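Claim: per round, the first emitted token is distributed exactly as the target's next-token distribution. Write $p_T$ and $q$ for the target and draft distributions at position 1. (The argument iterates: condition on the context plus the accepted token and apply the same logic to position 2, and so on for every accepted position.)
For any token $x$, the round emits $x$ at position 1 via one of two disjoint events: the draft proposes $x$ and it is accepted, or the draft's proposal is rejected and the residual resample returns $x$:
$$P(\text{emit } x) = P(\text{draft } x,\ \text{accept}) + P(\text{reject}) \cdot \tilde p(x)$$
The first term is direct from the rule:
$$P(\text{draft } x,\ \text{accept}) = q(x) \cdot \min\!\left(1, \frac{p_T(x)}{q(x)}\right) = \min(q(x), p_T(x))$$
For the second, total rejection mass is
$$P(\text{reject}) = \sum_y q(y)\left(1 - \min\!\left(1, \frac{p_T(y)}{q(y)}\right)\right) = \sum_y \big(q(y) - \min(q(y), p_T(y))\big) = 1 - \sum_y \min(q(y), p_T(y))$$
This equals $Z = \sum_y \max(0, p_T(y) - q(y))$, because $\min(q(y), p_T(y)) + \max(0, p_T(y) - q(y)) = p_T(y)$ for every $y$ and $p_T$ sums to 1: the min and the positive part of the difference partition the probability mass. Conditional on rejection we resample from $\tilde p(x) = \max(0, p_T(x) - q(x)) / Z$, so
$$P(\text{reject}) \cdot \tilde p(x) = Z \cdot \frac{\max(0, p_T(x) - q(x))}{Z} = \max(0, p_T(x) - q(x))$$
Adding the two terms:
$$P(\text{emit } x) = \min(q(x), p_T(x)) + \max(0, p_T(x) - q(x))$$
Case split: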
| case | min term | max term | sum |
|---|---|---|---|
| $p_T(x) \ge q(x)$ | $q(x)$ | $p_T(x) - q(x)$ | $p_T(x)$ |
| $p_T(x) < q(x)$ | $p_T(x)$ | $0$ | $p_T(x)$ |
Both cases give $p_T(x)$. The marginal of the emitted token is exactly the target distribution. ∎
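The identity can be checked exactly with rational arithmetic. The sketch below computes each token's emission probability from the two paths, taking the rejection probability directly from the accept rule rather than assuming it equals Z; `emit_prob` is a hypothetical helper name and the distributions are arbitrary toy values.

```python
from fractions import Fraction as F
from typing import Dict

Dist = Dict[str, F]


def emit_prob(p_t: Dist, q: Dist) -> Dist:
    """Exact P(first emitted token = x) for one position of speculative sampling."""
    zero = F(0)
    vocab = set(p_t) | set(q)
    # Accept path: draft proposes x, ratio test keeps it: q(x) * min(1, p_T(x)/q(x)).
    accept = {x: q[x] * min(F(1), p_t.get(x, zero) / q[x]) if q.get(x, zero) else zero
              for x in vocab}
    # Total rejection probability, computed directly from the accept rule.
    reject = sum(q[y] - accept[y] for y in q)
    # Residual distribution max(0, p_T - q) / Z (unused when nothing is rejected).
    diff = {x: max(zero, p_t.get(x, zero) - q.get(x, zero)) for x in vocab}
    z = sum(diff.values())
    resid = {x: (diff[x] / z if z else zero) for x in vocab}
    return {x: accept[x] + reject * resid[x] for x in vocab}


p_t = {"a": F(1, 2), "b": F(3, 10), "c": F(1, 5)}
q = {"a": F(1, 5), "b": F(1, 2), "c": F(3, 10)}
print(emit_prob(p_t, q) == p_t)  # True
```

Using `Fraction` instead of floats means the equality holds exactly rather than up to rounding error.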
Interactive · how the residual fills the gap, per token
The widget below shows two paths emitting the same target distribution. For each token x:
- The $p_T$ bar is split into green (mass arriving via accept, equal to $\min(p_T, q_D)$) and yellow (mass arriving via residual resample, equal to $\max(0, p_T - q_D)$).
- The $q_D$ bar is split into green (proposals that will be accepted, the same height as the green on $p_T$) and red (proposals that will be rejected, equal to $\max(0, q_D - p_T)$).
- The empirical bar is what the algorithm has actually emitted. Click step repeatedly: green grows from accept events, yellow grows from rejection-then-resample events, and their sum per token tracks $p_T(x)$.
The red mass on the $q_D$ side has nowhere to go on the accept path; its total is the rejection probability $Z$. The residual $\max(0, p_T - q_D) / Z$ redistributes exactly that mass so the totals match. Watch the yellow on the empirical bar refill what the red rejected.
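A text-only stand-in for the widget: simulate many single-position rounds and tally, per token, how often it was emitted via accept (green) versus via rejection-then-resample (yellow). `simulate_split` is a hypothetical helper and the two distributions are arbitrary toy values.

```python
import random
from collections import Counter
from typing import Dict, Tuple


def simulate_split(p_t: Dict[str, float], q: Dict[str, float], n: int,
                   seed: int = 0) -> Dict[str, Tuple[float, float]]:
    """Per token: (fraction emitted via accept, fraction emitted via residual)."""
    rng = random.Random(seed)
    toks = sorted(set(p_t) | set(q))
    q_w = [q.get(t, 0.0) for t in toks]
    resid_w = [max(0.0, p_t.get(t, 0.0) - q.get(t, 0.0)) for t in toks]
    green, yellow = Counter(), Counter()
    for _ in range(n):
        x = rng.choices(toks, weights=q_w)[0]                # draft proposal
        if rng.random() < min(1.0, p_t.get(x, 0.0) / q[x]):
            green[x] += 1                                    # accept path
        else:
            yellow[rng.choices(toks, weights=resid_w)[0]] += 1  # residual resample
    return {t: (green[t] / n, yellow[t] / n) for t in toks}


p_t = {"a": 0.5, "b": 0.3, "c": 0.2}
q = {"a": 0.2, "b": 0.5, "c": 0.3}
for tok, (g, y) in simulate_split(p_t, q, 100_000).items():
    print(f"{tok}: green={g:.3f} yellow={y:.3f} total={g + y:.3f} target={p_t[tok]}")
```

Green should settle near $\min(p_T, q_D)$, yellow near $\max(0, p_T - q_D)$, and their sum near $p_T$. Tokens b and c never get yellow mass because their residual weight is zero: all of the rejected proposals are redirected to a.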
Speedup math
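Let α be the per-position acceptance probability, assumed iid across the K positions in a round (a reasonable working assumption; empirical α drifts only weakly with depth). Then the number of accepted prefix tokens $N$ is geometric-truncated: position $i$ is reached and accepted with probability $\alpha^i$, so
$$\mathbb{E}[N] = \sum_{i=1}^{K} P(N \ge i) = \sum_{i=1}^{K} \alpha^i = \frac{\alpha(1-\alpha^K)}{1-\alpha}$$
Every round additionally emits one extra token: either the residual-resampled correction on the first rejection, or the bonus sample on a full accept. So:
$$\mathbb{E}[\text{tokens/round}] = 1 + \frac{\alpha(1-\alpha^K)}{1-\alpha} = \frac{1-\alpha^{K+1}}{1-\alpha}$$
The cost of a round, measured in target-forward equivalents, where $c = t_{\text{draft}} / t_{\text{target}}$ and the $(K+1)$-position verify counts as one target forward (per the asymmetry above):
$$\text{cost/round} = 1 + Kc$$
Speedup over plain decode (which emits 1 token per target step):
$$\text{speedup} = \frac{1-\alpha^{K+1}}{(1-\alpha)(1+Kc)}$$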
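The formulas translate directly into code. A minimal sketch under the same iid-acceptance assumption; the function names are illustrative, and the direct sum is kept next to the closed form so the two can be compared.

```python
def expected_tokens(alpha: float, k: int) -> float:
    """E[tokens emitted per round], iid per-position acceptance alpha, k drafts."""
    # The accepted prefix N satisfies P(N >= i) = alpha**i, so E[N] = sum of alpha**i.
    accepted = sum(alpha ** i for i in range(1, k + 1))
    # Plus one corrected (on first rejection) or bonus (on full accept) token.
    return accepted + 1.0


def expected_tokens_closed(alpha: float, k: int) -> float:
    """Closed form of the same quantity; valid for 0 <= alpha < 1."""
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def speedup(alpha: float, k: int, c: float) -> float:
    """Tokens per round over cost per round, in target-forward units."""
    cost = 1.0 + k * c  # one verify forward + k draft steps at relative cost c
    return expected_tokens(alpha, k) / cost


print(round(expected_tokens(0.7, 5), 4), round(speedup(0.7, 5, 0.2), 2))  # 2.9412 1.47
```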
Worked example
α = 0.7, K = 5, c = 0.2 (draft is 5× cheaper than target):
E[tokens/round] = 0.7 · (1 − 0.7^5) / (1 − 0.7) + 1
= 0.7 · (1 − 0.16807) / 0.3 + 1
= 1.9412 + 1
= 2.9412
cost/round = 1 + 5 · 0.2 = 2.0
speedup = 2.9412 / 2.0 ≈ 1.47×
That's at the low end of reported production numbers (1.5-3×); raising α to 0.9 at the same K and c gives ≈ 2.34×. Greedy decoding at temperature 0 typically gives a higher α and so a better speedup: the draft and target agree more often when both pick argmax, and the ratio test reduces to accepting exactly when the two argmaxes match.
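A quick Monte Carlo check of the per-round token count: simulate the accept/reject chain with iid acceptance probability α and compare the average with $(1 - \alpha^{K+1}) / (1 - \alpha)$. This only validates the arithmetic of the iid model; it says nothing about how well the iid assumption fits a real draft/target pair. `simulate_tokens_per_round` is a hypothetical helper.

```python
import random


def simulate_tokens_per_round(alpha: float, k: int, rounds: int = 100_000,
                              seed: int = 0) -> float:
    """Average tokens per round when each draft position is accepted iid w.p. alpha."""
    rng = random.Random(seed)
    total = 0
    for _ in range(rounds):
        accepted = 0
        # Walk the draft left to right; stop at the first rejection or after k.
        while accepted < k and rng.random() < alpha:
            accepted += 1
        total += accepted + 1  # + the corrected token or the bonus token
    return total / rounds


est = simulate_tokens_per_round(0.7, 5)
closed = (1 - 0.7 ** 6) / (1 - 0.7)
print(f"simulated {est:.3f} vs closed form {closed:.3f}")
```

With α = 0.7 and K = 5 both numbers land near 2.94 tokens per round.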
Variants worth knowing
| variant | idea | tradeoff |
|---|---|---|
| Medusa (Cai et al. 2024) | Add K small "decoding heads" to the target itself; head k predicts the token at (pos+k+1), so the heads cover (pos+2), (pos+3), … while the original LM head still predicts (pos+1). | No separate draft model, no separate KV cache. Trains the heads on top of the frozen target. |
| EAGLE (Li et al. 2024) | Draft model autoregresses on the target's penultimate-layer (pre-LM-head) features, together with the token sequence shifted one step ahead, rather than on token IDs alone. | Higher α: features carry more predictive info than the discrete sampled token. |
| Lookahead (Fu et al. 2024) | No draft model. Jacobi fixed-point iteration on the target's own K-step conditional. | Trades draft cost for extra target FLOPs; useful when no good small model exists. |
| Tree spec (Sequoia, Medusa-tree) | Draft a tree of branches, not a single line. Verify all branches in one target pass; take the longest accepted path. | Higher expected accepted length per round at fixed K; needs a kernel that supports custom tree attention masks. |
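For tree variants, the custom tree attention mask lets each draft node attend only to its own ancestors (plus the committed context). The sketch below builds that mask from a parent-index list; the parent-array encoding is an assumption for illustration, and a real kernel would consume an equivalent mask or index structure on the GPU rather than Python lists.

```python
from typing import List


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """mask[i][j] is True iff draft node j is node i itself or one of its ancestors.

    parents[i] is the index of node i's parent in the draft tree, or -1 if node i
    hangs directly off the last committed token. Every node also attends to the
    whole committed context; those columns are omitted here.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root, marking each node on the path
            mask[i][j] = True
            j = parents[j]
    return mask


# Hypothetical tree: two first-token candidates (0, 1); node 0 has two
# continuations (2, 3); node 2 has one further continuation (4).
parents = [-1, -1, 0, 0, 2]
for row in tree_attention_mask(parents):
    print("".join("x" if m else "." for m in row))
```

The printed rows are `x....`, `.x...`, `x.x..`, `x..x.`, `x.x.x`: node 4 sees only its path 0 → 2 → 4, which is what lets all five branches share one verify forward without seeing each other.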
When speculative decoding does not help
- Compute-saturated. In large-batch serving where the target is already near peak FLOPs, the K-token parallel forward is no longer ~free: its cost grows with the number of positions verified, and the +K·c draft cost is pure overhead on top.
- Low α with an expensive draft. Low acceptance erodes the gain, and the K·c draft overhead can turn it into a net loss. Plug in α = 0.2, K = 4, c = 0.5: E[tokens/round] = (1 − 0.2^5) / 0.8 ≈ 1.250; cost = 1 + 4·0.5 = 3.0; speedup ≈ 0.42×. A cheap draft does not rescue an overlong draft either: at c = 0.2 and α = 0.4, K = 4 gives ≈ 1.650 / 1.8 ≈ 0.92×, and only shrinking to K = 1 recovers a modest ≈ 1.17×. The K·c overhead bites whenever K is large relative to what α can sustain.
- Tight KV budget. Verifying K+1 positions at once needs KV-cache slots for all of them before you know how many survive, and a separate draft model keeps a KV cache of its own. Under preemption pressure (lesson 11) spec decoding can push you into eviction.
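A related question is how high α must be before speculation pays off at all. Since the expected token count rises monotonically with α, the break-even point for a given K and c can be found by bisection. A sketch, assuming the iid-acceptance speedup model; `break_even_alpha` is a hypothetical helper.

```python
from typing import Optional


def speedup(alpha: float, k: int, c: float) -> float:
    # E[tokens/round] = 1 + alpha + ... + alpha**k; cost/round = 1 + k*c.
    return sum(alpha ** i for i in range(k + 1)) / (1 + k * c)


def break_even_alpha(k: int, c: float, iters: int = 60) -> Optional[float]:
    """Smallest alpha at which speculation stops being a slowdown (k >= 1).

    speedup() is increasing in alpha, so bisection applies. As alpha -> 1 it
    tends to (k + 1) / (1 + k*c), which exceeds 1 only when c < 1.
    """
    if c >= 1.0:
        return None
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if speedup(mid, k, c) < 1.0:
            lo = mid
        else:
            hi = mid
    return hi


for k, c in [(4, 0.2), (4, 0.5), (1, 0.2)]:
    print(f"K={k}, c={c}: need alpha >= {break_even_alpha(k, c):.3f}")
```

Under this model K = 4 needs α of roughly 0.455 with c = 0.2 and roughly 0.74 with c = 0.5, while K = 1 with c = 0.2 breaks even at α = 0.2. Longer drafts raise the bar.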
Interactive · the speedup landscape
The speedup formula has three knobs. Slide them and watch the surface. The heatmap below plots speedup(α, K) for the c you set. Sweet spots and dead zones are immediately visible.
Things to find:
- Set c = 0.2. Trace the speedup ridge as α rises: the best K grows with α. At α = 0.9 the peak sits around K ≈ 6-7 (≈2.37×), and by α = 0.95 it has moved out to K ≈ 10 (≈2.87×). At α = 0.5 the best is K ≈ 1-2 (≈1.25×): more guesses are just wasted draft work.
- Set c = 0.5 (slow draft). Most of the heatmap goes cold; only high-α / low-K regions stay green (at α = 0.9 the best is K = 3, ≈1.38×). This is the "no good small model" regime where draft-free methods like Lookahead become attractive.
- Drop α to 0.3. Even at c = 0.2 only K = 1 clears 1.0 (≈1.08×); every longer draft makes the system slower.
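If the interactive heatmap is unavailable, a text version of the same landscape is easy to print. This sketch assumes the iid-acceptance speedup model; the α rows and K columns are arbitrary choices, and `best_k` is a hypothetical helper that reports where the ridge sits for each α.

```python
def speedup(alpha: float, k: int, c: float) -> float:
    # E[tokens/round] = 1 + alpha + ... + alpha**k; cost/round = 1 + k*c.
    return sum(alpha ** i for i in range(k + 1)) / (1 + k * c)


def best_k(alpha: float, c: float, k_max: int = 12) -> int:
    """K in 1..k_max that maximizes speedup for this (alpha, c)."""
    return max(range(1, k_max + 1), key=lambda k: speedup(alpha, k, c))


def print_grid(c: float, alphas=(0.3, 0.5, 0.7, 0.9, 0.95), ks=range(1, 11)) -> None:
    """Text stand-in for the heatmap: one row per alpha, one column per K."""
    print(f"c = {c}")
    print("alpha " + " ".join(f"K={k:<3d}" for k in ks) + "  best")
    for a in alphas:
        cells = " ".join(f"{speedup(a, k, c):5.2f}" for k in ks)
        print(f"{a:5.2f} {cells}  K={best_k(a, c)}")


print_grid(0.2)
print_grid(0.5)
```

The surface is flat near its peak (at α = 0.9, c = 0.2, K = 6 and K = 7 differ by under 0.01×), so in practice any K close to the ridge performs about as well.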