vllm_lessons · lesson 07 of 12: speculative decoding

Speculative decoding

Draft K cheap tokens, verify them in one parallel target forward, keep the prefix that survives a ratio test. The output distribution is identical to plain decode, and reported speedups are typically 1.5-3×.

The asymmetry that makes it work

Lesson 01, Consequence B: decode is memory-bound. Each step the target model streams all of its weights and its whole KV cache from memory once, yet does only one token's worth of matmul. The compute is tiny relative to the bytes moved.

That has an underappreciated corollary. A K-token parallel forward through the target model takes essentially the same wall-clock time as a 1-token forward:

$t_{\text{target}}(K \text{ parallel tokens}) \approx t_{\text{target}}(1 \text{ token})$

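Why? Both forwards move the same number of bytes: one pass over the weights and one over the full KV cache. The compute scales with K, but compute was never the bottleneck. So as long as K stays small enough that the forward remains memory-bound, a K-token parallel verify costs about as much as a single decode step.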
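A back-of-the-envelope roofline makes the asymmetry concrete. The sketch below assumes a hypothetical 7B-parameter fp16 target on an accelerator with 2 TB/s of memory bandwidth and 300 TFLOP/s of matmul throughput. These are illustrative round numbers, not measurements, and the model ignores attention FLOPs and kernel overheads.

```javascript
// Illustrative roofline model of one target forward pass.
// Every constant is a hypothetical round number, not a measurement.
const PARAMS = 7e9;        // assumed 7B-parameter target
const BYTES_PER_PARAM = 2; // fp16/bf16 weights
const HBM_BW = 2e12;       // assumed memory bandwidth, bytes/s
const PEAK_FLOPS = 3e14;   // assumed dense matmul throughput, FLOP/s

// A forward over numTokens positions streams the weights (plus the KV cache)
// once, but does ~2 FLOPs per parameter per token. Its time is whichever is
// slower. Attention FLOPs and kernel overheads are ignored.
function forwardTime(numTokens, kvCacheBytes = 0) {
  const bytesMoved = PARAMS * BYTES_PER_PARAM + kvCacheBytes;
  const flops = 2 * PARAMS * numTokens;
  return Math.max(bytesMoved / HBM_BW, flops / PEAK_FLOPS);
}

// Token count at which compute time catches up with memory time.
function memoryBoundLimit(kvCacheBytes = 0) {
  const bytesMoved = PARAMS * BYTES_PER_PARAM + kvCacheBytes;
  return (bytesMoved * PEAK_FLOPS) / (HBM_BW * 2 * PARAMS);
}

for (const k of [1, 6, 64, 512]) {
  console.log(`${k} token(s): ${(forwardTime(k) * 1e3).toFixed(2)} ms`);
}
console.log(`compute-bound beyond ~${memoryBoundLimit().toFixed(0)} tokens`);
```

Under these assumed numbers, a 1-token step and a 6-token verify both come out at 7 ms, set by streaming 14 GB of weights; compute only dominates past about 150 tokens, far beyond typical values of K.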

Speculative decoding exploits exactly this: instead of generating 1 token per target step, generate K cheap guesses with a small model, then verify them all in one expensive-but-not-K×-expensive target step.

The algorithm

From Leviathan et al. 2023 (and concurrently Chen et al. at DeepMind). Given target $M_T$, draft $M_D$, speculative length $K$.

One round:

  1. Draft K tokens sequentially from the small model:
    for $i = 1..K$: $q_i = p_D(\cdot \mid \text{ctx} + x_{1..i-1})$; $x_i \sim q_i$
  2. Target verifies in ONE parallel forward over K+1 positions:
    $p_T^i(\cdot) = p_T(\cdot \mid \text{ctx} + x_{1..i-1})$ for $i = 1..K+1$
  3. Accept-reject by the ratio test, left to right. For each $i = 1..K$, sample $r \sim U(0,1)$:
    accept $x_i$ iff $r < \min\left(1, p_T^i(x_i) / q_i(x_i)\right)$; stop at the first rejection.
  4. On first rejection at position $j$, resample from the residual:
    $\tilde{p}(x) = \max(0, p_T^j(x) - q_j(x)) / Z$, with $Z = \sum_x \max(0, p_T^j(x) - q_j(x))$
    Append this corrected token; discard the rest of the draft.
  5. If all K are accepted, sample one bonus token from $p_T^{K+1}$. The round yields K+1 tokens.
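The five steps can be sketched in JavaScript over explicit probability vectors. This is a minimal sketch, not vLLM's implementation: `draftDist` and `targetDist` are hypothetical callbacks that return a next-token probability array for a context (an array of token ids), and the K+1 separate target calls stand in for what a real engine does as one batched forward.

```javascript
// Draw an index from a probability array, using one uniform sample from rng().
function sampleFrom(probs, rng) {
  const u = rng();
  let acc = 0;
  for (let i = 0; i < probs.length; i++) {
    acc += probs[i];
    if (u < acc) return i;
  }
  return probs.length - 1; // guard against floating-point round-off
}

// One round of speculative sampling. draftDist/targetDist are hypothetical
// callbacks: (ctx: number[]) => probability array over the vocabulary.
// Returns the tokens emitted this round (between 1 and K + 1 of them).
function speculativeRound(ctx, K, draftDist, targetDist, rng = Math.random) {
  // Step 1: draft K tokens sequentially, remembering each q_i.
  const drafts = [];
  const qs = [];
  for (let i = 0; i < K; i++) {
    const q = draftDist(ctx.concat(drafts));
    qs.push(q);
    drafts.push(sampleFrom(q, rng));
  }
  // Step 2: target distributions at all K + 1 positions
  // (a real engine computes these in one parallel forward).
  const ps = [];
  for (let i = 0; i <= K; i++) ps.push(targetDist(ctx.concat(drafts.slice(0, i))));

  const out = [];
  for (let i = 0; i < K; i++) {
    const x = drafts[i], p = ps[i], q = qs[i];
    // Step 3: accept with probability min(1, p(x) / q(x)).
    if (rng() < Math.min(1, p[x] / q[x])) {
      out.push(x);
      continue;
    }
    // Step 4: first rejection, so resample from max(0, p - q) / Z and stop.
    const residual = p.map((pv, j) => Math.max(0, pv - q[j]));
    const Z = residual.reduce((s, v) => s + v, 0);
    out.push(sampleFrom(residual.map((v) => v / Z), rng));
    return out;
  }
  // Step 5: every draft accepted, so add a bonus token from p_T^{K+1}.
  out.push(sampleFrom(ps[K], rng));
  return out;
}
```

When the draft matches the target exactly, every ratio is 1 and a round emits K+1 tokens; when the draft puts all its mass where the target puts none, every round ends after a single corrected token.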
The non-obvious bit
Step 4's residual is exactly what makes the whole thing an exact sampler of $p_T$. Skip the proof below and you'll either disbelieve the unbiasedness or implement the rejection step wrong.

Unbiasedness — why this is exact, not approximate

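Claim: per round, the distribution of the first emitted token equals $p_T$ exactly. Below, $q$ and $p_T$ stand for $q_1$ and $p_T^1$. (The argument iterates: conditional on position 1 being accepted, apply the same logic to position 2 with the extended context, and so on for each accepted position.)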

For any token x, the round emits x at position 1 via one of two disjoint events:

P(x emitted) = P(x proposed & accepted) + P(some y proposed, rejected, x resampled)

The first term is direct from the rule:

$P(x \text{ proposed and accepted}) = q(x) \cdot \min\left(1, \frac{p_T(x)}{q(x)}\right) = \min(q(x), p_T(x))$

For the second, total rejection mass is

$P(\text{any rejection}) = \sum_y q(y)\left(1 - \min\left(1, \frac{p_T(y)}{q(y)}\right)\right) = \sum_y \max(0, q(y) - p_T(y)) = Z$

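The last equality needs a step, because step 4 defined $Z$ with the difference the other way round. Using $\min(a, b) + \max(0, a - b) = a$ and the fact that both distributions sum to 1, $\sum_y \max(0, q(y) - p_T(y)) = 1 - \sum_y \min(q(y), p_T(y)) = \sum_y \max(0, p_T(y) - q(y))$. Conditional on rejection we resample from $\tilde{p}(x) = \max(0, p_T(x) - q(x)) / Z$, so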

$P(\text{rejection and } x \text{ resampled}) = Z \cdot \frac{\max(0, p_T(x) - q(x))}{Z} = \max(0, p_T(x) - q(x))$

Adding the two terms:

$P(x \text{ emitted}) = \min(q(x), p_T(x)) + \max(0, p_T(x) - q(x))$

Case split:

| case | min term | max term | sum |
| --- | --- | --- | --- |
| $p_T(x) \ge q(x)$ | $q(x)$ | $p_T(x) - q(x)$ | $p_T(x)$ |
| $p_T(x) < q(x)$ | $p_T(x)$ | $0$ | $p_T(x)$ |

Both cases give $p_T(x)$. The marginal of the emitted token is exactly the target distribution. ∎
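The proof can also be checked numerically without any sampling: compute each token's accept-path mass straight from the ratio-test rule, add the rejection probability times the residual, and compare with $p_T$. The sketch below does that for one hand-picked (target, draft) pair; the distributions are arbitrary illustrative values.

```javascript
// Exact marginal of the first emitted token, computed from the algorithm's
// rules rather than from the simplified min/max identity.
function emittedDistribution(pT, q) {
  // Accept path: propose x with prob q(x), keep it with prob min(1, pT(x)/q(x)).
  const accept = pT.map((p, x) => (q[x] > 0 ? q[x] * Math.min(1, p / q[x]) : 0));
  // Rejection probability in the form derived from the ratio test.
  const Zreject = q.reduce((s, qy, y) => s + qy - accept[y], 0);
  // Residual distribution, normalized with the step-4 form of Z.
  const raw = pT.map((p, x) => Math.max(0, p - q[x]));
  const Zresid = raw.reduce((s, v) => s + v, 0);
  // If pT equals q there is no residual mass (and no rejection) at all.
  const residual = raw.map((v) => (Zresid > 0 ? v / Zresid : 0));
  return accept.map((a, x) => a + Zreject * residual[x]);
}

const pT = [0.5, 0.3, 0.15, 0.05];
const q = [0.1, 0.2, 0.3, 0.4];
console.log(emittedDistribution(pT, q)); // matches pT up to floating-point round-off
```

Because the two normalizers are computed independently (one from the ratio test, one from the residual definition), the match also confirms that both forms of $Z$ agree.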

Quality is preserved
Speculative decoding is not a "draft model approximates the target" heuristic. The output stream has exactly the same distribution as running $M_T$ alone. The draft only affects how many target steps you need, never what you get out.

How the residual fills the gap, per token

For each token x, the draft proposes x with probability $q_D(x)$, and the ratio test keeps $\min(p_T(x), q_D(x))$ of that mass on the accept path. Wherever $q_D(x) > p_T(x)$, the excess $\max(0, q_D(x) - p_T(x))$ has nowhere to go on the accept path; summed over the vocabulary, that excess is the rejection probability $Z$. The residual $\max(0, p_T - q_D) / Z$ redistributes exactly that rejected mass onto the tokens the draft under-proposed, so every token's total comes out to $p_T(x)$.

[Interactive widget: "Residual resampling · live decomposition (vocab size = 8)". For each token, three bars show $p_T$ (target), $q_D$ (draft), and the empirical distribution of emitted tokens; readouts report the rejection mass Z, samples drawn, acceptance rate α, and L1(empirical, $p_T$).]

The per-token identity it illustrates:
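// pT, qD: probability arrays over the same vocabulary; x: a token index.
const acceptMass   = (pT, qD, x) => Math.min(pT[x], qD[x]);      // emitted via the accept path
const residualMass = (pT, qD, x) => Math.max(0, pT[x] - qD[x]);  // emitted via reject → resample
const rejectMass   = (pT, qD, x) => Math.max(0, qD[x] - pT[x]);  // draft mass the ratio test rejects

// Algebraic identity (proof above), for every x:
//   acceptMass(pT, qD, x) + residualMass(pT, qD, x) = pT[x]
// Complementary identity on the draft side, for every x:
//   acceptMass(pT, qD, x) + rejectMass(pT, qD, x)   = qD[x]

// Summed over the vocabulary, reject mass and residual mass are equal;
// both are Z, the algorithm's rejection probability.
const Z = (pT, qD) => pT.reduce((s, _, x) => s + rejectMass(pT, qD, x), 0);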

Speedup math

Let α be the per-position acceptance probability, assumed iid across the K positions in a round (a reasonable working assumption; empirical α drifts only weakly with depth). The first i drafted tokens all survive with probability $\alpha^i$, so the number of accepted prefix tokens is a truncated geometric, and its expectation is the sum of those survival probabilities:

$E[\text{accepted prefix length}] = \sum_{i=1}^{K} \alpha^i = \frac{\alpha(1 - \alpha^K)}{1 - \alpha}$

Every round also emits exactly one token beyond the accepted prefix: either the residual-resampled correction on the first reject, or the bonus sample on a full accept. So:

$E[\text{tokens/round}] = 1 + \frac{\alpha(1 - \alpha^K)}{1 - \alpha} = \frac{1 - \alpha^{K+1}}{1 - \alpha}$

The cost of a round, measured in target-forward equivalents, where $c = t_{\text{draft}} / t_{\text{target}}$:

$\text{cost/round} = 1 + K c$ (one target verify + K sequential draft steps)

Speedup over plain decode (which emits 1 token per target step):

$\text{speedup}(\alpha, K, c) = \frac{1 - \alpha^{K+1}}{(1 - \alpha)(1 + K c)}$

Worked example

α = 0.7, K = 5, c = 0.2 (draft is 5× cheaper than target):

E[tokens/round] = (1 − 0.7^6) / (1 − 0.7)
                = (1 − 0.117649) / 0.3
                ≈ 2.9412

cost/round      = 1 + 5 · 0.2 = 2.0

speedup         ≈ 2.9412 / 2.0 ≈ 1.47×

That is just under the 1.5-3× range of reported production numbers. At these α and c, a shorter K = 3 does better (about 1.58×): past that point each extra draft step costs more than the shrinking chance that it survives. Greedy decoding at temperature 0 typically gives a higher α and so a better speedup, because the draft and target agree more often when both pick argmax.
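A short sketch makes the speedup math checkable. It computes E[tokens/round] both by enumerating the truncated-geometric outcomes and from the closed form, then searches K for the best speedup at a given α and c. It relies on the same iid-α assumption as above; `maxK` is an arbitrary search bound chosen for illustration.

```javascript
// E[tokens/round] by direct enumeration: the accepted prefix has length n < K
// with prob α^n (1 − α), and length K with prob α^K; each round adds one token.
function expectedTokensEnumerated(alpha, K) {
  let e = 0;
  for (let n = 0; n < K; n++) e += Math.pow(alpha, n) * (1 - alpha) * (n + 1);
  return e + Math.pow(alpha, K) * (K + 1);
}

// Closed form (1 − α^{K+1}) / (1 − α), with the α = 1 limit handled separately.
function expectedTokens(alpha, K) {
  return alpha === 1 ? K + 1 : (1 - Math.pow(alpha, K + 1)) / (1 - alpha);
}

// Speedup over plain decode: tokens per round divided by cost per round (1 + K·c).
function speedup(alpha, K, c) {
  return expectedTokens(alpha, K) / (1 + K * c);
}

// Best speculative length in 1..maxK for a given acceptance rate and cost ratio.
function optimalK(alpha, c, maxK = 16) {
  let best = { K: 1, speedup: speedup(alpha, 1, c) };
  for (let K = 2; K <= maxK; K++) {
    const s = speedup(alpha, K, c);
    if (s > best.speedup) best = { K, speedup: s };
  }
  return best;
}

console.log(expectedTokensEnumerated(0.7, 5), expectedTokens(0.7, 5)); // both ≈ 2.9412
console.log(speedup(0.7, 5, 0.2));                                     // ≈ 1.4706
console.log(optimalK(0.7, 0.2));                                       // K = 3 at these settings
```

Enumerating the outcomes and summing the series independently is a cheap guard against the off-by-one that is easy to make in the geometric sum.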

Variants worth knowing

| variant | idea | tradeoff |
| --- | --- | --- |
| Medusa (Cai et al. 2024) | Add K small "decoding heads" to the target itself; each predicts (pos+1), (pos+2), … | No separate draft model, no separate KV cache. Trains the heads on top of the frozen target. |
| EAGLE (Li et al. 2024) | Draft model takes the target's penultimate-layer features as input (alongside the sampled tokens), not raw token IDs alone. | Higher α: features carry more predictive info than the discrete sampled token. |
| Lookahead (Fu et al. 2024) | No draft model. Jacobi fixed-point iteration on the target's own K-step conditional. | Trades draft cost for extra target FLOPs; useful when no good small model exists. |
| Tree spec (Sequoia, Medusa-tree) | Draft a tree of branches, not a single line. Verify all branches in one target pass; take the longest accepted path. | Higher expected accepted length per round at fixed K; needs a kernel that supports custom tree attention masks. |
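For the tree variants, the custom mask is the unusual part: each drafted node may attend to the committed context plus its own ancestors in the tree, but not to sibling branches. Below is a minimal sketch of building that mask from a parent array; the `parents` encoding and the `treeAttentionMask` name are hypothetical, chosen for illustration, and do not reflect the layout any particular kernel expects.

```javascript
// parents[i] is the index of draft node i's parent, or -1 when node i hangs
// directly off the last committed context token. The tree must be acyclic.
// Returns mask[i][j] === true iff draft node i may attend to draft node j.
// (Every node also attends to the full committed context; not shown here.)
function treeAttentionMask(parents) {
  const n = parents.length;
  const mask = Array.from({ length: n }, () => new Array(n).fill(false));
  for (let i = 0; i < n; i++) {
    // Walk from node i up to its root, marking itself and every ancestor.
    for (let j = i; j !== -1; j = parents[j]) mask[i][j] = true;
  }
  return mask;
}

// Two branches sharing a first token: 0 -> 1 -> 3 and 0 -> 2.
const mask = treeAttentionMask([-1, 0, 0, 1]);
console.log(mask.map((row) => row.map((b) => (b ? 1 : 0)).join(" ")).join("\n"));
```

In this example node 3 sees nodes 0, 1, and itself, but not node 2 on the other branch, so both branches can be scored in the same target pass without leaking into each other.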

When speculative decoding does not help

Three regimes where you lose

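The speedup landscape

The speedup formula has three knobs: α, K, and c. Plotted as speedup(α, K) for a fixed c, it shows clear sweet spots and dead zones. E[tokens/round] saturates as K grows while cost/round keeps growing linearly in K, so every (α, c) pair has a finite best K. At low α, where almost every draft is rejected, the K·c draft cost pushes the speedup below 1.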

[Interactive widget: speedup landscape, speedup(α, K) for a fixed c. Green marks speedup > 1, blue marks speedup < 1; readouts show E[tokens/round], cost/round, speedup, and the optimal K at the chosen α and c.]

Empirical unbiasedness check

Run 5000 rounds of speculative decoding on a synthetic (target, draft) pair and compare the empirical distribution of the first emitted token against the true $p_T$. With a finite number of rounds the L1 distance won't be exactly 0, but it should shrink toward 0 as rounds accumulate.
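The same check fits in a few lines of JavaScript. This sketch assumes a fixed 4-token vocabulary with arbitrary illustrative target and draft distributions, and uses a small seeded PRNG (mulberry32, a widely copied public-domain generator) so runs are reproducible. It only examines the first emitted token, which is exactly what the proof covers.

```javascript
// Small seeded PRNG (mulberry32) returning floats in [0, 1).
function mulberry32(seed) {
  let a = seed >>> 0;
  return function () {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Draw an index from a probability array.
function sampleFrom(probs, rng) {
  const u = rng();
  let acc = 0;
  for (let i = 0; i < probs.length; i++) {
    acc += probs[i];
    if (u < acc) return i;
  }
  return probs.length - 1; // guard against floating-point round-off
}

// First emitted token of one round: propose from q, ratio-test against pT,
// and resample from the normalized residual on rejection.
function firstEmitted(pT, q, rng) {
  const x = sampleFrom(q, rng);
  if (rng() < Math.min(1, pT[x] / q[x])) return x;
  const residual = pT.map((p, j) => Math.max(0, p - q[j]));
  const Z = residual.reduce((s, v) => s + v, 0);
  return sampleFrom(residual.map((v) => v / Z), rng);
}

// L1 distance between the empirical first-token distribution and pT.
function unbiasednessCheck(pT, q, rounds, seed = 1) {
  const rng = mulberry32(seed);
  const counts = new Array(pT.length).fill(0);
  for (let r = 0; r < rounds; r++) counts[firstEmitted(pT, q, rng)]++;
  return counts.reduce((s, cnt, x) => s + Math.abs(cnt / rounds - pT[x]), 0);
}

const pT = [0.5, 0.3, 0.15, 0.05];
const q = [0.1, 0.2, 0.3, 0.4]; // a deliberately poor draft: L1(q, pT) = 1.0
console.log(unbiasednessCheck(pT, q, 5000));
```

The printed L1 should come out far below the 1.0 L1 distance of the draft alone, even though the draft is badly mismatched: the residual step repairs every sample, not just the average.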

The speedup formula plotted in the landscape widget, as a function:
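// a = acceptance rate α, K = speculative length, c = t_draft / t_target.
function speedup(a, K, c) {
  // expected tokens per round: accepted prefix (truncated geometric) + 1 correction/bonus token
  const Etok = (a === 1) ? K + 1 : (1 - Math.pow(a, K + 1)) / (1 - a);
  // cost per round in target-step equivalents: 1 verify + K draft steps
  const cost = 1 + K * c;
  // speedup vs plain decode (1 token per target step)
  return Etok / cost;
}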