The KV cache and PagedAttention — why decode is its own problem
Inside the rollout engine sits a single data structure that decides almost everything about throughput: the KV cache. This lesson opens it up: why generation is memory-bandwidth-bound, what PagedAttention is and why it changed the field, and how the cache's geometry decides which batch sizes you can run.
Why generation is its own engineering problem
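A single forward pass over an L-token prompt processes all L tokens in parallel, as a few large, matmul-heavy kernels per layer. Generating L new tokens is L sequential forward passes, each one consuming the previous token's output. The naive implementation re-runs the forward pass over the entire context for every new token, recomputing the keys and values of every earlier token each time: O(L²) total work for the K/V projections alone (re-running full attention at every step is worse still), and the activations of the whole prefix are re-materialized at every step.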
The basic fix is the KV cache: per layer, cache the keys and values computed for every previous token. Each new token's attention then becomes O(L) reads against the cache instead of a recomputation. Generation becomes memory-bandwidth-bound: for every new token, you stream the model weights and the full KV cache from HBM to the SMs once, while doing only a handful of FLOPs per byte read.
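A minimal single-head sketch of the idea, assuming a toy one-layer model where token t's query, key and value are plain linear projections of a made-up embedding x_t (no positional encoding, no multi-head split; all names are illustrative). It runs the naive recompute loop and the cached loop side by side, checks that they produce the same outputs, and counts K/V projections: the naive loop performs L(L+1) of them, the cached loop 2L.

```python
import math
import random


def project(W, x):
    """Matrix-vector product W @ x with plain lists."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all cached keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(len(values[0]))]


def decode_naive(xs, Wq, Wk, Wv):
    """Recompute K and V for the whole prefix at every step."""
    outs, kv_projections = [], 0
    for t in range(len(xs)):
        prefix = xs[: t + 1]
        keys = [project(Wk, x) for x in prefix]
        values = [project(Wv, x) for x in prefix]
        kv_projections += 2 * len(prefix)
        outs.append(attend(project(Wq, xs[t]), keys, values))
    return outs, kv_projections


def decode_cached(xs, Wq, Wk, Wv):
    """Project only the new token; read everything else from the KV cache."""
    k_cache, v_cache, outs, kv_projections = [], [], [], 0
    for x in xs:
        k_cache.append(project(Wk, x))
        v_cache.append(project(Wv, x))
        kv_projections += 2
        outs.append(attend(project(Wq, x), k_cache, v_cache))  # O(t) cache reads
    return outs, kv_projections


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d, L = 4, 32
Wq, Wk, Wv = (random_matrix(d, d, rng) for _ in range(3))
xs = random_matrix(L, d, rng)
naive_out, naive_proj = decode_naive(xs, Wq, Wk, Wv)
cached_out, cached_proj = decode_cached(xs, Wq, Wk, Wv)
print(naive_proj, cached_proj)  # 1056 64
```

The outputs match exactly because the cached keys and values are the same numbers the naive loop recomputes; the cache only removes the redundant work, and what remains per step is reading t cached entries.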
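Prefill (process the prompt): all prompt tokens go through the model in one parallel pass; large matmuls make it compute-bound.
Decode (generate one token at a time): per-step memory traffic = the model weights + the KV cache of every sequence in the batch. At batch size 1 you bottleneck on HBM bandwidth long before you bottleneck on FLOPs. Batching multiple decoding streams amortizes the weight reads but not the KV reads: every sequence still streams its own cache.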
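A back-of-envelope sketch of that claim, assuming bf16 weights, ignoring activation traffic, and using round illustrative numbers: an ~8B-parameter model, the ~131 KB/token KV footprint listed for Llama-3-8B in the table below, and a 4096-token context. Per generated token, the weight share falls as 1/batch while the KV share stays fixed.

```python
def decode_bytes_per_token(n_params, batch, ctx_len, kv_bytes_per_token, bytes_per_weight=2):
    """HBM bytes read per generated token in one decode step (activations ignored)."""
    weight_bytes = n_params * bytes_per_weight        # read once, shared by the whole batch
    kv_bytes = batch * ctx_len * kv_bytes_per_token   # each sequence reads its own cache
    return weight_bytes / batch, kv_bytes / batch     # (weight share, KV share) per token


for batch in (1, 8, 64):
    w, kv = decode_bytes_per_token(8e9, batch, 4096, 131_072)
    print(f"batch={batch:3d}  weights/token={w / 1e9:6.2f} GB  kv/token={kv / 1e9:5.2f} GB")
```

Going from batch 1 to 64 cuts the weight share from 16 GB to 0.25 GB per token, while the KV share stays at about 0.54 GB; at large batch the KV cache, not the weights, dominates per-token traffic.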
So the inference engine has two opposite-looking workloads stitched into one. Everything in this lesson is about getting the KV cache out of the way of throughput, and everything in lesson 21 is about getting prefill and decode out of each other's way.
The KV cache, sized
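Per layer, per token, you store one K and one V vector. With full multi-head attention (MHA), every one of the n_heads heads has its own K and V of width d_head:
$$\text{KV bytes per token} = 2 \cdot n_{\text{layers}} \cdot n_{\text{heads}} \cdot d_{\text{head}} \cdot b$$
where the factor 2 counts K and V, and b is bytes per element (2 for bf16).
For architectures with Grouped-Query Attention (GQA), such as Llama-3 and Qwen-2.5, only KV heads count, not query heads, so the cache shrinks by the query-to-KV head ratio:
$$\text{KV bytes per token} = 2 \cdot n_{\text{layers}} \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot b = 2 \cdot n_{\text{layers}} \cdot d_{\text{kv}} \cdot b$$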
| Model (KV heads / query heads) | Layers | d_kv = n_kv × d_head | KV cache per token (bf16) |
|---|---|---|---|
| Llama-3-8B (GQA 8/32) | 32 | 1024 | ~131 KB |
| Llama-3-70B (GQA 8/64) | 80 | 1024 | ~328 KB |
| Qwen-2.5-7B (GQA 4/28) | 28 | 512 | ~57 KB |
Scale these up by sequence length and batch size and the cache fills HBM fast. A 70B model serving 256 concurrent 8k-token sequences needs 256 × 8192 × 328 KB ≈ 687 GB (640 GiB) of cache alone: several times more than any single GPU's HBM, before even counting the ~140 GB of bf16 weights.
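The formula, the table and the 70B example can be checked in a few lines of Python. The (layers, KV heads, head_dim) triples below are assumptions taken to match the models' published configs (head dimension 128 for all three); verify them against each model's config before relying on them.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """K and V (factor 2) for every KV head in every layer; bf16 = 2 bytes."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


def kv_cache_bytes(per_token, n_seqs, seq_len):
    """Total cache for n_seqs sequences of seq_len tokens each."""
    return per_token * n_seqs * seq_len


# (layers, KV heads, head_dim): assumed configs, verify before relying on them
MODELS = {
    "Llama-3-8B": (32, 8, 128),
    "Llama-3-70B": (80, 8, 128),
    "Qwen-2.5-7B": (28, 4, 128),
}

for name, (layers, kv_heads, head_dim) in MODELS.items():
    print(f"{name:12s} {kv_bytes_per_token(layers, kv_heads, head_dim) / 1e3:6.1f} KB/token")

# MHA counterfactual for Llama-3-8B: all 32 query heads would carry K/V (4x the GQA cache)
mha_8b = kv_bytes_per_token(32, 32, 128)

# The 70B example from the text: 256 sequences x 8192 tokens
total = kv_cache_bytes(kv_bytes_per_token(80, 8, 128), 256, 8192)
print(f"{total / 1e9:.0f} GB ({total / 2**30:.0f} GiB)")
```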
The pre-PagedAttention world
Older inference frameworks pre-allocated, for every sequence, one contiguous KV-cache slab sized for the worst case (prompt plus max_new_tokens). Three problems followed:
- Worst-case allocation. Most sequences finished early; the unused tail of every slab was wasted HBM.
- Fragmentation. When sequences finished out of order, the slab pool filled with holes. New requests were rejected even though aggregate free memory was high, because no single hole was large enough.
- No sharing. Two sequences with the same prompt couldn't physically share the prompt's KV — each had its own contiguous slab.
For RL specifically, problem 3 is severe. K-rollout-per-prompt sampling means K trajectories sharing one prompt. Pre-PagedAttention, that's K copies of the prompt's KV in HBM.
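Problem 2 is easy to reproduce. The sketch below is a hypothetical first-fit allocator over one contiguous arena measured in token slots: it places variable-size slabs, frees two of them out of order, and then rejects a request that the total free space could have held.

```python
class SlabArena:
    """Contiguous per-sequence KV slabs with first-fit placement (pre-paging sketch)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.slabs = {}  # seq_id -> (start, length), in token slots

    def _holes(self):
        holes, cursor = [], 0
        for start, length in sorted(self.slabs.values()):
            if start > cursor:
                holes.append((cursor, start - cursor))
            cursor = start + length
        if cursor < self.capacity:
            holes.append((cursor, self.capacity - cursor))
        return holes

    def alloc(self, seq_id, length):
        for start, size in self._holes():
            if size >= length:
                self.slabs[seq_id] = (start, length)
                return True
        return False  # rejected, even if the holes add up to enough space

    def free(self, seq_id):
        del self.slabs[seq_id]

    def total_free(self):
        return sum(size for _, size in self._holes())


arena = SlabArena(capacity=10)
for seq_id, length in [("a", 3), ("b", 2), ("c", 3), ("d", 2)]:
    assert arena.alloc(seq_id, length)
arena.free("b")
arena.free("d")
admitted = arena.alloc("e", 4)
print(arena.total_free(), admitted)  # 4 False
```

Four slots are free, but they sit in two separate 2-slot holes, so a 4-slot slab cannot be placed.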
PagedAttention — KV cache as virtual memory
PagedAttention (vLLM, 2023) treats the KV cache like an OS page table. Chunk the cache into fixed-size blocks (typically 16 tokens) and maintain a per-sequence block table mapping logical positions to physical blocks. New tokens claim a block from a free pool when the current one fills; finished sequences return their blocks. The attention kernel reads K and V through the block-table indirection: one extra lookup per block, a modest kernel-level overhead that the larger batches more than repay.
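A minimal sketch of the block-table bookkeeping, assuming a 16-token block and a plain Python free list. The class names are hypothetical, not vLLM's; the point is the logical-to-physical translation the attention kernel performs on every K/V read.

```python
BLOCK_SIZE = 16  # tokens per KV block (assumed; the typical value mentioned above)


class BlockPool:
    """Free list of physical KV blocks shared by all sequences."""

    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))

    def alloc(self):
        if not self.free_blocks:
            raise MemoryError("KV cache exhausted; a scheduler would preempt here")
        return self.free_blocks.pop()

    def release(self, blocks):
        self.free_blocks.extend(blocks)


class SequenceKV:
    """Per-sequence block table: logical block index -> physical block id."""

    def __init__(self, pool):
        self.pool = pool
        self.block_table = []
        self.num_tokens = 0

    def append_token(self):
        if self.num_tokens % BLOCK_SIZE == 0:  # last block full (or no block yet)
            self.block_table.append(self.pool.alloc())
        self.num_tokens += 1

    def physical_slot(self, pos):
        """Where token `pos` lives: (physical block id, offset within that block)."""
        return self.block_table[pos // BLOCK_SIZE], pos % BLOCK_SIZE

    def finish(self):
        self.pool.release(self.block_table)
        self.block_table, self.num_tokens = [], 0


pool = BlockPool(num_blocks=8)
a, b = SequenceKV(pool), SequenceKV(pool)
for _ in range(40):
    a.append_token()  # 40 tokens -> 3 blocks, the last one half full
for _ in range(20):
    b.append_token()  # 20 tokens -> 2 blocks
print(a.block_table, a.physical_slot(17), len(pool.free_blocks))  # [7, 6, 5] (6, 1) 3
a.finish()
print(len(pool.free_blocks))  # 6
```

Physical block ids need not be contiguous or ordered; only the block table knows where a sequence's tokens live.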
Three consequences worth naming:
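- Near-zero waste. Blocks are demand-allocated, so memory you might not use is never reserved; the only slack is the partially filled last block of each sequence.
- No external fragmentation. Every block is the same size, so any free block can serve any request; allocation is an O(1) pop from the free pool.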
- Block sharing. Two sequences with the same prompt point at the same physical blocks via their block tables. This is the foundation of prefix caching.
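Block sharing needs only a reference count per physical block. The sketch below uses a hypothetical class and assumes the prompt fills a whole number of blocks, so no shared block is ever written to; a partially filled shared block would need copy-on-write, which vLLM uses for that case and which is omitted here. It forks K rollouts off one prompt:

```python
from collections import Counter


class RefCountedPool:
    """Physical KV blocks with reference counts; a block is freed when its count hits 0."""

    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))
        self.refs = Counter()

    def alloc(self):
        if not self.free_blocks:
            raise MemoryError("KV cache exhausted")
        block = self.free_blocks.pop()
        self.refs[block] = 1
        return block

    def share(self, blocks):
        """Point a new block table at existing blocks: no copy, just refcount bumps."""
        for block in blocks:
            self.refs[block] += 1
        return list(blocks)

    def release(self, blocks):
        for block in blocks:
            self.refs[block] -= 1
            if self.refs[block] == 0:
                del self.refs[block]
                self.free_blocks.append(block)


pool = RefCountedPool(num_blocks=64)
prompt = [pool.alloc() for _ in range(3)]  # 48-token prompt = 3 full 16-token blocks
K, completion_blocks = 4, 2
rollouts = [pool.share(prompt) + [pool.alloc() for _ in range(completion_blocks)]
            for _ in range(K)]
used = 64 - len(pool.free_blocks)
print(used, K * (len(prompt) + completion_blocks))  # 11 shared vs 20 unshared
for table in rollouts:
    pool.release(table)
pool.release(prompt)
print(len(pool.free_blocks))  # 64
```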
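Combined, the achievable batch size typically jumps 4–8× over the contiguous-cache approach; the exact factor depends on how far real sequence lengths fall below the reserved maximum.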
Without PagedAttention, each sequence's contiguous KV slab is independent and the shared prompt is duplicated K times. With PagedAttention the prefix blocks are physically shared, so a fresh K-rollout-per-prompt batch costs (prompt blocks) + K × (completion blocks) instead of K × (prompt + completion blocks). For typical RL workloads (K=16, prompt ≫ completion in early training), the prompt portion shrinks by exactly K× and total KV memory for the batch drops roughly 5–10×, and that memory turns directly into achievable batch size.
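The same arithmetic in code, counting tokens rather than blocks (so it ignores rounding up to whole blocks), with illustrative prompt and completion lengths:

```python
def prefix_sharing_ratio(prompt_tokens, completion_tokens, k):
    """KV memory without sharing divided by KV memory with one shared prompt prefix."""
    unshared = k * (prompt_tokens + completion_tokens)
    shared = prompt_tokens + k * completion_tokens
    return unshared / shared


for prompt, completion in [(1000, 200), (2000, 200), (4000, 200)]:
    print(prompt, completion, round(prefix_sharing_ratio(prompt, completion, k=16), 1))
```

As the prompt-to-completion ratio grows, the total saving approaches K, the factor by which the prompt portion alone shrinks.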
RadixAttention — the SGLang generalization
PagedAttention shares at the block level: two sequences either share a full block of prompt or they don't. SGLang's RadixAttention generalizes this to share any common token prefix. Every cached prefix lives in a radix tree (with least-recently-used eviction when memory runs short); new requests walk the tree to find the longest matching prefix and reuse its cached KV.
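The lookup can be sketched with a plain token trie. This is a simplification: a real radix tree compresses single-child chains into one edge and also tracks eviction and reference counts, all omitted here. The `kv_slot` integers stand in for hypothetical handles to cached KV entries, and the token ids are toy values.

```python
class Node:
    def __init__(self, kv_slot=None):
        self.children = {}      # next token id -> Node
        self.kv_slot = kv_slot  # hypothetical handle to this token's cached K/V


class PrefixTree:
    """Token trie standing in for a radix tree (no edge compression, no eviction)."""

    def __init__(self):
        self.root = Node()
        self.num_slots = 0

    def insert(self, tokens):
        node = self.root
        for tok in tokens:
            if tok not in node.children:  # only tokens past the shared prefix get new KV
                node.children[tok] = Node(kv_slot=self.num_slots)
                self.num_slots += 1
            node = node.children[tok]

    def match_prefix(self, tokens):
        """KV slots of the longest cached prefix of `tokens`."""
        node, slots = self.root, []
        for tok in tokens:
            node = node.children.get(tok)
            if node is None:
                break
            slots.append(node.kv_slot)
        return slots


tree = PrefixTree()
tree.insert([1, 2, 3, 4, 5, 6])  # shared system prompt [1, 2, 3] + one conversation
tree.insert([1, 2, 3, 7, 8])     # same system prompt, different turn
print(tree.match_prefix([1, 2, 3, 4, 9]), tree.num_slots)  # [0, 1, 2, 3] 8
```

The new request reuses four cached tokens even though it diverges mid-conversation, and the two stored sequences hold 8 KV entries instead of 11.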
For multi-turn conversations and tool-use agents — where many trajectories share not just a system prompt but several earlier turns — RadixAttention catches reuse opportunities that simple block-level prefix caching misses. The two innovations stack: PagedAttention is the data structure, RadixAttention is the indexing scheme over it. Both engines (vLLM, SGLang) ultimately need the paged storage; SGLang adds the tree.
Interactive · how the KV cache fills HBM
Slide concurrency and sequence length. The widget computes the steady-state KV-cache footprint and the per-step KV bandwidth requirement, and reports whether the engine is HBM-capacity-bound or HBM-bandwidth-bound. Toggle GQA to feel its effect.
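A rough, non-interactive version of that calculation, under stated assumptions: one GPU with 80 GB of HBM at about 3.35 TB/s (H100 SXM-class figures), ~16 GB of bf16 weights for an ~8B model, and a lower bound where each decode step reads every weight and cached byte exactly once. It is a sketch of the idea, not the widget's actual logic; the MHA figure is a hypothetical Llama-3-8B with K/V on all 32 heads.

```python
def kv_regime(n_seqs, seq_len, kv_bytes_per_token,
              weight_bytes=16e9, hbm_bytes=80e9, hbm_bw=3.35e12):
    """Classify a decode workload; returns (regime, KV bytes, tokens/s ceiling or None).

    Ignores the compute ceiling, which very large batches can eventually hit.
    """
    kv_bytes = n_seqs * seq_len * kv_bytes_per_token
    if weight_bytes + kv_bytes > hbm_bytes:
        return "capacity-bound", kv_bytes, None        # the cache does not fit at all
    step_seconds = (weight_bytes + kv_bytes) / hbm_bw   # bytes streamed per step / bandwidth
    return "bandwidth-bound", kv_bytes, n_seqs / step_seconds


GQA, MHA = 131_072, 524_288  # bytes/token: 8 KV heads vs. a hypothetical all-32-head cache
for label, per_token in [("GQA", GQA), ("MHA", MHA)]:
    for n_seqs in (16, 64, 256):
        regime, kv, tps = kv_regime(n_seqs, 4096, per_token)
        rate = f"{tps:,.0f} tok/s ceiling" if tps else "does not fit"
        print(f"{label} seqs={n_seqs:3d} kv={kv / 1e9:6.1f} GB  {regime:15s} {rate}")
```

With these numbers, 64 sequences of 4096 tokens fit comfortably under GQA but overflow HBM under MHA: the 4× smaller cache is what moves the workload from capacity-bound to bandwidth-bound.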
What the framework actually sees
From RL/framework/rollout.py, the inference engine exposes:
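```python
from typing import List
from torch import Tensor

class Rollout:  # Trajectory is defined elsewhere in the framework (not shown)
    def generate(self, prompts, sampling_params) -> List["Trajectory"]: ...
    def update_weights(self, state_dict) -> None: ...
    def get_old_logp(self, trajectory) -> Tensor: ...  # log-prob at sampling time
```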
Behind generate sits paged storage and (in lesson 21) the scheduler. The framework cares about three production-shaped properties:
- Bulk throughput. 256+ concurrent sequences without ballooning latency.
- Old-logp consistency. The log-prob recorded at sampling time must equal what the trainer would compute on a forward of the same tokens. In fp16/bf16 this means the rollout and trainer kernels need to agree to within ~10⁻⁴ on log-prob. If they don't, lesson 06's PPO-ratio symptoms appear.
- Fast weight ingestion. The trainer broadcasts new weights every step (lesson 19). The engine must accept them without dropping the KV cache for in-flight sequences — that would force a recomputation of every prefix.
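A minimal consistency check for the old-logp requirement, assuming you can obtain per-token log-probs from both the rollout engine and a trainer forward pass over the same tokens as plain lists of floats (how you fetch them is framework-specific and not shown; the function name is hypothetical). Before any optimizer step, the PPO ratio exp(trainer_logp - old_logp) should be 1 for every token, so its deviation is the number to watch.

```python
import math


def check_old_logp(old_logp, trainer_logp, atol=1e-4):
    """Compare sampling-time log-probs with a trainer recompute of the same tokens.

    Returns (within_tolerance, max |logp difference|, max |ratio - 1|).
    """
    if len(old_logp) != len(trainer_logp):
        raise ValueError("log-prob sequences must cover the same tokens")
    diffs = [abs(new - old) for old, new in zip(old_logp, trainer_logp)]
    # PPO importance ratio before any update: exactly 1 if the kernels agree
    ratio_dev = max(abs(math.exp(new - old) - 1.0) for old, new in zip(old_logp, trainer_logp))
    return max(diffs) <= atol, max(diffs), ratio_dev


old = [-0.5, -1.2, -3.0]
ok, max_diff, ratio_dev = check_old_logp(old, [-0.50002, -1.20005, -3.00001])
bad, bad_diff, bad_ratio_dev = check_old_logp(old, [-0.5, -1.25, -3.0])
print(ok, bad)  # True False
```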