RadixAttention — the radix tree of KV

A radix tree whose nodes carry refcounts and KV-block pointers. Inserts, lookups, and evictions are O(token-length), and the structure is a strict generalization of every prefix cache that came before it.

One picture, then the rest of the lesson explains it

[Figure: the radix tree for the running example. Under the root sits one sys-prompt node ("You are…", tokens[0:2500], refcount=32, blocks=[03,04,…,159]; edge length 2500 tok → 157 blocks @ bs=16). Below it, on "\nUser: …" edges, are three query nodes: query A (85 tok, rc=3, blocks=[201..206]), query B (22 tok, rc=5, blocks=[212,213]), and query C (19 tok, rc=24, blocks=[220,221]). In-flight decode leaves: #41 (blocks=[250..]), #42, #51, #87. Caption: edges store token runs; nodes store {refcount, block_ids, last_use_ts}.]

Three properties of the figure are load-bearing:

  1. Edges (not nodes) hold the tokens. Each edge owns a contiguous token run. Each edge also owns the KV blocks for that run, one block per block_size tokens. In code, an edge's tokens and block ids are stored on the node the edge leads into. The KV pool is unchanged; the tree is an index over it.
  2. Refcounts climb the tree. A request occupying a leaf increments the refcount of every node on its root-to-leaf path. When the request finishes, decrements walk back up. A node whose refcount hits zero becomes evictable; its blocks return to the free pool only when the node is actually evicted.
  3. Branches happen exactly where prompts diverge. Two requests that share 1,587 tokens get a single edge of length 1,587 from root; the 1,588th differing token forces a node split.
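
The numbers in the figure can be checked with a few lines of arithmetic. This is only a sanity check on the figure's values (block size 16, the four edge lengths, the three query refcounts); the helper name blocks_needed is illustrative.

```python
import math

BLOCK_SIZE = 16  # "bs=16" in the figure

def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Number of KV blocks an edge of `num_tokens` tokens occupies."""
    return math.ceil(num_tokens / block_size)

# Edge lengths taken from the figure.
edges = {"sys-prompt": 2500, "query A": 85, "query B": 22, "query C": 19}
block_counts = {name: blocks_needed(n) for name, n in edges.items()}

# Every request under a query node also pins the shared sys-prompt node,
# so the sys-prompt refcount is the sum of the query refcounts.
query_refcounts = {"query A": 3, "query B": 5, "query C": 24}
sys_prompt_refcount = sum(query_refcounts.values())

print(block_counts)         # {'sys-prompt': 157, 'query A': 6, 'query B': 2, 'query C': 2}
print(sys_prompt_refcount)  # 32
```

The block counts match the block id ranges in the figure (03..159 is 157 ids, 201..206 is 6), and the refcounts add up to the 32 concurrent requests.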

The node, in 20 lines

import time

class RadixNode:
    def __init__(self, parent, token_ids, block_ids):
        self.parent     = parent
        self.children   = {}          # first_token_id → RadixNode
        self.token_ids  = token_ids   # tokens on the edge into this node
        self.block_ids  = block_ids   # KV pool block ids, len == ceil(len(token_ids)/B), B = block size
        self.refcount   = 0           # how many active requests hold this node
        self.last_use   = time.time()

# The pool is the same paged pool from vLLM lesson 02 (interface only):
class KVPool:
    free: list[int]       # ids of unallocated blocks
    refcount: list[int]   # per-block reference counts
    def allocate(self, n): ...             # take n block ids from `free`
    def free_blocks(self, block_ids): ...  # return block ids to `free`

That's it. A radix tree is a dictionary first_token_id → child at each level; the parent pointer enables the upward walks used by refcounting and LRU eviction (operations 3 and 4 below); the rest is bookkeeping.

The four operations

1. Lookup (longest prefix match)

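def common_prefix_length(a, b):
    """Number of leading positions at which sequences a and b agree."""
    n = 0
    for x, y in zip(a, b):
        if x != y: break
        n += 1
    return n

def lookup(root, tokens):
    """Return (matched_tokens, (node, offset)): the deepest node reached and how
    many tokens of its incoming edge matched. Walks the tree, eating tokens."""
    node, i = root, 0
    while i < len(tokens):
        ch = node.children.get(tokens[i])
        if ch is None: break
        edge = ch.token_ids
        j = common_prefix_length(tokens[i:], edge)
        if j == len(edge):
            node, i = ch, i + j         # walked through the whole edge
        else:
            return i + j, (ch, j)       # diverged mid-edge: must split on insert
    return i, (node, len(node.token_ids))   # root.token_ids is assumed to be []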

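For an incoming prompt of length T, lookup costs O(T) token comparisons in the worst case. More precisely, it compares each matched token once and does one O(1) dictionary probe per node visited, so the cost is O(M + D) for M matched tokens and D nodes on the path; a prompt that misses at the root costs a single probe. The walk identifies the path whose blocks the request will reuse; matched_tokens is how many tokens it doesn't have to prefill.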
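
To make the three possible outcomes concrete (full match along a path, divergence mid-edge, miss at the root), here is a self-contained sketch on a tiny hand-built tree. The Node dataclass and the name longest_prefix are stand-ins for this example only, and the token ids are arbitrary.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    token_ids: list                               # tokens on the edge into this node
    children: dict = field(default_factory=dict)  # first_token_id -> Node

def common_prefix_length(a, b):
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def longest_prefix(root, tokens):
    """Return (matched_tokens, node, offset_into_that_node's_edge)."""
    node, i = root, 0
    while i < len(tokens):
        child = node.children.get(tokens[i])
        if child is None:
            break
        j = common_prefix_length(tokens[i:], child.token_ids)
        if j < len(child.token_ids):
            return i + j, child, j                # diverged mid-edge
        node, i = child, i + j                    # consumed the whole edge
    return i, node, len(node.token_ids)

# Tiny tree: root -> [1,2,3,4] -> { [5,6], [7] }
root = Node([])
shared = Node([1, 2, 3, 4])
root.children[1] = shared
shared.children[5] = Node([5, 6])
shared.children[7] = Node([7])

full = longest_prefix(root, [1, 2, 3, 4, 5, 6, 9])  # matches 6 tokens, ends on [5,6]
mid = longest_prefix(root, [1, 2, 8])               # matches 2 tokens, stops inside [1,2,3,4]
miss = longest_prefix(root, [9, 9])                 # matches nothing, stays at root
print(full[0], mid[0], miss[0])                     # 6 2 0
```

Only the mid-edge case requires a split if the prompt is subsequently inserted.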

2. Insert

The new request runs lookup, then writes a child edge for its un-matched suffix. If the divergence point was mid-edge, the existing edge is split at that point first:

[Figure: before insert, one long edge [tok 0..2499] holding 157 blocks. After insert (diverges at tok 1587): the edge is split into [tok 0..1586], which ends at a new branching node, and that node has two children: the old tail [1587..2499] and the new suffix (the new request).]

Notice what the split does to the physical KV: nothing. The blocks for tokens 0..1586 stay where they are; the new branching node simply holds the pointers to that block run. The original edge's block list is sliced: the first 99 blocks belong to the parent edge now, the next 58 to the "old tail" edge. (Since 1587 = 99 × 16 + 3, the 100th block straddles the split point; the 99 + 58 count leaves that block with the old tail.) No copy.
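
The split can be sketched as pure pointer surgery. This is a minimal illustration under stated assumptions, not SGLang's implementation: split_edge and attach_suffix are hypothetical helpers, the block that straddles the split is left with the old tail (matching the 99 + 58 count above), and the new request's token ids (9000, 9001) and block id (300) are made up.

```python
class Node:
    def __init__(self, parent, token_ids, block_ids):
        self.parent, self.children = parent, {}
        self.token_ids, self.block_ids = token_ids, block_ids
        self.refcount = 0

def split_edge(child, k, block_size=16):
    """Split the edge into `child` after its first k tokens; return the new upper node."""
    assert 0 < k < len(child.token_ids)
    n_full = k // block_size                      # blocks lying entirely before the split
    parent = child.parent
    upper = Node(parent, child.token_ids[:k], child.block_ids[:n_full])
    upper.refcount = child.refcount               # every holder of `child` passes through `upper`
    parent.children[upper.token_ids[0]] = upper   # takes over `child`'s slot (same first token)
    child.parent = upper
    child.token_ids = child.token_ids[k:]
    child.block_ids = child.block_ids[n_full:]    # only the id list is sliced; KV is untouched
    upper.children[child.token_ids[0]] = child
    return upper

def attach_suffix(node, suffix_tokens, fresh_block_ids):
    """Hang the new request's unmatched suffix below `node`."""
    leaf = Node(node, suffix_tokens, fresh_block_ids)
    node.children[suffix_tokens[0]] = leaf
    return leaf

root = Node(None, [], [])
old = Node(root, list(range(2500)), list(range(3, 160)))  # 157 blocks, as in the figure
root.children[0] = old

branch = split_edge(old, 1587)
new_leaf = attach_suffix(branch, [9000, 9001], [300])

print(len(branch.block_ids), len(old.block_ids))          # 99 58
```

Concatenating branch.block_ids and old.block_ids gives back the original list of 157 ids, which is the "no copy" property in code form.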

3. Increment / decrement (request lifecycle)

When a request is admitted, walk from its leaf to the root incrementing refcount. When it finishes, decrement. A node with refcount 0 is evictable but not necessarily evicted — it stays in the tree as a cache.
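
A minimal sketch of that lifecycle walk, assuming a node needs only a parent pointer and a refcount. The names acquire and release are illustrative, and this version also counts the root, which is harmless because the root is never evicted.

```python
class Node:
    def __init__(self, parent=None):
        self.parent = parent
        self.refcount = 0

def acquire(leaf):
    """Pin every node on the leaf-to-root path for one request."""
    node = leaf
    while node is not None:
        node.refcount += 1
        node = node.parent

def release(leaf):
    """Unpin the same path when the request finishes."""
    node = leaf
    while node is not None:
        assert node.refcount > 0, "release without matching acquire"
        node.refcount -= 1
        node = node.parent

root = Node()
sys_prompt = Node(root)
query_a, query_b = Node(sys_prompt), Node(sys_prompt)

for _ in range(3):
    acquire(query_a)
for _ in range(5):
    acquire(query_b)
print(sys_prompt.refcount)                      # 8

for _ in range(3):
    release(query_a)
print(query_a.refcount, sys_prompt.refcount)    # 0 5
```

After the three releases, query_a is evictable but still linked under sys_prompt, so a later request with the same prefix can reuse it.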

4. Eviction (LRU at leaves)

When the KV pool needs more blocks than are free, the tree finds an evictable leaf (refcount 0) with the oldest last_use timestamp and frees its blocks. After freeing, if the parent has no other children and is itself refcount 0, evict it too — and recurse upward. This is what keeps the tree small over time: idle leaves and their unique-prefix chains collapse out.

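import heapq

def evict_one():
    # `root`, `pool`, and `leaves_heap` are module-level state. `leaves_heap` is a
    # heapq min-heap of (last_use, -depth, id(node), node) over leaves with refcount==0.
    *_, leaf = heapq.heappop(leaves_heap)
    pool.free_blocks(leaf.block_ids)
    p = leaf.parent
    del p.children[leaf.token_ids[0]]
    # Collapse upward through idle ancestors that are now childless.
    while p is not root and p.refcount == 0 and not p.children:
        pool.free_blocks(p.block_ids)
        grand = p.parent
        del grand.children[p.token_ids[0]]
        p = grand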

Why LRU at leaves specifically? An interior node's KV is the prefix of everything beneath it. Evicting it would orphan its whole subtree and force a re-prefill on every future request whose prefix passes through. Evicting a leaf discards only the suffix unique to that branch, so it gives up the least reuse.
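
The upward collapse is easiest to see on a three-node tree. The sketch below is a self-contained variant of evict_one that takes its state as arguments and models the pool as a plain free list; it additionally skips stale heap entries (a leaf that was pinned again after being queued), which is an assumption of this example rather than something the lesson specifies.

```python
import heapq
from itertools import count

class Node:
    def __init__(self, parent, key, block_ids, last_use):
        self.parent, self.key = parent, key
        self.block_ids, self.last_use = block_ids, last_use
        self.refcount, self.children = 0, {}
        if parent is not None:
            parent.children[key] = self

def evict_one(heap, root, free_list):
    """Free the least recently used idle leaf, then collapse idle ancestors."""
    while heap:
        _, _, node = heapq.heappop(heap)
        # Skip stale entries: pinned again, no longer a leaf, or already removed.
        if (node.refcount or node.children or node.parent is None
                or node.parent.children.get(node.key) is not node):
            continue
        while node is not root and node.refcount == 0 and not node.children:
            free_list.extend(node.block_ids)
            del node.parent.children[node.key]
            node = node.parent
        return True
    return False

root = Node(None, None, [], last_use=0)
a = Node(root, 1, [10, 11], last_use=5)   # idle interior node
b = Node(a, 2, [12], last_use=1)          # idle leaf under a
c = Node(root, 3, [20], last_use=0)       # oldest, but pinned by a running request
c.refcount = 1

tie = count()                             # tie-breaker so nodes are never compared
heap = [(n.last_use, next(tie), n) for n in (b, c)]
heapq.heapify(heap)

free_list = []
evicted = evict_one(heap, root, free_list)
print(evicted, free_list, list(root.children))   # True [12, 10, 11] [3]
```

The pinned leaf c survives even though it has the oldest timestamp, and freeing b takes its now-childless parent a with it.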

Copy-on-write — the rarely-needed escape valve

What if two requests share a leaf and one of them keeps decoding? At the time of sharing, the shared node covered a fixed token range. As soon as one sharer produces a new token, that token (and its KV block) must not be visible to the other sharer.

The radix tree handles this by always appending new tokens to a new child node. The shared node is read-only; the diverging request creates a new leaf below it, allocates a fresh block, writes there. No copy of the parent is ever needed — because each request's decode creates a fresh branch by construction.

Why "CoW" is (almost) a misnomer here
In OS virtual memory, two processes share a page until one writes — at which point the page is duplicated. RadixAttention mostly doesn't need that, because KV blocks are append-only: when a request's next token wants to extend past a shared node, it plants a new child edge (and allocates a fresh block for the new tokens) rather than mutating the shared block. The one subtle case: the last block of the shared prefix may be partially filled — e.g., the shared edge ends mid-block at slot 11/16, and the diverging request's first new token wants slot 12. Rather than write into slot 12 of the shared block (which would alias the writer's new KV onto every sharer), the runtime starts a fresh block at logical position 12 for that request. There's a small block-internal fragmentation cost (slots 12..15 of the shared block stay unused), but no copy. The cost is one node creation + one block allocation, never a page copy.
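
The append-only rule can be sketched as follows, assuming each decoding request owns a private leaf below the shared node and fills its own blocks from slot 0. The helper append_token and the block ids are hypothetical; how a real runtime lays out the partially filled shared block in the block table is not modeled here.

```python
from itertools import count

BLOCK_SIZE = 16

class Node:
    def __init__(self, parent, token_ids, block_ids):
        self.parent, self.token_ids, self.block_ids = parent, token_ids, block_ids

def append_token(leaf, token, alloc, block_size=BLOCK_SIZE):
    """Append one decoded token to a request's private leaf, never to its parent."""
    if len(leaf.token_ids) == len(leaf.block_ids) * block_size:
        leaf.block_ids.append(alloc())        # private blocks are full (or none yet)
    leaf.token_ids.append(token)

next_id = count(100)                          # stand-in for the KV pool allocator
alloc = lambda: next(next_id)

shared = Node(None, list(range(12)), [7])     # shared edge ends at slot 11 of block 7
r1 = Node(shared, [], [])                     # private leaf of request 1
r2 = Node(shared, [], [])                     # private leaf of request 2

for t in range(17):                           # request 1 decodes 17 tokens
    append_token(r1, 500 + t, alloc)
append_token(r2, 900, alloc)                  # request 2 decodes 1 token

unused_slots = BLOCK_SIZE - len(shared.token_ids) % BLOCK_SIZE
print(shared.block_ids, r1.block_ids, r2.block_ids, unused_slots)   # [7] [100, 101] [102] 4
```

The shared node is never written after it is created; the only cost of sharing a partially filled block is the four unused slots.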

Cost analysis

Operation                      Time                                         Where it lives
lookup(tokens)                 O(matched tokens + depth); worst case O(T)   CPU, per admission
insert(tokens)                 O(T) + at most 1 split                       CPU, per admission
increment/decrement refcount   O(depth)                                     CPU, per lifecycle event
evict_one()                    O(log L) heap + O(depth) bubble              CPU, only when pool is tight
attention kernel read          unchanged                                    GPU

The point of this table: every tree operation runs on the CPU and costs time proportional to token count or depth; none of them adds GPU work. For Llama-3-70B, one attention layer kernel runs in ~10–30 µs, and a full decode step (all 80 layers + sampling) is ~5–15 ms. Admitting a new request through the tree is well under 100 µs total, so at most about 2% of a single decode step. The radix tree is not on the hot path.
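
The ratio is plain arithmetic on the figures quoted above; nothing here is measured, and the constant names are only labels for those figures.

```python
ADMISSION_US = 100                  # upper bound on tree work per admission
DECODE_STEP_MS = (5, 15)            # full decode step, low and high estimate
LAYERS, KERNEL_US = 80, (10, 30)    # attention kernel time per layer

# Admission cost as a fraction of one decode step.
admission_share = [ADMISSION_US / (ms * 1000) for ms in DECODE_STEP_MS]
# Total attention-kernel time per decode step, in milliseconds.
attention_ms = [LAYERS * us / 1000 for us in KERNEL_US]

print([f"{s:.2%}" for s in admission_share])   # ['2.00%', '0.67%']
print(attention_ms)                            # [0.8, 2.4]
```

Admission happens once per request while a decode step repeats for every generated token, so the amortized share is smaller still.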

The block-table assembly step

The attention kernel still wants a flat block_table, indexed by logical block position, to find the KV for token i. The runtime assembles this once per request by concatenating the block_ids on its root-to-leaf path:

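def assemble_block_table(leaf):
    segments = []
    node = leaf
    while node.parent is not None:         # stop at the root, which owns no blocks
        segments.append(node.block_ids)    # collected leaf-first
        node = node.parent
    # Flatten in root-first order.
    return [bid for seg in reversed(segments) for bid in seg]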

This runs once at admission; afterwards the table is extended by one entry whenever a decoding sequence crosses a block boundary. For the running 32-fork example: 32 requests, each assembling a block_table of ~160 entries by walking 2 nodes. Microseconds.
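
A small usage sketch shows how the flat table is consumed. It assumes every edge ends on a block boundary, so that token i lives in logical block i // block_size; the helper locate, the block size of 4, and the block ids are made up for the example.

```python
class Node:
    def __init__(self, parent, block_ids):
        self.parent, self.block_ids = parent, block_ids

def assemble_block_table(leaf):
    """Concatenate block ids along the root-to-leaf path, root first."""
    segments = []
    node = leaf
    while node.parent is not None:
        segments.append(node.block_ids)
        node = node.parent
    return [bid for seg in reversed(segments) for bid in seg]

def locate(block_table, i, block_size):
    """Map token position i to (physical block id, slot inside that block)."""
    return block_table[i // block_size], i % block_size

BLOCK_SIZE = 4
root = Node(None, [])
prefix = Node(root, [3, 4])     # 8 shared tokens in blocks 3 and 4
leaf = Node(prefix, [9])        # 4 request-specific tokens in block 9

table = assemble_block_table(leaf)
print(table)                            # [3, 4, 9]
print(locate(table, 5, BLOCK_SIZE))     # (4, 1)
print(locate(table, 8, BLOCK_SIZE))     # (9, 0)
```

A second request sharing the same prefix would get a table that starts with the same two ids, which is how the kernel reads shared KV without knowing the tree exists.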

What it gives the rest of the system

[Interactive demo: radix tree under load. Each submitted request takes its shared prefix from a small pool of system prompts and adds a unique suffix; the demo shows the tree structure and KV usage as requests arrive and finish. "Agent" requests (long shared prefix) add little KV per request, while "raw" requests (no sharing) add far more.]