Beyond attention: GEMM, quantization, MoE, sampling
Once attention is well-tuned and the scheduler keeps the GPU busy, the bottleneck moves elsewhere in the decode chain. This lesson walks the rest: weight GEMMs and quantization, MoE routing, communication, and the small launches around sampling.
The question this lesson answers
From lesson 02, decode HBM traffic split into "KV" (handled by paged KV + prefix reuse) and "weights." For Llama-3-70B at small batch the weight share dominates: the bf16 weights total ~140 GB (70B parameters × 2 B), or ~35 GB per GPU under 4-way tensor parallelism, and every decode step reads all of them. Reading 4-bit weights instead of bf16 cuts those bytes to roughly a quarter, and weight-bound decode latency could fall by up to a similar factor, but only if a kernel exists that can do the math on the new layout. This lesson is about that kind of move: each remaining win is a "layout × kernel × scheduler" agreement, not just a faster math primitive.
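As a rough check on the weight-traffic arithmetic, the sketch below computes weight bytes per GPU and the bandwidth-only lower bound on one decode step. The 70e9 parameter count, the 3.35 TB/s HBM figure (H100 SXM spec-sheet peak), and the 4-way tensor parallelism are assumptions, and the helper names are hypothetical. The model ignores KV traffic, compute, and communication.

```python
def weight_bytes(n_params: float, bits_per_weight: float) -> float:
    """Bytes of weights read from HBM for one full forward pass."""
    return n_params * bits_per_weight / 8


def min_step_ms(bytes_per_gpu: float, hbm_bytes_per_s: float) -> float:
    """Bandwidth-only lower bound on one decode step, in milliseconds."""
    return bytes_per_gpu / hbm_bytes_per_s * 1e3


N_PARAMS = 70e9   # assumption: ~70B parameters
HBM_BW = 3.35e12  # assumption: H100 SXM peak HBM bandwidth, bytes/s
TP = 4            # assumption: 4-way tensor parallelism

for name, bits in [("bf16", 16), ("int8", 8), ("int4", 4)]:
    per_gpu = weight_bytes(N_PARAMS, bits) / TP
    bound = min_step_ms(per_gpu, HBM_BW)
    print(f"{name}: {per_gpu / 1e9:.2f} GB/GPU, step >= {bound:.2f} ms")
```

The bound scales linearly with bits per weight, which is why the byte count is the first number to look at in the weight-bound regime.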
The decode chain, ranked by where time actually goes
[Chart not reproduced: decode-chain time breakdown for Llama-3-70B at B=32, T=4K on H100 (rough, illustrative).]
Quantization as a layout decision
"Quantize the weights" really means two things at once:
- Store each weight in fewer bits (e.g., 4 instead of 16) → less HBM read per token.
- Compute the matmul using those compressed weights → requires a kernel that can dequantize on-the-fly into tensor-core math.
The bytes-saved part is the easy half. The kernel part is where most of the engineering effort goes. A 4-bit weight isn't a number a tensor core can multiply directly; the kernel must read packed nibbles, apply per-group scales (and sometimes zero-points), produce bf16/fp16 operands, and feed them to mma instructions. Done well, dequant happens in registers and throughput is bound by the bandwidth needed to read the packed bytes; done badly, dequant overhead exceeds the savings.
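To make the dequant step concrete, here is a minimal pure-Python sketch of packing and unpacking 4-bit weights with per-group scales and zero-points. The packing order (low nibble first) and the unsigned-with-zero-point convention are assumptions chosen for illustration; real formats such as those consumed by GPTQ, AWQ, or Marlin kernels use their own interleaved layouts, and a real kernel does this in registers rather than in a Python loop.

```python
from typing import List


def pack_int4(values: List[int]) -> bytes:
    """Pack unsigned 4-bit values (0..15) two per byte, low nibble first."""
    assert len(values) % 2 == 0 and all(0 <= v <= 15 for v in values)
    return bytes(values[i] | (values[i + 1] << 4) for i in range(0, len(values), 2))


def dequant_int4(packed: bytes, scales: List[float], zeros: List[int],
                 group_size: int) -> List[float]:
    """Unpack nibbles and apply per-group scale and zero-point: w = (q - z) * s."""
    out: List[float] = []
    for byte in packed:
        for q in (byte & 0x0F, byte >> 4):   # low nibble, then high nibble
            g = len(out) // group_size        # group index of this weight
            out.append((q - zeros[g]) * scales[g])
    return out


# Toy example: 8 weights, group size 4 (real checkpoints often use 128).
q = [0, 15, 8, 7, 1, 2, 3, 4]
packed = pack_int4(q)
w = dequant_int4(packed, scales=[0.5, 2.0], zeros=[8, 0], group_size=4)
print(len(packed), w)
```

Eight weights occupy four bytes instead of sixteen in bf16; the scales and zero-points add a small per-group overhead on top.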
The quant landscape in one paragraph. GPTQ and AWQ are calibration algorithms that produce 4-bit weight checkpoints with per-group scales. SmoothQuant shifts quantization difficulty from activations to weights by rescaling both, so both can be 8-bit. Marlin and Machete are mixed-precision GEMM kernels (int4 weights × fp16/bf16 activations) tuned for NVIDIA GPUs and built to consume those checkpoint formats. cuBLAS and CUTLASS are NVIDIA's reference GEMM libraries (FP16/BF16/FP8). Triton is a kernel-authoring DSL with its own GEMM implementations. The checkpoint format and the kernel have to agree on packing, group size, and scale layout.
| Format | Weight bytes / param | Kernel families that actually run fast | Notes |
|---|---|---|---|
| bf16 / fp16 | 2 B | cuBLAS, CUTLASS, Triton | Reference; no quirks. |
| fp8 (e4m3 / e5m2) | 1 B | CUTLASS fp8 GEMM, cuBLAS | Hopper+; activation also fp8; needs scale management. |
| int8 W8A8 | 1 B | CUTLASS i8 GEMM (e.g., with SmoothQuant-calibrated checkpoints) | Activation also int8; per-channel scales. |
| int8 W8A16 | 1 B | Marlin (W8A16), CUTLASS | bf16 activation, int8 weight; per-group scales. |
| int4 W4A16 (GPTQ, AWQ) | 0.5 B | Marlin, Machete, AWQ kernels | Group size (often 128) must match kernel; activation stays bf16/fp16. |
| 2-bit / mixed | 0.25 B + outliers | Specialized; gains diminish | Tail of distribution often needs higher precision. |
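Whether a format in this table pays off depends on how much of the step is weight-bound and how much the dequant costs. The following is a deliberately simple accounting model, not a measurement: the function names, the linear cost model, and the example numbers are all assumptions.

```python
def quantized_step_ms(dense_ms: float, weight_fraction: float,
                      bits: int, dequant_overhead: float) -> float:
    """Estimated decode step time after weight quantization.

    dense_ms         : measured bf16 decode step time
    weight_fraction  : share of dense_ms spent streaming weights (0..1)
    bits             : quantized weight width (bf16 baseline is 16)
    dequant_overhead : extra GEMM time from dequant, as a fraction of the
                       ideal bandwidth-bound quantized GEMM time
    """
    other_ms = dense_ms * (1 - weight_fraction)
    ideal_gemm_ms = dense_ms * weight_fraction * bits / 16
    return other_ms + ideal_gemm_ms * (1 + dequant_overhead)


def report(dense_ms: float, weight_fraction: float, bits: int,
           dequant_overhead: float):
    """Return (quantized step time, speedup, verdict)."""
    q_ms = quantized_step_ms(dense_ms, weight_fraction, bits, dequant_overhead)
    verdict = "win" if q_ms < dense_ms else "dequant eats the win"
    return q_ms, dense_ms / q_ms, verdict


# Placeholder inputs: 20 ms dense step, 80% of it weight-bound.
print(report(20.0, 0.8, bits=4, dequant_overhead=0.2))
print(report(20.0, 0.8, bits=4, dequant_overhead=3.5))
```

With a tuned kernel (low overhead) the non-weight share quickly becomes the limit; with a poorly tuned one the quantized step can end up slower than the dense step.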
The MoE detour
Mixture-of-experts replaces the dense MLP with K experts and a router that picks the top-r per token. From the kernel side this introduces five new pieces, in order:
- Routing kernel: compute gating scores, top-r selection, and per-expert token counts.
- Token permutation: reorder tokens so each expert gets a contiguous slab.
- Dispatch all-to-all: if experts are sharded across GPUs (expert parallel), each token's hidden state is sent to the rank holding its chosen experts.
- Grouped GEMM: one launch that computes many small GEMMs of different shapes, one per local expert (all K when experts are not sharded), on the local slab.
- Combine all-to-all + unpermutation: outputs travel back and are restored to original token order.
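The five steps above can be sketched on a single device, where the two all-to-all steps disappear. This toy version uses scalar hidden states, treats expert k as "multiply by k + 1", and uses raw router scores as gate weights; all of that, along with the function names, is an assumption made to keep the example short.

```python
from typing import List, Tuple


def route_top_r(scores: List[List[float]], r: int) -> List[List[int]]:
    """Per token, indices of the r highest-scoring experts."""
    return [sorted(range(len(s)), key=lambda e: -s[e])[:r] for s in scores]


def permute(choices: List[List[int]], n_experts: int) -> Tuple[List[Tuple[int, int]], List[int]]:
    """Return (token, expert) pairs sorted by expert, plus per-expert counts."""
    pairs = [(t, e) for t, picks in enumerate(choices) for e in picks]
    pairs.sort(key=lambda p: p[1])   # stable: keeps token order within an expert
    counts = [sum(1 for _, e in pairs if e == k) for k in range(n_experts)]
    return pairs, counts


def moe_forward(x: List[float], scores: List[List[float]], r: int) -> List[float]:
    n_experts = len(scores[0])
    pairs, counts = permute(route_top_r(scores, r), n_experts)
    out = [0.0] * len(x)
    start = 0
    for k, n in enumerate(counts):            # stands in for one grouped-GEMM launch
        for t, _ in pairs[start:start + n]:   # expert k's contiguous slab
            gate = scores[t][k]
            out[t] += gate * (k + 1) * x[t]   # toy expert; writing to out[t] unpermutes
        start += n
    return out


x = [1.0, 2.0, 3.0]
scores = [[0.5, 0.25, 0.25], [0.0, 0.5, 0.5], [0.25, 0.0, 0.75]]
print(permute(route_top_r(scores, 2), 3)[1], moe_forward(x, scores, r=2))
```

In an expert-parallel deployment the slabs would be exchanged between ranks before and after the inner loop, which is where the dispatch and combine all-to-alls sit.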
The all-to-all pair is a serving-system problem, not just a kernel one. It can saturate NVLink/InfiniBand bandwidth and is often the bottleneck of MoE decode. Throughput depends on routing balance: skewed routes leave some experts idle while others overflow.
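One simple way to quantify routing balance is the ratio of the busiest expert's load to the mean load. This is a sketch of one possible metric, with a hypothetical function name and made-up routing traces, not a metric any particular engine is known to report.

```python
from collections import Counter
from typing import List


def imbalance(expert_ids: List[int], n_experts: int) -> float:
    """Max per-expert load divided by mean load; 1.0 means perfectly balanced."""
    counts = Counter(expert_ids)
    mean = len(expert_ids) / n_experts
    return max(counts.get(e, 0) for e in range(n_experts)) / mean


balanced = [0, 1, 2, 3] * 4            # 16 assignments, 4 per expert
skewed = [0] * 10 + [1, 2, 3] * 2      # 16 assignments, 10 land on expert 0
print(imbalance(balanced, 4), imbalance(skewed, 4))
```

Because the grouped GEMM and the all-to-all both wait for the most heavily loaded expert, a ratio well above 1.0 suggests that much of the hardware is idle for part of the step.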
Sampling: many tiny kernels in a hot loop
Once logits exist, the engine applies (some subset of): temperature scaling, presence/frequency penalties, repetition penalty, top-k, top-p (nucleus), min-p, banned tokens, grammar constraints, random sampling. Each is a kernel; some need a sort or partial-sort; some require RNG state.
The naive implementation launches a dozen small ops per step and pays the launch overhead from lesson 06. Production engines fuse the chain: a single kernel applies temperature, penalties, masks, and top-k/p selection, then a single random sample. vLLM and FlashInfer both expose batched sampling kernels along these lines.
| Stage | What it does | Naive cost | Fused cost |
|---|---|---|---|
| Temperature | Divide logits by T. | 1 kernel | folded into the next op |
| Penalties | Adjust logits using per-token counts from the history (presence/frequency subtract, repetition rescales). | 1–2 kernels, gather over history | same op, fused with the mask |
| Top-k | Keep top-k logits; mask others. | radix-select or sort kernel | warp-level select, no extra launch |
| Top-p | Keep the smallest set of tokens whose probabilities sum to ≥ p. | sort + cumsum | fused within the same pass |
| Sample | Random draw from filtered distribution. | RNG + categorical | one CUDA-side RNG, no host trip |
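The stages in the table can be written as one function, which is the CPU-side analogue of a fused kernel: one pass over the logits instead of a launch per stage. This is a single-sequence reference sketch in pure Python, assuming temperature > 0 (greedy decoding would be handled separately) and a frequency-style subtractive penalty; it is not the API of vLLM or FlashInfer.

```python
import math
import random
from typing import Dict, List, Optional


def sample_token(logits: List[float], temperature: float = 1.0,
                 counts: Optional[Dict[int, int]] = None, freq_penalty: float = 0.0,
                 top_k: int = 0, top_p: float = 1.0,
                 rng: Optional[random.Random] = None) -> int:
    rng = rng or random.Random()
    counts = counts or {}
    # Penalty and temperature applied together in one pass over the logits.
    z = [(l - freq_penalty * counts.get(i, 0)) / temperature
         for i, l in enumerate(logits)]
    # Top-k: keep the k largest logits (top_k == 0 disables the filter).
    order = sorted(range(len(z)), key=lambda i: -z[i])
    if top_k > 0:
        order = order[:top_k]
    # Softmax over the survivors, max-subtracted for numerical stability.
    m = z[order[0]]
    probs = [math.exp(z[i] - m) for i in order]
    total = sum(probs)
    probs = [p / total for p in probs]
    # Top-p: smallest prefix of the sorted list whose mass reaches top_p.
    kept, mass = [], 0.0
    for i, p in zip(order, probs):
        kept.append((i, p))
        mass += p
        if mass >= top_p:
            break
    # Categorical draw over the survivors, scaled by their total mass.
    u = rng.random() * mass
    acc = 0.0
    for i, p in kept:
        acc += p
        if u < acc:
            return i
    return kept[-1][0]


print(sample_token([1.0, 3.0, 2.0], temperature=0.7, top_k=2, top_p=0.9,
                   rng=random.Random(0)))
```

A GPU version would replace the full sort with a partial select and keep the RNG state on the device, but the order of operations is the same.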
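Why this matters: sampling is small per token, but it runs every token. A 200 µs sampling chain at 100 tok/s per replica costs 20 ms of every second, about 2% of that replica's GPU time if nothing overlaps it. Across 1000 replicas that is roughly 20 replicas' worth of GPUs.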
Communication kernels
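Tensor parallelism splits each weight matrix across GPUs. Each layer then needs all-reduces to recombine partial results (typically one after the attention block and one after the MLP), and each is a kernel, usually NCCL ring or tree. Two practical knobs: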
- Overlap: launch the next layer's GEMM while the previous layer's all-reduce is in flight.
- Reduce-scatter + all-gather pairs: split a single all-reduce into two halves so different parts of the model can pipeline more aggressively.
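The second knob relies on an identity: a reduce-scatter followed by an all-gather produces the same result as one all-reduce. The sketch below simulates the semantics with plain lists, one list per rank. It illustrates the data movement only and is not NCCL's API; the function names mirror the collective names for readability.

```python
from typing import List

Vec = List[float]


def all_reduce(bufs: List[Vec]) -> List[Vec]:
    """Every rank ends with the elementwise sum over all ranks."""
    total = [sum(col) for col in zip(*bufs)]
    return [list(total) for _ in bufs]


def reduce_scatter(bufs: List[Vec]) -> List[Vec]:
    """Rank i ends with the i-th chunk of the elementwise sum."""
    n = len(bufs)
    chunk = len(bufs[0]) // n   # assumes the length is divisible by n
    total = [sum(col) for col in zip(*bufs)]
    return [total[i * chunk:(i + 1) * chunk] for i in range(n)]


def all_gather(chunks: List[Vec]) -> List[Vec]:
    """Every rank ends with the concatenation of all ranks' chunks."""
    full = [x for c in chunks for x in c]
    return [list(full) for _ in chunks]


bufs = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
shards = reduce_scatter(bufs)   # each rank holds 1/n of the summed vector
print(shards, all_gather(shards) == all_reduce(bufs))
```

Between the two halves each rank holds only its shard of the sum, so work that can run on a shard can be scheduled there before the all-gather.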
For MoE, the all-to-all dominates communication; production engines fuse the routing+permute step with the all-to-all so a single kernel does dispatch instead of multiple round trips.
The backend dispatcher, demystified
Every engine ends up with a small matrix: "for op X with model feature set Y on hardware Z, run backend B." This is what the docs call attention backend selection, quantization scheme, or communication backend. The matrix exists because head dim, KV dtype, page size, sliding-window attention, MLA vs. MHA, and GPU generation all constrain which kernels are correct. A "fast but invalid" kernel is not an optimization; it is a bug.
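A dispatcher of this kind can be as small as an ordered list of (backend, validity predicate) pairs. In the sketch below every backend name and every constraint is hypothetical and chosen only to show the shape of the matrix; real engines encode their own rules.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class AttnConfig:
    head_dim: int
    kv_dtype: str          # e.g. "bf16", "fp8"
    sliding_window: bool
    sm_arch: int           # e.g. 80 for Ampere, 90 for Hopper


# (backend name, predicate saying whether the backend is valid for a config).
# Hypothetical names and constraints, ordered fastest first.
BACKENDS: List[Tuple[str, Callable[[AttnConfig], bool]]] = [
    ("hopper_fp8_kernel", lambda c: c.sm_arch >= 90 and c.kv_dtype == "fp8"
                                    and c.head_dim in (64, 128)
                                    and not c.sliding_window),
    ("tuned_bf16_kernel", lambda c: c.sm_arch >= 80 and c.kv_dtype == "bf16"
                                    and c.head_dim in (64, 128, 256)),
    ("reference_kernel", lambda c: True),   # slow but always correct
]


def select_backend(cfg: AttnConfig) -> str:
    """Return the first (fastest) backend whose constraints the config satisfies."""
    for name, is_valid in BACKENDS:
        if is_valid(cfg):
            return name
    raise ValueError(f"no valid backend for {cfg}")


print(select_backend(AttnConfig(128, "fp8", False, 90)))
print(select_backend(AttnConfig(96, "fp8", False, 80)))
```

Validity is checked before speed: a config that fails every fast predicate falls through to the reference path instead of running a kernel that would silently produce wrong results.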
How to identify the next bottleneck
A short decision rule for "I made attention faster, now what?"
- Profile a decode step with timeline markers. Get per-kernel times.
- Pick the largest non-attention bar. If it's a linear GEMM and the model is bf16, try quantization (with a tuned kernel for your hardware).
- If it's MoE all-to-all, check routing imbalance and expert-parallel topology.
- If it's sampling/logits, fuse the pipeline or move to a batched sampling kernel.
- If it's "GPU idle" (CPU gaps), go back to lesson 06 — graphs, batching, or scheduler are the lever.
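The decision rule above can be written as a lookup over a per-category profile. The category names, the advice strings, and the timings below are placeholders invented for illustration; a real profile would come from your own timeline trace.

```python
from typing import Dict

ADVICE = {
    "linear_gemm": "try weight quantization with a kernel tuned for your hardware",
    "moe_all_to_all": "check routing imbalance and expert-parallel topology",
    "sampling": "fuse the pipeline or move to a batched sampling kernel",
    "gpu_idle": "revisit graphs, batching, or the scheduler (lesson 06)",
}


def next_bottleneck(step_times_us: Dict[str, float]) -> str:
    """Pick the largest non-attention bar from a per-category profile."""
    rest = {k: v for k, v in step_times_us.items() if k != "attention"}
    worst = max(rest, key=rest.get)
    return f"{worst}: {ADVICE.get(worst, 'profile this category in more detail')}"


# Made-up numbers, in microseconds per decode step.
profile = {"attention": 900.0, "linear_gemm": 1400.0,
           "sampling": 200.0, "gpu_idle": 300.0}
print(next_bottleneck(profile))
```

After applying a fix, re-profile and run the rule again; the largest bar usually changes.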
Notice this is exactly the chain from lesson 02, just walked once for each surviving bar. The track was structured to make this loop short.
What this gives you for the rest of the track
Every kernel-level concept is now in place: how the hardware works, what attention does, where KV lives, when prefixes share, how the scheduler hides launch overhead, and how the rest of the chain compresses. Lesson 08 zooms out by one level and shows the framework around those kernels: HTTP entry, queueing, routing, and where end-to-end latency really comes from.