Writing Triton kernels
Triton trades CUDA's thread-level control for a per-program tile language with an autotuner. You give up some last-mile control over MMA scheduling and warp specialization in exchange for far less code, automatic shape tuning, and easier maintenance. For 95 % of "I need a fused kernel" cases, that trade is the right one.
The question this lesson answers
You profiled (lesson 20), located a bandwidth-bound chain of small ops (lesson 21), and decided fusion is the lever. Now what? You could write CUDA, but you don't need 100 % of peak: you need a fused kernel today that reaches 80–90 % of optimal, autotunes across shapes, and is easy to maintain. That's exactly what Triton is for.
Triton's programming model in one diagram
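Each Triton program instance owns one tile: it derives its offsets from tl.program_id, predicates out-of-range lanes with a mask, and issues whole-tile loads, math, and stores. The sketch below emulates that model in plain Python for an element-wise add. It is an illustration of the semantics only, not Triton code; the names add_program and launch_add are hypothetical.

```python
# Plain-Python emulation of Triton's execution model for out = x + y.
# This is NOT Triton code; it only mirrors the semantics of one launch.

def cdiv(a: int, b: int) -> int:
    """Ceiling division, like triton.cdiv."""
    return (a + b - 1) // b


def add_program(pid: int, x: list, y: list, out: list, n: int, BLOCK: int) -> None:
    """What one program instance does: own one BLOCK-wide tile."""
    offs = [pid * BLOCK + i for i in range(BLOCK)]  # pid * BLOCK + tl.arange(0, BLOCK)
    mask = [o < n for o in offs]                     # predicate for the ragged last tile
    # tl.load(ptr + offs, mask=mask, other=0.0): masked-off lanes read `other`
    xs = [x[o] if m else 0.0 for o, m in zip(offs, mask)]
    ys = [y[o] if m else 0.0 for o, m in zip(offs, mask)]
    res = [a + b for a, b in zip(xs, ys)]            # one whole-tile (SIMD) operation
    # tl.store(ptr + offs, res, mask=mask): masked-off lanes write nothing
    for o, m, v in zip(offs, mask, res):
        if m:
            out[o] = v


def launch_add(x: list, y: list, BLOCK: int = 4) -> list:
    n = len(x)
    out = [0.0] * n
    grid = (cdiv(n, BLOCK),)  # one program per tile
    # On a GPU these programs run concurrently in no guaranteed order; a plain loop
    # is enough here because no program depends on another.
    for pid in range(grid[0]):
        add_program(pid, x, y, out, n, BLOCK)
    return out


print(launch_add([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [10.0] * 6, BLOCK=4))
```

With n = 6 and BLOCK = 4 the grid has two programs, and the second program's last two lanes (offsets 6 and 7) are masked off, which is exactly the boundary case the mask exists for.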
The Triton DSL, just the parts you need
| Primitive | Purpose | Notes |
|---|---|---|
| @triton.jit | JIT-compile this Python function as a GPU kernel. | The body never runs as ordinary Python: on the first call for a given signature and set of constexpr values it is compiled (Triton IR → LLVM IR → PTX) and cached. |
| tl.program_id(axis) | Which tile this program instance owns. | Up to 3 axes (0, 1, 2). Map your problem decomposition here. |
| tl.arange(0, BLOCK) | Vector of compile-time-known length. | The "thread vector"; operations on it are SIMD across the tile. The length must be a power of two. |
| tl.load(ptr, mask, other) | Predicated load of a whole tile. | Masked-off lanes get other. Coalesced when adjacent lanes read adjacent addresses. Always pass a mask for the boundary tile. |
| tl.store(ptr, val, mask) | Predicated store of a whole tile. | Same boundary contract. |
| tl.dot(a, b) | Tile matmul. Hits tensor cores when shapes and dtypes allow. | The right primitive for GEMM-shaped inner loops. |
| tl.sum / tl.max / tl.min / tl.cumsum | Reductions (and, for tl.cumsum, a scan) along a tile axis. | Compiles to warp shuffles + SMEM. |
| tl.exp / tl.log / tl.rsqrt | Math intrinsics (fast hardware versions). | Lower precision than standard libm but usually fine. |
| tl.constexpr | Marks an argument as a compile-time constant. | Tile sizes, dtypes, anything that affects code generation. Each distinct value compiles a separate kernel. |
| @triton.autotune(configs, key=...) | Compile several configs; pick the fastest on the first call per key. | The autotuner is your friend. Use it. |
| num_warps, num_stages (config knobs) | num_warps: warps per program (occupancy vs. work per warp). num_stages: depth of the software pipeline Triton emits to overlap memory loads with compute. | Both are config-level, not per-instruction. The autotuner sweeps them. |
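num_stages trades shared memory for latency hiding: each extra stage keeps another set of input tiles in flight. A back-of-envelope estimate shows why deep pipelines on large tiles can exceed a GPU's shared-memory budget, in which case the config fails with an out-of-resources error. The helper name pipelined_smem_bytes is hypothetical, and the formula assumes the pipeline buffers num_stages copies of one A tile and one B tile; Triton's real allocation can differ.

```python
def pipelined_smem_bytes(BM: int, BN: int, BK: int, num_stages: int, elem_bytes: int = 2) -> int:
    """Rough shared-memory estimate for a pipelined GEMM K-loop.

    Assumes the pipeline keeps `num_stages` buffered copies of one A tile (BM x BK)
    and one B tile (BK x BN). Triton's actual allocation can differ (layout padding,
    epilogue scratch, version-specific pipelining), so treat this as an estimate.
    """
    return num_stages * (BM * BK + BK * BN) * elem_bytes


# The three configs used in the GELU + bias kernel (bf16 => 2 bytes per element).
configs = [
    dict(BM=128, BN=128, BK=32, num_stages=3),
    dict(BM=128, BN=256, BK=32, num_stages=3),
    dict(BM=64, BN=64, BK=64, num_stages=4),
]
for cfg in configs:
    kib = pipelined_smem_bytes(**cfg) / 1024
    print(cfg, f"~{kib:.0f} KiB shared memory")
```

Doubling BN or adding a stage grows the estimate linearly, which is why the autotuner's candidate list should include at least one config that fits comfortably on your smallest target GPU.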
A complete fused kernel: GELU + linear bias
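Suppose a layer computes y = gelu(x @ w + b). Written that way in eager PyTorch, that is 3 kernel launches (matmul → add → gelu), and the M × N intermediate makes two extra round trips through global memory. A Triton kernel folds the bias and activation into the GEMM's epilogue, applying them to the fp32 accumulator while it is still on chip. The kernel below uses the tanh approximation of GELU, so its reference is F.gelu(z, approximate='tanh') rather than PyTorch's default erf-based GELU.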
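To put a number on "fusion saves bytes", here is a back-of-envelope count of the global-memory traffic the epilogue fusion removes. It is a sketch under simplifying assumptions (each eager op reads its inputs and writes its output exactly once, the bias is read once, and the matmul's reads of x and w are left out because both versions pay them); the helper name is hypothetical.

```python
def epilogue_traffic_bytes(M: int, N: int, elem_bytes: int = 2) -> tuple:
    """Global-memory bytes spent on the M x N result and the bias vector.

    Simplifying assumptions: every eager op reads its inputs from and writes its
    output to HBM exactly once, the bias is read once, and the matmul's reads of
    x and w are left out because both versions pay them.
    """
    mn = M * N
    unfused = (
        mn               # matmul writes z0 = x @ w
        + mn + N + mn    # add reads z0 and b, writes z1
        + mn + mn        # gelu reads z1, writes y
    ) * elem_bytes
    fused = (N + mn) * elem_bytes  # fused kernel reads b and writes y once
    return unfused, fused


unfused, fused = epilogue_traffic_bytes(4096, 4096)
print(f"saved: {(unfused - fused) / 2**20:.0f} MiB per call")
```

For a 4096 × 4096 bf16 output the difference is 4·M·N elements, i.e. 128 MiB of traffic avoided per call, which is what turns a bandwidth-bound chain into a single compute-bound kernel.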
```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def gelu_linear(
    x_ptr, w_ptr, b_ptr, y_ptr,
    M, N, K,
    sxm, sxk, swk, swn, sym, syn,  # strides
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)
    acc = tl.zeros((BM, BN), dtype=tl.float32)
    # K-loop: multiply a BM x BK tile of x by a BK x BN tile of w each step,
    # accumulating in fp32 on chip
    for k in range(0, K, BK):
        a = tl.load(x_ptr + offs_m[:, None] * sxm + (k + offs_k)[None, :] * sxk,
                    mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < K),
                    other=0.).to(tl.bfloat16)
        bw = tl.load(w_ptr + (k + offs_k)[:, None] * swk + offs_n[None, :] * swn,
                     mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N),
                     other=0.).to(tl.bfloat16)
        acc += tl.dot(a, bw)  # tensor-core matmul on the tile
    # fused epilogue: bias + GELU
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.).to(tl.float32)
    z = acc + bias[None, :]
    # tanh-approximate GELU: 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3)))
    # computed as z * sigmoid(2u), since 0.5 * (1 + tanh(u)) == sigmoid(2u).
    # tanh is not exposed as tl.tanh in every Triton release (it lives in a
    # version-dependent libdevice module); tl.sigmoid is part of the core API.
    c = 0.7978845608  # sqrt(2/pi)
    y = z * tl.sigmoid(2.0 * c * (z + 0.044715 * z * z * z))
    tl.store(y_ptr + offs_m[:, None] * sym + offs_n[None, :] * syn,
             y.to(y_ptr.dtype.element_ty),  # cast to the output tensor's dtype
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
```
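The epilogue implements the tanh approximation of GELU. A scalar check in plain Python is a cheap way to confirm two things before trusting the kernel: 0.5·z·(1 + tanh(u)) equals z·sigmoid(2u), which is handy when tanh is not available as a core tl function, and the approximation stays close to the exact erf-based GELU over a typical activation range. The function names are hypothetical; this is a reference sketch, not the kernel itself.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # 0.7978845608...


def _sigmoid(t: float) -> float:
    # numerically stable logistic: never calls exp on a large positive argument
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def gelu_tanh(z: float) -> float:
    """Tanh-approximate GELU, the formula in the kernel's epilogue comment."""
    u = SQRT_2_OVER_PI * (z + 0.044715 * z ** 3)
    return 0.5 * z * (1.0 + math.tanh(u))


def gelu_tanh_via_sigmoid(z: float) -> float:
    """Same function rewritten with 0.5 * (1 + tanh(u)) == sigmoid(2u)."""
    u = SQRT_2_OVER_PI * (z + 0.044715 * z ** 3)
    return z * _sigmoid(2.0 * u)


def gelu_erf(z: float) -> float:
    """Exact GELU, PyTorch's default (approximate='none')."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


zs = [i / 4 for i in range(-24, 25)]  # -6.0 .. 6.0 in steps of 0.25
print(max(abs(gelu_tanh(z) - gelu_tanh_via_sigmoid(z)) for z in zs))
print(max(abs(gelu_tanh(z) - gelu_erf(z)) for z in zs))
```

The first difference is at float-rounding level; the second is the approximation error itself, well under 1e-3 on this range. When validating the kernel against PyTorch, compare with F.gelu(..., approximate='tanh') so the tolerance only has to absorb bf16 rounding.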
The wrapper that makes it a PyTorch op
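```python
import torch
import triton

def gelu_linear_fn(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """y = gelu_tanh(x @ w + b) for x: (M, K), w: (K, N), b: (N,) on the same CUDA device."""
    M, K = x.shape
    K2, N = w.shape
    assert K == K2, "inner dimensions must match"
    assert b.shape == (N,) and b.is_contiguous()  # the kernel indexes b with unit stride
    y = torch.empty(M, N, device=x.device, dtype=x.dtype)
    # one program per BM x BN output tile; BM/BN come from the config the autotuner picked
    grid = lambda meta: (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    gelu_linear[grid](
        x, w, b, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        y.stride(0), y.stride(1),
    )
    return y
```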
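The grid is a lambda of meta because BM and BN are only known once the autotuner has chosen a config. For shapes that are not multiples of the tile size, the last program along an axis covers a partial tile, which is why every load and store carries a mask. A small sketch of that arithmetic (the helper name tile_grid is hypothetical):

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b  # same rounding as triton.cdiv


def tile_grid(M: int, N: int, BM: int, BN: int):
    """Grid the wrapper's lambda would return for one config, plus how many rows and
    columns of the last tile along each axis hold real data (the rest are masked)."""
    grid = (cdiv(M, BM), cdiv(N, BN))
    last_rows = M - (grid[0] - 1) * BM
    last_cols = N - (grid[1] - 1) * BN
    return grid, last_rows, last_cols


print(tile_grid(1000, 4096, 128, 128))  # -> ((8, 32), 104, 128)
```

Here M = 1000 needs 8 row tiles, and only 104 of the last tile's 128 rows are real; without the row mask those 24 extra rows would read and write past the end of x and y.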
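For training you also need a backward. Wrap the op in a torch.autograd.Function whose forward calls gelu_linear_fn and whose backward computes the gradients with explicit Triton kernels (or plain PyTorch ops). torch.compile (lesson 23) can compile around a hand-written Triton kernel, but it will not derive a backward by differentiating through it; if you want generated gradients, write the op in PyTorch and let torch.compile produce both passes.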
The design loop
What autotune actually does
The autotuner stores a dict keyed by your key=[...] values. For each new key, it compiles every config, runs each one a few times, picks the fastest, and caches the choice. Subsequent calls with the same key skip the search. Two practical consequences:
- The first call for each new key is slow, sometimes hundreds of ms, because every config is compiled and timed. Run a warmup pass before timing.
- Pick the key carefully. If you key only on M but performance also depends on N, the config cached for the first N is reused for every other N, even where it is suboptimal. Conversely, if you key on too many fields, every new shape triggers a fresh search.
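A toy model makes the keying trade-off concrete. TinyAutotuner below is hypothetical and much simpler than Triton's real autotuner; it only reproduces the caching behavior described above, with a made-up cost function standing in for GPU timing.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Config = Dict[str, int]


class TinyAutotuner:
    """Toy model of autotune caching; NOT triton's Autotuner class.

    The first time a key is seen, every config is "benchmarked" with `bench`
    (a stand-in for real GPU timing) and the fastest is cached under that key.
    """

    def __init__(self, configs: List[Config], key_names: Sequence[str],
                 bench: Callable[[Config, Dict[str, int]], float]):
        self.configs = configs
        self.key_names = list(key_names)
        self.bench = bench
        self.cache: Dict[Tuple[int, ...], Config] = {}
        self.searches = 0  # how many times the full sweep ran

    def pick(self, **args: int) -> Config:
        key = tuple(args[name] for name in self.key_names)
        if key not in self.cache:
            self.searches += 1
            self.cache[key] = min(self.configs, key=lambda cfg: self.bench(cfg, args))
        return self.cache[key]


def fake_cost(cfg: Config, args: Dict[str, int]) -> float:
    # hypothetical cost model: the best BN is the one closest to N
    return abs(cfg["BN"] - args["N"])


configs = [{"BN": 64}, {"BN": 256}]

keyed_on_m = TinyAutotuner(configs, ["M"], fake_cost)
print(keyed_on_m.pick(M=1024, N=64))    # sweep runs, caches BN=64
print(keyed_on_m.pick(M=1024, N=4096))  # cache hit: still BN=64, though BN=256 is better

keyed_on_mn = TinyAutotuner(configs, ["M", "N"], fake_cost)
keyed_on_mn.pick(M=1024, N=64)
print(keyed_on_mn.pick(M=1024, N=4096))  # new key, new sweep: BN=256
```

Keying on M alone runs one sweep and then silently reuses a config tuned for N = 64; adding N to the key costs a second sweep but picks the right tile.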
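```python
import triton.testing
gelu_linear_fn(x, w, b)  # x, w, b: bf16 CUDA tensors; first call runs the one-time autotune search
# time the op with the cached config; warmup and rep are time budgets in ms
ms = triton.testing.do_bench(lambda: gelu_linear_fn(x, w, b), warmup=25, rep=100)
print(f"{ms:.3f} ms")
```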
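A millisecond figure means more next to the work done. A minimal sketch for turning the do_bench result into achieved TFLOP/s and a percentage of peak; the helper names are hypothetical, and the peak value is whatever your GPU's datasheet gives for the input dtype.

```python
def achieved_tflops(M: int, N: int, K: int, ms: float) -> float:
    """GEMM throughput from a do_bench time: 2*M*N*K flops (one multiply and one add
    per inner-product term). The O(M*N) bias + GELU epilogue is ignored."""
    return 2.0 * M * N * K / (ms * 1e-3) / 1e12


def percent_of_peak(M: int, N: int, K: int, ms: float, peak_tflops: float) -> float:
    # peak_tflops: your GPU's dense peak for the input dtype (you supply it)
    return 100.0 * achieved_tflops(M, N, K, ms) / peak_tflops


print(f"{achieved_tflops(4096, 4096, 4096, 1.0):.1f} TFLOP/s")
```

This gives a direct check on the "within 80–90 % of optimal" target: compare against the library GEMM's achieved TFLOP/s on the same shape as well as against the datasheet peak.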
Common pitfalls and how to spot them
| Pitfall | Symptom | Fix |
|---|---|---|
| No boundary mask | Out-of-bounds accesses → illegal-memory-access errors or silently wrong outputs at irregular shapes. | Always pass mask= (and other= for loads) when the size is not a multiple of BLOCK. |
| Wrong stride argument order | Garbled output for non-contiguous tensors. | Pass x.stride(0), x.stride(1) in the same order the kernel multiplies. |
| Accumulating in bf16 | Loss drift; large-K matmuls go visibly wrong. | Always accumulate in fp32 (tl.zeros(..., dtype=tl.float32)). |
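| Forgot tl.constexpr | tl.arange(0, BM) and tl.zeros((BM, BN)) fail to compile, because their shapes must be compile-time constants; other tuning values lose specialization. | Mark BM, BN, BK, and dtypes as tl.constexpr. |
| Register spill | Achieved occupancy <25 %; ncu shows local-memory (spill) traffic and "No Eligible" stalls. | Smaller tile, fewer accumulators, or set maxnreg. |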
| SMEM bank conflicts (non-dot patterns) | SM throughput much lower than peak in reductions or custom SMEM scratchpad usage. | Triton's layout planner picks swizzle patterns to avoid bank conflicts in tl.dot automatically. If you see conflicts, they are almost always in custom reductions; pad strides or restructure the access pattern. |
| Autotune never re-runs | Cached config is stale after you change the kernel. | Clear ~/.triton/cache or bump a constant in the function. |
| Reduction across blocks | You wrote tl.atomic_add across many programs and lost determinism. | Use a two-pass scheme: per-block partial → second kernel to combine. |
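Why the two-pass fix restores determinism: floating-point addition is not associative, so the final sum depends on the order in which partials are added. The sketch below emulates both schemes in plain Python. The function names are hypothetical, the "atomic" order is modeled as a seeded shuffle, and Python floats are float64 rather than the fp32 a kernel would typically accumulate in, but the effect is the same.

```python
import random
from typing import List


def block_partials(xs: List[float], BLOCK: int) -> List[float]:
    # pass 1: each "program" reduces its own tile in a fixed order
    return [sum(xs[i:i + BLOCK]) for i in range(0, len(xs), BLOCK)]


def two_pass_sum(xs: List[float], BLOCK: int) -> float:
    # pass 2: a single combine step walks the partials in program-id order,
    # so the rounding sequence (and the result) is identical on every run
    total = 0.0
    for p in block_partials(xs, BLOCK):
        total += p
    return total


def atomic_style_sum(xs: List[float], BLOCK: int, seed: int) -> float:
    # emulates tl.atomic_add from many programs: partials land in whatever order
    # the programs happen to finish, modeled here as a seeded shuffle
    partials = block_partials(xs, BLOCK)
    order = list(range(len(partials)))
    random.Random(seed).shuffle(order)
    total = 0.0
    for i in order:
        total += partials[i]
    return total


xs = [1e16, 1.0, -1e16, 1.0]  # float addition is not associative at this magnitude
print({atomic_style_sum(xs, 1, seed) for seed in range(100)})
print(two_pass_sum(xs, 1))
```

The atomic-style totals vary with the finishing order, while the two-pass total is the same on every run because the combine order is fixed by program id rather than by scheduling.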
When NOT to write Triton
- Vanilla GEMM: cuBLAS / cuBLASLt and CUTLASS are usually 5–15 % faster. Use them via torch.matmul or torch._scaled_mm.
- Attention: FlashAttention v3 and FlashInfer have hand-tuned producer-consumer kernels with TMA. Don't reinvent them unless you're targeting a feature they don't support.
- Communication kernels: NCCL is the answer. Triton doesn't talk to NVLink primitives directly.
- The kernel is going to be 1 % of the step: not worth the maintenance cost.
Use Triton when (a) no library covers the operation, (b) the operation is bandwidth-bound and fusion saves bytes, or (c) you need autotuning across many shapes and don't have time to hand-tune CUDA per shape.
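The decision rules above fit in a few lines. This is a hypothetical helper, not a calibrated tool: the 1 % cutoff comes from the list above, and rule (c) is left out because it mostly decides Triton versus hand-written CUDA rather than Triton versus a library call.

```python
def recommend(step_fraction: float, library_covers_op: bool, fusion_saves_bytes: bool) -> str:
    """Hypothetical encoding of the rules of thumb above.

    step_fraction: share of step time the candidate region takes (0.0 to 1.0).
    The ~1 % cutoff comes from the text; everything else is a simplification.
    """
    if step_fraction <= 0.01:
        return "don't bother"   # not worth the maintenance cost
    if not library_covers_op:
        return "triton"         # rule (a): nothing off the shelf
    if fusion_saves_bytes:
        return "triton"         # rule (b): bandwidth-bound, and fusion saves bytes
    return "library"            # e.g. a vanilla GEMM: cuBLAS / CUTLASS


print(recommend(0.20, library_covers_op=True, fusion_saves_bytes=False))
```

Ordering matters: the step-fraction check comes first, because even an op no library covers is not worth a custom kernel if it barely shows up in the profile.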
What this gives you for the next lesson
You can write a fused kernel by hand. Lesson 23 covers the lazy alternative: let torch.compile generate Triton for you. That's right far more often than people think — and when it's not, you have the skills from this lesson to step in.