Production — backward, profile, when NOT
The kernel works on a unit test. Now it has to ship: it needs a backward pass, it has to be observable in a profile, and you need to know when to use it. This lesson is the production-readiness checklist plus the decision tree for when Triton is the wrong tool.
Wrapping a Triton kernel as a PyTorch op
For inference-only kernels, the function wrapper from earlier lessons is enough. For training, you need both directions, registered with autograd:
import torch, triton, triton.language as tl

# forward kernel
@triton.jit
def rmsnorm_fwd_kernel(...): ...

# backward kernel (also Triton)
@triton.jit
def rmsnorm_bwd_kernel(...): ...

class RMSNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, eps):
        y, rstd = rmsnorm_forward(x, w, eps)  # calls the fwd kernel
        ctx.save_for_backward(x, w, rstd)
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, rstd = ctx.saved_tensors
        dx, dw = rmsnorm_backward(dy, x, w, rstd)  # calls the bwd kernel
        return dx, dw, None  # one None per non-tensor input

rmsnorm = RMSNormFn.apply  # use this in your nn.Module
Three rules:
- Save what you need. The forward usually keeps tensors that the backward will reuse (rstd in this example) to avoid recomputing them.
- Return one None per non-tensor input. backward must return gradients for every forward arg, in order. Non-differentiable args (like eps) return None.
- Don't capture the kernel itself in the graph. The autograd graph sees one node (RMSNormFn). The two Triton kernels under it are opaque.
Backward kernels are not symmetric to forward
Forward RMSNorm: one row reduction. Backward RMSNorm: with rstd saved from the forward, dx needs one row reduction (over dy·w·x) plus a fused multiply, and dw needs a column reduction across rows. The backward is usually ~1.5–2× the cost of the forward, and that's normal.
For matmul-style ops, the backward is two more matmuls (dA and dB). For attention, the backward has its own paper (Flash Attention v2 recomputes S on the fly during the backward pass to avoid storing the full attention matrix). When you implement backward, expect it to be the harder kernel.
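To make the RMSNorm backward concrete, here is a minimal plain-Python reference, assuming the usual definition y = x · rstd · w with rstd = 1/sqrt(mean(x²) + eps). The function names rmsnorm_forward_ref and rmsnorm_backward_ref are hypothetical helpers for this sketch, not part of the lesson's kernels. It is only meant as something to compare a Triton backward against on tiny inputs.

```python
import math

def rmsnorm_forward_ref(x, w, eps):
    """x: list of rows (lists of floats), w: list of floats. Returns (y, rstd)."""
    y, rstd = [], []
    for row in x:
        mean_sq = sum(v * v for v in row) / len(row)  # forward row reduction
        r = 1.0 / math.sqrt(mean_sq + eps)
        rstd.append(r)
        y.append([v * r * wj for v, wj in zip(row, w)])
    return y, rstd

def rmsnorm_backward_ref(dy, x, w, rstd):
    """Returns (dx, dw) for the forward above, reusing the saved rstd."""
    n = len(w)
    dx = []
    dw = [0.0] * n
    for dy_row, x_row, r in zip(dy, x, rstd):
        g = [d * wj for d, wj in zip(dy_row, w)]          # dy * w
        c = sum(gj * xj for gj, xj in zip(g, x_row)) / n  # row reduction over dy*w*x
        dx.append([r * gj - (r ** 3) * c * xj for gj, xj in zip(g, x_row)])
        for j in range(n):
            dw[j] += dy_row[j] * x_row[j] * r             # column reduction across rows
    return dx, dw

# Tiny made-up inputs: 2 rows, 2 columns.
x = [[3.0, 4.0], [1.0, -2.0]]
w = [1.0, 0.5]
dy = [[1.0, 0.0], [0.5, 0.25]]
y, rstd = rmsnorm_forward_ref(x, w, eps=0.0)
dx, dw = rmsnorm_backward_ref(dy, x, w, rstd)
print("dx:", dx)
print("dw:", dw)
```

A cheap sanity check falls out of the math: with eps = 0 the forward is invariant to rescaling a row, so each row of dx should be orthogonal to the matching row of x.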
Verifying the backward — gradcheck
The single most important test for a Triton-backed autograd function:
x = torch.randn(4, 128, device='cuda', dtype=torch.float64, requires_grad=True)
w = torch.randn(128, device='cuda', dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(rmsnorm, (x, w, 1e-6), eps=1e-6, atol=1e-4)
This finite-differences your backward against the forward. Two notes:
- Use fp64 for gradcheck: bf16/fp16 are too imprecise for finite-diff to converge.
- A wrong backward is invisible to benchmarks. If gradcheck fails, a wall-clock test won't catch it, so don't skip this step.
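The precision point can be seen without a GPU. The sketch below is not how gradcheck is implemented. It just runs a central difference on f(x) = x² with the same 1e-6 step, once in Python's native fp64 and once with every value rounded to half precision through the struct module's "e" format. The helpers to_fp16 and central_diff are hypothetical names for this illustration.

```python
import struct

def to_fp16(v):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", v))[0]

def central_diff(f, x, h, cast=float):
    """Central-difference estimate of df/dx, rounding inputs and outputs with `cast`."""
    hi = cast(f(cast(x + h)))
    lo = cast(f(cast(x - h)))
    return (hi - lo) / (2.0 * h)

def square(v):
    return v * v  # true derivative at x is 2 * x

h = 1e-6
grad_fp64 = central_diff(square, 1.0, h)                # close to 2.0
grad_fp16 = central_diff(square, 1.0, h, cast=to_fp16)  # the perturbation rounds away
print("fp64 estimate:", grad_fp64)
print("fp16 estimate:", grad_fp16)
```

In half precision both 1 + h and 1 - h round back to 1.0, so the numerator is exactly zero and the estimate carries no information about the gradient.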
Profiling — three views you actually need
| Tool | What it tells you | When to use |
|---|---|---|
| triton.testing.do_bench | Single-kernel wall-clock with proper warmup & noise floor. | Iterating on tile sizes. |
| torch.profiler | The whole training step: which Triton kernels were called, in what order, with stack traces. | "Is my kernel even running, and how often?" |
| ncu (Nsight Compute) | SM utilisation, memory throughput, register spill, occupancy. Per-kernel deep dive. | "Why is this kernel slow?" |
The workflow:
- do_bench says it's slow.
- torch.profiler confirms it's the kernel you think (and not, say, an extra copy).
- ncu tells you whether you're SM-bound, memory-bound, or stall-bound.
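Before reaching for ncu, a rough way to judge "slow" is to turn a measured time into achieved bandwidth. The sketch below assumes the timing is in milliseconds and that an RMSNorm forward reads x and w once and writes y and rstd once. The 0.05 ms timing and the 2000 GB/s peak are made-up placeholders, and the helper names are hypothetical.

```python
def rmsnorm_fwd_bytes(n_rows, n_cols, elem_bytes=2, stat_bytes=4):
    """Bytes moved by one RMSNorm forward under the read-once/write-once assumption."""
    x_and_y = 2 * n_rows * n_cols * elem_bytes  # read x, write y
    weight = n_cols * elem_bytes                # read w
    rstd = n_rows * stat_bytes                  # write one fp32 rstd per row
    return x_and_y + weight + rstd

def achieved_gbps(n_bytes, ms):
    """Convert bytes moved and a time in milliseconds into GB/s."""
    return n_bytes / (ms * 1e-3) / 1e9

n_bytes = rmsnorm_fwd_bytes(4096, 4096)   # bf16 activations, 4096 x 4096
gbps = achieved_gbps(n_bytes, ms=0.05)    # placeholder timing, not a measurement
peak_gbps = 2000.0                        # placeholder peak, look up your GPU's spec
print(f"{n_bytes} bytes, {gbps:.1f} GB/s, {gbps / peak_gbps:.0%} of assumed peak")
```

If the fraction of peak is already high, tile-size tuning has little left to give. If it is low, that is the cue to look at the kernel in ncu.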
Reading the IR to confirm tensor cores are firing
If you suspect tl.dot isn't lowering to tensor cores, dump the IR:
import os
os.environ['MLIR_ENABLE_DUMP'] = '1' # before importing triton
# run your kernel once
# then grep the dump:
# grep -E 'mma|wgmma' /tmp/triton_dump_*.ttgir | head
If you see mma.sync (Ampere/Ada) or wgmma (Hopper), tensor cores are firing. If you see llvm.fma instead, you're on the vector ALU — go back to lesson 05 and check tile dims.
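If you check several kernels, the grep can be folded into a small script. This is a minimal sketch that scans whatever IR or PTX text you captured for the three markers named above. The function name classify_dot_lowering is hypothetical, and the sample strings are one-line stand-ins, not real compiler output.

```python
import re

def classify_dot_lowering(ir_text):
    """Report which matmul instruction family appears in a dumped IR/PTX string."""
    if re.search(r"\bwgmma\b", ir_text):
        return "tensor cores (wgmma)"
    if re.search(r"\bmma\.sync\b", ir_text):
        return "tensor cores (mma.sync)"
    if re.search(r"\bfma\b", ir_text):
        return "vector ALU (fma)"
    return "no dot instruction found"

# Made-up stand-ins for dump contents.
samples = {
    "hopper_like": "... wgmma ...",
    "ampere_like": "... mma.sync ... llvm.fma ...",
    "fallback": "... llvm.fma ...",
}
for name, text in samples.items():
    print(name, "->", classify_dot_lowering(text))
```

The tensor-core markers are checked first on purpose: fma instructions also show up for ordinary elementwise math, so their presence alone says nothing about tl.dot.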
The decision tree — when NOT to write Triton
Concrete examples of "not Triton"
- Vanilla bf16 matmul. Call torch.matmul. cuBLAS is 5–15% faster than hand-Triton, and you save a week.
- Attention. Call FlashAttention v3 or FlashInfer. Their Hopper-tuned versions beat Triton by 20–40%.
- NCCL collectives. Triton doesn't talk to NVLink primitives.
- A 100-line elementwise chain after a matmul. Try torch.compile first; in 2026 it fuses most of these correctly.
- An op that's 0.5% of the step. The lever is too small. Find a bigger one.
Concrete examples of "yes, Triton"
- A novel fused epilogue, for example matmul → bias → silu → dropout → residual add (lesson 08).
- Quantised GEMM with a dequant-on-load that no library covers.
- Custom attention variants. ALiBi + sliding window + learned bias, or MLA's compressed KV path.
- Backward of a fused forward. Whatever you fused on the way in needs a matching kernel on the way out.
- Sparse / block-sparse patterns not covered by cuSPARSE.
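The two lists can be read as an ordered set of questions. The sketch below is one possible encoding, not the lesson's official tree: OpProfile, recommend, and the 1% cutoff are assumptions made for illustration (the text only says that 0.5% of the step is too small).

```python
from dataclasses import dataclass

@dataclass
class OpProfile:
    step_fraction: float          # share of total step time spent in this op
    library_covers_it: bool       # e.g. cuBLAS, FlashAttention, FlashInfer, NCCL
    torch_compile_fuses_it: bool  # torch.compile already emits a good fused kernel

def recommend(op, min_step_fraction=0.01):
    """Walk the questions in order and return a short recommendation."""
    if op.step_fraction < min_step_fraction:
        return "skip: lever too small"
    if op.library_covers_it:
        return "use the library"
    if op.torch_compile_fuses_it:
        return "use torch.compile"
    return "write Triton"

print(recommend(OpProfile(0.005, False, False)))  # tiny op
print(recommend(OpProfile(0.30, True, False)))    # vanilla matmul or attention
print(recommend(OpProfile(0.10, False, True)))    # elementwise chain
print(recommend(OpProfile(0.10, False, False)))   # novel fused epilogue
```

The order matters: the size of the lever is checked first because it is the cheapest question to answer from a profile.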
Maintenance — what you sign up for
Triton kernels are easier than CUDA, but not free. The ongoing costs:
- Triton releases. Each release can change perf characteristics. Re-benchmark on every upgrade.
- New hardware. Hopper → Blackwell will require autotune sweeps; existing configs may not transfer.
- Backward parity. Forward and backward need to stay in sync as the model evolves.
- Numerical drift. If a new model uses a different activation or eps, you may need to re-derive the kernel.
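Re-benchmarking on every upgrade is easier to keep up if the comparison is mechanical. A minimal sketch, assuming you store per-kernel timings in milliseconds from the previous release: find_regressions and the 5% tolerance are hypothetical choices, and the numbers are made up.

```python
def find_regressions(baseline_ms, current_ms, tolerance=0.05):
    """Return kernels whose current time exceeds the baseline by more than `tolerance`."""
    regressions = {}
    for name, old in baseline_ms.items():
        new = current_ms.get(name)
        if new is not None and new > old * (1.0 + tolerance):
            regressions[name] = (old, new)
    return regressions

# Made-up timings before and after an upgrade.
baseline = {"rmsnorm_fwd": 0.050, "rmsnorm_bwd": 0.090}
current = {"rmsnorm_fwd": 0.051, "rmsnorm_bwd": 0.110}
print(find_regressions(baseline, current))
```

Anything this flags is a candidate for a fresh autotune sweep rather than proof of a problem, since timing noise can exceed a few percent.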
What this gives you
You've followed the whole arc:
- Lessons 01–03: the abstraction (tiles, not threads), the pipeline (Python → PTX), the why.
- Lessons 04–06: the DSL — loads/masks, tl.dot, reductions.
- Lessons 07–11: five canonical kernels, each one a one-step elaboration on the previous.
- Lesson 12: Flash Attention as the synthesis.
- Lessons 13–14: performance and shipping.
You can write a fused kernel, derive its bandwidth model, autotune across shapes, write the backward, gradcheck it, profile it, decide if Triton is the right tool, and ship it. That's the full skillset for kernels in 2026.