Lesson 11 / 17

Transformer forward as a kernel chain

Lesson 01 said "count the bytes." Now we do it. A single decoder layer is a chain of about a dozen kernels built from six or seven recognizable kernel types. By the end of this lesson the KV cache stops being a phrase and becomes a number, and you can predict which kernel saturates which resource.

The question this lesson answers

If a forward pass is just a chain of well-known operations, why does serving feel so model-specific? Because each operation hits the roofline (lesson 01) differently depending on whether the model is processing many input tokens at once or generating one new token at a time. The same matrix multiply that is compute-bound in the first regime is bandwidth-bound in the second. The KV cache, which we now define precisely, is the reason that asymmetry exists.

Two names we'll use from here on

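Two phases of generation are referenced throughout the rest of the track. In prefill, the whole prompt goes through the model in one forward pass. In decode, each forward pass produces one new token per running sequence. Lesson 03 will derive their different kernel needs in detail; here we just name them so the byte accounting lands.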

Two minor pieces of model vocabulary also appear below:

One decoder layer, drawn flat

For a standard Llama-style decoder layer, the kernels are: RMSNorm → QKV projection → RoPE → attention → output projection → residual add → RMSNorm → MLP gate/up → activation (SwiGLU) → MLP down → residual add. We list them with their byte and FLOP costs for one decode step (one new token per sequence) at batch size B. The model has hidden dim D, head dim d, H query heads, K KV heads, and intermediate size F. Weights are bf16 (2 bytes).

Kernel | Bytes moved (one decode step) | FLOPs
RMSNorm | ~2·B·D (activations) | tiny
QKV proj | 2·D·(D+2Kd) (W) | 2·B·D·(D+2Kd)
RoPE | ~2·B·D (activations) | small
Attention | 2·2·B·K·T·d (KV) | 4·B·H·T·d
Out proj | 2·D·D (W) | 2·B·D·D
MLP gate+up | 2·2·D·F (W) | 4·B·D·F
SwiGLU | ~2·B·F (activations) | B·F
MLP down | 2·F·D (W) | 2·B·F·D
B = batch, D = hidden, d = head dim, H = query heads, K = KV heads, F = MLP intermediate, T = current sequence length, (W) = weight bytes, (KV) = KV-cache bytes. Every tensor is bf16, so each element costs 2 bytes.

Read each row left to right. In prefill, B is replaced by B·T_prompt. Activations are tall, so the GEMMs hit the compute side of the roofline. In decode, B is just the number of running sequences. Activations are skinny, so the same GEMMs become weight-bandwidth-bound.
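The decode rows of that table turn directly into a cost model. The Python sketch below is a minimal version. It assumes a Llama-style layer (D = H·d) with bf16 weights and KV. `LayerShape` and `decode_layer_costs` are hypothetical names, and H = 64 for Llama-3-70B is simply D/d. The sketch keeps only the heavy kernels and prints each one's bytes and FLOP-per-byte ratio.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerShape:
    D: int  # hidden dim (H * d in a Llama-style layer)
    d: int  # head dim
    H: int  # query heads
    K: int  # KV heads
    F: int  # MLP intermediate size


def decode_layer_costs(s: LayerShape, B: int, T: int, b: int = 2) -> dict:
    """(bytes, FLOPs) of the heavy kernels for one decode step of one layer.

    GEMMs count weight bytes only; attention counts KV bytes only.
    Norms, RoPE, residual adds and activation traffic are left out.
    """
    D, d, H, K, F = s.D, s.d, s.H, s.K, s.F
    qkv_cols = D + 2 * K * d  # Q needs H*d = D columns; K and V need K*d each
    return {
        "qkv_proj": (b * D * qkv_cols, 2 * B * D * qkv_cols),
        "attention": (2 * B * K * T * d * b, 4 * B * H * T * d),
        "out_proj": (b * D * D, 2 * B * D * D),
        "mlp_gate_up": (2 * b * D * F, 4 * B * D * F),
        "mlp_down": (b * F * D, 2 * B * F * D),
    }


llama3_70b = LayerShape(D=8192, d=128, H=64, K=8, F=28672)
for name, (nbytes, flops) in decode_layer_costs(llama3_70b, B=64, T=4096).items():
    print(f"{name:12s} {nbytes / 1e6:8.1f} MB  {flops / nbytes:5.1f} FLOP/byte")
```

At B=64, every weight GEMM comes out at exactly B FLOP per weight byte. Attention lands at H/K = 8 whatever T is, because its FLOPs and its KV bytes both grow linearly in T.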

The KV cache, derived

Attention at decode step T reads every previous token's keys and values. In every layer, each cached token stores K·d key entries (K KV heads × head dim d) and the same number of value entries, at b bytes each. So the bytes per token, per layer, are:

kv_bytes_per_token_per_layer = 2 · K · d · b (where b = 2 for bf16)

Multiply by the number of layers L to get bytes per token across the whole model. For Llama-3-70B with L=80, K=8, d=128, b=2:

2 · 8 · 128 · 2 = 4 KB / token / layer → × 80 layers = 320 KB / token

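For an 8K-context request that is 2.5 GB of KV (320 KB × 8,192 tokens, counting in powers of 2 as above). For 1024 such requests it is 2.5 TB, far more than the HBM of any single GPU. The decode step must read the KV of every running sequence, layer by layer, for every new token. This is the single most important number in serving. Internalize it.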
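The same arithmetic works as a quick Python check. It assumes bf16 (b = 2) and the Llama-3-70B shape above, and `kv_bytes_per_token` is a hypothetical helper. Powers of 2 are used throughout, matching the 4 KB figure.

```python
def kv_bytes_per_token(L: int, K: int, d: int, b: int = 2) -> int:
    """KV-cache bytes one token adds across all L layers.

    2 tensors (keys, values) x K KV heads x d values x b bytes, per layer.
    """
    return 2 * K * d * b * L


per_token = kv_bytes_per_token(L=80, K=8, d=128)  # Llama-3-70B, bf16
per_request = per_token * 8192                    # one 8K-context request
print(per_token // 2**10, "KB per token")
print(per_request / 2**30, "GB per 8K request")
print(1024 * per_request / 2**40, "TB for 1024 such requests")
```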

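Model (KV layout) | L | KV heads × head dim (K × d) | KV bytes / token | KV at 8K context
Llama-3-8B (GQA) | 32 | 8 × 128 | 128 KB | 1.0 GB
Llama-3-70B (GQA) | 80 | 8 × 128 | 320 KB | 2.5 GB
Mistral-7B (GQA-8) | 32 | 8 × 128 | 128 KB | 1.0 GB
DeepSeek-V3 (MLA latent) | 61 | compressed latent, ~576 values | ~69 KB | ~0.54 GB
KB and GB are powers of 2 here (1 KB = 1024 bytes), as in the 70B example above.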
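The table can be regenerated with one extra case. As a simplification, this sketch assumes MLA caches a single latent of ~576 values per token per layer with no separate V tensor. That drops the leading factor of 2 from the GQA formula. The helper names are hypothetical. DeepSeek-V3 comes out at about 69 KB per token and 0.54 GB at 8K.

```python
def gqa_kv_bytes(L: int, K: int, d: int, b: int = 2) -> int:
    # Separate K and V tensors: 2 x K KV heads x d values x b bytes, per layer
    return 2 * K * d * b * L


def mla_kv_bytes(L: int, latent: int, b: int = 2) -> int:
    # Simplified MLA: one shared latent of `latent` values per token per layer
    return latent * b * L


kv_per_token = {
    "Llama-3-8B": gqa_kv_bytes(L=32, K=8, d=128),
    "Llama-3-70B": gqa_kv_bytes(L=80, K=8, d=128),
    "Mistral-7B": gqa_kv_bytes(L=32, K=8, d=128),
    "DeepSeek-V3": mla_kv_bytes(L=61, latent=576),
}
for name, nbytes in kv_per_token.items():
    # Binary units, matching the table: KB = 2**10 bytes, GB = 2**30 bytes
    print(f"{name:12s} {nbytes / 2**10:6.1f} KB/token  {nbytes * 8192 / 2**30:5.2f} GB at 8K")
```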

Arithmetic intensity per kernel, per phase

The roofline classification of each kernel changes with B and T. The table below uses the H100 break-even of ~295 FLOP/byte from lesson 01. A kernel under that ratio is bandwidth-bound; above, compute-bound.

Kernel | Prefill (B=1, T_prompt=4096) | Decode (B=64, T_hist=4096) | Why the gap
QKV / Out / MLP GEMM | ~B·T FLOP per weight byte ≈ 4096 → compute-bound | ~B FLOP per weight byte ≈ 64 → bandwidth-bound on weights | Decode activations are short; the weight has to be re-read to multiply almost nothing.
Attention (FA-style) | ~T FLOP per byte of QKV ≈ 4096 → compute-bound | ~H/K FLOP per byte of KV (the GQA group size, 8 for Llama-3-70B), independent of T → bandwidth-bound on KV | Decode reads the full KV history but emits only one row of scores per query head.
RMSNorm / RoPE / residual | tiny → launch-bound | tiny → launch-bound | Few FLOPs per byte; only worth fusing.
Sampling / logits | once per prefill, irrelevant | once per token, dozens of small launches → launch-bound | Top-k/p, masks, RNG, and constraint checks each launch a kernel.
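Here is a minimal sketch of the classification rule. It assumes the ~295 FLOP/byte H100 break-even from lesson 01 and ignores activation bytes, and all names are hypothetical. It encodes the two closed forms behind the table. A weight GEMM does about one FLOP per weight byte per token in flight, and decode attention does H/K FLOP per KV byte.

```python
H100_BREAK_EVEN = 295.0  # FLOP/byte, H100 SXM dense bf16 (from lesson 01)


def gemm_intensity(tokens_in_flight: int) -> float:
    # 2*tokens*N*K FLOPs over 2*N*K bf16 weight bytes: one FLOP/byte per token
    return float(tokens_in_flight)


def decode_attention_intensity(H: int, K: int) -> float:
    # 4*H*T*d FLOPs over 2*K*T*d*2 KV bytes: T and d cancel, leaving H/K
    return H / K


def classify(intensity: float, break_even: float = H100_BREAK_EVEN) -> str:
    return "compute-bound" if intensity > break_even else "bandwidth-bound"


print("prefill GEMM (B=1, T=4096):", classify(gemm_intensity(1 * 4096)))
print("decode GEMM  (B=64):       ", classify(gemm_intensity(64)))
print("decode attn  (H=64, K=8):  ", classify(decode_attention_intensity(64, 8)))
```

Under the same approximation, a decode GEMM would need on the order of 295 sequences in flight before it stops being weight-bound. Decode attention stays bandwidth-bound at any batch size, since each sequence brings its own KV.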

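The story in two lines:
Prefill: the GEMMs and attention are compute-bound.
Decode: the GEMMs are bandwidth-bound on weights, attention is bandwidth-bound on KV, and the small kernels (norms, RoPE, sampling) are launch-bound.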

Where does a decode step's time actually go?

Take Llama-3-70B at B=32, T=4096, on H100 SXM (HBM ~3.35 TB/s). Weight bytes per layer for the attention+MLP portion are roughly 2·D·(D+2Kd) + 2·D·D + 2·2·D·F + 2·F·D. Plugging in D=8192, K=8, d=128, F=28672 gives ~1.71 GB of weights per layer and ~137 GB across 80 layers. That matches the known ~140 GB total for a bf16 70B model. It does not fit in one H100's 80 GB, so the standard deployment is tensor-parallel-8: each rank holds ~17 GB of weights and reads all of them every decode step. KV reads on each rank are roughly 2·1·d·b · T · B · L (one KV head per rank under TP-8), ≈ 5.4 GB at B=32, T=4096. Per-rank HBM traffic per decode step is therefore ~22.5 GB. At 3.35 TB/s that is ~6.7 ms per step, an upper bound of about 150 decode steps/s. That means ~150 tokens/s for each sequence, or ~4,800 tokens/s for the replica across its batch of 32, before any inefficiency. Real systems land below that because of launch overhead, fragmentation, and imperfect kernels.
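The walkthrough above can be written as a Python sketch. It assumes TP-8 shards weights evenly and KV heads split across at most K ranks. Only weight and KV reads are counted, with no activations, all-reduce traffic, or launch overhead. `decode_step_estimate` is a hypothetical name. It prints about 17.1 GB of weights, 5.4 GB of KV and 6.71 ms per step.

```python
def decode_step_estimate(D, d, K, F, L, B, T, tp, hbm_bytes_per_s, b=2):
    """Bandwidth-only estimate for one decode step on one tensor-parallel rank.

    Counts weight and KV reads only; ignores activations, all-reduce traffic
    and launch overhead, so the step time is a lower bound.
    """
    weights_per_layer = b * (D * (D + 2 * K * d) + D * D + 2 * D * F + F * D)
    weight_bytes = weights_per_layer * L / tp   # weights shard evenly across ranks
    kv_heads_per_rank = K / min(tp, K)          # KV heads split across at most K ranks
    kv_bytes = 2 * kv_heads_per_rank * d * b * T * B * L
    step_s = (weight_bytes + kv_bytes) / hbm_bytes_per_s
    return weight_bytes, kv_bytes, step_s


B = 32
w, kv, step = decode_step_estimate(D=8192, d=128, K=8, F=28672, L=80,
                                   B=B, T=4096, tp=8, hbm_bytes_per_s=3.35e12)
print(f"weights {w / 1e9:.1f} GB + KV {kv / 1e9:.1f} GB per rank -> {step * 1e3:.2f} ms/step")
# One token per sequence per step; B tokens per step for the whole replica
print(f"<= {1 / step:.0f} tokens/s per sequence, <= {B / step:.0f} tokens/s per replica")
```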

Reading this
The point is not the exact number; it is that you can derive a believable upper bound for tokens/sec on your hardware before running anything. Whenever a benchmark surprises you, redo this math first.

Interactive · decode bandwidth accountant

Choose model shape and batch. The widget computes per-token HBM traffic and a theoretical decode rate from your chosen bandwidth. Then it shows the share that is weights vs KV — which determines whether the lever is quantization (weights) or paged-KV/prefix-reuse (KV).

What does a decode token cost?

All numbers are per-token, single-rank (no tensor parallelism), summed over all layers. Bf16 weights, bf16 KV. For TP-N divide weight bytes by ~N and KV bytes by ~min(N, K). Try B=32 vs B=128, T=2K vs T=32K.
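The accountant's arithmetic fits in a few lines. This is a minimal Python sketch, not the widget's code. `decode_traffic` is a hypothetical helper that applies the TP-N rule above (weights ÷ N, KV ÷ min(N, K)). The 50% KV-share cutoff for choosing a lever is an assumption made only for illustration. The sketch sweeps the suggested B and T values for Llama-3-70B on a single rank.

```python
def decode_traffic(D, d, K, F, L, B, T, tp=1, b=2):
    """Per-rank HBM bytes for one decode step, split into weights and KV."""
    # QKV + out proj + gate/up (2*D*F) + down (F*D), all layers, sharded over tp ranks
    weights = b * L * (D * (D + 2 * K * d) + D * D + 3 * D * F) / tp
    # KV heads shard across at most K ranks; beyond that they are replicated
    kv = 2 * K * d * b * T * B * L / min(tp, K)
    return weights, kv


LLAMA3_70B = dict(D=8192, d=128, K=8, F=28672, L=80)
HBM_BYTES_PER_S = 3.35e12  # H100 SXM

for B in (32, 128):
    for T in (2048, 32768):
        w, kv = decode_traffic(**LLAMA3_70B, B=B, T=T)
        kv_share = kv / (w + kv)
        # Illustrative cutoff (assumption): whichever term is the majority picks the lever
        lever = "paged KV / prefix reuse" if kv_share > 0.5 else "weight quantization"
        print(f"B={B:4d} T={T:6d}  {(w + kv) / HBM_BYTES_PER_S * 1e3:7.1f} ms/step  "
              f"KV share {kv_share:4.0%}  lever: {lever}")
```

At B=32 and T=2K, KV is a small slice of the traffic and weight quantization is the lever. At T=32K, KV dominates and paged KV or prefix reuse matters more.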

Two phases, two kernel preferences

Because the same kernels move from compute- to bandwidth-bound between phases, serving engines do not pick "an attention kernel" — they pick a family of kernels and route by phase and shape. Specifically:

Common confusion
FlashAttention is not "for decode" or "for prefill." It is a tile schedule for the attention math. Whether it is the right choice depends on shape. For a long Q tile (prefill), yes. For a single-row Q (decode), FA variants exist, but so do dedicated decode kernels such as FlashDecoding and FlashInfer's batch-decode path. Lesson 03 makes this concrete.
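Here is a toy sketch of routing by phase. It assumes a step's sequences are grouped by how many new query tokens each one carries. `Phase`, `phase_of`, `route`, and the two kernel labels are hypothetical, not entry points of FlashAttention, FlashInfer, or any serving engine. Real engines route on more of the shape than this.

```python
from enum import Enum


class Phase(Enum):
    PREFILL = "prefill"
    DECODE = "decode"


def phase_of(q_len: int) -> Phase:
    # A decode step carries exactly one new query token per sequence
    return Phase.DECODE if q_len == 1 else Phase.PREFILL


# Hypothetical kernel-family table: labels only, not real library functions
ATTENTION_KERNELS = {
    Phase.PREFILL: "flash_attention_prefill",  # tall Q tile, compute-bound
    Phase.DECODE: "flash_decoding_batch",      # one Q row, bandwidth-bound on KV
}


def route(q_lens: list[int]) -> dict:
    """Group sequence indices by the kernel family their phase maps to."""
    groups: dict = {}
    for i, q_len in enumerate(q_lens):
        kernel = ATTENTION_KERNELS[phase_of(q_len)]
        groups.setdefault(kernel, []).append(i)
    return groups


print(route([512, 1, 1, 37, 1]))
```

Each group then gets its own launch, so one step can mix a prefill kernel for new prompts with a decode kernel for running sequences.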

What this gives you for the rest of the track

From now on, every optimization will target a specific row in the per-kernel table above: