
Beyond attention: GEMM, quantization, MoE, sampling

Once attention is well-tuned and the scheduler keeps the GPU busy, the bottleneck moves elsewhere in the decode chain. This lesson walks the rest: weight GEMMs and quantization, MoE routing, communication, and the small launches around sampling.

The question this lesson answers

From lesson 02, decode HBM traffic split into "KV" (handled by paged KV + prefix reuse) and "weights." For Llama-3-70B at small batch the weight share dominates: every forward pass reads all the weights, roughly 140 GB in bf16 summed over layers (70B parameters × 2 bytes), or about 35 GB per GPU when sharded four ways. Reading 4-bit weights instead of bf16 cuts that traffic to about a quarter and could cut decode latency by nearly the same factor, but only if a kernel exists that can do the math on the new layout. This lesson is about that kind of move: each remaining win is a "layout × kernel × scheduler" agreement, not just a faster math primitive.

The decode chain, ranked by where time actually goes

For Llama-3-70B at B=32, T=4K on H100 (rough, illustrative):

  1. Linear GEMMs (QKV, Out, MLP up/gate/down): ≈ 60%
  2. Attention (KV read + softmax): ≈ 25%
  3. Norm: 8%
  4. Sampling: share not given

Figure: dense-model decode (Llama-3-70B at B=32, T=4K). The linear layers dominate, attention comes next, the rest is "small but adds up." For MoE models the MLP slot is replaced by routing + grouped GEMM + all-to-all, usually wider than the dense MLP it replaces.

Implication
Once attention and scheduling are tight, the next dollar of effort goes into the linear GEMMs. For dense models that means quantization; for MoE that also means routing & expert-parallel comms.

Quantization as a layout decision

"Quantize the weights" really means two things at once:

  1. Store each weight in fewer bits (e.g., 4 instead of 16) → less HBM read per token.
  2. Compute the matmul using those compressed weights → requires a kernel that can dequantize on-the-fly into tensor-core math.

The bytes-saved part is the easy half. The kernel part is where most of the engineering goes. A 4-bit weight isn't a number a tensor core can multiply directly; the kernel must read packed nibbles, apply per-group scales (and sometimes zero-points), produce bf16/fp16 operands, and feed them to mma instructions. Done well, dequant happens in registers and the throughput is bandwidth-bound on packed bytes; done badly, dequant overhead exceeds the savings.
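The read-unpack-scale path described above can be sketched on the CPU. This is a minimal illustration in plain Python, not any checkpoint's actual format: the low-nibble-first packing order, the min/max asymmetric scheme, and the function names (`pack_nibbles`, `quantize_w4`, `dequantize_w4`) are all assumptions made for the example. Real GPTQ/AWQ layouts interleave nibbles to suit the kernel's register layout.

```python
def pack_nibbles(codes):
    """Pack 4-bit codes two per byte, low nibble first (an assumed order)."""
    assert len(codes) % 2 == 0
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


def unpack_nibbles(packed):
    """Inverse of pack_nibbles: one byte yields two 4-bit codes."""
    codes = []
    for byte in packed:
        codes.append(byte & 0x0F)
        codes.append(byte >> 4)
    return codes


def quantize_w4(weights, group_size):
    """Asymmetric 4-bit quantization with one (scale, offset) pair per group."""
    assert group_size % 2 == 0 and len(weights) % group_size == 0
    packed, scales, offsets = bytearray(), [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / 15 or 1.0  # constant group: any nonzero scale works
        codes = [min(15, max(0, round((w - lo) / scale))) for w in group]
        packed += pack_nibbles(codes)
        scales.append(scale)
        offsets.append(lo)  # plays the role of the zero-point
    return bytes(packed), scales, offsets


def dequantize_w4(packed, scales, offsets, group_size):
    """The step a W4A16 kernel performs in registers before the matmul."""
    codes = unpack_nibbles(packed)
    return [offsets[i // group_size] + c * scales[i // group_size]
            for i, c in enumerate(codes)]
```

The packed buffer is half a byte per weight plus one scale and one offset per group, which is why group size shows up in both the storage cost and the kernel's inner loop.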

The quant landscape in one paragraph. GPTQ and AWQ are calibration algorithms that produce 4-bit weight checkpoints with per-group scales. SmoothQuant rescales weights and activations so both can be 8-bit. Marlin and Machete are mixed-precision GEMM kernels (int4 weights × fp16/bf16 activations) tuned for NVIDIA GPUs and built for those checkpoint formats. cuBLAS and CUTLASS are NVIDIA's reference GEMM libraries (FP16/BF16/FP8). Triton is a kernel-authoring DSL with its own GEMM implementations. The checkpoint format and the kernel have to agree on packing, group size, and scale layout.

Format | Weight bytes / param | Kernel families that actually run fast | Notes
bf16 / fp16 | 2 B | cuBLAS, CUTLASS, Triton | Reference; no quirks.
fp8 (e4m3 / e5m2) | 1 B | CUTLASS fp8 GEMM, cuBLAS, Hopper+ | Activation also fp8; needs scale management.
int8 W8A8 | 1 B | SmoothQuant kernels, CUTLASS i8 | Activation also int8; per-channel scales.
int8 W8A16 | 1 B | Marlin (W8A16), CUTLASS | bf16 activation, int8 weight; per-group scales.
int4 W4A16 (GPTQ, AWQ) | 0.5 B | Marlin, Machete, AWQ kernels | Group size (often 128) must match kernel; activation stays bf16/fp16.
2-bit / mixed | 0.25 B + outliers | Specialized; gains diminish | Tail of distribution often needs higher precision.

When quantization regresses
The kernel's packed-bytes pipeline can become compute-bound on dequant if group size is too small, scale handling is too expensive, or the kernel was tuned for an older arch. Always benchmark against a tuned bf16 baseline; "we quantized to 4-bit" is not by itself a win.
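A rough way to keep that benchmark honest is to compare the measured speedup against the ideal ratio implied by bytes per parameter. The sketch below assumes round figures that are not from this lesson: 70e9 parameters and 3.35e12 bytes/s of HBM bandwidth for an H100 SXM part. It also ignores the bytes spent on scales and zero-points. The function names are hypothetical.

```python
# Bytes per parameter, taken from the format table (scales not counted).
BYTES_PER_PARAM = {"bf16": 2.0, "fp8": 1.0, "int8": 1.0, "int4": 0.5}


def weight_read_floor_ms(n_params, fmt, hbm_bytes_per_s, n_gpus=1):
    """Lower bound on the time to stream each GPU's weight shard once."""
    shard_bytes = n_params * BYTES_PER_PARAM[fmt] / n_gpus
    return shard_bytes / hbm_bytes_per_s * 1e3


def quant_report(fmt, bf16_ms, quant_ms):
    """Compare a measured quantized GEMM time against a tuned bf16 baseline."""
    ideal = BYTES_PER_PARAM["bf16"] / BYTES_PER_PARAM[fmt]
    actual = bf16_ms / quant_ms
    return {"ideal": ideal, "actual": actual, "regressed": actual < 1.0}


# Assumed inputs: 70e9 params, 4-way tensor parallel, 3.35e12 B/s per GPU.
for fmt in BYTES_PER_PARAM:
    floor = weight_read_floor_ms(70e9, fmt, 3.35e12, n_gpus=4)
    print(f"{fmt:>5}: >= {floor:.2f} ms per decode step for weights alone")
```

If `actual` sits far below `ideal`, the kernel is probably no longer bound by packed bytes; if `regressed` is true, the quantized path is slower than the baseline it was meant to beat.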

The MoE detour

Mixture-of-experts replaces the dense MLP with K experts and a router that picks the top-r per token. From the kernel side this introduces five new pieces, in order:

  1. Routing kernel: compute gating scores, top-r selection, and per-expert token counts.
  2. Token permutation: reorder tokens so each expert gets a contiguous slab.
  3. Dispatch all-to-all: if experts are sharded across GPUs (expert parallel), each token's hidden state is sent to the rank holding its chosen experts.
  4. Grouped GEMM: one launch that computes K small GEMMs of different shapes, one per expert, on the local slab.
  5. Combine all-to-all + unpermutation: outputs travel back and are restored to original token order.
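The five steps above can be sketched on a single device, where the two all-to-alls become no-ops. This is a toy reference in plain Python under several assumptions: tokens are scalars, experts are ordinary callables, gate scores are already positive probabilities, and the top-r weights are renormalized to sum to 1 (one common convention, not the only one). The function names are hypothetical.

```python
def route_top_r(scores, r):
    """scores[t][e] -> list of (token, expert, gate_weight), top-r per token."""
    assignments = []
    for t, row in enumerate(scores):
        top = sorted(range(len(row)), key=lambda e: row[e], reverse=True)[:r]
        norm = sum(row[e] for e in top)
        assignments += [(t, e, row[e] / norm) for e in top]
    return assignments


def permute_by_expert(assignments, num_experts):
    """Order assignments so each expert's tokens are contiguous; count per expert."""
    order = sorted(range(len(assignments)), key=lambda i: assignments[i][1])
    counts = [0] * num_experts
    for _, e, _ in assignments:
        counts[e] += 1
    return order, counts


def moe_forward(tokens, scores, experts, r):
    assignments = route_top_r(scores, r)                           # step 1
    order, counts = permute_by_expert(assignments, len(experts))   # step 2
    # Step 3 (dispatch all-to-all) is skipped: every expert is local here.
    out = [0.0] * len(tokens)
    start = 0
    for e, n in enumerate(counts):           # step 4: stands in for grouped GEMM
        for i in order[start:start + n]:     # expert e's contiguous slab
            t, _, w = assignments[i]
            out[t] += w * experts[e](tokens[t])  # step 5: combine in token order
        start += n
    return out, counts
```

The per-expert `counts` are what a grouped GEMM consumes as its list of problem sizes, and what an expert-parallel dispatch uses to size its send buffers.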

The all-to-all pair is a serving system problem, not just a kernel one. It saturates NVLink/InfiniBand bandwidth and is the typical bottleneck of MoE decode. Throughput depends on routing balance: skewed routes leave some experts idle while others overflow.

router (top-r per token) → permute (group by expert) → a2a dispatch (to expert's rank) → grouped GEMM (K shapes, one launch) → a2a combine (back to origin rank) → unpermute (restore order)

FlashInfer-style libraries fuse router, permute/unpermute, and grouped GEMM. The two all-to-alls pull in NCCL or a custom EP transport; that's the part that scales with GPU count.

Balanced routing (~1.0× imbalance) decodes near dense-MoE throughput. At 2× imbalance the slowest rank sets the rate. Practical knobs: expert capacity factor (overflow handling), drop policy, expert-parallel topology.
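Routing imbalance and the capacity-factor knob can be made concrete with a few lines of arithmetic. The sketch below assumes imbalance is defined as the busiest rank's token count divided by the mean, that step time scales with the busiest rank, and that capacity is the mean load times the capacity factor, rounded up. These definitions and the function names are illustrative assumptions, not any engine's API.

```python
import math


def imbalance(tokens_per_rank):
    """Busiest rank's load over the mean load; 1.0 is perfectly balanced."""
    mean = sum(tokens_per_rank) / len(tokens_per_rank)
    return max(tokens_per_rank) / mean


def relative_throughput(tokens_per_rank):
    """Assumes the slowest (busiest) rank sets the step time."""
    return 1.0 / imbalance(tokens_per_rank)


def overflow(tokens_per_expert, capacity_factor):
    """Tokens beyond each expert's capacity; the drop policy decides their fate."""
    mean = sum(tokens_per_expert) / len(tokens_per_expert)
    capacity = math.ceil(capacity_factor * mean)
    return [max(0, n - capacity) for n in tokens_per_expert]
```

For a load of `[16, 8, 4, 4]` the mean is 8, so the imbalance is 2.0 and, under the stated assumption, throughput is half of the balanced case.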

Sampling: many tiny kernels in a hot loop

Once logits exist, the engine applies (some subset of): temperature scaling, presence/frequency penalties, repetition penalty, top-k, top-p (nucleus), min-p, banned tokens, grammar constraints, random sampling. Each is a kernel; some need a sort or partial-sort; some require RNG state.

The naive implementation launches a dozen small ops per step and pays the launch overhead from lesson 06. Production engines fuse the chain: a single kernel applies temperature, penalties, masks, and top-k/p selection, then a single random sample. vLLM and FlashInfer both expose batched sampling kernels along these lines.

Stage | What it does | Naive cost | Fused cost
Temperature | Divide logits by T. | 1 kernel | folded into the next op
Penalties | Subtract by token-frequency vectors. | 1–2 kernels, gather over history | same op, fused with the mask
Top-k | Keep top-k logits; mask others. | radix-select or sort kernel | warp-level select, no extra launch
Top-p | Keep smallest set summing to ≥ p. | sort + cumsum | fused within the same pass
Sample | Random draw from filtered distribution. | RNG + categorical | one CUDA-side RNG, no host trip
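As a reference for what the fused chain computes, here is a sequential sketch in plain Python covering temperature, top-k, top-p, and the final draw in one function. It is not a GPU kernel and does not mirror vLLM's or FlashInfer's implementations: it uses a full sort where a real kernel would use a partial select, it omits penalties and masks, and `sample_token` with its parameters is a hypothetical name. It assumes `temperature > 0`.

```python
import math
import random


def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Return one token id after temperature, top-k, and top-p filtering."""
    # Temperature: folded into the values the softmax will see.
    scaled = [x / temperature for x in logits]

    # Top-k: keep the k largest logits (top_k == 0 disables the filter).
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Softmax over the survivors, subtracting the max for stability.
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)

    # Top-p: smallest prefix of the sorted list whose mass reaches top_p.
    kept, mass = [], 0.0
    for i, w in zip(order, weights):
        p = w / total
        kept.append((i, p))
        mass += p
        if mass >= top_p:
            break

    # Sample: one uniform draw against the kept mass (implicit renormalization).
    u = rng.random() * mass
    acc = 0.0
    for i, p in kept:
        acc += p
        if u < acc:
            return i
    return kept[-1][0]  # guard against floating-point edge cases
```

Passing `rng=random.Random(seed)` makes the draw reproducible, which is the host-side analogue of carrying per-request RNG state into the kernel.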

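Why this matters: sampling is small per token, but it runs every token. At 100 tok/s/replica, a 200 µs sampling chain takes 20 ms of every second, about 2% of each replica's GPU time (assuming one pass through the chain per generated token). Across 1000 replicas that is the equivalent of 20 replicas doing nothing but sampling.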

Communication kernels

Tensor parallelism splits each weight matrix across GPUs. Every all-reduce between layers is a kernel — usually NCCL ring or tree. Two practical knobs:

For MoE, the all-to-all dominates communication; production engines fuse the routing+permute step with the all-to-all so a single kernel does dispatch instead of multiple round trips.

The backend dispatcher, demystified

Every engine ends up with a small matrix: "for op X with model feature set Y on hardware Z, run backend B." This is what the docs call attention backend selection, quantization scheme, or communication backend. The matrix exists because: head dim, KV dtype, page size, sliding-window-ness, MLA-vs-MHA, and GPU generation all constrain which kernel is correct. A "fast but invalid" kernel is not an optimization; it is a bug.

How to read backend docs
The columns are constraints, not preferences. If a backend's row says "fp8 KV: no," that means it would silently produce garbage; the engine refuses to use it for fp8 KV. Pick the row that fits all your features, then sort by speed.
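The "filter by constraints, then sort by speed" rule can be written down directly. The sketch below is a toy dispatcher, not any engine's actual selection code: the `Backend` fields, the backend names, and the speed numbers are all hypothetical, and a real matrix has more columns (page size, sliding window, MLA vs MHA).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Backend:
    name: str
    head_dims: frozenset     # head dimensions the kernel is valid for
    kv_dtypes: frozenset     # KV-cache dtypes the kernel is valid for
    min_sm: int              # minimum compute capability, e.g. 90 for Hopper
    relative_speed: float    # higher is faster; from your own benchmarks


def select_backend(backends, head_dim, kv_dtype, sm):
    """Drop every backend that violates a constraint, then take the fastest."""
    valid = [b for b in backends
             if head_dim in b.head_dims
             and kv_dtype in b.kv_dtypes
             and sm >= b.min_sm]
    if not valid:
        raise ValueError("no backend supports this feature set")
    return max(valid, key=lambda b: b.relative_speed)


# Hypothetical rows for illustration only.
BACKENDS = [
    Backend("fast_hopper", frozenset({64, 128}), frozenset({"bf16"}), 90, 1.5),
    Backend("portable", frozenset({64, 128, 256}), frozenset({"bf16", "fp8"}), 80, 1.0),
]
```

With these rows, a request for fp8 KV on Hopper selects `portable` even though `fast_hopper` is faster, because speed is only compared among backends that are valid for the feature set.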

Interactive · quantization payoff accountant

Set weight precision, dequant overhead, and dense decode time. The widget shows the bandwidth savings vs compute cost, and warns if dequant eats the win.

Quantization: bandwidth saved vs dequant cost

A 4-bit weight reads 4× less HBM than bf16. If your kernel's dequant overhead is small (≲ 10% of GEMM time), you keep ~3.5× decode speedup. If overhead is 50%, you keep 2×. If 100%, you've broken even.

How to identify the next bottleneck

A short decision rule for "I made attention faster, now what?"

  1. Profile a decode step with timeline markers. Get per-kernel times.
  2. Pick the largest non-attention bar. If it's a linear GEMM and the model is bf16, try quantization (with a tuned kernel for your hardware).
  3. If it's MoE all-to-all, check routing imbalance and expert-parallel topology.
  4. If it's sampling/logits, fuse the pipeline or move to a batched sampling kernel.
  5. If it's "GPU idle" (CPU gaps), go back to lesson 06 — graphs, batching, or scheduler are the lever.
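The decision rule above is mechanical enough to script against profiler output. This is a minimal sketch assuming per-kernel times have already been summed into a few category labels; the labels (`linear_gemm`, `moe_all_to_all`, `sampling`, `gpu_idle`) and the function name are hypothetical, and the advice strings simply restate the steps in the list.

```python
ADVICE = {
    "linear_gemm": "try quantization with a kernel tuned for your hardware",
    "moe_all_to_all": "check routing imbalance and expert-parallel topology",
    "sampling": "fuse the pipeline or move to a batched sampling kernel",
    "gpu_idle": "revisit graphs, batching, or the scheduler",
}


def next_bottleneck(kernel_ms, weights_dtype="bf16"):
    """kernel_ms maps a category label to its time per decode step (ms)."""
    # Step 2: pick the largest non-attention bar.
    others = {name: ms for name, ms in kernel_ms.items() if name != "attention"}
    worst = max(others, key=others.get)
    if worst == "linear_gemm" and weights_dtype != "bf16":
        # The list only covers the bf16 case; fall back to the baseline check.
        return worst, "already quantized: compare against a tuned bf16 baseline"
    return worst, ADVICE.get(worst, "no rule for this category; profile deeper")
```

Calling it with a profile such as `{"attention": 5.0, "linear_gemm": 3.0, "sampling": 0.4, "gpu_idle": 1.0}` returns the `linear_gemm` category and its advice, since attention is excluded from the comparison.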

Notice this is exactly the chain from lesson 02, just walked once for each surviving bar. The track was structured to make this loop short.

What this gives you for the rest of the track

Every kernel-level concept is now in place: how the hardware works, what attention does, where KV lives, when prefixes share, how the scheduler hides launch overhead, and how the rest of the chain compresses. Lesson 08 zooms out by one level and shows the framework around those kernels: HTTP entry, queueing, routing, and where end-to-end latency really comes from.