Scaled dot-product attention — from scratch

Derive it once: a learned weighted sum over a sequence, where the weights come from a softmax over scaled inner products. The √d_k factor isn't a hack; it's required to keep the softmax in its useful regime. Once you've derived it, every variant — MHA, MQA, GQA, FlashAttention — is a one-line diff.

The motivating problem

You have a sequence of tokens x_1, x_2, …, x_N ∈ ℝ^d. You want each output y_i to be a function of the whole sequence — not just x_i. The MLP processes each x_i independently; the convolution processes a local window. Attention says: for each i, compute a weighted sum over all j, where the weights are data-dependent.

y_i = Σ_j α_{ij} · v_j

The question is what to put for α_{ij} and v_j. Three principles:

  1. α_{ij} should be high when x_i is "interested in" x_j — i.e., depend on both. Hence a dot-product of learned projections.
  2. α_{ij} should sum to 1 over j — it's a distribution over the sequence. Hence softmax.
  3. v_j should be a learned representation of x_j — different from the "what am I" representation. Hence the separate V projection.

Derive the formula

Learn three projections from each token's d-dim feature: queries Q, keys K, values V, each of dim d_k.

Q = X W_Q, K = X W_K, V = X W_V   — X ∈ ℝ^{N × d}, W_* ∈ ℝ^{d × d_k}

"Token i's interest in token j" is the dot product Q_i · K_j. Stack into a matrix:

S = Q K^⊤ ∈ ℝ^{N × N}   — S_{ij} is token i's score against token j

Apply softmax row-wise to get a distribution per query token, then average the values:

A = softmax(S / √d_k)
Y = A V   — each row of Y is a weighted average of V

That's it. Five lines of math; the entire transformer rests on this.
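The formula above can be sketched directly in NumPy. This is a minimal illustration, not a reference implementation: the function names `softmax` and `attention`, the random weights, and the sizes N=5, d=16, d_k=8 are all assumptions chosen for the example.

```python
import numpy as np

def softmax(s, axis=-1):
    # Subtract the row max before exp for numerical stability.
    s = s - s.max(axis=axis, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=axis, keepdims=True)

def attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V   # [N, d_k] each
    d_k = Q.shape[-1]
    S = Q @ K.T                            # [N, N], S[i, j] = Q_i . K_j
    A = softmax(S / np.sqrt(d_k))          # row-wise distribution over j
    return A @ V, A                        # Y is [N, d_k]

# Illustrative sizes and random weights (assumptions, not from a real model).
rng = np.random.default_rng(0)
N, d, d_k = 5, 16, 8
X = rng.normal(size=(N, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d_k)) / np.sqrt(d) for _ in range(3))

Y, A = attention(X, W_Q, W_K, W_V)
print(Y.shape, A.sum(axis=1))
```

Each row of `A` is a probability distribution over the sequence, so each row of `Y` is a convex combination of the rows of `V`.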

Why √d_k — derive it
Suppose the entries of Q_i and K_j are independent with mean 0 and variance 1. Each product Q_{ik} K_{jk} then has mean 0 and variance 1, so the dot product Q_i · K_j = Σ_k Q_{ik} K_{jk} is a sum of d_k independent terms with total variance d_k and std √d_k. The softmax saturates when the spread of its inputs is >> 1. So without the division, as d_k grows, the softmax becomes increasingly one-hot, its gradient shrinks toward 0 (for the winning entry as well as the rest), and learning stalls. Dividing by √d_k keeps the scores at unit variance regardless of d_k. The interviewer's tell: candidates who say "it normalises" without deriving the variance argument haven't actually thought about it.
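The variance argument is easy to check empirically. The sketch below assumes standard-normal entries for Q and K (one distribution that satisfies the mean-0, variance-1 assumption); the helper name `score_std` and the sample count are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 20_000

def score_std(d_k, scale):
    # Draw n_samples independent (q, k) pairs with iid N(0, 1) entries.
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    scores = (q * k).sum(axis=1)          # one dot product per pair
    if scale:
        scores = scores / np.sqrt(d_k)
    return scores.std()

for d_k in (16, 64, 256):
    raw = score_std(d_k, scale=False)
    scaled = score_std(d_k, scale=True)
    print(f"d_k={d_k:4d}  raw std={raw:6.2f}  sqrt(d_k)={np.sqrt(d_k):5.1f}  scaled std={scaled:.2f}")
```

The raw standard deviation should track √d_k, while the scaled one should stay close to 1 for every d_k.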

The shapes — memorise these

For a single-head attention on a single sequence:

X                      [N, d]
W_Q [d, d_k]  →  Q     [N, d_k]
W_K [d, d_k]  →  K     [N, d_k]
W_V [d, d_v]  →  V     [N, d_v]
S = Q K^⊤              [N, N]
A = softmax(S/√d_k)    [N, N]
Y = A V                [N, d_v]

Two non-obvious points:

  1. d_v can be different from d_k. In practice they're usually set equal, but algorithmically there's no requirement: an output projection maps the result back to dim d before it is residual-added to x.
  2. The N×N attention matrix is the cost driver. It's quadratic in sequence length, both in compute (O(N²·d_k)) and memory (O(N²)). This is why everything in serving land is about avoiding materialising the full A — FlashAttention, online softmax, etc.
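To make the quadratic memory cost concrete, here is a small back-of-envelope calculator. It assumes the full attention matrix is materialised in a 2-byte format such as bf16; the function name `attn_matrix_bytes` is hypothetical.

```python
def attn_matrix_bytes(seq_len, n_heads=1, bytes_per_value=2):
    # One N x N matrix per head, each entry stored in bytes_per_value bytes.
    return n_heads * seq_len * seq_len * bytes_per_value

for n in (4096, 32768, 131072):
    gib = attn_matrix_bytes(n) / 2**30
    print(f"N={n:6d}  {gib:8.3f} GiB per head")
```

Multiplying N by 8 multiplies the matrix size by 64, which is why long-context kernels avoid storing A at all.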

Multi-head — what's actually different

One head computes one (Q, K, V) projection. Multi-head attention learns h independent projections in parallel, each of dim d_h = d / h, then concatenates:

For head i:
  Q_i = X W_Q^{(i)},  K_i = X W_K^{(i)},  V_i = X W_V^{(i)}   each [N, d_h]
  Y_i = softmax(Q_i K_i^⊤ / √d_h) V_i
Y = concat(Y_1, …, Y_h) W_O   [N, d]

Why does this help? Three views:

  1. Low-rank trick. Each head is a rank-d_h projection of the full d-dim feature. h heads together give h · d_h = d total rank, same as a single full-dim attention — but with the freedom to put different content in each subspace.
  2. Different "what should I attend to" patterns. Empirically, one head attends to the previous token, another to syntactic dependencies, another to the start-of-sequence token, etc. Interpretability work (Anthropic, Olah et al.) shows that heads specialise.
  3. Parallelism. The h heads compute independently, so the matmul is one tensor operation of shape (h, N, d_h). On a GPU, this is exactly the workload tensor cores like.
A subtle interview probe
Multi-head attention has the same total parameter count as a single big-head attention with d_k = d. The projection matrices W_Q, W_K, W_V for one head are (d × d/h); for h heads stacked they're (d × d). No new parameters. The difference comes from the softmax: it is applied separately per head, to scores computed inside that head's d_h-dim subspace, rather than once to a single score computed from the full d-dim projections, which is a different non-linearity.
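A minimal sketch of multi-head attention, assuming the common layout where the h per-head projections are stacked into one [d, d] matrix and split by reshaping. The names, sizes, and random weights are illustrative assumptions.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    N, d = X.shape
    d_h = d // h

    def split(M):
        # [N, d] -> [h, N, d_h]; head i owns columns i*d_h .. (i+1)*d_h.
        return M.reshape(N, h, d_h).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(d_h)   # [h, N, N]
    A = softmax(S)                                 # softmax per head, per row
    Y = A @ V                                      # [h, N, d_h]
    Y = Y.transpose(1, 0, 2).reshape(N, d)         # concat heads -> [N, d]
    return Y @ W_O                                 # mix heads

rng = np.random.default_rng(0)
N, d, h = 6, 32, 4
X = rng.normal(size=(N, d))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(4))
Y = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)

# Parameter count of the Q, K, V projections is the same either way.
single_head_params = 3 * d * d              # one head with d_k = d
multi_head_params = 3 * h * d * (d // h)    # h heads of width d/h
print(Y.shape, single_head_params == multi_head_params)
```

The only place the heads differ from one wide head is the `softmax(S)` line, which runs over h separate [N, N] score matrices instead of one.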

Causal masking — the autoregressive constraint

For a language model, token i mustn't attend to tokens j > i (no future-information leak). Implementation:

S_{ij} = Q_i · K_j / √d_k   if j ≤ i
S_{ij} = −∞                 if j > i
A = softmax(S)   — the −∞ becomes 0 after exp

Three implementations:

  1. Add a mask matrix. Compute the full S, then add M_{ij} = 0 if j ≤ i, −∞ otherwise, before softmax. Simplest. Materialises the full N×N matrix.
  2. FlashAttention-style streaming. Skip the tiles that lie entirely above the diagonal. Roughly half the compute is saved. Only works with a custom kernel.
  3. For inference: token N+1 only attends to tokens 1..N+1, so you only need to compute the (N+1)-th row of S. Combined with the KV cache (next section), this is what makes autoregressive decode fast.
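The first option (an additive mask matrix) can be sketched as follows. This is an illustration under assumed names and sizes; it materialises the full N×N matrix, exactly as that option describes.

```python
import numpy as np

def causal_attention_weights(Q, K):
    N, d_k = Q.shape
    S = Q @ K.T / np.sqrt(d_k)                 # [N, N]
    i, j = np.indices((N, N))
    M = np.where(j <= i, 0.0, -np.inf)         # 0 on/below the diagonal, -inf above
    S = S + M
    # The diagonal is never masked, so every row max is finite.
    S = S - S.max(axis=-1, keepdims=True)
    e = np.exp(S)                              # exp(-inf) == 0
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))
K = rng.normal(size=(4, 8))
A = causal_attention_weights(Q, K)
print(np.round(A, 2))
```

The first row can only attend to itself, so its single non-zero weight is 1; every later row spreads its weight over the tokens up to and including its own position.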

KV cache — why decode is memory-bound

During autoregressive generation, you generate tokens one at a time. At step t, you have t tokens. Naïvely, you'd recompute Q, K, V for all of them every step → O(t²) compute per token → O(N³) per sequence.

The fix: K and V don't depend on later tokens. So you cache them. At step t:

  1. Compute Q only for the new token x_t — shape (1, d_k).
  2. Compute K, V only for x_t — shape (1, d_k) each. Append to the cache.
  3. Compute attention: Q_t · K_{1:t}^⊤, softmax, weighted sum of V_{1:t}.

Now each token costs O(t·d_k), not O(t²·d_k). The cost is memory: you store K, V for every layer and every head.
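The three steps can be sketched as a tiny single-head cache and checked against full causal attention. `KVCache` is a hypothetical class written for this example; a real implementation would preallocate the buffers instead of growing them with `np.vstack`.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

class KVCache:
    """Hypothetical single-head cache; grows by one row per decoded token."""
    def __init__(self, d_k, d_v):
        self.K = np.zeros((0, d_k))
        self.V = np.zeros((0, d_v))

    def step(self, x_t, W_Q, W_K, W_V):
        q = x_t @ W_Q                                    # step 1: query for the new token
        self.K = np.vstack([self.K, x_t @ W_K])          # step 2: append K row -> [t, d_k]
        self.V = np.vstack([self.V, x_t @ W_V])          #         append V row -> [t, d_v]
        a = softmax(self.K @ q / np.sqrt(q.shape[-1]))   # step 3: weights over 1..t
        return a @ self.V

rng = np.random.default_rng(0)
N, d, d_k = 6, 16, 8
X = rng.normal(size=(N, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d_k)) / np.sqrt(d) for _ in range(3))

cache = KVCache(d_k, d_k)
Y_cached = np.stack([cache.step(x, W_Q, W_K, W_V) for x in X])

# Reference: full causal attention computed in one shot.
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
S = Q @ K.T / np.sqrt(d_k)
S = np.where(np.tril(np.ones((N, N), dtype=bool)), S, -np.inf)
Y_full = softmax(S) @ V
print(np.allclose(Y_cached, Y_full))
```

The cached path never forms an N×N matrix: each step does one [t, d_k] by [d_k] product, which matches the O(t·d_k) cost stated above.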

KV cache size = 2 · layers · heads · d_h · max_seq_len · batch · bytes_per_value
              = 2 · L · n_h · d_h · S · B · 2   (bf16)
              = 4 · L · d · S · B   (since n_h · d_h = d)
Why KV cache size is the load-bearing number for serving
Worked MHA example: layers=80, d=8192, max sequence 4096, batch=1, bf16. KV cache = 4 × 80 × 8192 × 4096 × 1 ≈ 10.7 GB per request. (LLaMA-2-70B itself uses GQA-8, which is 8× cheaper at ~1.3 GB, but the MHA arithmetic is what motivated GQA in the first place.) A 70B model in bf16 needs about 140 GB for weights alone, so it is sharded across several 80 GB H100s, and whatever memory is left over for KV is the binding constraint on concurrent requests. Doubling context doubles KV. This is the entire motivation for MQA/GQA (next section), paged KV (vLLM), and 8-bit KV cache.
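The worked example can be reproduced with a few lines of arithmetic. The function name `kv_cache_bytes` is hypothetical; d_h = 128 follows from d = 8192 and h = 64 as used in this lesson.

```python
def kv_cache_bytes(layers, kv_heads, d_h, seq_len, batch=1, bytes_per_value=2):
    # Leading 2 = one K tensor plus one V tensor per layer.
    return 2 * layers * kv_heads * d_h * seq_len * batch * bytes_per_value

mha = kv_cache_bytes(layers=80, kv_heads=64, d_h=128, seq_len=4096)
gqa = kv_cache_bytes(layers=80, kv_heads=8, d_h=128, seq_len=4096)
print(f"MHA:   {mha / 1e9:.1f} GB per request")
print(f"GQA-8: {gqa / 1e9:.2f} GB per request")
```

Only the number of K, V heads changes between the two calls, which is the whole point of GQA for serving.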

MHA → MQA → GQA — the KV cache shrink trick

Standard multi-head attention has h heads, each with its own K and V. The KV cache grows with h. But: queries are computed fresh every step, while keys and values sit in the cache. So:

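Variant                                  | Q heads | K, V heads      | KV cache vs MHA           | Quality vs MHA
-----------------------------------------|---------|-----------------|---------------------------|-----------------------
MHA (multi-head)                         | h       | h               | baseline                  | baseline
MQA (multi-query, Shazeer 2019)          | h       | 1               | 1/h ×                     | noticeable degradation
GQA (grouped-query, Ainslie et al. 2023) | h       | g (typically 8) | g/h × (e.g. 1/8 for h=64) | ~MHA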

In GQA, the h query heads are split into g groups; each group shares K, V. g is a knob between MHA (g=h) and MQA (g=1). LLaMA-2-70B uses g=8 with h=64 — an 8× KV cache reduction at near-zero quality loss.

MHA (h=4)        MQA (h=4, kv=1)    GQA (h=4, g=2)
Q1 Q2 Q3 Q4      Q1 Q2 Q3 Q4        Q1 Q2 Q3 Q4
│  │  │  │       │  │  │  │         │  │  │  │
K1 K2 K3 K4      K (shared)         K1    K2
V1 V2 V3 V4      V (shared)         V1    V2
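A minimal sketch of the grouping, assuming the usual convention that consecutive query heads share a KV head (query head i reads KV head i // (h/g)). The function name `gqa` and the sizes are illustrative; a production kernel would index the shared heads rather than materialise the repeat.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def gqa(Q, K, V):
    """Q: [h, N, d_h]; K, V: [g, N, d_h] with h divisible by g."""
    h, g = Q.shape[0], K.shape[0]
    assert h % g == 0, "query heads must split evenly into groups"
    rep = h // g
    # Expand the g cached heads so each query head sees its group's K, V.
    K = np.repeat(K, rep, axis=0)    # [h, N, d_h]
    V = np.repeat(V, rep, axis=0)
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    return softmax(S) @ V            # [h, N, d_h]

rng = np.random.default_rng(0)
h, g, N, d_h = 4, 2, 5, 8
Q = rng.normal(size=(h, N, d_h))
K = rng.normal(size=(g, N, d_h))     # only g heads are stored in the cache
V = rng.normal(size=(g, N, d_h))
Y = gqa(Q, K, V)
print(Y.shape)
```

Passing K and V with g = h heads gives plain MHA, and g = 1 gives MQA, so the same function covers all three rows of the table.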

Complexity — the load-bearing numbers

Per layer, per sequence, in standard MHA:

Stage                       | FLOPs                   | Memory
----------------------------|-------------------------|-------------------------
QKV projections (3 of them) | 3 · N · d · d = 3·N·d²  | 3·N·d (intermediates)
S = Q·K^⊤                   | N² · d                  | N² (S matrix)
softmax(S)                  | N² · h (elementwise)    | N² (A matrix, same as S)
Y = A · V                   | N² · d                  | N · d (output)
Output projection W_O       | N · d²                  | N · d
Total                       | 4·N·d² + 2·N²·d         | O(N·d + N²)

Two regimes, separated by the crossover where 2·N²·d ≈ 4·N·d², i.e. N ≈ 2d:

  1. N ≪ 2d: the projections (4·N·d²) dominate.
  2. N ≫ 2d: the N² attention terms (2·N²·d) dominate.

For typical LLMs with d ≈ 8192 and N ≈ 4096, projections dominate. For long-context models (N ≈ 32k or 128k), N² dominates by a lot. This is the regime where FlashAttention and similar tricks pay off most.
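The crossover can be checked numerically with the two totals from the table (4·N·d² and 2·N²·d). The function name `layer_costs` is hypothetical, and the counts follow the table's convention of one unit per multiply-add.

```python
def layer_costs(N, d):
    projections = 4 * N * d * d    # Q, K, V and output projections
    attention = 2 * N * N * d      # Q K^T plus A V, summed over heads
    return projections, attention

for N in (4096, 16384, 131072):
    p, a = layer_costs(N, d=8192)
    print(f"N={N:6d}  projections={p:.2e}  attention={a:.2e}  attention/projections={a / p:.2f}")
```

The ratio simplifies to N / (2d), so it is 1 exactly at N = 2d and grows linearly with context length after that.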

The interview probes — what makes attention work

  1. Where does multi-head attention's expressiveness come from? From the fact that the softmax is applied separately per head, to scores computed inside that head's d_h-dim subspace, rather than once to a single score computed from the full d-dim projections. The softmax non-linearity acts on each subspace independently → you can attend to different things in different subspaces.
  2. Why is the output projection W_O separate from the head concatenation? Each head produces a d_h-dim output; concatenation gives d-dim. W_O then mixes the heads. Without W_O, the heads can't exchange information — each one is a disjoint summary. With W_O, the next layer's input is a learned mixture of heads.
  3. What goes wrong if you remove √d_k? The dot products have variance d_k. For d_k = 64, std = 8. exp(8) ≈ 3000 → softmax becomes effectively one-hot for almost all rows. The model can learn around this (e.g., make Q, K small), but it'll do it slowly. With the √d_k factor, the model trains 5–10× faster.
  4. Why is attention "permutation-equivariant"? If you permute the tokens of the input, the rows of Q, K, V permute together, so the attention output permutes the same way. The function doesn't care about position. To inject position, you need a positional encoding — next lesson.
  5. Why is KV cache the bottleneck, not Q computation? Q is recomputed every step from a single new token → 1 row of cost. K, V from every token so far are required at every step → t rows of cost. As context grows, KV reads dominate.
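Permutation equivariance (probe 4) is straightforward to verify numerically for unmasked self-attention. The sketch below uses assumed names, sizes, and random weights; it checks that permuting the input rows permutes the output rows in the same way.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(0)
N, d, d_k = 7, 16, 8
X = rng.normal(size=(N, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d_k)) / np.sqrt(d) for _ in range(3))

perm = rng.permutation(N)
Y = attention(X, W_Q, W_K, W_V)
Y_perm = attention(X[perm], W_Q, W_K, W_V)

# Attention of the shuffled input equals the shuffled attention output.
print(np.allclose(Y_perm, Y[perm]))
```

Adding a causal mask breaks this property, because the mask itself depends on position.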

Where things go subtly wrong

Bug | Symptom | Diagnosis
----|---------|----------
Mask off by one | Train loss looks normal; eval shows the model is "psychic": it can predict token t given token t. | The mask should allow j ≤ i (non-strict) if you want token i to use tokens 1..i inclusive of itself (standard), or strict j < i if you want the shifted variant. Pick one and be consistent across train and eval.
Softmax NaN | Loss is fine for a few thousand steps, then NaN. | Score has Inf (from -inf mask added to a position that is then read elsewhere?). Or attention scores have escaped to ±large values from bad init.
KV cache stale | Generated text doesn't change when you change the prompt. | You're re-using a cache from a previous request. Reset the cache between sequences.
Head dim too small under quantisation | Quantising to int8 hurts a lot. | Each head has d_h = d/h. For small d_h (e.g. 64), int8 quantisation has poor resolution. Larger d_h is more quant-robust → fewer heads, larger heads, GQA-style.
1/math.sqrt(d_k) vs math.sqrt(1/d_k) | Subtle training instability in fp16. | For large d_k, the intermediate 1/d_k is a small number that can lose precision or underflow if it is formed in fp16; 1/math.sqrt(d_k) is the safe form. Equivalent in fp32 but not in low precision.
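One concrete way a -inf mask can produce NaN, offered here as an assumption about a likely cause rather than the only diagnosis, is a row in which every position is masked (for example an all-padding row). The sketch contrasts a naive masked softmax with a guarded one; both function names are hypothetical.

```python
import numpy as np

def naive_masked_softmax(S, keep):
    S = np.where(keep, S, -np.inf)
    S = S - S.max(axis=-1, keepdims=True)   # all-masked row: -inf - (-inf) = nan
    e = np.exp(S)
    return e / e.sum(axis=-1, keepdims=True)

def safe_masked_softmax(S, keep):
    S = np.where(keep, S, -np.inf)
    row_max = S.max(axis=-1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)   # all-masked row: shift by 0
    e = np.exp(S - row_max)                                   # exp(-inf) == 0
    denom = e.sum(axis=-1, keepdims=True)
    return e / np.where(denom == 0, 1.0, denom)               # all-masked row -> all zeros

S = np.zeros((2, 3))
keep = np.array([[True, True, False],
                 [False, False, False]])    # second row is fully masked

with np.errstate(invalid="ignore"):
    bad = naive_masked_softmax(S, keep)
good = safe_masked_softmax(S, keep)
print(bad)
print(good)
```

Returning zeros for a fully masked row is a design choice made for this sketch; the important part is that the NaN never reaches the loss.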

Interview prompts you should be ready for

  1. "Derive the √d_k scaling factor." (Q and K entries iid with var 1 → Q·K has var d_k, std √d_k. Without scaling, softmax saturates → vanishing gradients. Divide by √d_k to keep scores at unit variance.)
  2. "What's the FLOP cost of a single attention layer in a 70B LLaMA at N=4k?" (d=8192, h=64, counting one multiply-add as one FLOP; double everything if you count the multiply and the add separately. Per token: QKV+O projections ≈ 4·d² ≈ 268 MFLOPs. Across N=4k tokens: ~1.1 TFLOPs per layer. Attention matmuls (Q·K^T + A·V across all heads): ~2·N²·d ≈ 275 GFLOPs per layer per sequence. At N=4k, projections dominate. At N≈2d=16k, attention catches up. At N=32k+, attention dominates, which is the regime FlashAttention was designed for.)
  3. "GQA: how does grouping save memory if compute is similar?" (GQA doesn't save compute on Q (still h heads of Q computed). It saves KV cache memory: g heads of K, V instead of h. For inference, that's the only thing that matters because KV cache dominates memory and you can serve more concurrent requests.)
  4. "Why isn't the V projection 'data-dependent' like Q and K?" (It is! V also depends on the token. The phrasing is "Q and K determine the weights; V is what's being weighted." V is the content you're aggregating; Q,K determine the aggregation pattern. Symmetric in structure, asymmetric in role.)
  5. "Causal mask in pseudocode: what's the shape and what value goes where?" (N×N matrix, M[i,j] = 0 if j ≤ i, -inf if j > i. Add to S before softmax. Setting A[i,j] = 0 for j > i after the softmax is not equivalent: the surviving weights were normalised against the masked positions too, so the rows no longer sum to 1 unless you renormalise.)
  6. "You're running inference on a 70B model with 64k context. What's the KV cache for one request?" (MHA: 4 × 80 × 8192 × 64000 × 1 ≈ 168 GB. Doesn't fit on one H100. GQA-8: 21 GB. Fits. This is why long-context is GQA territory.)
  7. "Cross-attention vs self-attention — what's different?" (Self-attention: Q, K, V all come from the same sequence. Cross-attention: Q comes from one sequence, K, V from another. Used in encoder-decoder transformers and in VLMs (Q is text, K/V is image). Architecturally identical; just different sources for the projections.)
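Prompt 5 turns on the difference between masking before and after the softmax. The sketch below, with assumed names and random scores, shows that zeroing after the softmax leaves rows that no longer sum to 1, and that renormalising recovers the pre-softmax mask result.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
N = 4
S = rng.normal(size=(N, N))
causal = np.tril(np.ones((N, N), dtype=bool))    # True where j <= i

A_pre = softmax(np.where(causal, S, -np.inf))    # mask, then softmax
A_post = np.where(causal, softmax(S), 0.0)       # softmax, then zero out
A_fixed = A_post / A_post.sum(axis=-1, keepdims=True)

print(A_pre.sum(axis=1))     # every row sums to 1
print(A_post.sum(axis=1))    # all rows except the last sum to less than 1
print(np.allclose(A_fixed, A_pre))
```

Only the last row is unaffected, because the causal mask removes nothing from it.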
Takeaway
Attention is a learned weighted sum where the weights come from softmax(QK^⊤ / √d_k). Multi-head splits the projection into parallel subspaces. Causal masking enforces the autoregressive constraint. The KV cache makes generation O(N) per token instead of O(N²). GQA/MQA shrink the cache at minimal quality cost. Every detail — √d_k, masking, output projection — has a precise reason. The interview signal is being able to re-derive the formula and explain what would break if you removed any single piece.