Custom kernels & fusion
Every kernel reads from HBM and writes to HBM. If two ops can share one HBM round-trip, you've doubled the effective bandwidth. The whole point of "custom kernels" is to find those pairs.
The kernel-call cost, expanded
From lesson 13 we know the framework pays a few microseconds of Python overhead per op. The kernel itself adds its own fixed costs on the GPU:
- Launch latency. Submitting work to the GPU takes ~1–2 μs on the device side, separate from the host-side Python tax. CUDA Graphs (lesson 19) is the cure.
- Input read. The kernel must read every input element from HBM into SRAM/registers.
- The compute. The actual arithmetic.
- Output write. The result goes back to HBM.
For elementwise ops (add, ReLU, GELU, scaling), the compute is trivial: a few flops per element. Almost all the time goes to the HBM read and write. That's the memory-bound regime of the roofline (lesson 01). An elementwise op on a 1 GB tensor at the H100's 3.35 TB/s HBM bandwidth takes ~0.6 ms just for input + output traffic: ~0.3 ms to read 1 GB and another ~0.3 ms to write it back.
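As a rough check on the bandwidth arithmetic, here is a minimal sketch that estimates how long it takes to stream a tensor through HBM. The function name `hbm_time_ms` is hypothetical, 1 GB is assumed to mean 10^9 bytes, and the estimate assumes the kernel reaches peak bandwidth, which real kernels only approach.

```python
def hbm_time_ms(bytes_moved: float, bandwidth_bytes_per_s: float) -> float:
    """Lower bound on transfer time, assuming peak HBM bandwidth is reached."""
    return bytes_moved / bandwidth_bytes_per_s * 1e3


H100_HBM_BW = 3.35e12   # bytes/s, the figure quoted in the text
TENSOR_BYTES = 1e9      # a 1 GB tensor (assumed to mean 10^9 bytes)

one_pass_ms = hbm_time_ms(TENSOR_BYTES, H100_HBM_BW)        # read only
round_trip_ms = hbm_time_ms(2 * TENSOR_BYTES, H100_HBM_BW)  # read + write

print(f"one pass:   {one_pass_ms:.2f} ms")
print(f"read+write: {round_trip_ms:.2f} ms")
```

One pass over the tensor costs about 0.3 ms, and an elementwise op needs two passes (read the input, write the output), so the floor for the whole op is about twice that.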
Where fusion comes from
Now consider three ops in a row — say y = (x + bias) * scale; z = relu(y). Eager mode runs this as three kernels:
kernel 1: read x, read bias → write (x + bias) ─ HBM round-trip 1
kernel 2: read (x + bias), read scale → write y ─ HBM round-trip 2
kernel 3: read y → write z ─ HBM round-trip 3
For an N-element tensor, that's 3 · 2 · N · dtype_bytes of HBM traffic (treating bias and scale as small broadcast operands). But the arithmetic itself only needs one read each of x, bias, and scale, and one write of z. The intermediate buffers are pure overhead: written to HBM only to be re-read by the next kernel.
A fused kernel does all three ops in one pass: read each input element once, do the whole compute chain in registers, write the final result. HBM traffic drops by ~3×; on memory-bound ops, wall-clock drops by ~3×. The "compute" in the middle was already free.
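The traffic accounting above can be sketched in a few lines. This is an illustrative model, not a measurement: `chain_traffic_bytes` is a hypothetical helper, and it assumes every op reads one full-size tensor and writes one full-size tensor, ignoring small broadcast operands such as bias and scale.

```python
def chain_traffic_bytes(n_elems: int, n_ops: int, dtype_bytes: int = 2):
    """Return (unfused, fused) HBM bytes for a chain of elementwise ops.

    Assumes each op reads one full-size tensor and writes one full-size
    tensor; small broadcast operands (bias, scale) are ignored.
    """
    per_pass = n_elems * dtype_bytes
    unfused = n_ops * 2 * per_pass   # every op: one read + one write
    fused = 2 * per_pass             # read x once, write z once
    return unfused, fused


unfused, fused = chain_traffic_bytes(n_elems=1 << 28, n_ops=3)
print(f"unfused: {unfused / 2**30:.1f} GiB")
print(f"fused:   {fused / 2**30:.1f} GiB")
print(f"ratio:   {unfused / fused:.0f}x")
```

Under this model the ratio is simply the number of ops in the chain, which is why longer elementwise chains benefit more from fusion.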
Animated · HBM round-trip, unfused vs fused
Both timelines compute the same thing: y = sigmoid(softmax(x) · W). Top: three separate kernels, each pays its own HBM round-trip on the intermediate. Bottom: one fused kernel reads x once, keeps the intermediate in registers, writes y once. The "HBM bytes" counter on the right adds up as each kernel runs — watch the unfused total balloon.
Three flavours of fusion
| Pattern | What's fused | Win | Example |
|---|---|---|---|
| Elementwise chain | N elementwise ops in a row | ~N× HBM traffic reduction | add → mul → ReLU |
| Matmul + elementwise epilogue | GEMM + bias + activation + dropout | ~2× HBM (the pre-activation GEMM output is never written to HBM) | cuBLAS epilogues, aten::addmm + ReLU |
| Reduction + elementwise | LayerNorm = mean + var + normalize + scale + bias | ~5× HBM reduction | fused LayerNorm kernels |
The most-fused operator in modern training is FlashAttention, which we touch on next.
FlashAttention as a fusion masterclass
Attention is the canonical case where fusion matters most. Standard ("vanilla") attention is three kernels:
- S = Q K^T / √d_k — produces an N × N matrix; writes O(N²) bytes to HBM.
- P = softmax(S) — reads and writes the N × N matrix.
- O = P V — reads P and V, writes O.
For N = 8192 in bf16, the N × N matrix is 128 MB per head per batch element. Across 64 heads and batch 4, that's 32 GB of HBM round-tripping per attention layer. At 80 layers, attention alone reads/writes ~2.5 TB of HBM per training step, nearly all of it the materialised attention matrix that we don't actually need to keep.
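The sizes quoted above follow directly from the shapes. The sketch below reproduces them, counting a single copy of the N × N matrix per head; an unfused pipeline touches that matrix several times (S is written, read by the softmax, rewritten as P, then read by the P V matmul), so the real traffic is a multiple of these numbers. Units are binary (MiB, GiB, TiB), which is an assumption about how the text's "MB" and "GB" are meant.

```python
N = 8192                          # sequence length
DTYPE_BYTES = 2                   # bf16
HEADS, BATCH, LAYERS = 64, 4, 80  # model shape used in the text

per_head = N * N * DTYPE_BYTES          # one N x N score matrix
per_layer = per_head * HEADS * BATCH    # all heads, all batch elements
per_step = per_layer * LAYERS           # all layers

print(f"per head:  {per_head / 2**20:.0f} MiB")
print(f"per layer: {per_layer / 2**30:.0f} GiB")
print(f"per step:  {per_step / 2**40:.1f} TiB")
```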
FlashAttention (Dao et al. 2022) fuses all three ops by exploiting an algebraic identity: the softmax can be computed online while P V is being accumulated, tile by tile, without ever materialising the N × N intermediate. The kernel:
- Tiles Q into row blocks and K, V into column blocks.
- For each Q block: load it into SRAM. For each K, V tile in turn: load it into SRAM, compute the partial attention contribution, accumulate into a running output buffer in SRAM (plus a running max and normaliser for the online softmax).
- At the end, write the final output row block back to HBM. The N × N matrix never appears in HBM.
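The three steps above can be sketched in plain Python. This is a toy, CPU-only illustration of the online-softmax idea under stated assumptions, not the FlashAttention kernel: the function names, the tiny block sizes, and the use of nested lists are all hypothetical choices made for readability. The dictionaries `m`, `l` and `acc` stand in for the running max, the normaliser and the output buffer that a real kernel would keep on-chip.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference: materialises every full row of the N x N score matrix."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_q=2, block_kv=2):
    """Online-softmax attention; holds at most one score tile at a time."""
    n, d, dv = len(Q), len(Q[0]), len(V[0])
    out = [[0.0] * dv for _ in range(n)]
    for q0 in range(0, n, block_q):                  # one Q row block
        rows = range(q0, min(q0 + block_q, n))
        m = {i: -math.inf for i in rows}             # running max
        l = {i: 0.0 for i in rows}                   # running normaliser
        acc = {i: [0.0] * dv for i in rows}          # unnormalised output
        for k0 in range(0, len(K), block_kv):        # sweep K/V tiles
            cols = range(k0, min(k0 + block_kv, len(K)))
            for i in rows:
                s = [sum(a * b for a, b in zip(Q[i], K[j])) / math.sqrt(d)
                     for j in cols]
                m_new = max(m[i], max(s))
                scale = math.exp(m[i] - m_new)       # rescale old partials
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * scale + sum(p)
                acc[i] = [a * scale
                          + sum(pj * V[j][c] for pj, j in zip(p, cols))
                          for c, a in enumerate(acc[i])]
                m[i] = m_new
        for i in rows:                               # one write per row block
            out[i] = [a / l[i] for a in acc[i]]
    return out


random.seed(0)
n, d = 5, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

ref = naive_attention(Q, K, V)
out = tiled_attention(Q, K, V, block_q=2, block_kv=2)
max_err = max(abs(a - b) for r, o in zip(ref, out) for a, b in zip(r, o))
print(f"max abs difference vs naive: {max_err:.2e}")
```

The key line is the rescaling by `exp(m_old - m_new)`: whenever a new tile raises the running max, the partial sums accumulated so far are scaled down so that they stay consistent with the new max. That is what lets the result match the naive version up to floating-point rounding without ever storing a full row of scores.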
HBM traffic per attention: O(N · d) instead of O(N²), counting one read each of Q, K, V and one write of O (re-reads of K/V tiles across Q blocks add to this). For N = 8192, d = 128, that's N · d / N² = d / N ≈ 1.6% of the original, a ~64× HBM reduction. Wall-clock speedup is a more modest 2–4× because attention isn't quite memory-bound (the matmul has high arithmetic intensity), but the freed bandwidth lets other parts of the model breathe too.
For more depth: see vLLM lesson 03, which derives the online-softmax recurrence in detail and walks the tile schedule. The reason FlashAttention is a fusion case study and not just a "kernel" is that it's only a win in fusion terms — the underlying arithmetic is identical to vanilla attention.
2D · FlashAttention tile schedule
The Q rows are tiled (rows along the left), K/V columns are tiled (columns along the top). One Q tile sweeps across all K/V column tiles; per tile, the kernel computes the local attention contribution, updates a running max m and normaliser ℓ, and accumulates into the output buffer in SRAM. The N×N attention matrix is never materialised in HBM.
What fusion gives back at backward time
A fused kernel that doesn't materialise its intermediates also can't store them for backward — backward will need to recompute them. This is the same tradeoff as activation checkpointing (lesson 05): forward saves HBM traffic, backward pays in extra FLOPs.
FlashAttention's backward pass does exactly this: it recomputes the attention matrix block by block, on the fly. Net wall-clock effect: ~30% faster total (forward + backward) compared to vanilla, because the FLOP overhead is small compared to the HBM saving.
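The recompute tradeoff can be illustrated with the elementwise chain from earlier, z = relu((x + bias) * scale). The sketch below is a hand-written toy, not how any framework implements autograd: `fused_forward` and `fused_backward` are hypothetical names, and bias and scale are taken to be scalars for brevity. The forward saves only its inputs; the backward recomputes the intermediate y to recover the ReLU mask.

```python
def fused_forward(x, bias, scale):
    """z = relu((x + bias) * scale); saves only the inputs, not y."""
    z = [max((xi + bias) * scale, 0.0) for xi in x]
    saved = (x, bias, scale)      # the intermediate y is never stored
    return z, saved


def fused_backward(grad_z, saved):
    """Recompute y from the saved inputs to get the ReLU mask, then chain."""
    x, bias, scale = saved
    grad_x = []
    for xi, gz in zip(x, grad_z):
        y = (xi + bias) * scale   # recomputed: extra FLOPs, no stored y
        grad_x.append(gz * scale if y > 0 else 0.0)
    return grad_x


x = [-2.0, -0.5, 0.25, 3.0]
z, saved = fused_forward(x, bias=1.0, scale=2.0)
grad_x = fused_backward([1.0] * len(x), saved)
print(z)       # forward output
print(grad_x)  # gradient of sum(z) with respect to x
```

The recomputation here costs one add and one multiply per element, which is negligible next to the HBM write and read it avoids.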
When fusion doesn't help
- Already-compute-bound ops. A large matmul is on the right side of the roofline; cuBLAS is already saturating tensor cores. Fusing a dropout into it shaves the tail traffic but doesn't double throughput.
- When the intermediate is needed elsewhere. If a tensor is consumed by multiple downstream ops, the compiler has to choose between (a) materialising the intermediate once and reading it from HBM twice, or (b) duplicating the producer into each consumer's fused kernel (recomputing it). Inductor and XLA do both — duplication for cheap producers, materialisation for expensive ones — so "single consumer" is the easy case, not a hard rule.
- When the fused kernel doesn't fit in SRAM. Each SM has tens to hundreds of KB of SRAM. Larger fusions need to spill to HBM, defeating the point. FlashAttention's block sizes are tuned to fit that budget, for example the H100's 228 KB SMEM-per-SM limit.
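To make the SRAM constraint concrete, here is a back-of-the-envelope budget check for attention tiles. Everything in it is an assumption for illustration: the helper name, the candidate block sizes, and above all the simplification that only the Q, K and V tiles occupy shared memory while the score tile, running statistics and output accumulator live in registers. Real kernels use different layouts and often double-buffer their tiles.

```python
def attention_tile_smem_bytes(block_q, block_kv, head_dim, dtype_bytes=2):
    """Bytes for one Q tile plus one K tile and one V tile.

    Assumption: only these three tiles live in shared memory; scores,
    running max/normaliser and the output accumulator sit in registers.
    """
    q_tile = block_q * head_dim * dtype_bytes
    kv_tiles = 2 * block_kv * head_dim * dtype_bytes
    return q_tile + kv_tiles


SMEM_BUDGET = 228 * 1024   # H100 per-SM figure quoted in the text

for bq, bkv in [(128, 128), (256, 256), (512, 512)]:   # hypothetical sizes
    need = attention_tile_smem_bytes(bq, bkv, head_dim=128)
    verdict = "fits" if need <= SMEM_BUDGET else "exceeds budget"
    print(f"block_q={bq:4d} block_kv={bkv:4d}: {need // 1024:4d} KiB -> {verdict}")
```

Under these assumptions, doubling both block sizes doubles the footprint, so the largest tile that still fits is found quickly. Larger tiles mean fewer passes over K and V, which is why block sizes tend to be pushed close to the limit.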
Animated · GEMM with elementwise epilogue, fused vs trailing
A GEMM that produces C = A · B, immediately followed by D = GELU(C + bias). Top: GEMM writes C to HBM, then a separate elementwise kernel reads C back, does the activation, writes D. Bottom: cuBLAS-style "epilogue" — the elementwise math is done in the GEMM's epilogue stage on each output tile, before anything is written to HBM. C never appears in HBM.
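The two timelines can be compared with a simple traffic model. This is a sketch under stated assumptions, not a profile: `gemm_epilogue_traffic` is a hypothetical helper, A and B are assumed to be read from HBM exactly once each, and the small bias vector is ignored. The matrix shapes are arbitrary example values.

```python
def gemm_epilogue_traffic(m, n, k, dtype_bytes=2):
    """Return (trailing, fused) HBM bytes for D = act(A @ B + bias).

    Assumes A (m x k) and B (k x n) are each read once and ignores the
    small bias vector.
    """
    inputs = (m * k + k * n) * dtype_bytes
    out = m * n * dtype_bytes
    trailing = inputs + 3 * out   # write C, read C back, write D
    fused = inputs + out          # write D only
    return trailing, fused


trailing, fused = gemm_epilogue_traffic(m=8192, n=8192, k=1024)
print(f"trailing kernel: {trailing / 2**20:.0f} MiB")
print(f"fused epilogue:  {fused / 2**20:.0f} MiB")
print(f"ratio: {trailing / fused:.2f}x")
```

The saving depends on the shapes. When k is small relative to m and n, the output tensor dominates and the ratio approaches 3×; as k grows, reading A and B dominates and the benefit of fusing the epilogue shrinks.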
How custom kernels get written today
- cuBLAS / cuDNN epilogue API. The "I want a matmul + bias + GELU" use case. NVIDIA provides this directly.
- Hand-written CUDA C++. The hardest path, used for performance-critical kernels (FlashAttention, paged attention). Months of work to outperform a tuned library.
- Triton (lesson 17). The DSL that landed somewhere between CUDA and Python — block-level operations, automatic SRAM management, good autotuning.
- torch.compile / Inductor (lesson 18). Generates Triton kernels from PyTorch code. Automatic fusion of elementwise chains, decent at reductions, not great at matmul-class kernels.
- CUTLASS. A C++ template library for matmul-like kernels. Production-grade serving stacks (TensorRT-LLM, vLLM) use this for custom GEMM patterns like multi-LoRA.
Interactive · the fusion accountant
Stack any number of elementwise ops between (or after) a matmul. The widget computes HBM traffic with and without fusion, and shows the speedup ceiling. Add a softmax + matmul + bias combo and watch the unfused traffic blow past the fused traffic.