Lesson 16 of 19 · kernels & fusion

Custom kernels & fusion

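Every kernel reads its inputs from HBM and writes its output back to HBM. If two ops can share one HBM round-trip instead of paying for two, you've halved their HBM traffic, which for memory-bound ops is the same as doubling effective bandwidth. The whole point of "custom kernels" is to find those pairs.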

The kernel-call cost, expanded

From lesson 13 we know the framework pays a few microseconds of Python overhead per op. The kernel itself adds its own costs on the GPU: a fixed launch latency, plus the time to read its inputs from HBM and write its outputs back.

For elementwise ops (add, ReLU, GELU, scaling), the compute is trivial: a few flops per element. Almost all the time is the HBM read + write. That's the memory-bound regime of the roofline (lesson 01). An elementwise op on a 1 GB tensor at the H100's 3.35 TB/s HBM bandwidth moves 2 GB (1 GB read, 1 GB written), so it takes ~0.6 ms for the traffic alone.
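In the memory-bound regime, runtime is simply bytes moved divided by bandwidth. A minimal sketch of that arithmetic; `memory_bound_time_s` is a hypothetical helper, and it assumes every input and output tensor is the same size and that the op reaches the full quoted peak bandwidth (real kernels get somewhat less):

```python
def memory_bound_time_s(tensor_bytes, bandwidth_bytes_per_s, n_inputs=1, n_outputs=1):
    """Lower-bound runtime of a memory-bound op: total HBM traffic / peak bandwidth.

    Assumes every input and output tensor is the same size and that the
    (trivial) compute is fully hidden behind the memory traffic.
    """
    traffic_bytes = (n_inputs + n_outputs) * tensor_bytes
    return traffic_bytes / bandwidth_bytes_per_s


H100_HBM_BW = 3.35e12  # bytes/s, the figure quoted in the text

# One elementwise op on a 1 GB tensor: 1 GB read + 1 GB written.
t = memory_bound_time_s(1e9, H100_HBM_BW)
print(f"{t * 1e3:.2f} ms")  # 0.60 ms
```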

Where fusion comes from

Now consider three ops in a row — say y = (x + bias) * scale; z = relu(y). Eager mode runs this as three kernels:

kernel 1:  read x, read bias                      → write (x + bias)        ─ HBM round-trip 1
kernel 2:  read (x + bias), read scale            → write y                  ─ HBM round-trip 2
kernel 3:  read y                                 → write z                  ─ HBM round-trip 3

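Assume bias and scale are broadcast operands (a scalar or a per-channel vector), so their traffic is negligible. Each kernel then reads one N-element tensor and writes another, for 3 · 2 · N · dtype_bytes of HBM traffic in total. But the math only needs one read of x and one write of z: 2 · N · dtype_bytes. The intermediate buffers are pure overhead, written to HBM only to be re-read by the next kernel.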

A fused kernel does all three ops in one pass: read each input element once, do the whole compute chain in registers, write the final result. HBM traffic drops by ~3×; on memory-bound ops, wall-clock drops by ~3×. The "compute" in the middle was already free.
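A pure-Python model of the two access patterns. Lists stand in for HBM buffers and a local value stands in for a register; bias and scale are scalars, i.e. broadcast operands whose traffic is negligible. The function names are hypothetical, and the model shows the access pattern, not GPU performance:

```python
def unfused_chain(x, bias, scale):
    """Three 'kernels'; each is a full pass that materialises its output buffer."""
    t = [xi + bias for xi in x]        # kernel 1: read x, write t
    y = [ti * scale for ti in t]       # kernel 2: read t, write y
    return [max(yi, 0.0) for yi in y]  # kernel 3: read y, write z


def fused_chain(x, bias, scale):
    """One 'kernel': load each x[i] once, run add -> mul -> ReLU on a local
    value (standing in for a register), store only z[i]."""
    z = [0.0] * len(x)
    for i, xi in enumerate(x):                 # the only read of x[i]
        z[i] = max((xi + bias) * scale, 0.0)   # no intermediate buffer; the only write
    return z


def chain_traffic_elems(n, n_ops, fused):
    """HBM traffic in elements for an n-element chain of n_ops one-in/one-out ops,
    ignoring the broadcast bias/scale."""
    return 2 * n if fused else 2 * n * n_ops


x = [-2.0, -0.5, 0.0, 1.5, 3.0]
print(fused_chain(x, 0.5, 2.0))  # [0.0, 0.0, 1.0, 4.0, 7.0]
print(chain_traffic_elems(1 << 20, 3, False) / chain_traffic_elems(1 << 20, 3, True))  # 3.0
```

The fused version never creates t or y; that missing storage is exactly the 2 · 2 · N elements (write t, read t, write y, read y) the unfused version spends on intermediates.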

Animated · HBM round-trip, unfused vs fused

Both timelines compute the same thing: y = sigmoid(softmax(x) · W). Top: three separate kernels, each pays its own HBM round-trip on the intermediate. Bottom: one fused kernel reads x once, keeps the intermediate in registers, writes y once. The "HBM bytes" counter on the right adds up as each kernel runs — watch the unfused total balloon.

Figure: HBM round-trip · unfused vs fused, same math. Each kernel's bar shows its phases (read inputs from HBM, compute, write output to HBM); the fused version merges the three into one read + one write, with intermediates living in registers/SRAM.
The slogan for the rest of this lesson
Memory-bound ops are bottlenecked by HBM. Fusing them removes HBM round-trips, which is the only knob that matters in that regime. Compute-bound ops are bottlenecked by FLOPs — fusion doesn't help (you'd still do the same arithmetic). Most fusion wins come from fusing memory-bound epilogues onto a compute-bound matmul.
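One way to decide which regime an op is in is to compare its arithmetic intensity (FLOPs per HBM byte) with the hardware's ridge point. A rough sketch; `regime` is a hypothetical helper, the 989 TFLOP/s dense BF16 peak for an H100 SXM is an assumption to check against your part's datasheet, and the matmul traffic is the idealised minimum (each operand moved once):

```python
def regime(flops, hbm_bytes, peak_flops, peak_bw):
    """Classify an op with the roofline: compare its arithmetic intensity
    (FLOPs per HBM byte) against the ridge point peak_flops / peak_bw."""
    intensity = flops / hbm_bytes
    ridge = peak_flops / peak_bw
    return "compute-bound" if intensity >= ridge else "memory-bound"


PEAK_BF16_FLOPS = 989e12  # assumption: H100 SXM dense BF16 peak, FLOP/s
PEAK_HBM_BW = 3.35e12     # H100 HBM bandwidth, bytes/s (ridge ~295 FLOPs/byte)

n = 1 << 20
# bf16 elementwise add: 1 FLOP/element; read two tensors, write one, 2 bytes each.
print(regime(n, 3 * n * 2, PEAK_BF16_FLOPS, PEAK_HBM_BW))  # memory-bound

m = 4096
# bf16 square matmul: 2*m^3 FLOPs; ideal traffic reads A and B and writes C once.
print(regime(2 * m**3, 3 * m * m * 2, PEAK_BF16_FLOPS, PEAK_HBM_BW))  # compute-bound
```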

Three flavours of fusion

Pattern | What's fused | Win | Example
--- | --- | --- | ---
Elementwise chain | N elementwise ops in a row | ~N× HBM traffic reduction | add → mul → ReLU
Matmul + elementwise epilogue | GEMM + bias + activation + dropout | ~2× HBM (the raw GEMM output never round-trips through HBM) | cuBLAS epilogues, aten::addmm + ReLU
Reduction + elementwise | LayerNorm = mean + var + normalize + scale + bias | ~5× HBM reduction | fused LayerNorm kernels
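The reduction + elementwise row can be modelled the same way. A common design for fused LayerNorm kernels, assumed here, is one program per row with the whole row held on-chip. In the sketch below (hypothetical function names, lists standing in for HBM), the unfused version makes five passes that each materialise a full result, while the fused version reads each row once and writes it once:

```python
import math


def layernorm_unfused(x, gamma, beta, eps=1e-5):
    """Five separate passes over all rows; each materialises a full result."""
    means = [sum(r) / len(r) for r in x]                                              # 1: mean
    variances = [sum((v - mu) ** 2 for v in r) / len(r) for r, mu in zip(x, means)]  # 2: var
    xhat = [[(v - mu) / math.sqrt(var + eps) for v in r]
            for r, mu, var in zip(x, means, variances)]                               # 3: normalise
    scaled = [[v * g for v, g in zip(r, gamma)] for r in xhat]                        # 4: scale
    return [[v + b for v, b in zip(r, beta)] for r in scaled]                         # 5: bias


def layernorm_fused(x, gamma, beta, eps=1e-5):
    """One 'program' per row: the row is read once, all five steps run on the
    local copy (standing in for SRAM/registers), and the result is written once."""
    out = []
    for row in x:
        r = list(row)                                   # the single read of this row
        mu = sum(r) / len(r)
        var = sum((v - mu) ** 2 for v in r) / len(r)
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append([(v - mu) * inv_std * g + b for v, g, b in zip(r, gamma, beta)])  # single write
    return out


x = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.5, 0.0]]
gamma = [1.0, 0.5, 2.0, 1.5]
beta = [0.0, 0.1, -0.2, 0.3]
fused = layernorm_fused(x, gamma, beta)
ref = layernorm_unfused(x, gamma, beta)
print(max(abs(a - b) for ra, rb in zip(fused, ref) for a, b in zip(ra, rb)))
```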

The most-fused operator in modern training is FlashAttention, which we touch on next.

FlashAttention as a fusion masterclass

Attention is the canonical case where fusion matters most. Standard ("vanilla") attention is three kernels:

  1. S = Q K^T / √d_k — produces an N × N matrix; writes O(N²) bytes to HBM.
  2. P = softmax(S) — reads and writes the N × N matrix.
  3. O = P V — reads P and V, writes O.

For N = 8192 in bf16, the N × N matrix is 128 MB per head per batch element. Across 64 heads and batch 4, that's 32 GB of attention matrix per layer, and the three kernels touch it more than once (kernel 1 writes S, kernel 2 reads S and writes P, kernel 3 reads P). At 80 layers, ~2.5 TB of attention matrices are materialised in HBM per forward pass, each written and re-read several times, and almost all of it is an intermediate we don't actually need to keep.
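The sizes above can be reproduced directly; the text's MB/GB/TB are binary units (MiB/GiB/TiB). `attn_matrix_bytes` is a hypothetical helper, and it counts one copy of the matrix, not the number of times the kernels write and re-read it:

```python
def attn_matrix_bytes(seq_len, heads, batch, dtype_bytes=2):
    """Size of one materialised N x N score (or probability) matrix for a layer."""
    return seq_len * seq_len * dtype_bytes * heads * batch


GiB = 2**30
per_head = attn_matrix_bytes(8192, heads=1, batch=1)   # bf16, one head, one sequence
per_layer = attn_matrix_bytes(8192, heads=64, batch=4)
print(per_head / 2**20, "MiB per head per batch element")  # 128.0 MiB per head per batch element
print(per_layer / GiB, "GiB per layer")                     # 32.0 GiB per layer
print(80 * per_layer / 2**40, "TiB across 80 layers")       # 2.5 TiB across 80 layers
```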

FlashAttention (Dao et al. 2022) fuses all three ops by exploiting an algebraic identity: the softmax can be computed online, tile by tile, while the output P V is being accumulated, so the N × N intermediate is never materialised. The kernel:

  1. Tiles Q into row blocks and K, V into column blocks.
  2. For each Q block: load it into SRAM. For each K, V tile in turn: load it into SRAM, compute the partial attention contribution, accumulate into a running output buffer in SRAM (plus a running max and normaliser for the online softmax).
  3. At the end, divide the accumulated output by the running normaliser and write the final output row block back to HBM. The N × N matrix never appears in HBM.
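The schedule above can be sketched in pure Python for a single head. Lists stand in for HBM tensors, and the per-tile locals (s, p, m, l, acc) stand in for SRAM/registers. This is a toy model of the online-softmax recurrence with hypothetical function names and tiny tile sizes, not the FlashAttention kernel itself:

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference: softmax over each complete row of S = Q K^T / sqrt(d), then multiply by V."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_q=2, block_k=3):
    """Single-head forward pass with an online softmax: each Q tile sweeps all K/V
    tiles, and only a block_q x block_k patch of scores exists at any time."""
    d, dv = len(Q[0]), len(V[0])
    out = []
    for qs in range(0, len(Q), block_q):
        q_tile = Q[qs:qs + block_q]                  # "load Q tile into SRAM"
        m = [-math.inf] * len(q_tile)                # running row max
        l = [0.0] * len(q_tile)                      # running normaliser
        acc = [[0.0] * dv for _ in q_tile]           # unnormalised output accumulator
        for ks in range(0, len(K), block_k):
            k_tile, v_tile = K[ks:ks + block_k], V[ks:ks + block_k]  # "load K, V tile"
            for r, q in enumerate(q_tile):
                s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_tile]
                m_new = max(m[r], max(s))
                alpha = math.exp(m[r] - m_new)       # rescale state built under the old max
                p = [math.exp(x - m_new) for x in s]
                l[r] = alpha * l[r] + sum(p)
                acc[r] = [alpha * acc[r][c] + sum(p[j] * v_tile[j][c] for j in range(len(p)))
                          for c in range(dv)]
                m[r] = m_new
        out.extend([a / l[r] for a in acc[r]] for r in range(len(q_tile)))  # normalise, write once
    return out


rng = random.Random(0)
N, d, dv = 7, 4, 3  # N deliberately not a multiple of either block size
Q = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[rng.gauss(0, 1) for _ in range(dv)] for _ in range(N)]
ref, out = naive_attention(Q, K, V), tiled_attention(Q, K, V)
print(max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb)))
```

The key line is the rescale by alpha = exp(m_old − m_new): whenever a new tile raises a row's max, the previously accumulated ℓ and output are shrunk to the new reference point, so the final division by ℓ gives the same softmax-weighted sum as the naive version, up to rounding.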

HBM traffic per attention drops from O(N²) to roughly O(N · d): read Q, K, V and write O. (That is idealised; it ignores the extra reads of K and V once per Q tile.) For N = 8192, d = 128, the ratio is N · d / N² = d / N = 1/64 ≈ 1.6% of the original, a ~64× HBM reduction. Wall-clock speedup is a smaller 2–4×, because attention isn't purely memory-bound (its matmuls have high arithmetic intensity), but the freed bandwidth lets other parts of the model breathe too.

For more depth, see vLLM lesson 03, which derives the online-softmax recurrence in detail and walks through the tile schedule. FlashAttention is a fusion case study rather than just another kernel because its win is purely a fusion win: it computes the same attention as the vanilla version (up to floating-point rounding, plus a few extra rescaling FLOPs).

2D · FlashAttention tile schedule

The Q rows are tiled (rows along the left) and the K/V columns are tiled (columns along the top). One Q tile sweeps across all K/V column tiles; for each tile, the kernel computes the local attention contribution, updates a running max m and normaliser ℓ, and accumulates into the output buffer in SRAM. The N×N attention matrix is never materialised in HBM.

Figure: FlashAttention · tile-by-tile sweep with online softmax. The highlighted cell is the active (Q_i, K_j) tile; a side panel tracks the running softmax state (m, ℓ) for the current Q row tile as more K columns are folded in.

What fusion gives back at backward time

A fused kernel that doesn't materialise its intermediates also can't store them for backward — backward will need to recompute them. This is the same tradeoff as activation checkpointing (lesson 05): forward saves HBM traffic, backward pays in extra FLOPs.

FlashAttention's backward does exactly this: it recomputes the attention block by block from Q, K, V and the small per-row softmax statistics saved in the forward pass. Net wall-clock effect: ~30% faster total (forward + backward) compared to vanilla, because the extra recompute FLOPs cost less than the HBM traffic they replace.
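Recomputation needs only a tiny amount of saved state. A sketch, assuming (as FlashAttention-2 does) that the forward pass keeps one log-sum-exp value L per query row; then any tile of P can be rebuilt as exp(S_tile − L) from Q and K alone. Function names are hypothetical, and for brevity the usage computes L from the full score matrix, whereas a tiled forward would get it as m + log ℓ from its running state:

```python
import math
import random


def scores(Q, K):
    """Scaled dot-product scores Q K^T / sqrt(d) for the given rows."""
    d = len(Q[0])
    return [[sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K] for q in Q]


def row_logsumexp(s_row):
    m = max(s_row)
    return m + math.log(sum(math.exp(x - m) for x in s_row))


def recompute_p_tile(Q, K, lse, q_rows, k_cols):
    """Rebuild one tile of P = softmax(S) from Q, K and the saved per-row
    log-sum-exp, touching only len(q_rows) x len(k_cols) scores."""
    s_tile = scores([Q[i] for i in q_rows], [K[j] for j in k_cols])
    return [[math.exp(s - lse[i]) for s in row] for i, row in zip(q_rows, s_tile)]


def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    z = sum(e)
    return [v / z for v in e]


rng = random.Random(1)
N, d = 8, 4
Q = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]

S = scores(Q, K)                      # full matrix used here only to get L and a reference
lse = [row_logsumexp(r) for r in S]   # the per-row statistic saved by the forward pass
P = [softmax(row) for row in S]       # reference, computed independently of lse

tile = recompute_p_tile(Q, K, lse, range(2, 4), range(3, 6))
print(max(abs(tile[r][c] - P[2 + r][3 + c]) for r in range(2) for c in range(3)))
```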

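When fusion doesn't help

Fusion does nothing for the compute-bound part of a GEMM: the matmul arithmetic is the same whether or not anything is fused. What it can remove is the memory-bound elementwise work that trails the GEMM.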

Animated · GEMM with elementwise epilogue, fused vs trailing

A GEMM that produces C = A · B, immediately followed by D = GELU(C + bias). Top: GEMM writes C to HBM, then a separate elementwise kernel reads C back, does the activation, writes D. Bottom: cuBLAS-style "epilogue" — the elementwise math is done in the GEMM's epilogue stage on each output tile, before anything is written to HBM. C never appears in HBM.

Figure: Epilogue fusion · the GEMM saves one whole HBM round-trip. GEMM tiles complete one by one; the top track writes C to HBM and reads it back (a wasted round-trip), while the bottom track applies GELU + bias inside each tile's epilogue and writes D directly.
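The difference between the two tracks is where bias + GELU runs. A pure-Python sketch with hypothetical names: lists stand in for HBM, `acc` stands in for the GEMM's on-chip accumulator tile, and GELU is the exact erf form. Real epilogues are configured through the library (e.g. cuBLASLt), not written like this:

```python
import math


def gelu(x):
    """Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matmul_tile(A, B, rows, cols):
    """C[rows, cols] = A[rows, :] @ B[:, cols], computed as one output tile."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in cols] for i in rows]


def gemm_then_gelu(A, B, bias):
    """Unfused: the GEMM stores all of C, then a second pass loads C back."""
    C = matmul_tile(A, B, range(len(A)), range(len(B[0])))          # write C
    return [[gelu(c + b) for c, b in zip(row, bias)] for row in C]  # read C, write D


def gemm_with_epilogue(A, B, bias, tile=2):
    """Fused: bias + GELU are applied to each output tile while it is still in the
    accumulator, and only D is stored; C never exists as a whole."""
    M, N = len(A), len(B[0])
    D = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, tile):
        for j0 in range(0, N, tile):
            rows, cols = range(i0, min(i0 + tile, M)), range(j0, min(j0 + tile, N))
            acc = matmul_tile(A, B, rows, cols)              # on-chip accumulator tile
            for r, i in enumerate(rows):
                for c, j in enumerate(cols):
                    D[i][j] = gelu(acc[r][c] + bias[j])      # epilogue, then the only store
    return D


def c_round_trip_bytes(M, N, dtype_bytes=2):
    """HBM traffic the epilogue removes: writing C once and reading it back once."""
    return 2 * M * N * dtype_bytes


A = [[1.0, -2.0, 0.5], [0.0, 1.0, 1.0], [2.0, 0.5, -1.0], [-1.0, 1.5, 0.0], [0.5, 0.5, 0.5]]
B = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -0.5], [-1.0, 2.0, 1.0, 0.0]]
bias = [0.1, -0.2, 0.0, 0.3]
print(gemm_with_epilogue(A, B, bias) == gemm_then_gelu(A, B, bias))  # True
print(c_round_trip_bytes(4096, 4096) / 2**20, "MiB")                   # 64.0 MiB
```

Both versions do identical arithmetic; the only thing the epilogue changes is that the 5 × 4 C buffer is never stored. At 4096 × 4096 in bf16, that avoided write-then-read is 64 MiB of HBM traffic per GEMM.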

How custom kernels get written today

  1. cuBLAS / cuDNN epilogue APIs (for cuBLAS, exposed through cuBLASLt). The "I want a matmul + bias + GELU" use case. NVIDIA provides this directly.
  2. Hand-written CUDA C++. The hardest path, used for performance-critical kernels (FlashAttention, paged attention). Months of work to outperform a tuned library.
  3. Triton (lesson 17). The DSL that landed somewhere between CUDA and Python — block-level operations, automatic SRAM management, good autotuning.
  4. torch.compile / Inductor (lesson 18). Generates Triton kernels from PyTorch code. Automatic fusion of elementwise chains, decent at reductions, not great at matmul-class kernels.
  5. CUTLASS. A C++ template library for matmul-like kernels. Production-grade serving stacks (TensorRT-LLM, vLLM) use this for custom GEMM patterns like multi-LoRA.

Interactive · the fusion accountant

Stack any number of elementwise ops between (or after) a matmul. The widget computes HBM traffic with and without fusion, and shows the speedup ceiling. Add a softmax + matmul + bias combo and watch the unfused traffic blow past the fused traffic.

Figure: HBM traffic · unfused vs fused. Each op reads its inputs (R) and writes its output (W). Fused: all ops share one R/W. Unfused: each op pays separately.
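The R/W rule behind the accountant fits in a few lines. A minimal sketch, assuming a straight chain whose intermediates are consumed only inside the chain (so a fused kernel never has to write them out); `Op`, `hbm_traffic` and the tensor sizes are all hypothetical:

```python
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    inputs: list[str]  # names of tensors this op reads
    output: str        # name of the tensor it writes


def hbm_traffic(ops, sizes, fused):
    """Bytes moved to/from HBM by a chain of ops.

    Unfused: every op reads all its inputs and writes its output.
    Fused: one kernel reads only tensors produced outside the chain and
    writes only the chain's final output; everything else stays on-chip.
    """
    if not fused:
        return sum(sum(sizes[t] for t in op.inputs) + sizes[op.output] for op in ops)
    produced = {op.output for op in ops}
    external = {t for op in ops for t in op.inputs if t not in produced}
    return sum(sizes[t] for t in external) + sizes[ops[-1].output]


# y = (x + bias) * scale; z = relu(y), with 1M bf16 elements and a
# 1024-wide bias/scale vector (hypothetical sizes).
N_BYTES, VEC_BYTES = (1 << 20) * 2, 1024 * 2
sizes = {"x": N_BYTES, "t": N_BYTES, "y": N_BYTES, "z": N_BYTES,
         "bias": VEC_BYTES, "scale": VEC_BYTES}
chain = [Op("add", ["x", "bias"], "t"),
         Op("mul", ["t", "scale"], "y"),
         Op("relu", ["y"], "z")]
unfused = hbm_traffic(chain, sizes, fused=False)
fused = hbm_traffic(chain, sizes, fused=True)
print(unfused, fused, round(unfused / fused, 2))  # 12587008 4198400 3.0
```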
Takeaway
Fusion is HBM-traffic accounting. Memory-bound op chains benefit linearly with the number of ops fused. Compute-bound ops barely care. FlashAttention is the headline example because it changes attention from O(N²) HBM traffic to O(N · d). The next two lessons (Triton, torch.compile) are the tools that let mortals write fused kernels without spending a year on CUDA C++.