GQA / MQA
The attention-architecture change that shrinks the KV cache 4–8× with negligible quality loss. Skip this lesson and your serving math can be off by nearly an order of magnitude.
Why this lesson matters
Lesson 01 put one formula in your head:

KV bytes per token = 2 · n_layers · **n_kv_heads** · head_dim · bytes_per_elem
That bolded term, n_kv_heads, is the lever this lesson is about. In LLaMA-1 7B it equals the number of query heads (32). In LLaMA-2 70B it's 8. In Falcon-7B it's 1. LLaMA-1 7B and Falcon-7B are the same size class with the same 32 layers, yet their KV cache per token differs by 64× (32× from the KV-head count, 2× from head_dim 128 vs 64).
If you compute "LLaMA-2 70B has 64 attention heads at d_head = 128, so 80 layers × 64 × 128 × 2 × 2 = 2.5 MB/token," you've quietly assumed n_kv_heads = h, which no longer holds. The actual number is 320 KB/token, 8× smaller. That difference is the whole story for what you can serve on one H100 node.
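The per-token arithmetic is easy to get wrong by hand, so here is a minimal Python sketch of it. The function name `kv_bytes_per_token` and its parameters are illustrative, not from any library; it assumes K and V are stored at the same precision and ignores the paging and alignment overhead a real serving engine adds.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2) -> int:
    """Bytes of KV cache one token occupies across all layers.

    Hypothetical helper for illustration. The leading 2 counts one K vector
    and one V vector per KV head per layer; the default of 2 bytes assumes fp16.
    """
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


# LLaMA-2 70B: 80 layers, 64 query heads, 8 KV heads, head_dim 128.
naive = kv_bytes_per_token(80, 64, 128)   # wrongly assumes n_kv_heads = h = 64
actual = kv_bytes_per_token(80, 8, 128)   # the real GQA configuration

print(f"naive : {naive / 2**20:.2f} MiB/token")   # naive : 2.50 MiB/token
print(f"actual: {actual / 2**10:.0f} KiB/token")  # actual: 320 KiB/token
print(f"ratio : {naive // actual}x")              # ratio : 8x
```

The only difference between the two calls is the n_kv_heads argument, which is exactly the assumption the naive calculation gets wrong.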
MHA recap
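Standard multi-head attention uses hidden size d, h heads, and head dim k = d/h. Each head i has its own projections and attends independently:

Q_i = X W_i^Q,   K_i = X W_i^K,   V_i = X W_i^V
head_i = softmax(Q_i K_i^T / √k) V_i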
The KV cache stores all h key tensors and all h value tensors for every layer. For T cached tokens:

KV cache per layer = 2 · h · k · T · bytes
Halve the number of cached heads, halve the cache. That's the entire idea. The trick is deciding which heads can be cut without losing quality.
The empirical observation
Queries across heads naturally diverge: say, head 3 tracks syntax, head 17 coreference, head 24 semantics. Routing matters, and removing query heads degrades quality fast.
Keys and values across heads, however, end up highly correlated: many heads' K and V vectors land in roughly the same subspace. That redundancy costs HBM and buys little. The architecture change: keep all the Q heads (they do the routing) and cut the number of K and V heads (they're largely redundant).
MQA · the maximal version (Shazeer 2019)
All h query heads share a single key head and a single value head:

head_i = softmax(Q_i K^T / √k) V   for every head i

KV cache per layer drops to 2 · 1 · k · T · bytes, an h× reduction. With LLaMA-2 70B's h = 64, that would be 64×. Quality cost: roughly 0.5–1 PPL on LLaMA-scale models. That is real but small for many use cases, and sometimes too aggressive for frontier quality.
GQA · the interpolation (Ainslie et al. 2023)
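Introduce h_kv KV heads with 1 ≤ h_kv ≤ h, where h is divisible by h_kv. Each contiguous group of G = h / h_kv query heads shares one (K, V) head:

head_i = softmax(Q_i K_g^T / √k) V_g,   g = ⌊i / G⌋   (heads numbered from 0)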
Two extremes:
- h_kv = h → group_size 1 → exactly MHA.
- h_kv = 1 → group_size h → exactly MQA.
KV reduction is h / h_kv. LLaMA-2 70B picks h = 64, h_kv = 8 → 8× smaller cache, with quality close to MHA.
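The contiguous-group rule comes down to one integer division: query head i reads KV head i // G, with G = h / h_kv. A minimal sketch follows; `kv_head_for` and `head_plan` are hypothetical helper names, and the code assumes h is divisible by h_kv, as it is for every model in this lesson. Setting h_kv = h or h_kv = 1 reproduces the two extremes.

```python
def kv_head_for(q_head: int, h: int, h_kv: int) -> int:
    """Index of the shared (K, V) head that query head `q_head` reads.

    Hypothetical helper; assumes 1 <= h_kv <= h and h % h_kv == 0.
    """
    if h % h_kv != 0:
        raise ValueError("h must be divisible by h_kv")
    group_size = h // h_kv         # G: query heads per KV head
    return q_head // group_size    # contiguous groups: heads 0..G-1 -> 0, G..2G-1 -> 1, ...


def head_plan(h: int, h_kv: int) -> list[int]:
    """KV head index for every query head, in query-head order."""
    return [kv_head_for(i, h, h_kv) for i in range(h)]


print(head_plan(8, 8))  # MHA:        [0, 1, 2, 3, 4, 5, 6, 7]
print(head_plan(8, 2))  # GQA, G = 4: [0, 0, 0, 0, 1, 1, 1, 1]
print(head_plan(8, 1))  # MQA:        [0, 0, 0, 0, 0, 0, 0, 0]
```

This is the same ordering that `repeat_interleave(G, dim=1)` produces on the KV head axis: each KV head is repeated G times in a row, so position i in the expanded tensor holds KV head i // G.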
Shape flow side-by-side
Let B = batch size, T = sequence length, h = query heads, h_kv = KV heads, k = head_dim.

```text
MHA (h_kv = h):
  Q: (B, h, T, k)      K: (B, h, T, k)      V: (B, h, T, k)
  scores = Q @ K^T / √k                     (B, h, T, T)
  out    = softmax(scores) @ V              (B, h, T, k)

MQA (h_kv = 1):
  Q: (B, h, T, k)      K: (B, 1, T, k)      V: (B, 1, T, k)
  K, V broadcast over h by stride-0 expand  (B, h, T, k), no memory copy
  rest identical to MHA

GQA (h_kv = h / G, group_size G):
  Q: (B, h, T, k)      K: (B, h_kv, T, k)   V: (B, h_kv, T, k)
  K = K.repeat_interleave(G, dim=1)         (B, h, T, k), materializes a copy in pure PyTorch
  V = V.repeat_interleave(G, dim=1)         (fused kernels broadcast inside the tile loop instead)
  rest identical to MHA
```
The broadcast is the key trick. In pure PyTorch, repeat_interleave materializes a fresh (B, h, T, k) copy of the (B, h_kv, T, k) tensor. Production fused attention kernels (FlashAttention's GQA path, vLLM's PagedAttention kernel) skip this materialization and broadcast K/V over the group inside the tile loop. Either way the cache itself stores only h_kv heads per token; the fused path additionally keeps the bytes read from HBM proportional to h_kv, even though the math runs as if all h heads had their own K and V.
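To see that the broadcast is purely an indexing choice, here is a minimal pure-Python sketch of one GQA decode step written two ways: one deep-copies each KV head G times (what repeat_interleave does), the other indexes the shared KV head directly (what a fused kernel does inside its loop). The function names are hypothetical, the cache is simplified to nested lists shaped [h_kv][T][k], and no real kernel is this naive; the only claim is that both paths produce identical outputs while the second never builds h heads of K/V.

```python
import math
import random


def attend(q, keys, values):
    """Softmax attention of one query vector over T cached (key, value) rows."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, key)) / scale for key in keys]
    m = max(scores)                                   # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def decode_gqa_copy(q_heads, k_cache, v_cache):
    """Copy each KV head G times first (like repeat_interleave), then attend per Q head."""
    h, h_kv = len(q_heads), len(k_cache)
    G = h // h_kv
    # Explicit deep copies: h full [T][k] K and V tensors now exist in memory.
    k_full = [[row[:] for row in k_cache[j]] for j in range(h_kv) for _ in range(G)]
    v_full = [[row[:] for row in v_cache[j]] for j in range(h_kv) for _ in range(G)]
    return [attend(q_heads[i], k_full[i], v_full[i]) for i in range(h)]


def decode_gqa_broadcast(q_heads, k_cache, v_cache):
    """Index the shared KV head i // G for each Q head; only h_kv heads ever exist."""
    h, h_kv = len(q_heads), len(k_cache)
    G = h // h_kv
    return [attend(q_heads[i], k_cache[i // G], v_cache[i // G]) for i in range(h)]


# Toy sizes: h = 8 query heads, h_kv = 2 KV heads (G = 4), T = 5 cached tokens, k = 4.
random.seed(0)
h, h_kv, T, k = 8, 2, 5, 4
q_heads = [[random.gauss(0, 1) for _ in range(k)] for _ in range(h)]
k_cache = [[[random.gauss(0, 1) for _ in range(k)] for _ in range(T)] for _ in range(h_kv)]
v_cache = [[[random.gauss(0, 1) for _ in range(k)] for _ in range(T)] for _ in range(h_kv)]

out_copy = decode_gqa_copy(q_heads, k_cache, v_cache)
out_bcast = decode_gqa_broadcast(q_heads, k_cache, v_cache)
print(len(out_bcast), len(out_bcast[0]))  # 8 4  (one k-dim output per query head)
print(out_copy == out_bcast)              # True (same arithmetic, no copy needed)
```

Passing a single KV head (h_kv = 1) to either function gives the MQA case with no other change.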
Why quality holds at h_kv = 8
Two reasons, in order of importance:
- The K/V subspace is still wide. With h_kv = 8 and k = 128, each token's keys live in an 8 × 128 = 1024-dim space. That's plenty of capacity to express "what can be retrieved from this position." MQA's collapse to a single 128-dim space is where the PPL loss starts.
- Queries stay fully multi-headed. The routing — what the model decides to attend to — happens in Q-space. K is what's available; Q is what's asked for. Sharing K across a Q-group means those queries see the same retrieval set, but each can weight it independently. That's the same content from different angles, which is exactly what attention is supposed to do.
The numbers · models you actually serve
| Model | n_layers | h | h_kv | head_dim | KV/token (fp16) | KV/token (fp8) | Reduction vs MHA (h / h_kv) |
|---|---|---|---|---|---|---|---|
| LLaMA-1 7B (MHA) | 32 | 32 | 32 | 128 | 512 KB | 256 KB | 1× |
| LLaMA-2 70B (GQA G=8) | 80 | 64 | 8 | 128 | 320 KB | 160 KB | 8× |
| Mistral 7B (GQA G=4) | 32 | 32 | 8 | 128 | 128 KB | 64 KB | 4× |
| Gemma-2 9B (GQA G=2) | 42 | 16 | 8 | 256 | 336 KB | 168 KB | 2× |
| Falcon-7B (MQA) | 32 | 71 | 1 | 64 | 8 KB | 4 KB | 71× |
Two things to notice: (1) LLaMA-2 70B has a smaller per-token KV cache than LLaMA-1 7B despite being ten times bigger; (2) Falcon-7B's KV cache is nearly free, 64× smaller than LLaMA-1 7B's at the same 32 layers. That is MQA's whole appeal for very-long-context applications.
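The KV columns follow directly from the per-token formula. This sketch recomputes them for the LLaMA, Mistral and Gemma rows, assuming 2 bytes/element for fp16, 1 byte for fp8, and KB = 1024 bytes; `MODELS` is just those table rows transcribed as data, `kv_kib_per_token` is a hypothetical helper, and the last column is h / h_kv.

```python
# (n_layers, h, h_kv, head_dim) transcribed from the table rows.
MODELS = {
    "LLaMA-1 7B": (32, 32, 32, 128),
    "LLaMA-2 70B": (80, 64, 8, 128),
    "Mistral 7B": (32, 32, 8, 128),
    "Gemma-2 9B": (42, 16, 8, 256),
}


def kv_kib_per_token(n_layers, h_kv, head_dim, bytes_per_elem):
    """KV cache per token in KiB: 2 (K and V) x layers x KV heads x head_dim x bytes."""
    return 2 * n_layers * h_kv * head_dim * bytes_per_elem / 1024


for name, (n_layers, h, h_kv, head_dim) in MODELS.items():
    fp16 = kv_kib_per_token(n_layers, h_kv, head_dim, 2)
    fp8 = kv_kib_per_token(n_layers, h_kv, head_dim, 1)
    print(f"{name:12s} {fp16:6.0f} KB {fp8:6.0f} KB {h // h_kv:3d}x")
```

Note that the reduction column depends only on h / h_kv, while the absolute size also scales with depth and head_dim, which is why Gemma-2 9B (2×) still caches more per token than Mistral 7B (4×).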
The serving impact, made concrete
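Take LLaMA-2 70B served with TP=2 on two 80 GB H100s. After the fp16 weights (~70 GB per shard), activations, and framework overhead, roughly 30 GB of HBM remains for KV cache across the two GPUs. Plug into the formula:

30 GB ÷ 320 KB/token ≈ 98K tokens of KV cache (98,304 with binary units), shared by every in-flight sequence. The naive 2.5 MB/token estimate would give ≈ 12K tokens, 8× too few.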
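Dividing that KV budget by the per-token size, and then by a context length, gives the number of sequences that fit at once. A minimal sketch, assuming the ~30 GB figure means 30 GiB, every sequence is filled to the same context length, and the cache packs perfectly (real paged allocators lose a little to fragmentation); the constant and function names are hypothetical and the context lengths are arbitrary examples.

```python
KV_BUDGET_BYTES = 30 * 2**30        # ~30 GB of HBM left for KV cache (assumed to be GiB)
GQA_PER_TOKEN = 320 * 2**10         # LLaMA-2 70B, h_kv = 8, fp16
NAIVE_MHA_PER_TOKEN = 2560 * 2**10  # same model if n_kv_heads were 64


def max_concurrent_sequences(budget, per_token, context_len):
    """How many full-length sequences fit in the KV budget at the same time."""
    return budget // (per_token * context_len)


for ctx in (2048, 4096, 8192):
    gqa = max_concurrent_sequences(KV_BUDGET_BYTES, GQA_PER_TOKEN, ctx)
    mha = max_concurrent_sequences(KV_BUDGET_BYTES, NAIVE_MHA_PER_TOKEN, ctx)
    print(f"context {ctx:5d}: GQA {gqa:3d} sequences, naive MHA {mha:2d}")
```

At a 4K context this gives room for 24 concurrent sequences with GQA versus 3 under the naive MHA math, and batch size is what drives decode throughput.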
What gets read from HBM at decode time
The head plan tells you who shares what. A single decode step shows what actually moves:
- The full Q (all h heads) is computed from the one new input token.
- Only the h_kv physical K and V heads are read out of the cache, for every cached position.
- Each KV head is broadcast (stride-0, no copy) across its query group, so every Q head in that group sees the same K and V.
- Attention runs logically as if all h heads had their own K, V.
Logically there are h heads of attention no matter what h_kv is; physically only h_kv heads of K and V are stored and moved. That gap is the win: same compute graph, fewer bytes moved.
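The logical-versus-physical gap can be put in numbers. In one decode step, full attention reads every cached K and V row of the sequence, so the bytes read scale with h_kv while the logical head count stays h. A minimal sketch, assuming fp16, a single sequence, and ignoring the model weights (which are also read every step) and any on-chip reuse; `decode_kv_read_bytes` is a hypothetical name.

```python
def decode_kv_read_bytes(n_layers, h_kv, head_dim, context_len, bytes_per_elem=2):
    """Bytes of K and V pulled from HBM to produce one new token for one sequence."""
    per_token = 2 * n_layers * h_kv * head_dim * bytes_per_elem  # K and V, all layers
    return per_token * context_len


# LLaMA-2 70B shape at 4,096 cached tokens: logical h = 64, physical h_kv = 8.
logical = decode_kv_read_bytes(80, 64, 128, 4096)   # if every Q head had its own K/V
physical = decode_kv_read_bytes(80, 8, 128, 4096)   # what GQA actually stores and reads
print(f"{logical / 2**30:.2f} GiB vs {physical / 2**30:.2f} GiB per step")
# 10.00 GiB vs 1.25 GiB per step
```

Because decode is memory-bandwidth bound, cutting these bytes by h / h_kv translates fairly directly into faster steps at long context.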