Chunked prefill
Prefill is compute-bound; decode is memory-bound. Serialize them and you saturate one bottleneck at a time. Pack them in one fused pass and you fire both engines simultaneously.
The bottleneck table you already know
Lesson 01 noted, and lesson 08 dramatized, that the two phases of inference have opposite limits:
| phase | FLOPs / step | HBM bytes / step | bottleneck |
|---|---|---|---|
| prefill (T tokens) | O(T²) per layer | O(T) for weights + activations | compute |
| decode (1 token) | O(t) per layer | O(t) — full KV read | memory |
Run them on the same GPU back-to-back: while prefill is grinding through matmuls, the HBM bandwidth lane is sitting at 10% utilization. While the in-flight decoders are streaming KV out of HBM, the tensor cores are sitting at 5% utilization. You bought the GPU; you're using half of it.
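A rough way to see the two regimes is to compute arithmetic intensity (FLOPs per HBM byte) for a single weight projection. The sketch below is illustrative only: `PEAK_FLOPS` and `HBM_BANDWIDTH` are round placeholder numbers rather than a spec sheet, `matmul_intensity` and `is_compute_bound` are hypothetical helpers, and the model ignores attention itself and any caching effects.

```python
# Illustrative accelerator numbers (round placeholders, not a spec sheet).
PEAK_FLOPS = 1.0e15      # FLOP/s
HBM_BANDWIDTH = 3.0e12   # bytes/s
RIDGE = PEAK_FLOPS / HBM_BANDWIDTH  # FLOPs per byte needed to become compute-bound


def matmul_intensity(n_tokens: int, d_model: int, bytes_per_el: int = 2) -> float:
    """FLOPs per HBM byte for one (n_tokens, d) x (d, d) projection."""
    flops = 2 * n_tokens * d_model * d_model
    weight_bytes = bytes_per_el * d_model * d_model
    act_bytes = bytes_per_el * 2 * n_tokens * d_model  # read input, write output
    return flops / (weight_bytes + act_bytes)


def is_compute_bound(n_tokens: int, d_model: int = 4096) -> bool:
    """True when the projection does more FLOPs per byte than the ridge point."""
    return matmul_intensity(n_tokens, d_model) > RIDGE


decode_like = matmul_intensity(1, 4096)     # one token per sequence
prefill_like = matmul_intensity(512, 4096)  # one 512-token chunk
```

With these assumed numbers a single decode token does about 1 FLOP per byte, far below the ridge, while a 512-token chunk does about 410 FLOPs per byte and lands on the compute-bound side.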
The idea
Two moves:
- Break each prefill into fixed-size chunks (say 512 tokens). The math is unchanged — attention still attends within the prompt — but the work is now in 512-token bites instead of one 4096-token block.
- Within a single fused forward pass, pack a mix of (decode-step tokens, prefill-chunk tokens) up to a token budget. One step, one kernel launch, two workloads.
Because prefill chunks are FLOP-heavy and decode steps are HBM-heavy, batching them together keeps both of the GPU's scarce resources busy at once. Each phase's cost hides behind the other's within the same step.
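The first move, slicing a prompt into fixed-size chunks, can be sketched in a few lines. `split_prefill` is a hypothetical helper name, and the 512-token default simply mirrors the example size used above.

```python
from __future__ import annotations


def split_prefill(prompt_len: int, chunk_size: int = 512) -> list[tuple[int, int]]:
    """Return [start, end) token ranges that cover the prompt in order."""
    return [
        (start, min(start + chunk_size, prompt_len))
        for start in range(0, prompt_len, chunk_size)
    ]


full_prompt = split_prefill(4096)    # eight equal 512-token chunks
ragged_prompt = split_prefill(1300)  # the last chunk is shorter
```

The final chunk is allowed to be short, so a prompt never needs padding to a multiple of the chunk size.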
Why this works · the arithmetic
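Let a step process N tokens total. The kernel's wall time is roughly:
$$t_{\text{step}} \approx \max\left(\frac{N \cdot F_{\text{tok}}}{\text{peak FLOP/s}},\ \frac{B_{\text{HBM}}}{\text{HBM bandwidth}}\right)$$
Here $F_{\text{tok}}$ is the FLOPs spent per token and $B_{\text{HBM}}$ is the bytes read from HBM during the step (weights plus KV). For a pure decode step, N is small (1 per active sequence) and the right term dominates. For a pure prefill of a long prompt, N is large and the left term dominates. Stack a prefill chunk on top of N_decode decode tokens in one step: HBM bandwidth runs near 100% on behalf of the decoders, and the tensor cores run near 100% on behalf of the prefill chunk. The two costs combine in a max, not a sum.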
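To make the max-versus-sum point concrete, here is a small numeric sketch. Every constant is an assumption chosen for round arithmetic (a 7B model in fp16, 1 GB of KV per active decoder, placeholder hardware peaks), and the model ignores activation traffic and attention FLOPs.

```python
# Assumed, illustrative numbers only.
PEAK_FLOPS = 1.0e15             # FLOP/s
HBM_BANDWIDTH = 3.0e12          # bytes/s
FLOPS_PER_TOKEN = 2 * 7e9       # ~2 FLOPs per parameter per token
WEIGHT_BYTES = 14e9             # 7B parameters in fp16
KV_BYTES_PER_DECODER = 1e9      # assumed KV footprint per active sequence


def step_time_ms(n_tokens: int, bytes_read: float) -> float:
    """Wall time as the max of the compute term and the memory term."""
    compute_s = n_tokens * FLOPS_PER_TOKEN / PEAK_FLOPS
    memory_s = bytes_read / HBM_BANDWIDTH
    return 1e3 * max(compute_s, memory_s)


n_decoders, chunk = 32, 512
decode_bytes = WEIGHT_BYTES + n_decoders * KV_BYTES_PER_DECODER

decode_only = step_time_ms(n_decoders, decode_bytes)       # memory-bound
prefill_only = step_time_ms(chunk, WEIGHT_BYTES)           # compute-bound
fused = step_time_ms(n_decoders + chunk, decode_bytes)     # both in one step
serialized = decode_only + prefill_only
```

Under these assumptions the fused step costs the same as the decode-only step, because the chunk's compute fits inside the time the decoders already spend waiting on HBM. The serialized schedule pays for both.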
Scheduler, in 12 lines
budget_left = TOKEN_BUDGET  # e.g. 2048
# 1. Decodes go first: each needs exactly 1 slot, and none can be split.
tasks = []
for r in active_decoders:
    tasks.append(DecodeTask(r, 1))
    budget_left -= 1
# 2. Fill the remaining budget with chunks from waiting prefills.
for r in list(waiting_prefills):  # copy, since move_to_decoding() may edit the queue
    if budget_left <= 0:  # <= also covers more decoders than budget
        break
    chunk = min(budget_left, r.prefill_left)
    tasks.append(PrefillTask(r, chunk))
    r.prefill_left -= chunk
    budget_left -= chunk
    if r.prefill_left == 0:
        r.move_to_decoding()  # next step, this request joins active_decoders
run_fused_forward(tasks)
That's the entire algorithm. Real vLLM adds priorities, SLO awareness, and per-sequence limits, but the skeleton is right.
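For readers who want to step through the skeleton, here is a self-contained toy version. It is a sketch, not vLLM's scheduler: `Request` and `schedule_step` are hypothetical names, tasks are plain tuples, and the forward pass is left out entirely.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # compare by identity so list.remove() targets this object
class Request:
    name: str
    prefill_left: int  # prompt tokens not yet prefetched into KV


def schedule_step(active_decoders, waiting_prefills, token_budget=2048):
    """Build one step's task list: decodes first, then prefill chunks."""
    tasks = []
    budget_left = token_budget
    for r in active_decoders:
        if budget_left <= 0:
            break
        tasks.append(("decode", r.name, 1))
        budget_left -= 1
    finished = []
    for r in waiting_prefills:
        if budget_left <= 0:
            break
        chunk = min(budget_left, r.prefill_left)
        tasks.append(("prefill", r.name, chunk))
        r.prefill_left -= chunk
        budget_left -= chunk
        if r.prefill_left == 0:
            finished.append(r)
    # Promote finished prefills only after the loop, so they decode next step.
    for r in finished:
        waiting_prefills.remove(r)
        active_decoders.append(r)
    return tasks


decoders = [Request("d0", 0), Request("d1", 0)]
waiting = [Request("long", 1500), Request("short", 300)]
step1 = schedule_step(decoders, waiting, token_budget=1024)
step2 = schedule_step(decoders, waiting, token_budget=1024)
```

In this toy run the first step spends 2 slots on decodes and 1022 on the long prompt. The second step finishes the long prompt with a 478-token chunk and fits the whole short prompt behind it.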
The fused kernel shape
Tasks are concatenated along a flat token dim:
X: (sum_of_task_tokens, d) # packed input embeddings
# Per-task metadata, length = n_tasks
seq_start: (n_tasks,) # where each task starts in X
seq_len: (n_tasks,) # how many tokens this task owns
block_table: list of block-id arrays # paged KV (see lesson 02)
Inside the attention op each task dispatches its tokens against its own KV: past blocks for ongoing decodes, past chunks + current chunk for in-progress prefills. Attention math is unchanged — only the gather pattern differs, and the paged-KV block table makes that trivial.
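The per-task offsets are just a running sum over task lengths. A minimal sketch, with `pack_metadata` and `rows_for_task` as hypothetical helpers and a plain Python list standing in for the packed tensor X:

```python
def pack_metadata(task_lens):
    """Return (seq_start, seq_len, total_tokens) for a flat packed batch."""
    seq_start, offset = [], 0
    for n in task_lens:
        seq_start.append(offset)  # this task begins where the previous one ended
        offset += n
    return seq_start, list(task_lens), offset


def rows_for_task(X, seq_start, seq_len, i):
    """Slice out the rows of X owned by task i."""
    return X[seq_start[i]: seq_start[i] + seq_len[i]]


# Two decode tokens, then a 4-token and a 3-token prefill chunk (tiny for clarity).
task_lens = [1, 1, 4, 3]
seq_start, seq_len, total = pack_metadata(task_lens)
X = list(range(total))  # stand-in for (total, d) embeddings
chunk_rows = rows_for_task(X, seq_start, seq_len, 2)
```

Decode tasks always own exactly one row, while prefill tasks own a contiguous run, so one slice per task is enough to recover each task's queries.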
Interactive · zoom into one fused step
The picture below is one kernel launch. The wide bar is the token budget; each cell is a token. Blue cells are decode tokens — one per active decoder, each reading its own full KV cache from HBM. Colored stretches are prefill chunks, sliced off waiting requests. The dim tail is wasted budget.
Move the sliders: see how the compute lane (filled by the count of tokens) and the HBM lane (filled mostly by the count of active decoders reading their full KV) both light up only when the step is packed with both kinds of work.
The budget tradeoff
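The single tunable is max_num_batched_tokens. Two failure modes:
- Budget too small. Prefill chunks shrink, so each step carries too few tokens to keep the tensor cores busy and a long prompt needs many steps to finish prefilling.
- Budget too large. Each step takes longer, so every co-batched decoder waits longer between tokens and inter-token latency drifts back toward the serialized case.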
Empirical sweet spot: 2048–8192 on H100/A100 for 7B–70B models. The widget below lets you feel the cliff.
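A back-of-the-envelope sweep shows why the budget pulls in two directions. This is a toy model under stated assumptions: 64 active decoders, a 4096-token prompt, a 7B model at about 2 FLOPs per parameter per token, and a placeholder peak of 1e15 FLOP/s. `budget_tradeoff` is a hypothetical helper, and `step_compute_ms` is the compute time of a step that fills the whole budget.

```python
import math

FLOPS_PER_TOKEN = 2 * 7e9   # assumed 7B model
PEAK_FLOPS = 1.0e15         # assumed placeholder peak, FLOP/s


def budget_tradeoff(budget: int, n_decoders: int = 64, prompt_len: int = 4096):
    """Return (steps needed to prefill the prompt, compute time of a full step in ms)."""
    chunk = budget - n_decoders  # slots left for prefill after decodes
    if chunk <= 0:
        raise ValueError("budget must exceed the number of active decoders")
    steps_to_prefill = math.ceil(prompt_len / chunk)
    step_compute_ms = 1e3 * budget * FLOPS_PER_TOKEN / PEAK_FLOPS
    return steps_to_prefill, step_compute_ms


sweep = {b: budget_tradeoff(b) for b in (256, 2048, 8192)}
```

In this model a budget of 256 needs 22 steps to get through the prompt, while a budget of 8192 does it in one step that stalls every decoder for over 100 ms of compute. Real step times also depend on the memory term, which this sketch leaves out.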
TTFT impact · why this is a big deal for long prompts
Time to first token is, by definition, "how long after I send the request do I get token 1". In the plain regime:
$$\text{TTFT}_{\text{plain}} \approx \text{prefill\_time}(P)$$
For P = 4096 on a 7B model at ~40k tokens/s of prefill throughput: prefill_time ≈ 100 ms. The new request waits the full 100 ms before seeing anything.
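With chunked prefill at budget 512:
$$\text{TTFT}_{\text{chunked}} \approx \text{chunk\_time} \approx \text{prefill\_time}(P) \cdot \frac{512}{P}$$
That's ~13 ms for the chunk vs. 100 ms for the full prefill, an ~8× TTFT win on a 4096-token prompt. For short prompts (P ≤ chunk_size) it's a wash: the chunk is the prefill.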
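The arithmetic behind those figures, under the lesson's simplified model in which time scales linearly with token count. The throughput constant is an assumption back-derived from the example above (4096 tokens in about 100 ms), and `prefill_time_ms` is a hypothetical helper.

```python
PREFILL_TOKENS_PER_S = 40_960  # assumed: 4096 tokens in ~100 ms


def prefill_time_ms(n_tokens: int, tokens_per_s: float = PREFILL_TOKENS_PER_S) -> float:
    """Linear-in-tokens estimate of prefill time."""
    return 1e3 * n_tokens / tokens_per_s


full_prefill_ms = prefill_time_ms(4096)  # 100.0
one_chunk_ms = prefill_time_ms(512)      # 12.5
speedup = full_prefill_ms / one_chunk_ms # 8.0
```

The ratio is simply P / chunk_size, which is why the gain vanishes once the prompt fits in a single chunk.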
Caveats and what doesn't change
- Position IDs. The request tracks its logical position across chunks. Chunk 2 of an 8-chunk prefill knows it's tokens 512–1023, not 0–511.
- KV allocation. Blocks are allocated as the chunks land, not up front. Paged KV (lesson 02) makes this a free win — without it you'd need contiguous allocation for an unknown prefill length.
- Causal mask across chunks. Chunk 2's token at logical position 600 attends back to KV that was written in chunk 1. The block table looks it up transparently; there's no special-case code.
- Not free at low load. If your GPU is idle and one request shows up, chunking just adds launch overhead. Plain prefill is fine. The win is when you have a steady stream of decoders to interleave with.
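The position and causal-mask caveats above reduce to simple index arithmetic. A minimal sketch, with `chunk_positions` and `causal_kv_split` as hypothetical helpers and chunks indexed from 0 (so the lesson's "chunk 2" is index 1):

```python
def chunk_positions(chunk_idx: int, chunk_size: int, prompt_len: int) -> range:
    """Logical position IDs covered by one chunk of a prompt."""
    start = chunk_idx * chunk_size
    end = min(start + chunk_size, prompt_len)
    return range(start, end)


def causal_kv_split(pos: int, chunk_size: int):
    """How many keys a query at `pos` reads from earlier chunks vs its own chunk."""
    chunk_start = (pos // chunk_size) * chunk_size
    n_cached = chunk_start              # written by earlier chunks
    n_current = pos - chunk_start + 1   # own chunk, up to and including pos
    return n_cached, n_current


positions = chunk_positions(1, 512, 4096)       # second chunk of an 8-chunk prefill
cached, current = causal_kv_split(600, 512)     # the token at logical position 600
```

The token at position 600 attends to 601 keys in total: 512 already in the KV cache from the first chunk and 89 from its own chunk. That is the same set it would see in an unchunked prefill.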
Interactive · pack the budget
Animate the scheduler for 20 steps on a workload of mixed prompt lengths. Each step is a horizontal bar showing how that step's budget tokens were spent: blue for decode tokens (one per active decoder), orange for prefill chunks, dim for unused budget. In plain mode prefills consume the entire budget in one step; in chunked mode the bars fill consistently.
Putting it together
| regime | compute lane | HBM lane | GPU util | TTFT for new long-prompt req |
|---|---|---|---|---|
| serialized (plain) | full during prefill, idle during decode | idle during prefill, full during decode | ~50% | prefill_time |
| chunked + co-batched | full each step (chunk fills FLOPs) | full each step (decoders fill bandwidth) | ~85% | 1 chunk_time |