Autotune, num_stages, and pitfalls

You can write a Triton kernel. Now we make it fast. Three knobs the autotuner sweeps — tile sizes, num_warps, num_stages — and the pitfall checklist that explains why the same kernel is fast in one shape and slow in the next.

What the autotuner actually does

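import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64,  'BN': 64,  'BK': 64}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],     # ← the cache key: names of kernel arguments
)
@triton.jit
def my_kernel(a_ptr, b_ptr, c_ptr,            # illustrative pointer names
              M, N, K,                        # runtime sizes used in the key
              BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):  # filled in by the chosen Config
    ...  # kernel body elided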

On the first call with a new (M, N, K) tuple:

  1. Triton compiles every config (in parallel where possible).
  2. Runs each one a few times to find the fastest.
  3. Caches the winner under the cache key.
  4. Subsequent calls with matching (M, N, K) skip the search — instant lookup.
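
The four steps above can be modelled on the CPU. The sketch below is a toy, not Triton's implementation: `ToyAutotuner`, `fake_bench`, and its per-tile cost model are all hypothetical names and assumptions made up for illustration. It only shows the sweep-once-then-look-up behaviour.

```python
import math

# The three configs from the example above, as plain dicts.
CONFIGS = [
    {"BM": 128, "BN": 128, "BK": 32},
    {"BM": 128, "BN": 256, "BK": 32},
    {"BM": 64, "BN": 64, "BK": 64},
]


def fake_bench(cfg, args, per_tile_overhead=1000):
    """Hypothetical cost model standing in for real timing runs."""
    tiles = math.ceil(args["M"] / cfg["BM"]) * math.ceil(args["N"] / cfg["BN"])
    return tiles * (cfg["BM"] * cfg["BN"] + per_tile_overhead)


class ToyAutotuner:
    """Sweep all configs on a cache miss, then reuse the winner."""

    def __init__(self, configs, key, bench):
        self.configs = configs
        self.key = key          # argument names that form the cache key
        self.bench = bench      # (config, args) -> cost, lower is better
        self.cache = {}
        self.sweeps = 0         # how many full sweeps were paid for

    def pick(self, **args):
        k = tuple(args[name] for name in self.key)
        if k not in self.cache:
            self.sweeps += 1
            self.cache[k] = min(self.configs, key=lambda c: self.bench(c, args))
        return self.cache[k]


tuner = ToyAutotuner(CONFIGS, key=["M", "N", "K"], bench=fake_bench)
big = tuner.pick(M=4096, N=4096, K=4096)    # first sweep
small = tuner.pick(M=64, N=64, K=64)        # new key, second sweep
again = tuner.pick(M=4096, N=4096, K=4096)  # cache hit, no sweep
```

Under this made-up cost model the large shape picks the 128×256 tile and the small shape picks 64×64, and the third call costs nothing extra.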

Picking the key is the most important design choice. Two failure modes:

Key too narrow: You key on M only but perf also depends on N. The autotuner caches a config that's good for one (M, N) but slow for another with the same M.

Key too wide: You key on M, N, K, dtype, batch, num_heads. Every distinct input hits a fresh autotune sweep, paying ~500ms on the first call.

Rule of thumb: key on the dimensions that change the optimal tile size. For a matmul, that's (M, N, K). For a softmax kernel, just (N) — the number of rows doesn't affect the per-row config.
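
To make the two failure modes concrete, here is a small sketch that counts how many sweeps a stream of calls would trigger under different keys. `count_sweeps` and the example call lists are hypothetical; the only assumption is that one distinct key value costs one sweep.

```python
def count_sweeps(calls, key):
    """Number of distinct cache keys, i.e. sweeps, for a list of calls."""
    return len({tuple(call[name] for name in key) for call in calls})


# Same two matmul shapes, launched with eight different batch sizes each.
batched_calls = [
    {"M": m, "N": 4096, "K": 4096, "batch": b}
    for m in (128, 256)
    for b in range(1, 9)
]
sweeps_good_key = count_sweeps(batched_calls, ["M", "N", "K"])           # 2
sweeps_wide_key = count_sweeps(batched_calls, ["M", "N", "K", "batch"])  # 16

# Two very different N values that share an M.
mixed_n_calls = [{"M": 128, "N": 64}, {"M": 128, "N": 8192}]
sweeps_narrow_key = count_sweeps(mixed_n_calls, ["M"])                   # 1
```

The wide key pays for 16 sweeps where 2 would do. The narrow key pays for only 1, which means the second shape silently reuses a config tuned for the first.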

num_warps — warps per program

The number of 32-lane warps the compiler uses to execute your tile. More warps spread the tile across more threads, so each thread holds fewer of its elements.

Common values: 4 or 8. For very large tiles (BM·BN ≥ 32K) try 16. Triton enforces num_warps · 32 ≤ 1024 (max threads per block).
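
The guidance above fits in a few lines. This is a sketch of the rule of thumb only; `candidate_num_warps` and `fits_in_block` are hypothetical helper names, and the 32K threshold is the one quoted in the text.

```python
WARP_SIZE = 32
MAX_THREADS_PER_BLOCK = 1024


def fits_in_block(num_warps):
    """True if num_warps * 32 stays within the threads-per-block limit."""
    return num_warps * WARP_SIZE <= MAX_THREADS_PER_BLOCK


def candidate_num_warps(bm, bn):
    """num_warps values worth sweeping for a BM x BN tile."""
    candidates = [4, 8]
    if bm * bn >= 32 * 1024:   # "very large" tile per the rule of thumb
        candidates.append(16)
    return [w for w in candidates if fits_in_block(w)]


medium_tile = candidate_num_warps(128, 128)   # 16K elements
large_tile = candidate_num_warps(128, 256)    # 32K elements
```

A 128×128 tile stays with 4 or 8 warps; 128×256 reaches the 32K threshold and adds 16 to the sweep.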

num_stages — the software pipeline depth

This is the knob most people don't understand. It controls how deeply Triton overlaps HBM loads with compute. With num_stages = N, the compiler generates code that issues the next N − 1 iterations' loads before the current iteration's compute, so by the time the current compute is done, the next data is already in shared memory.
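
The issue order described above can be sketched as a simple schedule generator. This models only the ordering of loads and compute steps, under the assumption that the pipeline runs `num_stages - 1` loads ahead; it says nothing about real timing, and `pipeline_schedule` is a hypothetical helper, not something Triton exposes.

```python
def pipeline_schedule(num_iters, num_stages):
    """Issue order of loads and mma steps for a K-loop of num_iters iterations."""
    ahead = num_stages - 1
    events = []
    # Prologue: fill the pipeline before any compute starts.
    for i in range(min(ahead, num_iters)):
        events.append(f"load {i}")
    # Steady state: issue the load that is `ahead` iterations away, then compute.
    for i in range(num_iters):
        nxt = i + ahead
        if nxt < num_iters:
            events.append(f"load {nxt}")
        events.append(f"mma {i}")
    return events


serial = pipeline_schedule(num_iters=3, num_stages=1)
pipelined = pipeline_schedule(num_iters=3, num_stages=3)
```

With `num_stages=1` each load is immediately followed by its own compute step. With `num_stages=3` two loads are already in flight before the first mma is issued.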

num_stages = 1 (no pipelining), serial load + compute:

  load 0, mma 0, load 1, mma 1, load 2, mma 2

num_stages = 3 (typical for tl.dot), 2 loads issued ahead of compute:

  load 0, load 1, load 2, mma 0, mma 1, mma 2

Load latency is hidden behind compute. The tile lives in 3 SMEM buffers, rotated.

Cost: 3× the SMEM of a single-buffered tile, for each operand. Too high and you blow occupancy.

When to crank num_stages:

  - On Ampere (A100): num_stages = 3 is standard for matmul.
  - On Hopper (H100): num_stages = 4 or 5; TMA loads have higher latency.
  - Memory-bound kernels (vector add, RMSNorm): num_stages = 1 or 2, no extra benefit.

The shorthand

num_stages is "how many K-loop iterations Triton should overlap." Use 3 on Ampere for matmul; 4 or 5 on Hopper. Use 1–2 for memory-bound kernels with no dot product.
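
A rough way to see the SMEM cost is to multiply it out. The estimate below assumes each stage buffers one BM×BK tile of A and one BK×BN tile of B in a 2-byte dtype, and ignores anything else the compiler places in shared memory. `smem_bytes` is a hypothetical helper, and the 163 KB per-block limit is an assumed A100 figure; check your own device.

```python
def smem_bytes(bm, bn, bk, num_stages, elem_bytes=2):
    """Estimated shared memory for the pipelined A and B tiles, in bytes."""
    per_stage = (bm * bk + bk * bn) * elem_bytes
    return num_stages * per_stage


ASSUMED_SMEM_LIMIT = 163 * 1024   # assumption: A100 per-block limit

cfg_a = smem_bytes(128, 128, 32, num_stages=3)   # 48 KB
cfg_b = smem_bytes(128, 256, 32, num_stages=3)   # 72 KB
cfg_c = smem_bytes(64, 64, 64, num_stages=4)     # 64 KB

# Pushing the second config to 8 stages no longer fits under the assumed limit.
too_deep = smem_bytes(128, 256, 32, num_stages=8)
fits = too_deep <= ASSUMED_SMEM_LIMIT
```

All three configs from the autotune list fit comfortably. Staying under the hard limit is not the whole story, though: the more SMEM one program uses, the fewer programs can be resident on an SM at once.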

The full pitfall checklist

Eight things that go wrong in Triton kernels and how to spot each:

1 · No boundary mask
  Symptom: Wrong output or an illegal memory access at shapes that are not a multiple of BLOCK. Power-of-two tests pass.
  Fix: Pass mask= on every load and store, plus other= on loads, whenever a dimension may not be a multiple of BLOCK.

2 · Wrong identity for other
  Symptom: Boundary lanes contribute spurious values to a reduction.
  Fix: 0 for sum, 1 for product, −∞ for max, +∞ for min.

3 · Stride order swapped
  Symptom: Garbage output for transposed views. Contiguous tensors are fine.
  Fix: Pass x.stride(0), x.stride(1) in the same order the kernel multiplies them.

4 · bf16 accumulation
  Symptom: Visible loss drift; large-K matmul drifts away from reference after many steps.
  Fix: Always tl.zeros((BM,BN), dtype=tl.float32) for the accumulator.

5 · Missing tl.constexpr
  Symptom: Tile sizes become runtime values. tl.arange needs compile-time bounds, so this usually fails to compile; other unmarked parameters lose specialization and can run several times slower.
  Fix: Mark BM, BN, BK, dtypes as tl.constexpr.

6 · Register spill
  Symptom: ncu shows "No Eligible" stalls; achieved occupancy < 25%.
  Fix: Smaller tile, fewer accumulators, or set maxnreg. A 128×128 fp32 acc is 64KB of registers per program; it fits on H100 with effort.

7 · Cross-program atomics
  Symptom: You wrote tl.atomic_add across many programs. Determinism gone; performance fragile.
  Fix: Two-pass scheme: each program writes a partial result; a second kernel combines.

8 · Stale autotune cache
  Symptom: You changed the kernel but it's behaving like the old version.
  Fix: rm -rf ~/.triton/cache. Or bump a constant in the function signature.
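
Pitfalls 2 and 4 are numerical, so they can be reproduced on the CPU without Triton. The sketch below is an emulation under stated assumptions: `masked_load` models what a masked `tl.load` with `other=` returns for one block, and `to_bf16` emulates bf16 rounding (round to nearest even) with `struct`. Both are hypothetical helpers written for this illustration.

```python
import math
import struct


def masked_load(row, block, other):
    """Model of a masked block load: out-of-range lanes receive `other`."""
    return [row[i] if i < len(row) else other for i in range(block)]


# Pitfall 2: a row of 3 negatives read with BLOCK = 4.
row = [-3.0, -1.0, -2.0]
wrong_max = max(masked_load(row, 4, other=0.0))        # padding lane wins
right_max = max(masked_load(row, 4, other=-math.inf))  # true maximum


def to_bf16(x):
    """Round a float to the nearest bf16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_fp32(x):
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def accumulate(values, rounder):
    """Sum values, rounding the accumulator after every add."""
    acc = 0.0
    for v in values:
        acc = rounder(acc + v)
    return acc


# Pitfall 4: summing 1000 ones with a bf16 vs an fp32 accumulator.
bf16_sum = accumulate([1.0] * 1000, to_bf16)
fp32_sum = accumulate([1.0] * 1000, to_fp32)
```

With `other=0.0` the max of an all-negative row comes out as 0. The bf16 accumulator stalls at 256 because bf16 values between 256 and 512 are spaced 2 apart, so adding 1 rounds back down every time.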

Register spill — the silent killer

Each thread has a fixed register budget (at most 255 on H100). If your kernel exceeds it, the compiler spills the excess values to local memory, which is backed by HBM and only cached in L1/L2. A spilled value costs a memory round trip instead of a register read, which on a cache miss means hundreds of cycles.

Signs you've spilled (from the checklist above): ncu shows "No Eligible" stalls, and achieved occupancy drops below 25%.

Cures, in order:

  1. Smaller accumulator tile (BM × BN). A 128 × 128 fp32 acc is 64 KB; drop to 128 × 64 = 32 KB.
  2. More warps for the same tile: pass num_warps=8 at launch or in a triton.Config (it is not an argument to @triton.jit). More warps spread the tile across more threads, so each thread needs fewer registers.
  3. num_stages = 2 instead of 3 — fewer in-flight buffers.
  4. If all else fails, the kernel is too big. Split it.
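
The first two cures can be checked with arithmetic. This sketch assumes the accumulator is split evenly across the program's threads and that each fp32 element takes one 32-bit register; it ignores registers for operands, pointers, and loop state, so real usage is higher. `acc_footprint` is a hypothetical helper.

```python
WARP_SIZE = 32
MAX_REGS_PER_THREAD = 255   # per-thread register ceiling


def acc_footprint(bm, bn, num_warps, elem_bytes=4):
    """Return (accumulator bytes per program, accumulator registers per thread)."""
    total_bytes = bm * bn * elem_bytes
    threads = num_warps * WARP_SIZE
    regs_per_thread = total_bytes // 4 // threads
    return total_bytes, regs_per_thread


baseline = acc_footprint(128, 128, num_warps=4)      # 64 KB, 128 regs/thread
smaller_tile = acc_footprint(128, 64, num_warps=4)   # cure 1: 32 KB, 64 regs/thread
more_warps = acc_footprint(128, 128, num_warps=8)    # cure 2: 64 KB, 64 regs/thread
```

At 4 warps the 128×128 accumulator alone takes half of each thread's register ceiling. Halving the tile or doubling the warps both bring it down to 64 registers per thread.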

Bank conflicts (rare, but real)

Shared memory is divided into 32 banks. When multiple lanes in a warp read from the same bank, the reads serialise. For tl.dot, Triton picks swizzle patterns automatically to avoid this. For custom SMEM access patterns (rare in Triton — you don't allocate SMEM yourself), bank conflicts can show up. The cure is to pad strides so consecutive lanes hit consecutive banks.
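
The padding cure is easy to verify by counting. The sketch below assumes 32 banks of 4-byte words and a warp in which lane i reads the word at offset i × stride; `worst_conflict` is a hypothetical helper that reports how many lanes land on the busiest bank.

```python
from collections import Counter

NUM_BANKS = 32


def worst_conflict(stride_words, lanes=32):
    """Largest number of lanes in one warp that hit the same bank."""
    banks = Counter((lane * stride_words) % NUM_BANKS for lane in range(lanes))
    return max(banks.values())


unit_stride = worst_conflict(1)     # consecutive words, one lane per bank
column_read = worst_conflict(32)    # walking a column of a 32-word-wide tile
padded_read = worst_conflict(33)    # same column after padding rows to 33 words
```

A stride of 32 words puts every lane on bank 0, a 32-way conflict. Padding each row by one word makes the stride 33, and the lanes spread across all 32 banks again.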

Reading ncu output, the 30-second version

Nsight Compute (ncu) is the deepest tool. The two metrics you care about for Triton:

One-liner:

ncu --set roofline --target-processes all python my_script.py

That produces a roofline showing where each kernel sits. Lessons 20–21 of the GPU Kernels series go deeper.

Which config wins?

The "right" tile depends on whether the kernel is compute- or bandwidth-bound.

Compute-bound kernels (large square matmul) like big tiles + many warps. Bandwidth-bound (RMSNorm, small N) like medium tiles + fewer stages.
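
Whether a kernel is compute- or bandwidth-bound can be estimated from its arithmetic intensity, the same quantity a roofline plot uses. This is a back-of-the-envelope sketch: the peak numbers are assumed, roughly A100-class placeholders rather than measurements, the byte counts assume every operand is read or written exactly once, and `classify` is a hypothetical helper.

```python
ASSUMED_PEAK_FLOPS = 312e12   # assumption: fp16 tensor-core peak, FLOP/s
ASSUMED_PEAK_BW = 2.0e12      # assumption: HBM bandwidth, bytes/s


def classify(flops, bytes_moved,
             peak_flops=ASSUMED_PEAK_FLOPS, peak_bw=ASSUMED_PEAK_BW):
    """Compare arithmetic intensity (FLOP/byte) with the roofline ridge point."""
    ridge = peak_flops / peak_bw
    intensity = flops / bytes_moved
    return "compute-bound" if intensity >= ridge else "bandwidth-bound"


# Large square fp16 matmul: 2*M*N*K FLOPs, A, B and C each touched once.
M = N = K = 4096
matmul_kind = classify(2 * M * N * K, 2 * (M * K + K * N + M * N))

# fp32 vector add: 1 FLOP per element, 2 loads + 1 store of 4 bytes each.
n = 1 << 20
vector_add_kind = classify(n, 12 * n)
```

The matmul lands at roughly 1365 FLOP/byte, far above the assumed ridge of 156, while the vector add sits below 0.1. That gap is why the two workloads prefer such different configs.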

What's next

You can make a kernel fast. Lesson 14 makes it production-ready: backward passes via torch.autograd.Function, profiling with Nsight, dumping the IR to confirm tensor cores are firing, and the decision tree for when to write Triton vs CUDA vs torch.compile.