RadixAttention — the radix tree of KV
A radix tree whose nodes carry refcounts and KV-block pointers. Lookups and inserts cost O(prompt length), eviction costs a heap pop plus a walk up the tree, and the structure is a strict generalization of every prefix cache that came before it.
One picture, then the rest of the lesson explains it
Three properties of the figure are load-bearing:
- Edges (not nodes) hold the tokens. Each edge owns a contiguous token run, plus the KV blocks for that run: one block per block_size tokens. The KV pool is unchanged; the tree is an index over it.
- Refcounts climb the tree. A request occupying a leaf increments the refcount of every node on its root-to-leaf path. When the request finishes, the decrements walk back up. A node whose refcount reaches zero becomes evictable; its blocks return to the free pool only when eviction actually removes it.
- Branches happen exactly where prompts diverge. Two requests that share 1,587 tokens get a single edge of length 1,587 from the root; the 1,588th token, where they differ, forces a node split.
The node, in 20 lines
```python
import time

class RadixNode:
    def __init__(self, parent, token_ids, block_ids):
        self.parent = parent          # None only for the root
        self.children = {}            # first_token_id → RadixNode
        self.token_ids = token_ids    # tokens on the edge into this node
        self.block_ids = block_ids    # KV pool block ids, len == ceil(len(token_ids)/B), B = block_size
        self.refcount = 0             # how many active requests hold this node
        self.last_use = time.time()   # LRU timestamp used by eviction

# The pool is the same paged pool from vLLM lesson 02:
class KVPool:
    free: list[int]                        # ids of unallocated blocks
    refcount: list[int]                    # per-block reference counts
    def allocate(self, n): ...             # hand out n block ids
    def free_blocks(self, block_ids): ...  # return block ids to the free list
```
That's it. A radix tree is a dictionary first_token_id → child at each level; the parent pointer lets refcount updates and eviction walk back up toward the root (operations 3 and 4 below); the rest is bookkeeping.
The four operations
1. Lookup (longest prefix match)
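```python
def common_prefix_length(tokens, start, edge):
    """Leading tokens of `edge` that match tokens[start:], compared in place."""
    n = min(len(tokens) - start, len(edge))
    j = 0
    while j < n and tokens[start + j] == edge[j]:
        j += 1
    return j

def lookup(root, tokens):
    """Return (matched_tokens, (node, offset)). Walks the tree, eating tokens.
    offset < len(node.token_ids) means the match ended mid-edge."""
    node, i = root, 0
    while i < len(tokens):
        ch = node.children.get(tokens[i])
        if ch is None:
            break
        j = common_prefix_length(tokens, i, ch.token_ids)
        if j == len(ch.token_ids):
            node, i = ch, i + j          # walked through the whole edge
        else:
            return i + j, (ch, j)        # diverged mid-edge: must split on insert
    return i, (node, len(node.token_ids))
```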
For an incoming prompt of length T, lookup does O(T) token comparisons: each matched token is compared once, and choosing the child at each node is a single dict lookup on its first token. The cost depends on the prompt, not on how many requests the tree holds. The returned node marks where the match ends; the request reuses the KV blocks along the whole matched path, and matched_tokens is how many tokens it doesn't have to prefill.
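A tiny hand-built tree makes lookup's return values concrete. The block below restates lookup in self-contained form, with a stripped-down `Node` (no block ids, refcounts, or timestamps) and made-up token ids; it is an illustrative sketch, not a library API.

```python
class Node:
    """Trimmed-down RadixNode: only the fields lookup() reads."""
    def __init__(self, parent, token_ids):
        self.parent = parent
        self.children = {}                 # first_token_id -> Node
        self.token_ids = list(token_ids)   # tokens on the edge into this node
        if parent is not None:
            parent.children[self.token_ids[0]] = self

def common_prefix_length(tokens, start, edge):
    """How many leading tokens of `edge` match tokens[start:] (no slicing)."""
    n = min(len(tokens) - start, len(edge))
    j = 0
    while j < n and tokens[start + j] == edge[j]:
        j += 1
    return j

def lookup(root, tokens):
    """Return (matched_tokens, (node, offset)); offset < len(node.token_ids)
    means the match stopped mid-edge."""
    node, i = root, 0
    while i < len(tokens):
        ch = node.children.get(tokens[i])
        if ch is None:
            break
        j = common_prefix_length(tokens, i, ch.token_ids)
        if j == len(ch.token_ids):
            node, i = ch, i + j            # consumed the whole edge
        else:
            return i + j, (ch, j)          # diverged mid-edge
    return i, (node, len(node.token_ids))

# root -[1,2,3]-> a -[4,5]-> b
#                 a -[9]---> c
root = Node(None, [])
a = Node(root, [1, 2, 3])
b = Node(a, [4, 5])
c = Node(a, [9])

print(lookup(root, [1, 2, 3, 4, 5, 6])[0])  # 5: path root->a->b fully matched
print(lookup(root, [1, 2, 7])[0])           # 2: stopped inside a's edge
```

The second call returns `(2, (a, 2))`; since 2 < `len(a.token_ids)`, inserting `[1, 2, 7]` would have to split a's edge after two tokens.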
2. Insert
The new request runs lookup, then adds a child edge for its unmatched suffix. If the divergence point was mid-edge, the existing edge is first split there into a shared head (a new branching node) and the "old tail".
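Notice what the split does to the physical KV: nothing. The blocks for tokens 0..1586 stay where they are; the new branching node simply holds pointers into that block run. The original edge's block list is sliced: the first 99 blocks now belong to the parent edge, the next 58 to the "old tail" edge. No copy. One caveat: the slice is only clean when the split lands on a block boundary. If it falls inside a block, that block straddles both edges, so implementations either match and split only at block boundaries or use a block size of one token.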
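The split and insert steps can be sketched in a few lines. This is a minimal illustration under stated assumptions: `ToyPool` is a hypothetical stand-in for the KV pool that just hands out fresh block ids, `Node` is a trimmed RadixNode, and only block-aligned splits are handled (the code asserts `offset % block_size == 0`). Note that `split()` only slices Python lists of block ids; no KV data moves.

```python
import itertools

class ToyPool:
    """Hypothetical stand-in for the KV pool: hands out fresh block ids."""
    def __init__(self):
        self._ids = itertools.count()
        self.in_use = 0
    def allocate(self, n):
        self.in_use += n
        return [next(self._ids) for _ in range(n)]

class Node:
    """Trimmed RadixNode (no timestamps)."""
    def __init__(self, parent, token_ids, block_ids):
        self.parent = parent
        self.children = {}
        self.token_ids = list(token_ids)
        self.block_ids = list(block_ids)
        self.refcount = 0
        if parent is not None:
            parent.children[self.token_ids[0]] = self   # may replace an old child

def lookup(root, tokens):
    """Longest-prefix match: (matched, (node, offset))."""
    node, i = root, 0
    while i < len(tokens):
        ch = node.children.get(tokens[i])
        if ch is None:
            break
        edge, j = ch.token_ids, 0
        while j < len(edge) and i + j < len(tokens) and tokens[i + j] == edge[j]:
            j += 1
        if j < len(edge):
            return i + j, (ch, j)
        node, i = ch, i + j
    return i, (node, len(node.token_ids))

def split(node, offset, block_size):
    """Cut node's edge after `offset` tokens; return the new branching node.
    Only the block-id list is sliced: the KV itself is not copied."""
    assert 0 < offset < len(node.token_ids)
    assert offset % block_size == 0, "sketch handles block-aligned splits only"
    cut = offset // block_size
    # The head takes over node's slot in its parent (same first token).
    head = Node(node.parent, node.token_ids[:offset], node.block_ids[:cut])
    head.refcount = node.refcount   # every holder of node also holds the head
    # The original object becomes the "old tail" below the head.
    node.parent = head
    node.token_ids = node.token_ids[offset:]
    node.block_ids = node.block_ids[cut:]
    head.children[node.token_ids[0]] = node
    return head

def insert(root, tokens, pool, block_size):
    """Reuse the longest cached prefix; add a leaf for the rest.
    Returns (leaf, matched_tokens)."""
    matched, (node, offset) = lookup(root, tokens)
    if offset < len(node.token_ids):           # stopped mid-edge
        node = split(node, offset, block_size)
    rest = tokens[matched:]
    if not rest:
        return node, matched
    n_blocks = -(-len(rest) // block_size)     # ceil(len(rest) / block_size)
    return Node(node, rest, pool.allocate(n_blocks)), matched

B = 4
pool = ToyPool()
root = Node(None, [], [])
first, m1 = insert(root, list(range(1, 13)), pool, B)                 # 12 tokens, 3 blocks
second, m2 = insert(root, list(range(1, 9)) + [20, 21, 22], pool, B)  # shares 8 tokens
head = root.children[1]
print(m2, head.block_ids, first.block_ids, second.block_ids)  # 8 [0, 1] [2] [3]
```

Reusing the old node object as the tail is a deliberate choice: any request whose leaf was that node stays valid, because its root-to-leaf path now runs through the new head but spells the same tokens.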
3. Increment / decrement (request lifecycle)
When a request is admitted, walk from its leaf to the root incrementing refcount. When it finishes, decrement. A node with refcount 0 is evictable but not necessarily evicted — it stays in the tree as a cache.
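A minimal sketch of the two walks. Assumptions: the request keeps a pointer to its leaf, the root is never pinned (it owns no blocks), and `acquire`/`release` are hypothetical names. Refreshing `last_use` on release is one simple choice of LRU clock.

```python
import time

class Node:
    """Just the fields the lifecycle walks touch."""
    def __init__(self, parent=None):
        self.parent = parent
        self.refcount = 0
        self.last_use = time.time()

def acquire(leaf, root):
    """Request admitted: pin every node from its leaf up to (not incl.) the root."""
    node = leaf
    while node is not root:
        node.refcount += 1
        node = node.parent

def release(leaf, root):
    """Request finished: unpin the same path and refresh the LRU timestamps.
    Nothing is freed here; refcount-0 nodes just become evictable."""
    now = time.time()
    node = leaf
    while node is not root:
        node.refcount -= 1
        assert node.refcount >= 0, "unbalanced release"
        node.last_use = now
        node = node.parent

root = Node()
shared = Node(root)            # e.g. a common system prompt
a, b = Node(shared), Node(shared)
acquire(a, root)
acquire(b, root)
release(a, root)
print(shared.refcount, a.refcount, b.refcount)   # 1 0 1
```

A fuller implementation would also push any leaf that reaches refcount 0 onto the eviction heap at this point.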
4. Eviction (LRU at leaves)
When the KV pool needs more blocks than are free, the tree finds an evictable leaf (refcount 0) with the oldest last_use timestamp and frees its blocks. After freeing, if the parent has no other children and is itself refcount 0, evict it too — and recurse upward. This is what keeps the tree small over time: idle leaves and their unique-prefix chains collapse out.
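```python
import heapq

def evict_one(root, pool, leaves_heap):
    # min-heap of (last_use, -depth, seq, node) over leaves with refcount==0;
    # seq breaks ties so nodes are never compared. Entries go stale when a leaf
    # is re-acquired or gains children, so they are validated lazily on pop.
    while leaves_heap:
        last_use, _, _, leaf = heapq.heappop(leaves_heap)
        if (leaf.parent is not None and leaf.refcount == 0
                and not leaf.children and leaf.last_use == last_use):
            break                          # still an unchanged, evictable leaf
    else:
        return False                       # nothing evictable
    node = leaf
    while node is not root and node.refcount == 0 and not node.children:
        pool.free_blocks(node.block_ids)
        parent = node.parent
        del parent.children[node.token_ids[0]]
        node.parent = None                 # detached: any heap entry is now stale
        node = parent
    return True
```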
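Why LRU at leaves specifically? An interior node is a prefix of everything below it: every descendant's KV was computed while attending to the interior node's tokens, so it is useless without them. Evicting an interior node would either orphan its subtree or force a re-prefill for every future request whose prefix passes through it. A leaf's tokens lie on exactly one root-to-leaf path, so evicting it loses only that path's unique suffix.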
Copy-on-write — the rarely-needed escape valve
What if two requests share a leaf and one of them keeps decoding? At the moment of sharing, the node covers a fixed token range. As soon as one sharer produces a new token, that token (and its KV block) must not become visible to the other sharer.
The radix tree handles this by always appending new tokens to a new child node. The shared node is read-only; the diverging request creates a new leaf below it, allocates a fresh block, and writes there. No copy of the parent is needed, because each request's decode creates a fresh branch by construction. That holds exactly when the shared node's last block is full. If it is only partly filled, the next token's slot sits inside a block the other sharer also reads, so the diverging request must copy that one block before writing: the rarely-needed copy-on-write escape valve.
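A minimal sketch of that rule, under stated assumptions: `ToyPool` is a hypothetical pool stand-in, every edge starts on a block boundary, and the shared node's last block is full (the code asserts this, since a partly filled shared block is exactly the case that needs a copy). Once the request owns its leaf outright (refcount 1, no children), later tokens extend that leaf in place; that policy is an assumption of this sketch, as is the simplification of not handling two sharers decoding the same first token.

```python
import itertools

class ToyPool:
    """Hypothetical pool stand-in: hands out fresh block ids."""
    def __init__(self):
        self._ids = itertools.count()
    def allocate(self, n):
        return [next(self._ids) for _ in range(n)]

class Node:
    def __init__(self, parent, token_ids, block_ids, refcount=0):
        self.parent = parent
        self.children = {}
        self.token_ids = list(token_ids)
        self.block_ids = list(block_ids)
        self.refcount = refcount
        if parent is not None:
            parent.children[self.token_ids[0]] = self

def append_decode_token(leaf, token, pool, block_size):
    """Append one decoded token for the request whose path ends at `leaf`.
    Returns the request's leaf afterwards; a shared node is never written."""
    if leaf.refcount > 1 or leaf.children:
        # Read-only node: branch off a private child with a fresh block.
        # The request still holds `leaf` as an ancestor, so its refcount is unchanged.
        assert len(leaf.token_ids) % block_size == 0, "partial shared block needs copy-on-write"
        assert token not in leaf.children, "child collision not handled in this sketch"
        return Node(leaf, [token], pool.allocate(1), refcount=1)
    # Private leaf: extend in place, allocating only when the last block is full.
    if len(leaf.token_ids) % block_size == 0:
        leaf.block_ids += pool.allocate(1)
    leaf.token_ids.append(token)
    return leaf

B = 2
pool = ToyPool()
root = Node(None, [], [])
shared = Node(root, [1, 2, 3, 4], pool.allocate(2), refcount=2)  # blocks [0, 1], two holders
a_leaf = append_decode_token(shared, 7, pool, B)   # A branches: leaf [7], block [2]
a_leaf = append_decode_token(a_leaf, 8, pool, B)   # fits in block 2
a_leaf = append_decode_token(a_leaf, 9, pool, B)   # block 2 full, allocates block 3
b_leaf = append_decode_token(shared, 5, pool, B)   # B branches too: leaf [5], block [4]
print(shared.token_ids, a_leaf.token_ids, a_leaf.block_ids, b_leaf.block_ids)
# [1, 2, 3, 4] [7, 8, 9] [2, 3] [4]
```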
Cost analysis
| Operation | Time | Where it lives |
|---|---|---|
| lookup(tokens) | O(T) token compares + one dict lookup per node visited | CPU, per admission |
| insert(tokens) | O(T) + at most 1 split | CPU, per admission |
| increment/decrement refcount | O(depth) | CPU, per lifecycle event |
| evict_one() | O(log L) heap pop (L = evictable leaves) + O(depth) upward cascade | CPU, only when pool is tight |
| attention kernel read | unchanged | GPU |
The point of this table: every cost is on the CPU and proportional to token count or depth, not to GPU work. For Llama-3-70B, one attention layer kernel runs in ~10–30 µs, and a full decode step (all 80 layers + sampling) is ~5–15 ms. Admitting a new request through the tree is well under 100 µs total. The radix tree is not on the hot path.
The block-table assembly step
The attention kernel still wants a flat block_table[i] for token i. The runtime assembles this once per request by concatenating the block_ids on its root-to-leaf path:
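```python
def assemble_block_table(leaf, root):
    """Flat block table for the request ending at `leaf`, in root-first order."""
    runs = []
    node = leaf
    while node is not root:
        runs.append(node.block_ids)   # collected leaf-first
        node = node.parent
    blocks = []
    for run in reversed(runs):        # emit root-first
        blocks.extend(run)
    return blocks
```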
This runs once at admission, plus once per decode batch (to extend by one block when a sequence crosses a block boundary). For the running 32-fork example: 32 requests, each assembling a block_table of ~160 entries by walking 2 nodes. Microseconds.
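The flat table is what lets the kernel map a token position to physical KV with one division. A sketch, assuming every interior edge holds a whole number of blocks (otherwise `position // block_size` would point at the wrong block after concatenation); `slot` and `maybe_extend` are hypothetical helper names, and the block ids are made up.

```python
import itertools

class Node:
    def __init__(self, parent, block_ids):
        self.parent = parent
        self.block_ids = list(block_ids)

def assemble_block_table(leaf, root):
    """Concatenate block ids along the root-to-leaf path, root first."""
    runs = []
    node = leaf
    while node is not root:
        runs.append(node.block_ids)
        node = node.parent
    return [b for run in reversed(runs) for b in run]

def slot(position, block_table, block_size):
    """(physical block id, offset within block) holding KV for `position`."""
    return block_table[position // block_size], position % block_size

def maybe_extend(block_table, position, block_size, allocate):
    """Before writing KV for `position`, append a block iff it opens a new one."""
    if position % block_size == 0:
        block_table.append(allocate())
        return True
    return False

B = 16
root = Node(None, [])
sys_prompt = Node(root, [10, 11])   # 32 shared tokens
suffix = Node(sys_prompt, [12])     # 16 private tokens -> 48 tokens total
table = assemble_block_table(suffix, root)
fresh = itertools.count(100)
print(table)                                            # [10, 11, 12]
print(maybe_extend(table, 48, B, lambda: next(fresh)))  # True: position 48 opens block 3
print(maybe_extend(table, 49, B, lambda: next(fresh)))  # False
print(slot(49, table, B))                               # (100, 1)
```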
What it gives the rest of the system
- Prefix sharing across calls, time, and partial overlap. The tree doesn't care whether two prompts arrived 1 ms or 1 hour apart: if their tokens line up, they share (provided the relevant nodes haven't been evicted in between).
- Cache state is queryable. The scheduler can ask "what's the deepest existing match for this incoming prompt?" before admitting it (lesson 05).
- Memory is bounded by unique prefix bytes, not by total request bytes. A million identical requests cost the tree no more than one.
- Cleanly composes with paging. Block size, gpu_memory_utilization, and max_model_len all still mean what they meant in vLLM.
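The memory claim can be checked with a small simulation: requests draw their first tokens from a few system prompts and end in a unique suffix. To stay short it uses a token-level trie (nested dicts) instead of a compressed radix tree; with one-token blocks the two store exactly the same set of KV slots, and path compression only reduces the node count. The prompt sizes and request count are made up for illustration.

```python
import random

def insert_tokens(trie, tokens):
    """Insert into a token-level trie of nested dicts; return the number of
    new KV slots (tokens not already cached under this prefix)."""
    node, new = trie, 0
    for t in tokens:
        if t not in node:
            node[t] = {}
            new += 1
        node = node[t]
    return new

rng = random.Random(0)
SYSTEM_PROMPTS = [list(range(1000 * k, 1000 * k + 200)) for k in range(3)]  # 3 x 200 tokens
trie, kv_slots, request_tokens, used = {}, 0, 0, set()
for req in range(30):
    k = rng.randrange(len(SYSTEM_PROMPTS))
    used.add(k)
    suffix = [10_000 + 10 * req + i for i in range(10)]   # unique 10-token tail
    tokens = SYSTEM_PROMPTS[k] + suffix
    request_tokens += len(tokens)
    kv_slots += insert_tokens(trie, tokens)

print(request_tokens, kv_slots)     # 6300 requested; kv_slots == 200 * len(used) + 300
print(insert_tokens(trie, tokens))  # 0: an identical request adds nothing
```

Growth tracks unique prefix tokens (200 per distinct system prompt plus 10 per request), not the 6,300 tokens requested.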