GQA / MQA

The attention-architecture change that shrinks KV cache by 4–8× with negligible quality loss. If you don't know this lesson, your serving math is wrong by an order of magnitude.

Why this lesson matters

Lesson 01 put one formula in your head:

bytes_per_token  =  2 · n_layers · n_kv_heads · head_dim · dtype_bytes

The term to watch is n_kv_heads: it is the lever this lesson is about. In LLaMA-1 7B it equals the number of query heads (32). In LLaMA-2 70B it's 8. In Falcon-40B it's 1. With layer count and head_dim held fixed, that one term moves KV cache per token by up to 64× (64 KV heads versus 1).

If you compute "LLaMA-2 70B has 64 attention heads at d_head=128, so 80 layers × 64 × 128 × 2 × 2 = 2.6 MB/token", you've quietly assumed that n_kv_heads equals the number of query heads, which no longer holds. The actual number is 320 KB/token, 8× smaller. That difference is the whole story for what you can serve on one H100 node.
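As a quick check on that arithmetic, here is a minimal sketch of the lesson-01 formula in Python. The function name `kv_bytes_per_token` and its parameter names are illustrative assumptions, not part of vLLM; the model numbers are the LLaMA-2 70B values quoted in this section.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    """KV-cache bytes stored per token: one K and one V vector per KV head per layer."""
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes


# LLaMA-2 70B in fp16: 80 layers, head_dim 128, 64 query heads, 8 KV heads.
naive = kv_bytes_per_token(n_layers=80, n_kv_heads=64, head_dim=128)  # wrongly plugs in query heads
actual = kv_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128)  # plugs in the real KV head count

print(f"naive MHA assumption: {naive / 1e6:.2f} MB/token")
print(f"actual GQA:           {actual / 1024:.0f} KiB/token")
print(f"overestimate factor:  {naive // actual}x")
```

The only input that differs between the two calls is `n_kv_heads`, which is the point: the mistake is one argument, and the result is off by the group size.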

MHA recap

Standard multi-head attention with hidden size d, h heads, head dim k = d/h. For each head independently:

Q_i = X · W^Q_i,   K_i = X · W^K_i,   V_i = X · W^V_i   ∈  ℝ^(T × k)

The KV cache stores all h key tensors and all h value tensors, one set per layer:

KV per layer per sequence  =  2 · h · k · T · dtype_bytes

Halve h, halve the cache. That's the entire idea. The trick is finding the right h to halve.

The empirical observation

Queries across heads naturally diverge: head 3 attends to syntax, head 17 to coreference, head 24 to semantics. Routing matters; removing query heads kills quality fast.

Keys and values across heads, however, end up highly correlated — many heads' K vectors and V vectors land in roughly the same subspace. That redundancy costs HBM but earns nothing. The architecture change is: keep all the Q heads (they do the routing), shrink the K and V heads (they're redundant).

MQA · the maximal version (Shazeer 2019)

All h query heads share one key and one value:

Q_i = X · W^Q_i  (h heads),    K = X · W^K,   V = X · W^V  (one shared)
head_i  =  softmax(Q_i · K^T / √k) · V

KV cache per layer drops to 2 · 1 · k · T · bytes, an h× reduction. At LLaMA-2 70B's h = 64, that would be 64×. Quality cost: ~0.5–1 PPL on LLaMA-scale models. Real but small for many use cases, sometimes too aggressive for frontier quality.

GQA · the interpolation (Ainslie et al. 2023)

Introduce h_kv KV heads with h_kv < h. Each contiguous group of (h / h_kv) query heads shares one (K, V) head:

Q_i = X · W^Q_i  (h heads),    K_j = X · W^K_j,   V_j = X · W^V_j  (h_kv heads)
head_i  =  softmax(Q_i · K_group(i)^T / √k) · V_group(i)

Two extremes: h_kv = h recovers MHA, and h_kv = 1 recovers MQA.

KV reduction is h / h_kv. LLaMA-2 70B picks h = 64, h_kv = 8 → 8× smaller cache, quality essentially indistinguishable from MHA.
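The group(i) mapping in the formula can be sketched in a few lines. This assumes 0-indexed heads and contiguous groups, as the text describes; the helper name `kv_head_for_query_head` is hypothetical.

```python
def kv_head_for_query_head(i: int, h: int, h_kv: int) -> int:
    """Index of the KV head that query head i reads, assuming contiguous groups."""
    if h % h_kv != 0:
        raise ValueError("h_kv must divide h")
    group_size = h // h_kv
    return i // group_size


h, h_kv = 64, 8  # LLaMA-2 70B head counts quoted in this section
mapping = [kv_head_for_query_head(i, h, h_kv) for i in range(h)]

print(mapping[:10])  # the first group of 8 query heads, then the start of the second
print(h // h_kv)     # KV-cache reduction factor relative to MHA
```

Setting `h_kv = h` makes the mapping the identity (MHA), and `h_kv = 1` sends every query head to KV head 0 (MQA).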

Shape flow side-by-side

Let B = batch, T = sequence length, h = query heads, h_kv = KV heads, k = head_dim.

MHA  (h_kv = h):
    Q: (B, h,   T, k)    K: (B, h,   T, k)    V: (B, h,   T, k)
    scores = softmax(Q @ K^T / √k)            (B, h, T, T)
    out    = scores @ V                       (B, h, T, k)

MQA  (h_kv = 1):
    Q: (B, h,   T, k)    K: (B, 1,   T, k)    V: (B, 1,   T, k)
    broadcast K, V over h: stride-0 expand    (B, h, T, k)   — no memcopy
    rest identical to MHA.

GQA  (h_kv = h/G, group_size G):
    Q: (B, h,    T, k)   K: (B, h_kv, T, k)   V: (B, h_kv, T, k)
    K = K.repeat_interleave(G, dim=1)         (B, h, T, k)   — materializes a copy in pure PyTorch
    V = V.repeat_interleave(G, dim=1)                          (fused kernels broadcast inside the tile loop instead)
    rest identical to MHA.

The broadcast is the key trick. In pure PyTorch, repeat_interleave materializes a copy, growing the (B, h_kv, T, k) tensor to (B, h, T, k). Production fused attention kernels (FlashAttention's GQA path, vLLM's PagedAttention kernel) skip this materialization and broadcast K/V over the group inside the tile loop, so the per-token cache size matches h_kv even though the math operates as if h heads contribute.
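To make the equivalence concrete, here is a minimal NumPy sketch (not the FlashAttention or vLLM implementation) that compares the two paths: repeating K and V up to h heads, versus folding the query heads into (h_kv, G) and letting K and V broadcast over the group axis. The function names and toy sizes are assumptions for illustration, and the causal mask is omitted to keep the sketch short.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_repeat(Q, K, V):
    """Reference path: copy each KV head G times, then run plain MHA."""
    G = Q.shape[1] // K.shape[1]
    K_rep = np.repeat(K, G, axis=1)  # (B, h, T, k), a materialized copy
    V_rep = np.repeat(V, G, axis=1)
    scores = softmax(Q @ K_rep.transpose(0, 1, 3, 2) / np.sqrt(Q.shape[-1]))
    return scores @ V_rep


def gqa_grouped(Q, K, V):
    """Grouped path: fold query heads into (h_kv, G) and broadcast K, V over G."""
    B, h, T, k = Q.shape
    h_kv = K.shape[1]
    G = h // h_kv
    Qg = Q.reshape(B, h_kv, G, T, k)
    Kg = K[:, :, None]  # (B, h_kv, 1, T, k): broadcasts over the group axis
    Vg = V[:, :, None]
    scores = softmax(Qg @ Kg.transpose(0, 1, 2, 4, 3) / np.sqrt(k))
    return (scores @ Vg).reshape(B, h, T, k)


rng = np.random.default_rng(0)
B, h, h_kv, T, k = 2, 8, 2, 5, 16  # toy sizes
Q = rng.standard_normal((B, h, T, k))
K = rng.standard_normal((B, h_kv, T, k))
V = rng.standard_normal((B, h_kv, T, k))

out_repeat = gqa_repeat(Q, K, V)
out_grouped = gqa_grouped(Q, K, V)
print(out_repeat.shape, np.allclose(out_repeat, out_grouped))
```

With `h_kv = 1` the same two functions cover MQA, and with `h_kv = h` they reduce to MHA. The grouped path never builds an explicit (B, h, T, k) copy of K or V in this code, which is the property the fused kernels exploit.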

Why quality holds at h_kv = 8

Two reasons, in order of importance:

  1. The K/V subspace is still wide. With h_kv = 8, k = 128, your keys live in a 1024-dim space per token. That's plenty of capacity to express "what can be retrieved from this position." MQA's collapse to a 128-dim space is where the PPL loss starts.
  2. Queries stay fully multi-headed. The routing — what the model decides to attend to — happens in Q-space. K is what's available; Q is what's asked for. Sharing K across a Q-group means those queries see the same retrieval set, but each can weight it independently. That's the same content from different angles, which is exactly what attention is supposed to do.

The numbers · models you actually serve

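| Model | n_layers | h | h_kv | head_dim | KV/tok (fp16) | KV/tok (fp8) | vs naive-MHA |
|---|---|---|---|---|---|---|---|
| LLaMA-1 7B (MHA) | 32 | 32 | 32 | 128 | 512 KB | 256 KB | 1× |
| LLaMA-2 70B (GQA G=8) | 80 | 64 | 8 | 128 | 320 KB | 160 KB | 8× |
| Mistral 7B (GQA G=4) | 32 | 32 | 8 | 128 | 128 KB | 64 KB | 4× |
| Gemma-2 9B (GQA G=2) | 42 | 16 | 8 | 256 | 336 KB | 168 KB | 2× |
| Falcon-40B (MQA) | 60 | 64 | 1 | 64 | 15 KB | 7.5 KB | 64× |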

Two things to notice: (1) LLaMA-2 70B has a smaller per-token KV than LLaMA-1 7B, despite being ten times bigger; (2) Falcon-40B's KV is essentially free — that's MQA's whole appeal for very-long-context applications.
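The per-token figures can be reproduced directly from the formula. The sketch below takes the configurations exactly as this section lists them (they are not re-verified against the model cards here) and reads KB as 1024 bytes, which is the convention the table's numbers follow.

```python
# (model, n_layers, h, h_kv, head_dim) as listed in this section's table
CONFIGS = [
    ("LLaMA-1 7B", 32, 32, 32, 128),
    ("LLaMA-2 70B", 80, 64, 8, 128),
    ("Mistral 7B", 32, 32, 8, 128),
    ("Gemma-2 9B", 42, 16, 8, 256),
    ("Falcon-40B", 60, 64, 1, 64),
]


def kv_kib_per_token(n_layers, h_kv, head_dim, dtype_bytes):
    """Per-token KV cache in KiB; only the h_kv heads are stored."""
    return 2 * n_layers * h_kv * head_dim * dtype_bytes / 1024


rows = []
for name, n_layers, h, h_kv, head_dim in CONFIGS:
    fp16 = kv_kib_per_token(n_layers, h_kv, head_dim, dtype_bytes=2)
    fp8 = kv_kib_per_token(n_layers, h_kv, head_dim, dtype_bytes=1)
    rows.append((name, fp16, fp8, h // h_kv))
    print(f"{name:<12} fp16={fp16:6.1f} KiB  fp8={fp8:6.1f} KiB  vs MHA={h // h_kv}x")
```

Note that h never enters the cache size; it only appears in the last column, as the ratio against a hypothetical MHA model with the same query head count.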

The serving impact, made concrete

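LLaMA-2 70B on a node of 80 GB H100s (TP=2). Assume that after weights, activations, and framework overhead you have ~30 GB of HBM left for KV cache. Treat that as a working budget: fp16 weights alone are ~70 GB per shard at TP=2, so the real figure depends on weight dtype and TP degree. Plug into the formula: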

Naive-MHA (wrong):   2 · 80 · 64 · 128 · 2  =  2.6 MB / token  →  ~11k concurrent tokens
GQA (actual):   2 · 80 · 8 · 128 · 2  =  320 KB / token  →  ~93k concurrent tokens

The punchline
At 2048-token average context, the difference is "serve 5 concurrent sessions" vs "serve 45+". That's the same model, the same hardware, the same kernels: the gap comes entirely from the architecture choice in h_kv.
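The capacity estimate is a single division, sketched below under the same assumptions as the text: a ~30 GB KV budget and a 2048-token average context. The constant and function names are illustrative. Exact figures shift by a percent or two depending on whether KB and GB are read as decimal or binary units, so treat the outputs as rough.

```python
KV_BUDGET_BYTES = 30e9  # ~30 GB of HBM for KV cache (the working budget assumed in the text)
AVG_CONTEXT = 2048      # average tokens per session


def capacity(bytes_per_token):
    """Return (concurrent tokens, concurrent sessions) that fit in the KV budget."""
    tokens = KV_BUDGET_BYTES / bytes_per_token
    return tokens, tokens / AVG_CONTEXT


naive_tokens, naive_sessions = capacity(2 * 80 * 64 * 128 * 2)  # h = 64 plugged in (wrong)
gqa_tokens, gqa_sessions = capacity(2 * 80 * 8 * 128 * 2)       # h_kv = 8 plugged in (actual)

print(f"naive MHA: {naive_tokens / 1e3:.1f}k tokens, {naive_sessions:.1f} sessions")
print(f"GQA:       {gqa_tokens / 1e3:.1f}k tokens, {gqa_sessions:.1f} sessions")
```

Whatever the budget, the ratio between the two rows stays at h / h_kv = 8, because the budget cancels out.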

Interactive · the head plan

[Interactive widget: Q/KV head plan. Sliders set h and h_kv. The top row of boxes is the h query heads; the bottom row is the h_kv KV heads. Lines show which Q heads share a (K, V). h_kv is clamped to a divisor of h (the only legal choices); h = 64, h_kv = 8 is LLaMA-2 70B. Readouts: group size, flavor, KV / tok / layer, and reduction vs MHA at the same h.]

[Interactive table: pick layer count, head_dim, and dtype to evaluate against real model configs. Columns: preset, h, h_kv, KV/tok (your dtype), per 2k seq. Per-token KV is computed using only what's in the cache (h_kv).]

The formula this widget uses:
// only h_kv heads sit in the KV cache, not h
const per_tok = 2 * n_layers * h_kv * head_dim * dtype_bytes;
const reduction_vs_mha = h / h_kv;
const group_size = h / h_kv;
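For readers working outside the browser, the same head-plan readouts can be sketched in Python. This is not the widget's source: the function names are hypothetical, and the clamping rule (snap down to the largest divisor of h not exceeding the request) is an assumption, since the page only says h_kv is clamped to a divisor of h.

```python
def legal_kv_head_counts(h: int) -> list:
    """All h_kv values that split h query heads into equal groups."""
    return [d for d in range(1, h + 1) if h % d == 0]


def clamp_to_divisor(h: int, requested: int) -> int:
    """Assumed rule: largest divisor of h that is <= requested (and at least 1)."""
    return max(d for d in legal_kv_head_counts(h) if d <= max(requested, 1))


def head_plan(h: int, h_kv: int, n_layers: int, head_dim: int, dtype_bytes: int) -> dict:
    h_kv = clamp_to_divisor(h, h_kv)
    flavor = "MHA" if h_kv == h else "MQA" if h_kv == 1 else "GQA"
    return {
        "h_kv": h_kv,
        "group_size": h // h_kv,
        "flavor": flavor,
        # only h_kv heads sit in the KV cache, not h
        "kv_bytes_per_token": 2 * n_layers * h_kv * head_dim * dtype_bytes,
        "reduction_vs_mha": h // h_kv,
    }


print(legal_kv_head_counts(64))
print(head_plan(h=64, h_kv=8, n_layers=80, head_dim=128, dtype_bytes=2))
```

Calling `head_plan(h=64, h_kv=10, ...)` would snap h_kv down to 8 under the assumed rule; a different rule (nearest divisor, say) is equally plausible.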

Interactive · what gets read from HBM at decode time

The plan above tells you who shares what. The picture below tells you what actually moves. Press play to walk a single decode step:

  1. The full Q (all h heads) is computed from the one new input token.
  2. Only the physical h_kv rows of K and V are read out of the cache (the boxes on the right).
  3. Each row is broadcast (stride-0, no copy) across its query group so every Q head can see the same K, V.
  4. Attention runs logically as if all h heads had their own K, V.

The left "logical" panel has h head rows that don't shrink as you slide h_kv. The right "physical" panel does. That gap is the win: same compute graph, fewer bytes moved.
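The "fewer bytes moved" claim can be put in numbers. The sketch below estimates how many bytes of cached K and V one sequence reads during a single decode step, assuming the kernel reads every cached token once per layer. It ignores weights, the new token's Q/K/V, and any paging overhead; the function name is hypothetical.

```python
def kv_read_bytes_per_decode_step(context_len, n_layers, kv_heads, head_dim, dtype_bytes=2):
    """Bytes of cached K and V read to attend over context_len tokens, for one sequence."""
    return 2 * n_layers * kv_heads * head_dim * dtype_bytes * context_len


# LLaMA-2 70B shapes from this section, fp16, 2048 tokens of context.
physical = kv_read_bytes_per_decode_step(2048, n_layers=80, kv_heads=8, head_dim=128)
logical = kv_read_bytes_per_decode_step(2048, n_layers=80, kv_heads=64, head_dim=128)

print(f"physical (h_kv = 8): {physical / 2**20:.0f} MiB per decode step")
print(f"if all h = 64 heads were stored: {logical / 2**20:.0f} MiB per decode step")
```

Decode is typically memory-bandwidth-bound, so under these assumptions the 8× cut in bytes read matters for step latency as well as for capacity.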

[Interactive widget: decode-step flow, physical KV (right) → broadcast → logical attention (left). Sliders set h and h_kv. Readouts: physical K+V bytes / tok / layer, logical attention heads, stride-0 broadcast factor, flavor.]

The broadcast trick in code:
import torch

# Toy sizes for illustration; group_size = h / h_kv
n_layers, h, h_kv, T, head_dim = 2, 8, 2, 16, 64
group_size = h // h_kv

# physical K_cache shape: (n_layers, h_kv, T, head_dim)
K_cache = torch.randn(n_layers, h_kv, T, head_dim)

# Add a group axis and expand it: stride 0 along that axis, no copy.
K_view = K_cache.unsqueeze(2).expand(n_layers, h_kv, group_size, T, head_dim)

# logical view used by the attention math: (n_layers, h, T, head_dim)
K_logical = K_view.reshape(n_layers, h, T, head_dim)

# In PyTorch the expand() is free; the reshape forces a materialization,
# so fused kernels (FlashAttention's GQA path, vLLM's PagedAttention kernel)
# skip the reshape and broadcast inside the tile loop. Cache stays at h_kv.
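The same stride behavior can be observed without PyTorch. This is a minimal NumPy sketch with toy sizes chosen for illustration: `np.broadcast_to` plays the role of `expand()`, and the reshape that merges the KV-head and group axes is where the copy happens.

```python
import numpy as np

h, h_kv, T, head_dim = 8, 2, 4, 16  # toy sizes, one layer only
G = h // h_kv

K_cache = np.zeros((h_kv, T, head_dim), dtype=np.float16)  # physical K for one layer

# Insert a group axis and broadcast it: a read-only view with stride 0 on that axis.
K_view = np.broadcast_to(K_cache[:, None], (h_kv, G, T, head_dim))

# Merging (h_kv, G) into one h axis cannot be expressed with strides, so NumPy copies.
K_flat = K_view.reshape(h, T, head_dim)

print(K_view.strides)                     # the group axis has stride 0
print(np.shares_memory(K_view, K_cache))  # the broadcast view aliases the cache
print(np.shares_memory(K_flat, K_cache))  # the reshaped array does not
print(K_cache.nbytes, K_flat.nbytes)      # the copy is G times larger
```

The view costs nothing extra, while the flattened array is G times the size of the cache. Avoiding that blow-up is the reason for broadcasting inside the kernel.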

Putting it together

Takeaway
GQA replaces h with h_kv in the lesson-01 formula, period. Everything else (the attention math, the kernel shape) is unchanged apart from the broadcast of K and V across each query group. It's not a clever serving trick; it's a one-line architecture change at training time that gives you 4–8× more concurrent sequences at inference time for a negligible quality cost. Most modern open models use it.