
The KV cache and PagedAttention — why decode is its own problem

Inside the rollout engine sits a single data structure that decides almost everything about throughput: the KV cache. This lesson opens it up: why generation is memory-bandwidth-bound, what PagedAttention is and why it changed the field, and how the cache's geometry decides what batch sizes you can run.

Where this lesson sits
Lesson 19 decided where rollout and trainer sit on the cluster. This lesson zooms inside the rollout pool's data structure. Lesson 21 takes the same engine and adds the scheduler tricks on top (continuous batching, prefix caching, chunked prefill, speculative decoding). After that, lesson 22 does the memory math that decides whether engine + trainer fit on the GPUs you have.

Why generation is its own engineering problem

A single forward pass on an L-token prompt is one matmul-heavy GPU kernel. Generating L new tokens is L sequential forward passes, each one consuming the previous token's output. The naive implementation re-runs the full attention over the entire context for every new token: O(L²) total compute and just as bad in memory traffic.

The basic fix is the KV cache: per layer, cache the keys and values computed for every previous token. Each new token's attention then becomes O(L) reads against the cache instead of a recomputation. Generation becomes memory-bandwidth-bound — for every new token, you stream the full KV cache from HBM to the SMs once.
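A minimal sketch of the idea, in pure Python with a single attention head. The `project` helper is a toy stand-in for learned K and V projections, and each token vector doubles as its own query; both are assumptions made to keep the example short. It counts how many projections each strategy performs and shows that the outputs match.

```python
import math

def attend(q, keys, values):
    """Single-head softmax attention for one query over a list of keys/values."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)                                   # subtract max for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]

def project(x, scale):
    """Toy stand-in for a learned projection (assumption: elementwise scale)."""
    return [scale * xi for xi in x]

def decode_without_cache(tokens):
    """Recompute K and V for the whole context at every step."""
    outputs, projections = [], 0
    for t in range(1, len(tokens) + 1):
        context = tokens[:t]
        keys = [project(x, 0.5) for x in context]
        values = [project(x, 2.0) for x in context]
        projections += 2 * t
        outputs.append(attend(context[-1], keys, values))
    return outputs, projections

def decode_with_cache(tokens):
    """Project each token once, append to the cache, attend against the cache."""
    keys, values, outputs, projections = [], [], [], 0
    for x in tokens:
        keys.append(project(x, 0.5))
        values.append(project(x, 2.0))
        projections += 2
        outputs.append(attend(x, keys, values))
    return outputs, projections

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
slow, slow_proj = decode_without_cache(tokens)
fast, fast_proj = decode_with_cache(tokens)
print(slow_proj, fast_proj)  # 20 projections without the cache, 8 with it
```

The cached version still reads every stored key and value at each step, which is exactly the memory traffic the rest of the lesson is about.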

The shape of the workload — prefill vs decode
Prefill (process the prompt): one big matmul per layer, FLOPs-bound, looks like training.
Decode (generate one token at a time): per-step memory traffic = KV-cache size of one sequence. At batch size 1 you bottleneck on HBM bandwidth long before you bottleneck on FLOPs. Batching multiple decoding streams amortizes the weight reads but not the KV reads.
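The amortization argument can be sketched with a simplified traffic model: one decode step reads the weights once for the whole batch and each sequence's KV cache once. The byte counts below are round illustrative numbers (an 8B-parameter model in bf16, 8192 tokens of context at 131,072 KV bytes per token), not measurements.

```python
def decode_step_bytes(weight_bytes, kv_bytes_per_seq, batch_size):
    """HBM bytes read in one decode step (simplified: weights once, KV per sequence)."""
    return weight_bytes + batch_size * kv_bytes_per_seq

def bytes_per_generated_token(weight_bytes, kv_bytes_per_seq, batch_size):
    """Traffic per generated token: weight reads are shared, KV reads are not."""
    return decode_step_bytes(weight_bytes, kv_bytes_per_seq, batch_size) / batch_size

WEIGHT_BYTES = 16_000_000_000        # assumption: 8e9 parameters x 2 bytes
KV_BYTES_PER_SEQ = 8192 * 131_072    # assumption: 8192-token context

for batch in (1, 8, 64):
    per_token = bytes_per_generated_token(WEIGHT_BYTES, KV_BYTES_PER_SEQ, batch)
    print(batch, f"{per_token / 1e9:.2f} GB per token")
```

As the batch grows, the per-token cost falls toward the KV bytes of one sequence and no further.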

So the inference engine has two opposite-looking workloads stitched into one. Everything in this lesson is about getting the KV cache out of the way of throughput, and everything in lesson 21 is about getting prefill and decode out of each other's way.

The KV cache, sized

Per layer, per token, you store one K and one V vector. With full multi-head attention (MHA):

KV bytes per token  =  2 · n_layers · d_model · bytes_per_element

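With Grouped-Query Attention (GQA), as in Llama-3 architectures, only KV heads count, not query heads. The formula becomes 2 · n_layers · d_kv · bytes_per_element with d_kv = n_kv_heads · d_head, so the cache shrinks by the query-to-KV head ratio: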

Model                     Layers   d_kv   KV / token (bf16)
Llama-3-8B (GQA 8/32)     32       1024   ~131 KB
Llama-3-70B (GQA 8/64)    80       1024   ~328 KB
Qwen-2.5-7B (GQA 4/28)    28       512    ~57 KB
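The rows above can be reproduced with a few lines of arithmetic. This sketch assumes a head dimension of 128 for all three models, which is the value that makes d_kv = n_kv_heads × head_dim match the d_kv column; the function name is made up for illustration.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_element=2):
    """K and V (factor 2) for every layer, counting KV heads only; bf16 = 2 bytes."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_element

# (layers, KV heads, head_dim); head_dim = 128 is an assumption matching d_kv
models = {
    "Llama-3-8B": (32, 8, 128),
    "Llama-3-70B": (80, 8, 128),
    "Qwen-2.5-7B": (28, 4, 128),
}

for name, (layers, kv_heads, head_dim) in models.items():
    print(f"{name}: {kv_bytes_per_token(layers, kv_heads, head_dim) / 1000:.0f} KB")

# Same 8B shape with full MHA (32 KV heads) for comparison: 4x larger.
mha_8b = kv_bytes_per_token(32, 32, 128)
```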

Scale these up by sequence length and batch size and the cache fills HBM fast. A 70B model serving 256 concurrent 8k-token sequences needs 256 × 8192 × 328 KB ≈ 687 GB of cache alone, more than any single GPU has.

GQA is non-negotiable for serving
Nearly every modern open-weight model uses GQA or MQA. Multi-head attention's KV cache grows so fast with concurrency that serving large MHA models past a batch size of roughly 32 becomes uneconomic. If a model ships without GQA in 2026, it's targeting research, not serving.

The pre-PagedAttention world

Older inference frameworks pre-allocated one contiguous KV cache slab per sequence, sized for the worst case of max_new_tokens. Three problems followed:

  1. Worst-case allocation. Most sequences finished early; the unused tail of every slab was wasted HBM.
  2. Fragmentation. When sequences finished out of order, the slab pool became riddled with holes. New requests were rejected even though aggregate free memory was high.
  3. No sharing. Two sequences with the same prompt couldn't physically share the prompt's KV — each had its own contiguous slab.

For RL specifically, problem 3 is severe. K-rollout-per-prompt sampling means K trajectories sharing one prompt. Pre-PagedAttention, that's K copies of the prompt's KV in HBM.
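To put rough numbers on problems 1 and 3, here is a small sketch. The sequence lengths, slab size, and prompt length are made-up illustrative values, and the helper names are hypothetical.

```python
def slab_utilization(actual_lengths, max_len):
    """Fraction of reserved KV slots actually used when every sequence
    gets a contiguous slab of max_len tokens."""
    reserved = len(actual_lengths) * max_len
    return sum(actual_lengths) / reserved

def duplicated_prompt_tokens(prompt_len, k):
    """Extra prompt tokens held in HBM when K rollouts each keep a private copy."""
    return (k - 1) * prompt_len

lengths = [120, 340, 2048, 75]       # illustrative finished lengths
print(f"utilization: {slab_utilization(lengths, max_len=2048):.1%}")
print("duplicated prompt tokens:", duplicated_prompt_tokens(prompt_len=1000, k=16))
```

With these numbers under a third of the reserved memory holds real tokens, and 15 of the 16 prompt copies are redundant.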

PagedAttention — KV cache as virtual memory

PagedAttention (vLLM '23) treats the KV cache like an OS page table. Chunk the cache into fixed-size blocks (typically 16 tokens) and maintain a per-sequence block table mapping logical positions to physical blocks. New tokens claim a block from a free pool; finished sequences return their blocks. The attention kernel reads K and V via the block table indirection — one extra pointer dereference per page, negligible cost.
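The bookkeeping described above can be sketched in a few dozen lines. This is a toy model of the block table and free pool only, not vLLM's implementation: the class and method names are hypothetical, and no K/V tensors or attention kernel are involved.

```python
class PagedKVCache:
    """Toy block allocator: fixed-size blocks, a free pool, per-sequence block tables."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))   # physical block ids
        self.block_tables = {}                       # seq_id -> [physical block ids]
        self.seq_lens = {}                           # seq_id -> tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one new token, claiming a block only when needed."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.seq_lens.get(seq_id, 0)
        if length % self.block_size == 0:            # no block yet, or last one is full
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())
        self.seq_lens[seq_id] = length + 1

    def locate(self, seq_id, position):
        """Translate a logical token position to (physical block, offset in block)."""
        block = self.block_tables[seq_id][position // self.block_size]
        return block, position % self.block_size

    def free(self, seq_id):
        """Return all of a finished sequence's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id))
        del self.seq_lens[seq_id]

cache = PagedKVCache(num_blocks=8, block_size=4)
for _ in range(6):
    cache.append_token("a")      # 6 tokens -> 2 blocks
for _ in range(3):
    cache.append_token("b")      # 3 tokens -> 1 block
cache.free("a")                  # both of a's blocks become reusable immediately
```

Because every block is interchangeable, the blocks freed by sequence "a" can serve any later sequence regardless of the order in which sequences finish.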

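Three consequences worth naming:

  1. On-demand allocation. Blocks are claimed as tokens arrive, so waste is bounded by the unfilled tail of each sequence's last block instead of a worst-case slab.
  2. No external fragmentation. Every block is the same size, so any free block can serve any sequence.
  3. Physical sharing. Two block tables can point at the same physical block, so sequences with a common prompt store its KV once.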

Combined, the achievable batch size jumps 4–8× over the contiguous-cache approach.

[Figure: Sequences A and B each map their logical KV blocks (L0–L4) through a block table into the physical block pool in HBM. The shared prompt blocks (L0, L1, L2) are pointed at by both sequences; the divergent completion blocks are unique to each sequence.]

Without PagedAttention, each sequence's contiguous KV slab is independent and the shared prompt is duplicated K times. With PagedAttention the prefix blocks are physically shared, so a fresh K-rollout-per-prompt batch costs (prompt blocks) + K × (completion blocks) instead of K × (prompt + completion blocks). For typical RL workloads (K=16, prompt ≫ completion in early training), the prompt is stored once instead of 16 times, which works out to roughly a 5–10× cut in total KV memory, and that memory turns directly into achievable batch size.
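The two block counts can be compared directly. This sketch assumes the prompt length is a multiple of the block size, so completions start in fresh blocks; the prompt and completion lengths are illustrative, and the function name is hypothetical.

```python
import math

def kv_blocks(prompt_len, completion_len, k, block_size=16, share_prompt=True):
    """Blocks needed for K rollouts of one prompt, with or without prefix sharing."""
    prompt_blocks = math.ceil(prompt_len / block_size)
    completion_blocks = math.ceil(completion_len / block_size)
    if share_prompt:
        return prompt_blocks + k * completion_blocks
    return k * (prompt_blocks + completion_blocks)

shared = kv_blocks(prompt_len=2048, completion_len=256, k=16)
unshared = kv_blocks(prompt_len=2048, completion_len=256, k=16, share_prompt=False)
print(shared, unshared, unshared / shared)   # 384 blocks vs 2304 blocks, a 6x saving
```

The saving approaches K as the prompt grows relative to the completion, and shrinks toward 1 as completions lengthen later in training.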

RadixAttention — the SGLang generalization

PagedAttention shares at the block level: two sequences either share a full block of prompt or they don't. SGLang's RadixAttention generalizes this to share arbitrary subtrees of tokens. Every prefix the engine has ever computed lives in a radix tree; new requests walk the tree to find the longest matching prefix and reuse its blocks.

For multi-turn conversations and tool-use agents — where many trajectories share not just a system prompt but several earlier turns — RadixAttention catches reuse opportunities that simple block-level prefix caching misses. The two innovations stack: PagedAttention is the data structure, RadixAttention is the indexing scheme over it. Both engines (vLLM, SGLang) ultimately need the paged storage; SGLang adds the tree.
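The lookup step can be illustrated with a plain token-level trie. This is a simplification and not SGLang's implementation: a real radix tree compresses single-child chains into one edge, and each node would reference cached KV blocks and participate in eviction. The class name and token ids are made up.

```python
class PrefixTree:
    """Token-level trie supporting longest-matching-prefix lookup."""

    def __init__(self):
        self.root = {}

    def insert(self, tokens):
        """Record a token sequence whose KV has been computed."""
        node = self.root
        for tok in tokens:
            node = node.setdefault(tok, {})

    def longest_prefix(self, tokens):
        """Number of leading tokens of `tokens` already present in the tree."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node:
                break
            node = node[tok]
            matched += 1
        return matched

tree = PrefixTree()
tree.insert([1, 2, 3, 4, 5])      # system prompt + turn 1 + turn 2
tree.insert([1, 2, 3, 9])         # same system prompt and turn 1, different turn 2
reusable = tree.longest_prefix([1, 2, 3, 4, 7])   # new request reuses 4 tokens of KV
```

Only the tokens past the matched prefix need prefill; everything before it can be served from blocks that already exist.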

How the KV cache fills HBM

Concurrency and sequence length together determine the steady-state KV-cache footprint and the per-step KV bandwidth requirement, and therefore whether the engine is HBM-capacity-bound or HBM-bandwidth-bound.

Per-token KV size depends on layers × d_kv × bytes. Capacity = concurrency × length × per-token size. Per-step bandwidth = capacity divided by step time. Moving from full MHA to GQA-8 (8 KV heads) shrinks the cache by the head ratio, ~4× for a 32-head model.
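The relations above can be sketched as a small calculator. The free-HBM and bandwidth figures are hypothetical placeholders, not the specs of any particular GPU, and the capacity-versus-bandwidth rule is one simple way to classify the bottleneck, assumed here for illustration.

```python
def kv_footprint_bytes(concurrency, seq_len, kv_bytes_per_token):
    """Steady-state KV cache size: every sequence at full length."""
    return concurrency * seq_len * kv_bytes_per_token

def kv_read_time_s(footprint_bytes, hbm_bandwidth_bytes_per_s):
    """Time to stream the whole KV cache from HBM once, i.e. per decode step."""
    return footprint_bytes / hbm_bandwidth_bytes_per_s

def bottleneck(footprint_bytes, hbm_free_for_kv_bytes):
    """Assumed rule: if the cache does not fit, capacity binds; otherwise bandwidth."""
    return "capacity" if footprint_bytes > hbm_free_for_kv_bytes else "bandwidth"

HBM_FREE_FOR_KV = 40e9      # hypothetical: bytes left after weights and activations
HBM_BANDWIDTH = 2e12        # hypothetical: bytes per second

footprint = kv_footprint_bytes(concurrency=64, seq_len=4096, kv_bytes_per_token=131_072)
print(f"total KV: {footprint / 1e9:.1f} GB")
print(f"per-step KV time: {kv_read_time_s(footprint, HBM_BANDWIDTH) * 1e3:.1f} ms")
print("bottleneck:", bottleneck(footprint, HBM_FREE_FOR_KV))
```

Doubling either concurrency or sequence length in this example pushes the footprint past the assumed free HBM and flips the result to capacity-bound.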

What the framework actually sees

From RL/framework/rollout.py, the inference engine exposes:

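from __future__ import annotations  # annotations stay unevaluated, so no imports are needed here

class Rollout:
    def generate(self, prompts, sampling_params) -> List[Trajectory]: ...
    def update_weights(self, state_dict) -> None: ...
    def get_old_logp(self, trajectory) -> Tensor: ...    # log-prob at sampling time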

Behind generate sits paged storage and (in lesson 21) the scheduler. The framework cares about three production-shaped properties:

Takeaway
Decode is memory-bandwidth-bound; the KV cache is the data structure that consumes most of that bandwidth. PagedAttention turns the cache from a slab into a paged virtual memory — eliminating fragmentation, enabling sharing, and lifting achievable batch size 4–8×. RadixAttention layers a prefix tree on top so that arbitrary shared subtrees are reused, not just block-aligned prompts. Everything in the next lesson — continuous batching, prefix caching, chunked prefill, speculative decoding — assumes this paged cache as its substrate.