Transformer forward as a kernel chain
Lesson 01 said "count the bytes." Now we do it. A single decoder layer is a short chain of six or seven recognizable kinds of kernel. By the end of this lesson the KV cache stops being a phrase and becomes a number, and you can predict which kernel saturates which resource.
The question this lesson answers
If a forward pass is just a chain of well-known operations, why does serving feel so model-specific? Because each operation hits the roofline (lesson 01) differently depending on whether the model is processing many input tokens at once or generating one new token at a time. The same matrix multiply that is compute-bound in the first regime is bandwidth-bound in the second. The KV cache, which we now define precisely, is the reason that asymmetry exists.
Two names we'll use from here on
Two phases of generation are referenced throughout the rest of the track. Lesson 03 will derive their different kernel needs in detail; here we just name them so the byte accounting lands.
- Prefill — the engine ingests the user's prompt (T tokens at once) and fills the KV cache. One forward pass over T tokens.
- Decode — the engine generates one new token. The model takes the single newest token, attends to all previous K/V, and emits a logit row. One forward pass per token.
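The two phases can be sketched as a toy loop. The sketch assumes nothing about a real engine. `forward` and `generate` are hypothetical names, the "model" only appends cache entries, and the sampled token is a placeholder. It shows prefill as one pass over all T prompt tokens and decode as one pass per new token. Each decode pass reads the entire cache so far.

```python
def forward(new_tokens, kv_cache):
    """Hypothetical stand-in for one forward pass over len(new_tokens) tokens.

    It appends one K/V entry per new token, then returns how many cached
    entries the newest token's attention has to read (the whole cache).
    """
    kv_cache.extend(new_tokens)
    return len(kv_cache)


def generate(prompt, max_new_tokens):
    """Toy generation loop: one prefill pass, then one decode pass per new token."""
    kv_cache = []
    kv_entries_read = [forward(prompt, kv_cache)]  # prefill: all T prompt tokens at once
    next_token = -1  # placeholder for the token sampled from the last logit row
    for _ in range(max_new_tokens):
        # decode: feed only the newest token, attend to everything cached so far
        kv_entries_read.append(forward([next_token], kv_cache))
    forward_passes = len(kv_entries_read)
    return forward_passes, len(kv_cache), kv_entries_read


passes, cache_len, reads = generate(prompt=[101, 102, 103, 104], max_new_tokens=3)
print(passes, cache_len, reads)  # 4 7 [4, 5, 6, 7]
```

The growing read count is the point. Decode work per step scales with history length, not with the single new token.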
Two minor pieces of model vocabulary also appear below:
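- GQA (grouped-query attention) — several Q heads share one KV head. If the model has H query heads and K KV heads with K<H, the KV bytes stored and read shrink by a factor of H/K relative to full multi-head attention. Llama-3-70B uses K=8 with H=64; Llama-3-8B uses K=8 with H=32.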
- MLA (multi-head latent attention, DeepSeek-style) — a different attention factorization that stores a compressed latent KV. We mention it only to keep the table honest; the math is out of scope here.
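The GQA head mapping can be sketched minimally. The sketch assumes query heads are grouped contiguously, so query head h reads KV head h // (H/K). The function names are illustrative and do not come from any library.

```python
def gqa_kv_head(query_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """KV head read by a given query head, assuming contiguous GQA groups."""
    assert num_q_heads % num_kv_heads == 0, "H must be a multiple of K"
    group_size = num_q_heads // num_kv_heads  # H/K query heads share each KV head
    return query_head // group_size


def kv_read_reduction(num_q_heads: int, num_kv_heads: int) -> float:
    """Factor by which GQA cuts KV bytes versus full multi-head attention (K = H)."""
    return num_q_heads / num_kv_heads


# Llama-3-70B-like shape: H=64 query heads, K=8 KV heads.
print([gqa_kv_head(h, 64, 8) for h in (0, 7, 8, 63)], kv_read_reduction(64, 8))
```

Under this grouping, query heads 0 through 7 all read KV head 0. Each KV byte fetched therefore serves eight query heads.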
One decoder layer, drawn flat
For a standard Llama-style decoder layer, the kernels are: RMSNorm → QKV projection → RoPE → attention → output projection → residual add → RMSNorm → MLP gate/up → activation → MLP down → residual add. We list them with their byte and FLOP costs for one decode step (one new token per request) at batch size B for a model with hidden dim D, head dim d, H query heads, K KV heads, and intermediate size F. Weights are bf16 (2 bytes per parameter).
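The per-kernel costs for one decode step can be approximated in a few lines. This sketch rests on stated assumptions. GEMMs and attention count 2 FLOPs per multiply-add and only their dominant bytes, which are weights for GEMMs and cached K/V for attention. The elementwise FLOP counts are rough guesses: 4 per element for RMSNorm, 3 for RoPE, 5 for the SiLU-times-up activation. `KernelCost` and `decode_layer_costs` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class KernelCost:
    name: str
    flops: float
    nbytes: float

    @property
    def intensity(self) -> float:
        return self.flops / self.nbytes


def decode_layer_costs(B, T_hist, D, d, H, K, F, b=2):
    """Approximate FLOPs and HBM bytes of one Llama-style layer for one decode step.

    GEMMs and attention: 2 FLOPs per multiply-add, counting only the dominant bytes
    (weights for GEMMs, cached K/V for attention). Elementwise kernels: bytes are
    activations read plus written; FLOPs per element are rough assumptions.
    """
    q_dim, kv_dim = H * d, K * d

    def gemm(name, n_in, n_out):
        # B activation rows times an (n_in x n_out) bf16 weight matrix.
        return KernelCost(name, 2 * B * n_in * n_out, b * n_in * n_out)

    def elementwise(name, n_elems, flops_per_elem, tensors_moved):
        return KernelCost(name, flops_per_elem * n_elems, b * n_elems * tensors_moved)

    attention = KernelCost(
        "attention",
        4 * B * T_hist * d * H,       # q.K^T and p.V: 2*T_hist*d FLOPs each per query head
        2 * B * T_hist * kv_dim * b,  # K and V for every cached token of every sequence
    )
    return [
        elementwise("rmsnorm", B * D, 4, 2),              # square, sum, rsqrt-scale, weight
        gemm("qkv_proj", D, q_dim + 2 * kv_dim),
        elementwise("rope", B * (q_dim + kv_dim), 3, 2),  # rotate q and k in place
        attention,
        gemm("out_proj", q_dim, D),
        elementwise("residual_add", B * D, 1, 3),         # read two tensors, write one
        elementwise("rmsnorm", B * D, 4, 2),
        gemm("mlp_gate_up", D, 2 * F),
        elementwise("activation", B * F, 5, 3),           # SiLU(gate) * up, ~5 FLOPs assumed
        gemm("mlp_down", F, D),
        elementwise("residual_add", B * D, 1, 3),
    ]


for k in decode_layer_costs(B=64, T_hist=4096, D=8192, d=128, H=64, K=8, F=28672):
    print(f"{k.name:13s} {k.flops:10.2e} FLOP  {k.nbytes:10.2e} B  {k.intensity:6.2f} FLOP/B")
```

With the Llama-3-70B shape, every GEMM comes out at exactly B FLOP per byte and attention at H/K. The four GEMM weight terms sum to about 1.71 GB per layer.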
The KV cache, derived
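Attention at decode step T reads every previous token's K and V tensors. In every layer, each token contributes K·d values to the K cache and the same number to the V cache, stored at b = 2 bytes each (bf16). So the bytes per token, per layer, are:

$$\text{KV bytes per token per layer} = 2 \cdot K \cdot d \cdot b$$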
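Multiply by the number of layers L to get bytes per token across the whole model. For Llama-3-70B with L=80, K=8, d=128, b=2:

$$2 \cdot L \cdot K \cdot d \cdot b = 2 \cdot 80 \cdot 8 \cdot 128 \cdot 2 = 327{,}680 \text{ bytes} = 320\ \text{KiB}$$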
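For an 8K-context request (8,192 tokens), that is 8,192 × 320 KiB = 2.5 GiB of KV. For 1024 such requests it is about 2.7 TB, far more than the 640 GB of HBM on an 8×H100 (80 GB) node, so a real batch holds far fewer long requests. Whatever does fit, the decode kernel must read all of it, in every layer, on every decode step. This is the single most important number in serving. Internalize it.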
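The KV arithmetic above can be written as a small sketch. It assumes bf16 storage and the GQA layout, meaning separate K and V tensors with K heads of dim d per layer. The function names are hypothetical.

```python
def kv_bytes_per_token(L, K, d, b=2):
    """GQA KV-cache bytes for one token across all L layers (K and V, K heads of dim d)."""
    return 2 * L * K * d * b


def kv_cache_bytes(L, K, d, context, batch=1, b=2):
    """KV-cache bytes for `batch` requests that each hold `context` tokens."""
    return kv_bytes_per_token(L, K, d, b) * context * batch


GiB = 2**30
# Llama-3-70B shape from the text: L=80, K=8, d=128, bf16 (b=2).
per_token = kv_bytes_per_token(L=80, K=8, d=128)
one_request = kv_cache_bytes(L=80, K=8, d=128, context=8192)
many = kv_cache_bytes(L=80, K=8, d=128, context=8192, batch=1024)
print(per_token, one_request / GiB, f"{many / 1e12:.2f} TB")  # 327680 2.5 2.75 TB
```

Any realistic batch size, not just 1024, sits on the steep linear slope of `context * batch`. That slope is why later lessons treat KV bytes as a first-class term.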
| Model (attention type) | L | K·d (KV heads × head dim) | KV bytes / token | KV at 8K context |
|---|---|---|---|---|
| Llama-3-8B (GQA-8) | 32 | 8·128 | 128 KiB | 1.0 GiB |
| Llama-3-70B (GQA-8) | 80 | 8·128 | 320 KiB | 2.5 GiB |
| Mistral-7B (GQA-8) | 32 | 8·128 | 128 KiB | 1.0 GiB |
| DeepSeek-V3 (MLA latent) | 61 | ~576 (one compressed latent, no separate K and V) | ~69 KiB | ~0.54 GiB |
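The table rows can be reproduced the same way. The one structural difference is MLA. Assume it caches a single compressed latent per token per layer, at the table's ~576 width in bf16. Then there is no factor of 2 for separate K and V. The function names are illustrative.

```python
KiB, GiB = 2**10, 2**30


def gqa_kv_bytes_per_token(L, K, d, b=2):
    # Separate K and V tensors, K heads of dim d, in each of L layers.
    return 2 * L * K * d * b


def mla_kv_bytes_per_token(L, latent_dim, b=2):
    # Assumed MLA cache: one compressed latent vector per token per layer.
    return L * latent_dim * b


rows = {
    "Llama-3-8B": gqa_kv_bytes_per_token(L=32, K=8, d=128),
    "Llama-3-70B": gqa_kv_bytes_per_token(L=80, K=8, d=128),
    "Mistral-7B": gqa_kv_bytes_per_token(L=32, K=8, d=128),
    "DeepSeek-V3": mla_kv_bytes_per_token(L=61, latent_dim=576),
}
for name, per_token in rows.items():
    print(f"{name:12s} {per_token / KiB:6.1f} KiB/token  {per_token * 8192 / GiB:5.2f} GiB at 8K")
```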
Arithmetic intensity per kernel, per phase
The roofline classification of each kernel changes with B and T. The table below uses the H100 break-even of ~295 FLOP/byte from lesson 01. A kernel under that ratio is bandwidth-bound; above, compute-bound.
| Kernel | Prefill (B=1, T_prompt=4096) | Decode (B=64, T_hist=4096) | Why the gap |
|---|---|---|---|
| QKV / Out / MLP GEMM | ~T·B FLOP per byte of weight ≈ 4096 → compute-bound | ~B FLOP per byte of weight ≈ 64 → bandwidth-bound on weights | Decode multiplies each weight by only B activation rows, so every weight byte fetched feeds only ~B FLOPs. |
| Attention (FA-style) | ~T FLOP per byte of QKV ≈ 4096 → compute-bound | ~H/K FLOP per byte of KV (≈ 8 for Llama-3-70B), independent of B and T_hist → bandwidth-bound on KV | Decode reads the full history KV, but each cached token meets only one query row per head, so a longer history adds bytes and FLOPs in equal proportion. |
| RMSNorm / RoPE / residual | ≤ ~1 FLOP per byte → bandwidth-bound, but a small share of time | tiny tensors → launch-bound | Few FLOPs per byte; the fix is fusing them into neighboring kernels. |
| Sampling / logits | once per prefill, irrelevant | once per token, dozens of small launches → launch-bound | Top-k/p, masks, RNG, and constraint checks each cost a separate launch. |
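The roofline classification behind the table can be sketched in a few lines. The sketch assumes H100 SXM spec-sheet figures for the ridge point: about 989 TFLOP/s dense BF16 and 3.35 TB/s HBM. It also ignores activation bytes, and the helper names are hypothetical. Decode attention needs one extra step of reasoning. Per cached token, it does 4·d·H FLOPs against 4·K·d bytes of bf16 K and V, so its intensity is H/K regardless of B or T_hist.

```python
PEAK_BF16_FLOPS = 989e12    # H100 SXM dense BF16 tensor-core peak (spec sheet, assumed)
HBM_BYTES_PER_S = 3.35e12   # H100 SXM HBM3 bandwidth (spec sheet, assumed)
RIDGE = PEAK_BF16_FLOPS / HBM_BYTES_PER_S  # ~295 FLOP/byte, lesson 01's break-even


def classify(intensity, ridge=RIDGE):
    return "compute-bound" if intensity > ridge else "bandwidth-bound"


def gemm_intensity(tokens):
    # 2*tokens*N*K FLOPs over 2*N*K bytes of bf16 weight; activation bytes ignored.
    return float(tokens)


def decode_attention_intensity(num_q_heads, num_kv_heads):
    # Per cached token: 4*d*H FLOPs over 4*K*d bytes of bf16 K and V -> H/K.
    return num_q_heads / num_kv_heads


cases = {
    "GEMM, prefill (B=1, T=4096)": gemm_intensity(1 * 4096),
    "GEMM, decode (B=64)": gemm_intensity(64),
    "attention, prefill (~T)": 4096.0,
    "attention, decode (H=64, K=8)": decode_attention_intensity(64, 8),
}
for name, ai in cases.items():
    print(f"{name:32s} {ai:7.0f} FLOP/B  {classify(ai)}")
```

Note that raising B helps the GEMM rows but not decode attention. Each sequence brings its own KV, so bytes and FLOPs grow together.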
The story in two lines:
- Prefill = weights are amortized over T tokens, so GEMMs are compute-bound. The bottleneck is tensor-core throughput, and GEMM (matmul) kernels dominate the time.
- Decode = each step processes one token per request. Weights and KV are read once per step, but each byte fetched feeds only a handful of FLOPs. The bottleneck is HBM bandwidth.
Where does a decode step's time actually go?
Take Llama-3-70B at B=32, T=4096, on an H100 SXM (HBM ~3.35 TB/s). Weight bytes per layer for the attention+MLP portion are roughly 2·D·(D+2Kd) + 2·D·D + 2·2·D·F + 2·F·D (QKV, output projection, gate/up, down; 2 bytes per bf16 parameter). Plugging in D=8192, K=8, d=128, F=28672 gives ~1.71 GB of weights per layer and ~137 GB across 80 layers. That matches the known ~140 GB total for a bf16 70B model once the embeddings are added.

That does not fit on one 80 GB H100, so a common deployment is tensor-parallel-8. Each rank holds ~17 GB of weights and reads all of them on every decode step. KV reads on each rank are roughly 2·1·d·b · T · B · L bytes (one KV head per rank under TP-8), ≈ 5.4 GB at B=32, T=4096.

Per-rank HBM traffic per decode step is therefore ~22.5 GB. At 3.35 TB/s that is ~6.7 ms per step, which caps the rate at ~150 steps/s before any inefficiency. That means ~150 tokens/s per request, or ~4,800 tokens/s per replica across the batch of 32. Real systems land below that because of launch overhead, fragmentation, TP all-reduces, and imperfect kernels.
Interactive · decode bandwidth accountant
Choose model shape and batch. The widget computes per-token HBM traffic and a theoretical decode rate from your chosen bandwidth. Then it shows the share that is weights vs KV — which determines whether the lever is quantization (weights) or paged-KV/prefix-reuse (KV).
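The worked example above and the widget's weights-vs-KV split reduce to the same few lines. This is a minimal sketch under several assumptions: bf16, tensor parallelism that splits weights and KV heads evenly (tp ≤ K), embeddings and LM head ignored, and a purely bandwidth-bound step. `decode_step_traffic` and `decode_rate` are hypothetical names.

```python
def decode_step_traffic(L, D, d, H, K, F, B, T, tp=1, b=2):
    """Per-rank HBM bytes for one decode step, returned as (weight_bytes, kv_bytes).

    Weights: attention + MLP matrices only (embeddings and LM head ignored),
    split evenly over `tp` ranks. KV: the K heads are split over ranks (assumes tp <= K).
    """
    per_layer_params = D * (H * d + 2 * K * d) + H * d * D + 2 * D * F + F * D
    weight_bytes = b * L * per_layer_params / tp
    kv_bytes = 2 * L * (K / tp) * d * b * T * B
    return weight_bytes, kv_bytes


def decode_rate(weight_bytes, kv_bytes, hbm_bytes_per_s, batch):
    """Bandwidth-only upper bound: (step seconds, tokens/s per request, tokens/s per replica)."""
    step_s = (weight_bytes + kv_bytes) / hbm_bytes_per_s
    return step_s, 1 / step_s, batch / step_s


shape = dict(L=80, D=8192, d=128, H=64, K=8, F=28672)  # Llama-3-70B
for B in (1, 32, 128):
    w, kv = decode_step_traffic(**shape, B=B, T=4096, tp=8)
    step_s, per_req, per_replica = decode_rate(w, kv, 3.35e12, B)
    print(f"B={B:3d}: {(w + kv) / 1e9:5.1f} GB/step, KV share {kv / (w + kv):4.0%}, "
          f"{step_s * 1e3:4.1f} ms, {per_req:5.0f} tok/s/request, {per_replica:6.0f} tok/s/replica")
```

Sweeping B shows the lever shifting. Weight bytes are fixed per step, while KV bytes grow with B·T. Under these assumptions KV is about 24% of the step at B=32 and about 56% at B=128 (T=4096).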
Two phases, two kernel preferences
Because the same kernels move from compute- to bandwidth-bound between phases, serving engines do not pick "an attention kernel" — they pick a family of kernels and route by phase and shape. Specifically:
- Prefill kernels want tiled GEMM-like attention that keeps tensor cores busy and amortizes a single KV scan across many query rows.
- Decode kernels want low launch overhead, ragged batch support, and KV reads that follow indirection without losing coalescing.
- Mixed-phase kernels (introduced in lesson 06 under chunked prefill) need to handle "some sequences are extending, some are decoding" in one launch.
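Routing by phase can be sketched as follows. The sketch assumes each sequence in a batch reports how many new query rows it brings this step and how many tokens are already cached. `Seq`, `pick_attention_path`, and the three family labels are hypothetical. The shape-based routing the text mentions (head dim, KV layout, batch size) is not modeled here.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Seq:
    new_tokens: int     # query rows this step: a prompt chunk, or 1 when decoding
    cached_tokens: int  # tokens whose K/V are already in the cache


def pick_attention_path(batch: List[Seq]) -> str:
    """Pick a kernel family for one launch by phase (hypothetical family labels)."""
    decoding = [s for s in batch if s.new_tokens == 1 and s.cached_tokens > 0]
    if len(decoding) == len(batch):
        return "decode"   # one query row each: low launch overhead, ragged KV reads
    if not decoding:
        return "prefill"  # many query rows: tiled, tensor-core-friendly attention
    return "mixed"        # chunked prefill: extending and decoding sequences together


print(pick_attention_path([Seq(1, 900), Seq(1, 40)]))   # decode
print(pick_attention_path([Seq(512, 0)]))               # prefill
print(pick_attention_path([Seq(256, 512), Seq(1, 40)])) # mixed
```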
What this gives you for the rest of the track
From now on, every optimization will target a specific row or byte term in the accounting above:
- Lesson 03 attacks the attention row with FlashAttention (math) and with prefill/decode-specialized kernels (shape).
- Lesson 04 attacks the KV bytes term with paged storage (allocation efficiency, not fewer bytes).
- Lesson 05 attacks the KV bytes term again by skipping work entirely when prefixes repeat.
- Lesson 06 attacks launch overhead across the chain with batching and graph capture.
- Lesson 07 attacks the weight bytes term with quantization, and tackles the small launches around sampling.