Autotune, num_stages, and pitfalls
You can write a Triton kernel. Now we make it fast. This lesson covers the three knobs the autotuner sweeps (tile sizes, num_warps, num_stages) and the pitfall checklist that explains why the same kernel is fast for one shape and slow for the next.
What the autotuner actually does
import triton

@triton.autotune(
    configs=[
        # Each Config pairs tile sizes (passed to the kernel as constexpr
        # arguments) with the launch options num_warps and num_stages.
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],  # ← the cache key
)
@triton.jit
def my_kernel(...): ...
On the first call with a new (M, N, K) tuple:
- Triton compiles every config (in parallel where possible).
- Runs each one a few times to find the fastest.
- Caches the winner under the cache key.
- Subsequent calls with matching (M, N, K) skip the search — instant lookup.
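The search-then-cache behaviour in the steps above can be modelled in a few lines of plain Python. This is a toy sketch, not Triton's implementation: `ToyAutotuner` and `padded_work` are hypothetical names, and the cost function is a made-up stand-in for actually compiling and timing each config.

```python
from typing import Callable, Dict, List, Tuple

Config = Dict[str, int]
Shape = Tuple[int, int, int]


class ToyAutotuner:
    """Toy model of a keyed autotune cache; not Triton's real implementation."""

    def __init__(self, configs: List[Config], cost: Callable[[Config, Shape], float]):
        self.configs = configs
        self.cost = cost      # hypothetical stand-in for "compile, run, and time it"
        self.cache: Dict[Shape, Config] = {}
        self.sweeps = 0       # number of full searches performed so far

    def best_config(self, shape: Shape) -> Config:
        if shape not in self.cache:          # first call with this key: search
            self.sweeps += 1
            self.cache[shape] = min(self.configs, key=lambda c: self.cost(c, shape))
        return self.cache[shape]             # later calls: plain dictionary lookup


def padded_work(cfg: Config, shape: Shape) -> float:
    """Made-up cost: work after rounding each dimension up to a tile multiple."""
    def round_up(dim: int, tile: int) -> int:
        return -(-dim // tile) * tile
    M, N, K = shape
    return round_up(M, cfg["BM"]) * round_up(N, cfg["BN"]) * round_up(K, cfg["BK"])


CONFIGS = [
    {"BM": 128, "BN": 128, "BK": 32},
    {"BM": 128, "BN": 256, "BK": 32},
    {"BM": 64, "BN": 64, "BK": 64},
]

tuner = ToyAutotuner(CONFIGS, padded_work)
first = tuner.best_config((64, 64, 64))    # new key: triggers a sweep
again = tuner.best_config((64, 64, 64))    # same key: cache hit, no new sweep
other = tuner.best_config((100, 300, 32))  # new key: second sweep
print(first, other, tuner.sweeps)
```

The point of the model is the counter: only a previously unseen key pays for a search, which is why the choice of key fields matters so much.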
Picking the key is the most important design choice. Two failure modes:
| Key too narrow | Key too wide |
|---|---|
| You key on M only but perf also depends on N. The autotuner caches a config that's good for one (M, N) but slow for another with the same M. | You key on M, N, K, dtype, batch, num_heads. Every distinct input hits a fresh autotune sweep, paying ~500ms on the first call. |
Rule of thumb: key on the dimensions that change the optimal tile size. For a matmul, that's (M, N, K). For a softmax kernel, just (N) — the number of rows doesn't affect the per-row config.
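To make the trade-off concrete, here is a small sketch that counts how many sweeps a stream of calls would trigger under different keys. The call list and the helper `count_sweeps` are invented for illustration; the 500 ms figure is the rough first-call cost quoted above, not a measurement.

```python
from typing import Dict, List


def count_sweeps(calls: List[Dict[str, int]], key_fields: List[str]) -> int:
    """Number of distinct cache keys, i.e. autotune sweeps, for these calls."""
    seen = set()
    for call in calls:
        seen.add(tuple(call[field] for field in key_fields))
    return len(seen)


# Hypothetical matmul calls: two distinct (M, N, K) shapes, three batch sizes.
calls = [
    {"M": 512, "N": 1024, "K": 1024, "batch": 1},
    {"M": 512, "N": 1024, "K": 1024, "batch": 2},
    {"M": 512, "N": 4096, "K": 1024, "batch": 1},
    {"M": 512, "N": 4096, "K": 1024, "batch": 4},
]

SWEEP_MS = 500  # rough per-sweep cost from the text, assumed constant here

narrow = count_sweeps(calls, ["M"])                     # one sweep, reused for both N
matched = count_sweeps(calls, ["M", "N", "K"])          # one sweep per real shape
wide = count_sweeps(calls, ["M", "N", "K", "batch"])    # one sweep per call

for name, n in [("narrow", narrow), ("matched", matched), ("wide", wide)]:
    print(f"{name}: {n} sweeps, ~{n * SWEEP_MS} ms of tuning")
```

The narrow key is cheapest to tune but reuses one config across both values of N; the wide key retunes for batch sizes that, in this example, are assumed not to change the best tile.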
num_warps — warps per program
The number of 32-lane warps the compiler will use to execute your tile. Effect:
- More warps = more parallelism within one tile, more registers used overall, lower per-warp register pressure. Better for compute-bound tiles where you want lots of mma in flight.
- Fewer warps = each warp does more work, fewer cross-warp barriers, more registers per warp (so larger acc tiles without spill). Better for memory-bound tiles where one warp can saturate HBM.
Common values: 4 or 8. For very large tiles (BM·BN ≥ 32K) try 16. Triton enforces num_warps · 32 ≤ 1024 (max threads per block).
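The arithmetic behind these numbers is simple enough to check by hand. The sketch below assumes the fp32 accumulator tile is split evenly across all threads of the program, which is a simplification of how the compiler actually lays out the tile; the helper names are hypothetical.

```python
WARP_SIZE = 32
MAX_THREADS_PER_BLOCK = 1024


def threads_per_program(num_warps: int) -> int:
    """Threads launched per program, enforcing the 1024-thread block limit."""
    threads = num_warps * WARP_SIZE
    if threads > MAX_THREADS_PER_BLOCK:
        raise ValueError(f"num_warps={num_warps} gives {threads} threads (> 1024)")
    return threads


def acc_values_per_thread(BM: int, BN: int, num_warps: int) -> int:
    """Accumulator elements each thread holds, assuming an even split."""
    return (BM * BN) // threads_per_program(num_warps)


print(acc_values_per_thread(128, 128, 4))   # 128 values per thread
print(acc_values_per_thread(128, 128, 8))   # 64 values per thread
print(acc_values_per_thread(128, 256, 16))  # 64 values per thread
```

Doubling num_warps halves the accumulator share per thread, which is why a 128 × 256 tile (BM·BN = 32K) with 16 warps lands at the same per-thread load as 128 × 128 with 8.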
num_stages — the software pipeline depth
This is the knob most people don't understand. It controls how deeply Triton overlaps HBM loads with compute. With num_stages = N, the compiler generates code that issues the next N − 1 iterations' loads before the current iteration's compute, so by the time the current compute is done, the next data is already in shared memory.
num_stages is "how many K-loop iterations Triton should overlap." Use 3 on Ampere for matmul; 4 on Hopper. Use 1–2 for memory-bound kernels with no dot product.
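A plain-Python sketch of the issue order may help. It lists the order in which loads and computes are issued for a K-loop, following the description above; it is a conceptual model, not the code Triton generates. The `staged_bytes` helper further assumes that each stage buffers one BM × BK tile of A and one BK × BN tile of B in 2-byte elements.

```python
from typing import List, Tuple


def pipeline_schedule(num_iters: int, num_stages: int) -> List[Tuple[str, int]]:
    """Issue order of loads and computes for a software-pipelined K-loop."""
    prefetch = num_stages - 1            # loads kept in flight ahead of compute
    events: List[Tuple[str, int]] = []
    for k in range(min(prefetch, num_iters)):
        events.append(("load", k))       # prologue: fill the pipeline
    for k in range(num_iters):
        ahead = k + prefetch
        if ahead < num_iters:
            events.append(("load", ahead))   # issue a future tile's load first
        events.append(("compute", k))        # then compute on the current tile
    return events


def staged_bytes(BM: int, BN: int, BK: int, num_stages: int, elem_bytes: int = 2) -> int:
    """Shared memory needed if every stage holds one A tile and one B tile."""
    return num_stages * (BM * BK + BK * BN) * elem_bytes


print(pipeline_schedule(4, 3))
print(staged_bytes(128, 128, 32, 3))  # bytes of staging buffers
```

With num_stages = 3, loads for tiles 0, 1 and 2 are all issued before the first compute. The second helper shows the price: each extra stage adds another full set of tile buffers to shared memory.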
The full pitfall checklist
Eight things that go wrong in Triton kernels and how to spot each:
| Pitfall | Symptom | Fix |
|---|---|---|
| 1 · No boundary mask | Wrong output or an illegal memory access at non-power-of-two shapes. Power-of-two tests pass. | Always pass mask= and other= for non-multiples of BLOCK. |
| 2 · Wrong identity for other | Boundary lanes contribute spurious values to a reduction. | 0 for sum, 1 for product, −∞ for max, +∞ for min. |
| 3 · Stride order swapped | Garbage output for transposed views. Contiguous tensors are fine. | Pass x.stride(0), x.stride(1) in the same order the kernel multiplies them. |
| 4 · bf16 accumulation | Visible loss drift; large-K matmul drifts away from reference after many steps. | Always tl.zeros((BM,BN), dtype=tl.float32) for the accumulator. |
| 5 · Missing tl.constexpr | Tile sizes become runtime values; tl.arange can't unroll; ~5× slower. | Mark BM, BN, BK, dtypes as tl.constexpr. |
| 6 · Register spill | ncu shows "No Eligible" stalls; achieved occupancy < 25%. | Smaller tile, fewer accumulators, or set maxnreg. A 128×128 fp32 acc is 64KB of registers per program — fits on H100 with effort. |
| 7 · Cross-program atomics | You wrote tl.atomic_add across many programs. Determinism gone; performance fragile. | Two-pass scheme: each program writes a partial result; a second kernel combines. |
| 8 · Stale autotune cache | You changed the kernel but it's behaving like the old version. | rm -rf ~/.triton/cache. Or bump a constant in the function signature. |
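Pitfalls 1 and 2 are easy to reproduce without a GPU. The sketch below imitates a masked block load in plain Python; `masked_load` is a hypothetical helper that mirrors the idea of mask= and other=, not Triton's tl.load itself.

```python
import math
from typing import List


def masked_load(data: List[float], offsets: List[int], other: float) -> List[float]:
    """Out-of-range lanes receive `other` instead of reading past the end."""
    n = len(data)
    return [data[off] if off < n else other for off in offsets]


data = [-3.0, -1.0, -2.0, -5.0, -4.0]   # 5 elements, not a multiple of BLOCK
BLOCK = 8
offsets = list(range(BLOCK))            # lanes 5, 6, 7 fall outside the data

total = sum(masked_load(data, offsets, other=0.0))
wrong_max = max(masked_load(data, offsets, other=0.0))        # padding wins
right_max = max(masked_load(data, offsets, other=-math.inf))
right_min = min(masked_load(data, offsets, other=math.inf))
right_prod = math.prod(masked_load(data, offsets, other=1.0))

print(total, wrong_max, right_max, right_min, right_prod)
```

With all-negative data, padding with 0 silently turns the maximum into 0.0, which is exactly the kind of bug that passes tests on block-aligned shapes.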
Register spill — the silent killer
Each thread has a fixed register budget (at most 255 32-bit registers on H100). If your kernel exceeds it, the compiler spills the excess values to local memory, which is backed by HBM. A spilled value that misses the caches costs a full HBM round trip, orders of magnitude slower than a register read.
Signs you've spilled:
- ncu output: "Achieved occupancy < 25%" or "Local memory access" non-zero.
- Compiler warning: "warning: function uses too much shared data" or "register spilling".
- Wall-clock 5–10× worse than expected.
Cures, in order:
- Smaller accumulator tile (BM × BN). A 128 × 128 fp32 acc is 64 KB; drop to 128 × 64 = 32 KB.
- Fewer registers per thread: raise num_warps (for example num_warps=8 in the triton.Config or at launch) so more warps spread the same tile across more threads.
- num_stages = 2 instead of 3: fewer in-flight buffers.
- If all else fails, the kernel is too big. Split it.
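The accumulator sizes quoted in the first cure follow directly from tile shape times element size. The sketch below redoes that arithmetic and compares it to the register file, assuming 64K 32-bit registers (256 KiB) per SM for an H100-class part; treat that constant as an assumption and check it against your GPU.

```python
REG_FILE_BYTES_PER_SM = 64 * 1024 * 4   # assumption: 64K 32-bit registers per SM


def acc_bytes(BM: int, BN: int, elem_bytes: int = 4) -> int:
    """Bytes of registers needed to hold a BM x BN accumulator (fp32 default)."""
    return BM * BN * elem_bytes


def acc_share_of_reg_file(BM: int, BN: int) -> float:
    """Fraction of one SM's register file taken by the accumulator alone."""
    return acc_bytes(BM, BN) / REG_FILE_BYTES_PER_SM


for BM, BN in [(128, 128), (128, 64)]:
    kib = acc_bytes(BM, BN) // 1024
    share = acc_share_of_reg_file(BM, BN)
    print(f"{BM}x{BN} fp32 acc: {kib} KiB, {share:.1%} of the register file")
```

Under that assumption a 128 × 128 accumulator takes a quarter of the SM's registers before any operands, addresses, or pipeline state are counted, which is why halving the tile is the first cure to try.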
Bank conflicts (rare, but real)
Shared memory is divided into 32 banks. When multiple lanes in a warp read from the same bank, the reads serialise. For tl.dot, Triton picks swizzle patterns automatically to avoid this. For custom SMEM access patterns (rare in Triton — you don't allocate SMEM yourself), bank conflicts can show up. The cure is to pad strides so consecutive lanes hit consecutive banks.
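The effect of padding can be checked with a short model. It assumes 4-byte words with successive words mapped to successive banks, and that lanes reading the same word are served by a single broadcast; `conflict_degree` is a hypothetical helper, not a profiler metric.

```python
from collections import defaultdict

NUM_BANKS = 32


def conflict_degree(stride_words: int, lanes: int = 32) -> int:
    """Worst-case number of distinct words one bank must serve for a warp.

    Lane i reads word i * stride_words; a result of 1 means conflict-free.
    """
    words_per_bank = defaultdict(set)
    for lane in range(lanes):
        word = lane * stride_words
        words_per_bank[word % NUM_BANKS].add(word)
    return max(len(words) for words in words_per_bank.values())


print(conflict_degree(1))    # consecutive words: conflict-free
print(conflict_degree(32))   # column of a 32-word-wide tile: fully serialised
print(conflict_degree(33))   # same tile padded by one word: conflict-free
```

A row stride of 32 words sends every lane to the same bank, while padding the row to 33 words shifts each lane to the next bank.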
Reading ncu output, the 30-second version
Nsight Compute (ncu) is the deepest tool. The two metrics you care about for Triton:
- SM Throughput % — fraction of peak the kernel is using. > 80% means tensor cores are firing; < 30% means something is gating them (bandwidth or spill).
- Memory Throughput % — HBM utilisation. If both this and SM Throughput are low, the kernel is launch-bound or has too few programs.
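As a reading aid, the two rules of thumb above can be written as a tiny triage function. The 80% and 30% cut-offs come from the text; applying the same 30% cut-off to memory throughput is an assumption of this sketch, and `triage` is a hypothetical helper that takes numbers you copy from the ncu report by hand.

```python
def triage(sm_pct: float, mem_pct: float, low: float = 30.0, high: float = 80.0) -> str:
    """Rough first-pass diagnosis from two ncu throughput percentages."""
    if sm_pct > high:
        return "compute-bound: tensor cores are firing"
    if sm_pct < low and mem_pct < low:
        return "launch-bound or too few programs"
    if sm_pct < low:
        return "compute is gated: check bandwidth or register spill"
    return "in between: inspect the roofline and stall reasons"


print(triage(sm_pct=85.0, mem_pct=40.0))
print(triage(sm_pct=20.0, mem_pct=75.0))
print(triage(sm_pct=12.0, mem_pct=10.0))
```

It only narrows down where to look next; the stall breakdown in the full report is still what confirms the cause.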
One-liner:
ncu --set roofline --target-processes all python my_script.py
That produces a roofline showing where each kernel sits. Lessons 20–21 of the GPU Kernels series go deeper.
What's next
You can make a kernel fast. Lesson 14 makes it production-ready: backward passes via torch.autograd.Function, profiling with Nsight, dumping the IR to confirm tensor cores are firing, and the decision tree for when to write Triton vs CUDA vs torch.compile.