Scaled dot-product attention — from scratch
Derive it once: a learned weighted sum over a sequence, where the weights come from a softmax over scaled inner products. The √d_k factor isn't a hack; it's required to keep the softmax in its useful regime. Once you've derived it, every variant — MHA, MQA, GQA, FlashAttention — is a one-line diff.
The motivating problem
You have a sequence of tokens x_1, x_2, …, x_N ∈ ℝ^d. You want each output y_i to be a function of the whole sequence, not just x_i. An MLP processes each x_i independently; a convolution processes a local window. Attention says: for each i, compute a weighted sum over all j, where the weights are data-dependent: y_i = Σ_j α_{ij} · v_j.
The question is what to put for α_{ij} and v_j. Three principles:
- α_{ij} should be high when x_i is "interested in" x_j — i.e., depend on both. Hence a dot-product of learned projections.
- α_{ij} should sum to 1 over j — it's a distribution over the sequence. Hence softmax.
- v_j should be a learned representation of x_j — different from the "what am I" representation. Hence the separate V projection.
Derive the formula
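Learn three projections from each token's d-dim feature: queries Q = X·W_Q, keys K = X·W_K, values V = X·W_V, each of dim d_k, where X ∈ ℝ^{N×d} stacks the tokens as rows.
"Token i's interest in token j" is the dot product Q_i · K_j. Stack into a matrix: S = Q·K^⊤ ∈ ℝ^{N×N}.
Scale by √d_k and apply softmax row-wise to get a distribution per query token, then take the weighted average of the values: A = softmax(S / √d_k), Y = A·V.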
That's it. Five lines of math; the entire transformer rests on this.
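The derivation above can be sketched in plain Python. This is a minimal illustration rather than an optimised implementation: the helper names (`matmul`, `softmax`, `attention`) and the toy inputs are assumptions made for the example, and the learned projections are skipped by passing Q, K, V in directly.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(s - peak) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Scaled dot-product attention. Returns (Y, A)."""
    d_k = len(q[0])
    k_t = [list(col) for col in zip(*k)]       # K^T, shape (d_k, N)
    s = matmul(q, k_t)                         # S = Q K^T, shape (N, N)
    a = [softmax([x / math.sqrt(d_k) for x in row]) for row in s]
    return matmul(a, v), a                     # Y = A V, shape (N, d_v)

# Toy example (hypothetical values): N = 3 tokens, d_k = d_v = 2.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
Y, A = attention(Q, K, V)
```

Each row of `A` is a probability distribution over the three tokens, and each row of `Y` is the corresponding weighted average of the rows of `V`.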
The shapes — memorise these
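For a single-head attention on a single sequence:
| Tensor | Shape |
|---|---|
| X (input) | N × d |
| W_Q, W_K | d × d_k |
| W_V | d × d_v |
| Q, K | N × d_k |
| V | N × d_v |
| S, A | N × N |
| Y (output) | N × d_v |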
Two non-obvious points:
- d_v can be different from d_k. In practice they're almost always equal (both d/h), but algorithmically there's no requirement: the output projection W_O maps the result back to dim d before it is residual-added to x.
- The N×N attention matrix is the cost driver. It's quadratic in sequence length, both in compute (O(N²·d_k)) and memory (O(N²)). This is why everything in serving land is about avoiding materialising the full A — FlashAttention, online softmax, etc.
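To put a number on the quadratic memory cost, here is a small back-of-the-envelope helper. The function name and the fp16 assumption (2 bytes per element) are illustrative choices, not something the lesson prescribes.

```python
def attention_matrix_bytes(n_tokens, n_heads=1, bytes_per_element=2):
    """Bytes needed to materialise the N x N score matrix for every head."""
    return n_heads * n_tokens * n_tokens * bytes_per_element

for n in (4_096, 32_768, 131_072):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"N = {n:>7}: {gib:8.3f} GiB per head in fp16")
```

Going from 4k to 128k tokens multiplies the sequence length by 32 and the matrix size by 1024, which is why avoiding the full A matters so much at long context.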
Multi-head — what's actually different
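One head computes one (Q, K, V) projection. Multi-head attention learns h independent projections in parallel, each of dim d_h = d / h, then concatenates the head outputs and applies an output projection: head_i = softmax(Q_i·K_i^⊤ / √d_h)·V_i, MHA(X) = Concat(head_1, …, head_h)·W_O.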
Why does this help? Three views:
- Low-rank trick. Each head is a rank-d_h projection of the full d-dim feature. h heads together give h · d_h = d total rank, same as a single full-dim attention — but with the freedom to put different content in each subspace.
- Different "what should I attend to" patterns. Empirically, one head attends to the previous token, another to syntactic dependencies, another to the start-of-sequence token, etc. Interpretability work (Anthropic, Olah et al.) shows that heads specialise.
- Parallelism. The h heads compute independently, so the matmul is one tensor operation of shape (h, N, d_h). On a GPU, this is exactly the workload tensor cores like.
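A minimal sketch of the split, attend, concatenate pattern, in plain Python. It assumes Q, K, V have already been projected to the full width d, so that slicing columns yields the per-head d_h-dim pieces; `split_heads` and `multi_head_attention` are hypothetical helper names, and W_O is set to the identity only so the result is easy to check.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    peak = max(row)
    exps = [math.exp(s - peak) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Single-head scaled dot-product attention on (N, d_h) inputs."""
    scale = math.sqrt(len(q[0]))
    scores = [[sum(a * b for a, b in zip(qi, kj)) / scale for kj in k] for qi in q]
    weights = [softmax(row) for row in scores]
    return matmul(weights, v)

def split_heads(x, h):
    """Slice an (N, d) matrix into h column blocks of shape (N, d / h)."""
    d_h = len(x[0]) // h
    return [[row[i * d_h:(i + 1) * d_h] for row in x] for i in range(h)]

def multi_head_attention(q, k, v, w_o, h):
    """Run h heads independently, concatenate their outputs, then mix with W_O."""
    heads = [attention(qh, kh, vh)
             for qh, kh, vh in zip(split_heads(q, h), split_heads(k, h), split_heads(v, h))]
    concat = [[x for head in heads for x in head[i]] for i in range(len(q))]
    return matmul(concat, w_o)

# Toy example (hypothetical values): N = 2 tokens, d = 4, h = 2 heads.
Q = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
K = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
V = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
W_O = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]  # identity
Y = multi_head_attention(Q, K, V, W_O, h=2)
```

With the identity as W_O, the first d_h columns of `Y` are exactly what head 0 would produce on its own; a learned W_O is what lets the heads mix.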
Causal masking — the autoregressive constraint
For a language model, token i mustn't attend to tokens j > i (no future-information leak): the weights α_{ij} must be zero for all j > i.
Three implementations:
- Add a mask matrix. Compute the full S, then add M_{ij} = 0 if j ≤ i, −∞ otherwise, before softmax. Simplest. Materialises the full N×N matrix.
- FlashAttention-style streaming. Skip the upper-triangle tile entirely. Half the compute saved. Only works with a custom kernel.
- For inference: token N+1 only attends to tokens 1..N+1, so you only need to compute the (N+1)-th row of S. Combined with the KV cache (next section), this is what makes autoregressive decode fast.
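The first option, the additive mask, can be sketched as follows. This is an illustrative plain-Python version with hypothetical helper names (`causal_mask`, `masked_softmax`) and made-up scores; a real kernel would fuse these steps.

```python
import math

NEG_INF = float("-inf")

def causal_mask(n):
    """M[i][j] = 0 if j <= i, -inf otherwise."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]

def masked_softmax(scores, mask):
    """Add the mask to the scores, then apply a row-wise softmax."""
    out = []
    for s_row, m_row in zip(scores, mask):
        masked = [s + m for s, m in zip(s_row, m_row)]
        peak = max(masked)                           # finite: the diagonal is never masked
        exps = [math.exp(x - peak) for x in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

# Hypothetical 3 x 3 score matrix (already scaled by sqrt(d_k)).
S = [[0.5, 2.0, -1.0],
     [1.0, 0.0, 3.0],
     [0.2, 0.4, 0.6]]
A = masked_softmax(S, causal_mask(3))
```

Every entry above the diagonal of `A` is exactly zero, and each row still sums to 1 because the normalisation happens after masking.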
KV cache — why decode is memory-bound
During autoregressive generation, you generate tokens one at a time. At step t, you have t tokens. Naïvely, you'd recompute Q, K, V for all of them every step → O(t²) compute per token → O(N³) per sequence.
The fix: K and V don't depend on later tokens. So you cache them. At step t:
- Compute Q only for the new token x_t — shape (1, d_k).
- Compute K, V only for x_t — shape (1, d_k) each. Append to the cache.
- Compute attention: Q_t · K_{1:t}^⊤, softmax, weighted sum of V_{1:t}.
Now each token costs O(t·d_k), not O(t²·d_k). The cost is memory: you store K, V for every layer and every head.
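The three cached-decode steps can be sketched like this, for one head of one layer. The `KVCache` class and `attend` helper are hypothetical names for illustration, and the per-token Q, K, V rows are assumed to be already projected.

```python
import math

def softmax(row):
    peak = max(row)
    exps = [math.exp(s - peak) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q_row, keys, values):
    """Attention for one query row against every cached key/value row."""
    scale = math.sqrt(len(q_row))
    scores = [sum(a * b for a, b in zip(q_row, k)) / scale for k in keys]
    weights = softmax(scores)
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]

class KVCache:
    """Per-request cache of key and value rows for one head of one layer."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, q_t, k_t, v_t):
        """Append the new token's K and V, then attend from its query."""
        self.keys.append(k_t)
        self.values.append(v_t)
        return attend(q_t, self.keys, self.values)

# Hypothetical per-token rows for a 3-token sequence, d_k = d_v = 2.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

cache = KVCache()  # create a fresh cache for every new request
outputs = [cache.step(q, k, v) for q, k, v in zip(Q, K, V)]
```

Step t touches only t cached rows, so the result for the last token matches the last row of full causal attention without recomputing anything for earlier tokens.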
MHA → MQA → GQA — the KV cache shrink trick
Standard multi-head attention has h heads, each with its own K and V. The KV cache grows with h. But: queries are computed fresh every step, while keys and values sit in the cache. So:
| Variant | Q heads | K, V heads | KV cache vs MHA | Quality vs MHA |
|---|---|---|---|---|
| MHA (multi-head) | h | h | 1× | baseline |
| MQA (multi-query, Shazeer 2019) | h | 1 | 1/h × | noticeable degradation |
| GQA (grouped-query, Ainslie et al. 2023) | h | g (typically 8) | g/h × (e.g. 1/8 for h=64) | ~MHA |
In GQA, the h query heads are split into g groups; each group shares K, V. g is a knob between MHA (g=h) and MQA (g=1). LLaMA-2-70B uses g=8 with h=64 — an 8× KV cache reduction at near-zero quality loss.
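The cache-size column of the table can be reproduced with a few lines of arithmetic. The function below is a sketch; the configuration (80 layers, d = 8192, h = 64 so d_h = 128, fp16, 64,000 cached tokens) is an assumed 70B-class setup matching the numbers used elsewhere in this lesson.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, n_tokens,
                   bytes_per_element=2, batch=1):
    """K and V (the factor 2) for every layer, KV head and cached token."""
    return 2 * n_layers * n_kv_heads * head_dim * n_tokens * bytes_per_element * batch

# Assumed configuration: 80 layers, head_dim 128, 64,000 tokens, fp16.
mha = kv_cache_bytes(80, 64, 128, 64_000)   # h = 64 KV heads
gqa = kv_cache_bytes(80, 8, 128, 64_000)    # g = 8 KV heads
mqa = kv_cache_bytes(80, 1, 128, 64_000)    # a single KV head
print(f"MHA {mha / 1e9:.1f} GB, GQA-8 {gqa / 1e9:.1f} GB, MQA {mqa / 1e9:.1f} GB")
```

Only `n_kv_heads` changes between the three variants; the number of query heads never enters the formula, which is the whole point of MQA and GQA.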
Complexity — the load-bearing numbers
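Per layer, per sequence, in standard MHA. FLOPs are counted as one per multiply-add, and the N² memory entries are per head (multiply by h for all heads):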
| Stage | FLOPs | Memory |
|---|---|---|
| QKV projections (3 of them) | 3 · N · d · d = 3·N·d² | 3·N·d (intermediates) |
| S = Q·K^⊤ | N² · d | N² (S matrix) |
| softmax(S) | N² · h (elementwise) | N² (A matrix, same as S) |
| Y = A · V | N² · d | N · d (output) |
| Output projection W_O | N · d² | N · d |
| Total | 4·N·d² + 2·N²·d | O(N·d + N²) |
Two regimes:
- Short sequences (N < d/4). N·d² dominates. Attention is essentially a matmul → tensor-core bound.
- Long sequences (N ≳ 2d). N²·d dominates. The attention matrix itself is the bottleneck → HBM bandwidth bound (loading the N² matrix).
(Crossover where 2·N²·d ≈ 4·N·d² is at N ≈ 2d.) For typical LLMs with d ≈ 8192 and N ≈ 4096, projections dominate. For long-context models (N ≈ 32k or 128k), N² dominates by a lot — this is the regime where FlashAttention and similar tricks pay off most.
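The crossover can be checked numerically with the two totals from the table. This is a rough sketch that uses the table's counting convention and ignores softmax and other lower-order terms; the function names are illustrative.

```python
def projection_flops(n, d):
    """Q, K, V and output projections: 4 * N * d^2."""
    return 4 * n * d * d

def attention_flops(n, d):
    """Q K^T plus A V, summed over all heads: 2 * N^2 * d."""
    return 2 * n * n * d

d = 8192
for n in (4_096, 16_384, 32_768, 131_072):
    ratio = attention_flops(n, d) / projection_flops(n, d)
    print(f"N = {n:>7}: attention / projections = {ratio:.2f}")
```

The ratio simplifies to N / (2d), so it grows linearly with context length and passes 1 at N = 2d.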
The interview probes — what makes attention work
- Where does multi-head attention's expressiveness come from? From the fact that the softmax is applied within each head's subspace, not over the full d-dim score. The softmax non-linearity acts on each subspace independently → you can attend to different things in different subspaces.
- Why is the output projection W_O separate from the head concatenation? Each head produces a d_h-dim output; concatenation gives d-dim. W_O then mixes the heads. Without W_O, the heads can't exchange information — each one is a disjoint summary. With W_O, the next layer's input is a learned mixture of heads.
- What goes wrong if you remove √d_k? With zero-mean, unit-variance entries, the dot products have variance d_k. For d_k = 64, std = 8, and exp(8) ≈ 3000, so softmax becomes effectively one-hot for almost all rows and its gradients vanish. The model can learn around this (e.g., make Q, K small), but it'll do it slowly. With the √d_k factor, training is much faster.
- Why is attention "permutation-equivariant"? If you permute the tokens of the input, the rows of Q, K, V permute together, so the attention output permutes the same way. The function doesn't care about position. To inject position, you need a positional encoding — next lesson.
- Why is KV cache the bottleneck, not Q computation? Q is recomputed every step from a single new token → 1 row of cost. K, V from every token so far are required at every step → t rows of cost. As context grows, KV reads dominate.
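The variance claim behind the √d_k probe is easy to check empirically. The simulation below is a sketch under idealised assumptions (independent standard-normal entries, a fixed seed, 2000 samples); trained Q and K will not be exactly like this.

```python
import math
import random
import statistics

def sample_scores(d_k, n_samples, rng):
    """Dot products of random q, k with iid zero-mean, unit-variance entries."""
    scores = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        scores.append(sum(a * b for a, b in zip(q, k)))
    return scores

rng = random.Random(0)   # fixed seed so the run is repeatable
d_k = 64
raw = sample_scores(d_k, 2000, rng)
scaled = [s / math.sqrt(d_k) for s in raw]

raw_var = statistics.pvariance(raw)
scaled_var = statistics.pvariance(scaled)
print(f"raw variance ~ {raw_var:.1f}, scaled variance ~ {scaled_var:.2f}")
```

The unscaled variance should land near d_k = 64 and the scaled one near 1, which is the regime where softmax still spreads weight over several tokens.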
Where things go subtly wrong
| Bug | Symptom | Diagnosis |
|---|---|---|
| Mask off by one | Train loss looks normal; eval shows the model is "psychic": it can predict token t given token t. | The standard causal mask is non-strict: allow j ≤ i, so token i uses tokens 1..i inclusive of itself. Use strict j < i only if you want the shifted variant. Pick one and be consistent across train and eval. |
| Softmax NaN | Loss is fine for a few thousand steps, then NaN. | A row of scores is entirely -inf (every position masked, e.g. a fully padded row), so softmax computes 0/0. Or attention scores have escaped to ±large values from bad init. |
| KV cache stale | Generated text doesn't change when you change the prompt. | You're re-using a cache from a previous request. Reset the cache between sequences. |
| Head dim too small under quantisation | Quantising to int8 hurts a lot. | Each head has d_h = d/h. For small d_h (e.g. 64), int8 quantisation has poor resolution. Larger d_h is more quant-robust → fewer heads, larger heads, GQA-style. |
| 1/math.sqrt(d_k) vs math.sqrt(1/d_k) | Subtle training instability in fp16. | For large d_k, the intermediate 1/d_k is much smaller than 1/√d_k and can lose precision or underflow when computed in fp16; 1/math.sqrt(d_k) is the safe form. Equivalent in fp32 but not in low precision. |
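One concrete way a softmax NaN can arise is a row in which every position is masked. The sketch below reproduces it in plain Python; it illustrates one possible cause, not necessarily the one behind any given run, and returning zeros for such a row in `safe_softmax` is just one possible convention.

```python
import math

NEG_INF = float("-inf")

def naive_softmax(row):
    """Stable softmax that still breaks when every entry is -inf."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]   # -inf - (-inf) is nan
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax(row):
    """Return zeros for a fully masked row instead of propagating NaN."""
    if all(x == NEG_INF for x in row):
        return [0.0] * len(row)
    return naive_softmax(row)

fully_masked = [NEG_INF, NEG_INF, NEG_INF]
bad = naive_softmax(fully_masked)    # every entry is nan
good = safe_softmax(fully_masked)    # every entry is 0.0
```

A row with at least one unmasked position is unaffected by the guard, so it only changes behaviour in the degenerate case.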
Interview prompts you should be ready for
- "Derive the √d_k scaling factor." (Q and K entries independent with zero mean and var 1 → Q·K has mean 0 and var d_k, std √d_k. Without scaling, softmax saturates → vanishing gradients. Divide by √d_k to keep scores at unit variance.)
- "What's the FLOP cost of a single attention layer in a 70B LLaMA at N=4k?" (d=8192, h=64. Per token: QKV+O projections ≈ 4·d² = 268 MFLOPs. Across N=4k tokens: ~1.1 TFLOPs per layer. Attention matmul (Q·K^T + A·V across all heads): ~2·N²·d ≈ 274 GFLOPs per layer per sequence. At N=4k, projections dominate. At N≈2d=16k, attention catches up. At N=32k+, attention dominates — the regime FlashAttention was designed for.)
- "GQA: how does grouping save memory if compute is similar?" (GQA doesn't save compute on Q (still h heads of Q computed). It saves KV cache memory: g heads of K, V instead of h. For inference, that's the only thing that matters because KV cache dominates memory and you can serve more concurrent requests.)
- "Why isn't the V projection 'data-dependent' like Q and K?" (It is! V also depends on the token. The phrasing is "Q and K determine the weights; V is what's being weighted." V is the content you're aggregating; Q,K determine the aggregation pattern. Symmetric in structure, asymmetric in role.)
- "Causal mask in pseudocode — what's the shape and what value goes where?" (N×N matrix, M[i,j] = 0 if j ≤ i, -inf if j > i. Add to S before softmax. Setting A[i,j] = 0 for j > i after softmax is not equivalent: the remaining weights in each row no longer sum to 1 unless you renormalise.)
- "You're running inference on a 70B model with 64k context. What's the KV cache for one request?" (MHA in fp16: 2 (K and V) × 2 bytes × 80 layers × 8192 × 64,000 tokens × batch 1 ≈ 168 GB. Doesn't fit on one 80 GB H100. GQA-8: 21 GB. Fits. This is why long-context is GQA territory.)
- "Cross-attention vs self-attention — what's different?" (Self-attention: Q, K, V all come from the same sequence. Cross-attention: Q comes from one sequence, K, V from another. Used in encoder-decoder transformers and in VLMs (Q is text, K/V is image). Architecturally identical; just different sources for the projections.)
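For the first prompt in this list, the √d_k derivation can be written out in full. This is the standard argument under idealised assumptions: the entries of a query row q and a key row k are mutually independent with zero mean and unit variance, which is roughly true at initialisation but not guaranteed after training.

```latex
\begin{aligned}
\mathbb{E}[q \cdot k] &= \sum_{l=1}^{d_k} \mathbb{E}[q_l]\,\mathbb{E}[k_l] = 0, \\
\operatorname{Var}(q \cdot k) &= \sum_{l=1}^{d_k} \operatorname{Var}(q_l k_l)
  = \sum_{l=1}^{d_k} \mathbb{E}[q_l^2]\,\mathbb{E}[k_l^2] = d_k, \\
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) &= \frac{d_k}{d_k} = 1.
\end{aligned}
```

The second line uses independence across l to add the variances, and zero mean to reduce each term's variance to a product of second moments.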