Triton kernels · lesson 14 of 14

Production — backward, profile, when NOT

The kernel passes a unit test. Now it has to ship: it needs a backward, it has to be observable in a profile, and you need to know when to use it at all. This lesson is the production-readiness checklist plus the decision tree for when Triton is the wrong tool.

Wrapping a Triton kernel as a PyTorch op

For inference-only kernels, the function wrapper from earlier lessons is enough. For training, you need both directions, registered with autograd:

import torch, triton, triton.language as tl

# forward kernel
@triton.jit
def rmsnorm_fwd_kernel(...): ...

# backward kernel (also Triton)
@triton.jit
def rmsnorm_bwd_kernel(...): ...

class RMSNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, eps):
        y, rstd = rmsnorm_forward(x, w, eps)        # calls the fwd kernel
        ctx.save_for_backward(x, w, rstd)
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
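        x, w, rstd = ctx.saved_tensors
        dy = dy.contiguous()                          # incoming grads can be strided; Triton kernels typically assume contiguous rows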
        dx, dw = rmsnorm_backward(dy, x, w, rstd)    # calls the bwd kernel
        return dx, dw, None                           # one None per non-tensor input

rmsnorm = RMSNormFn.apply        # use this in your nn.Module

Three rules:

  1. Save what you need. The forward usually keeps tensors that the backward will reuse (rstd in this example) to avoid recomputing them.
  2. Return one None per non-tensor input. backward must return gradients for every forward arg, in order. Non-differentiable args (like eps) return None.
  3. The kernels stay out of the autograd graph. Autograd sees one node (RMSNormFn); the two Triton kernels under it are opaque, so the backward you write is the only gradient path autograd will use.

Backward kernels are not symmetric to forward

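Forward RMSNorm: one row reduction (the mean of x² that gives rstd). Backward RMSNorm: one row reduction over dy·w·x for dx, a fused multiply, plus a column reduction across rows for dw. That column reduction crosses program boundaries, since every row's program contributes to the same dw, so it typically needs per-program partial sums and a second reduction step. The backward is usually ~1.5–2× the cost of forward; that's normal.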
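Written out, assuming the forward is y_j = w_j · x_j · r over a row of length N, with r playing the role of rstd (this is the standard derivation, sketched here rather than taken from the kernel code):

```latex
\begin{aligned}
r &= \Big(\tfrac{1}{N}\textstyle\sum_{j} x_j^2 + \varepsilon\Big)^{-1/2}, \qquad y_j = w_j\, x_j\, r \\
g_j &= dy_j\, w_j \\
\frac{\partial r}{\partial x_i} &= -\frac{r^3\, x_i}{N} \\
dx_i &= r\, g_i \;-\; \frac{r^3\, x_i}{N} \sum_{j} g_j\, x_j \qquad \text{(the one row reduction)} \\
dw_j &= \sum_{\text{rows } b} dy_{b,j}\, x_{b,j}\, r_b \qquad \text{(the column reduction across rows)}
\end{aligned}
```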

For matmul-style ops, the backward is two more matmuls (dA and dB). For attention, the backward is an algorithm in its own right (Flash Attention v2 recomputes S on the fly during the backward pass to avoid storing the full attention matrix). When you implement a backward, expect it to be the harder kernel.
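A small NumPy check of the matmul case, with NumPy standing in for the kernels (an illustration, not Triton code): for C = A·B and upstream gradient dC, the two backward matmuls are dA = dC·Bᵀ and dB = Aᵀ·dC, and both should match central finite differences of the scalar loss sum(dC ⊙ C).

```python
import numpy as np

def matmul_backward(a, b, dc):
    """Gradients of C = A @ B given upstream dC: two more matmuls."""
    da = dc @ b.T   # (M, N) @ (N, K) -> (M, K), same shape as A
    db = a.T @ dc   # (K, M) @ (M, N) -> (K, N), same shape as B
    return da, db

def finite_diff_grad(f, x, h=1e-6):
    """Central finite differences of scalar f() with respect to every entry of x (x is perturbed in place)."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + h
        f_plus = f()
        x[idx] = orig - h
        f_minus = f()
        x[idx] = orig  # restore the entry
        g[idx] = (f_plus - f_minus) / (2 * h)
    return g

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3))
b = rng.standard_normal((3, 5))
dc = rng.standard_normal((4, 5))

def loss():
    # Scalar loss whose gradient with respect to C is exactly dc.
    return float(np.sum(dc * (a @ b)))

da, db = matmul_backward(a, b, dc)
print(np.allclose(da, finite_diff_grad(loss, a), atol=1e-6))
print(np.allclose(db, finite_diff_grad(loss, b), atol=1e-6))
```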

Verifying the backward — gradcheck

The single most important test for a Triton-backed autograd function:

x = torch.randn(4, 128, device='cuda', dtype=torch.float64, requires_grad=True)
w = torch.randn(128, device='cuda', dtype=torch.float64, requires_grad=True)

torch.autograd.gradcheck(rmsnorm, (x, w, 1e-6), eps=1e-6, atol=1e-4)

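This compares your analytical backward against finite-difference estimates computed from the forward. Two notes: the inputs are float64 because fp32 finite differences are too noisy to pass the tolerance, so your kernels must accept fp64 for this test; and the 1e-6 inside the input tuple is RMSNorm's eps, while the eps=1e-6 keyword is gradcheck's finite-difference step size.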
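To exercise the wrapper pattern and gradcheck on a machine without a GPU, one option is a pure-PyTorch stand-in for the two kernels. This is a sketch: rmsnorm_forward_ref, rmsnorm_backward_ref and RMSNormRefFn are hypothetical names, and in real code the two helpers would launch the Triton kernels instead. The backward uses the row reduction over dy·w·x for dx and a column reduction across rows for dw.

```python
import torch

def rmsnorm_forward_ref(x, w, eps):
    # Hypothetical stand-in for the Triton forward kernel.
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)  # row reduction
    return x * rstd * w, rstd

def rmsnorm_backward_ref(dy, x, w, rstd):
    # Hypothetical stand-in for the Triton backward kernel.
    n = x.shape[-1]
    g = dy * w
    row_dot = (g * x).sum(dim=-1, keepdim=True)        # row reduction over dy*w*x
    dx = rstd * g - x * rstd.pow(3) * row_dot / n
    dw = (dy * x * rstd).reshape(-1, n).sum(dim=0)      # column reduction across rows
    return dx, dw

class RMSNormRefFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, eps):
        y, rstd = rmsnorm_forward_ref(x, w, eps)
        ctx.save_for_backward(x, w, rstd)               # keep rstd so backward doesn't recompute it
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, rstd = ctx.saved_tensors
        dx, dw = rmsnorm_backward_ref(dy, x, w, rstd)
        return dx, dw, None                             # None for eps, which is not a tensor

rmsnorm_ref = RMSNormRefFn.apply

x = torch.randn(4, 128, dtype=torch.float64, requires_grad=True)
w = torch.randn(128, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(rmsnorm_ref, (x, w, 1e-6), eps=1e-6, atol=1e-4)
print(ok)
```

The same gradcheck call should work unchanged once the stand-ins are swapped for the Triton launches and the tensors moved to a CUDA device.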

Profiling — three views you actually need

Tool | What it tells you | When to use
triton.testing.do_bench | Single-kernel wall-clock time with proper warmup and a noise floor. | Iterating on tile sizes.
torch.profiler | The whole training step: which Triton kernels were called, in what order, with stack traces. | "Is my kernel even running, and how often?"
ncu (Nsight Compute) | SM utilisation, memory throughput, register spills, occupancy. A per-kernel deep dive. | "Why is this kernel slow?"

The workflow:

  1. do_bench says it's slow.
  2. torch.profiler confirms it's the kernel you think (and not, say, an extra copy).
  3. ncu tells you whether you're SM-bound, memory-bound, or stall-bound.
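Before opening ncu, a back-of-envelope roofline estimate tells you which of the first two answers to expect. A minimal sketch: the peak numbers are inputs you must supply for your GPU (the values below are illustrative, roughly A100-class, not measurements), the FLOP counts are rough assumptions, and it cannot detect stall-bound kernels; that still takes ncu.

```python
def roofline_bound(flops, bytes_moved, peak_flops, peak_bytes_per_s):
    """Classify a kernel as compute- or memory-bound by arithmetic intensity.

    flops:            floating-point operations the kernel performs
    bytes_moved:      bytes read + written to global memory
    peak_flops:       peak FLOP/s of the GPU for this dtype (assumed input)
    peak_bytes_per_s: peak DRAM bandwidth in bytes/s (assumed input)
    Returns (label, lower-bound runtime in seconds).
    """
    intensity = flops / bytes_moved               # FLOP per byte
    ridge = peak_flops / peak_bytes_per_s         # intensity where the two roofs meet
    t_compute = flops / peak_flops
    t_memory = bytes_moved / peak_bytes_per_s
    label = "compute-bound" if intensity >= ridge else "memory-bound"
    return label, max(t_compute, t_memory)

# Illustrative peaks only: 312e12 bf16 FLOP/s, 2e12 bytes/s.
PEAK_FLOPS, PEAK_BW = 312e12, 2.0e12

# RMSNorm forward on a 4096 x 4096 bf16 tensor: read x, write y (2 bytes each),
# assuming ~3 FLOPs per element.
n = 4096 * 4096
print(roofline_bound(3 * n, 2 * 2 * n, PEAK_FLOPS, PEAK_BW))

# Square bf16 matmul, M = N = K = 4096: 2*M*N*K FLOPs, read A and B, write C once.
m = 4096
print(roofline_bound(2 * m**3, 3 * m * m * 2, PEAK_FLOPS, PEAK_BW))
```

If ncu then reports a memory-bound kernel running far below the memory roof, or a compute-bound one far below the compute roof, stalls are the likely culprit.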

Reading the IR to confirm tensor cores are firing

If you suspect tl.dot isn't lowering to tensor cores, dump the IR:

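import os
os.environ['MLIR_ENABLE_DUMP'] = '1'      # before importing triton; dumps the IR before every MLIR pass to stderr
# a kernel already in Triton's cache is not recompiled (so nothing is dumped): clear ~/.triton/cache first
# run your kernel once with stderr captured, e.g. (my_kernel.py is a placeholder):
#   python my_kernel.py 2> ir_dump.txt
# then grep the dump:
#   grep -E 'mma|wgmma' ir_dump.txt | head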

If you see mma.sync (Ampere/Ada) or wgmma (Hopper), tensor cores are firing. If the dot has instead lowered to plain FMA instructions with no mma/wgmma anywhere, you're on the regular FP32 units rather than tensor cores; go back to lesson 05 and check tile dims.
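Triton also keeps each compiled kernel's artifacts, including its .ttgir and .ptx files, in its cache directory (~/.triton/cache by default, or wherever TRITON_CACHE_DIR points), so you can check the final PTX without a full pass dump. A minimal stdlib sketch, assuming that layout; the cache's subdirectory naming is an internal detail and may change between releases, and scan_ptx is a hypothetical helper name:

```python
import os
import re
from pathlib import Path

# PTX instructions that mean tensor cores are in use:
# mma.sync (Ampere/Ada warp-level MMA) and wgmma.mma_async (Hopper warpgroup MMA).
TENSOR_CORE_PTX = re.compile(r"\b(mma\.sync|wgmma\.mma_async)\b")

def triton_cache_dir():
    """TRITON_CACHE_DIR if set, else Triton's default ~/.triton/cache."""
    return Path(os.environ.get("TRITON_CACHE_DIR", str(Path.home() / ".triton" / "cache")))

def scan_ptx(cache_dir=None):
    """Map each cached .ptx file (path relative to the cache root) to the tensor-core instructions in it."""
    root = Path(cache_dir) if cache_dir is not None else triton_cache_dir()
    if not root.is_dir():
        return {}
    report = {}
    for ptx in sorted(root.rglob("*.ptx")):
        text = ptx.read_text(errors="replace")
        report[str(ptx.relative_to(root))] = set(TENSOR_CORE_PTX.findall(text))
    return report

if __name__ == "__main__":
    for name, hits in scan_ptx().items():
        print(f"{name}: {', '.join(sorted(hits)) if hits else 'no mma/wgmma found'}")
```

An empty set is expected for kernels with no tl.dot (RMSNorm, for example); it only signals a problem for a kernel you expected to hit tensor cores.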

The decision tree — when NOT to write Triton

  1. Found a hot region in the profile? No: skip it and find a real lever first. Yes: go on.
  2. Does a library cover it? Yes: use the library. No: go on.
  3. Is it an elementwise chain after a matmul? Yes: try torch.compile first; write Triton if that fails. No: go on.
  4. Do you need warp specialisation or async TMA? Yes: CUDA / CUTLASS. No: write Triton.
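The same tree as a function, for instance in a design-review script. A sketch only: choose_tool and its boolean parameters are hypothetical names that map one-to-one, in order, onto the four questions.

```python
def choose_tool(hot_in_profile: bool,
                library_covers_it: bool,
                elementwise_after_matmul: bool,
                needs_warp_spec_or_tma: bool) -> str:
    """Walk the 'should I write Triton?' decision tree from the top."""
    if not hot_in_profile:
        return "skip: find a real lever first"
    if library_covers_it:
        return "use the library"
    if elementwise_after_matmul:
        return "try torch.compile first; Triton if it fails"
    if needs_warp_spec_or_tma:
        return "CUDA / CUTLASS"
    return "write Triton"

print(choose_tool(True, False, False, False))   # -> write Triton
```

The order of the checks matters: a library or torch.compile answer is cheaper to maintain, so it wins before hand-written kernels are even considered.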

Concrete examples of "not Triton"

Concrete examples of "yes, Triton"

Maintenance — what you sign up for

Triton kernels are easier than CUDA, but not free. The ongoing costs:

  1. Triton releases. Each release can change perf characteristics. Re-benchmark on every upgrade.
  2. New hardware. Hopper → Blackwell will require autotune sweeps; existing configs may not transfer.
  3. Backward parity. Forward and backward need to stay in sync as the model evolves.
  4. Numerical drift. If a new model uses a different activation or eps, you may need to re-derive the kernel.
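For the first item, one lightweight habit is to store benchmark timings per Triton version and diff them on upgrade. A minimal sketch: the JSON layout, kernel names, timings and the 5% threshold are all hypothetical, and in practice the numbers would come from triton.testing.do_bench.

```python
import json

def find_regressions(baseline, current, threshold=0.05):
    """Return {kernel: fractional slowdown} for kernels that got slower than `threshold`.

    baseline/current: {"kernel_name": milliseconds} measured on the old and new Triton version.
    Kernels missing from `current` are ignored.
    """
    regressions = {}
    for name, old_ms in baseline.items():
        new_ms = current.get(name)
        if new_ms is None:
            continue
        slowdown = new_ms / old_ms - 1.0
        if slowdown > threshold:
            regressions[name] = round(slowdown, 3)
    return regressions

# Hypothetical timings: the baseline as it might be stored on disk, the current run in memory.
baseline = json.loads('{"rmsnorm_fwd": 0.100, "rmsnorm_bwd": 0.180, "matmul": 1.50}')
current = {"rmsnorm_fwd": 0.101, "rmsnorm_bwd": 0.210, "matmul": 1.49}
print(find_regressions(baseline, current))
```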

Production-readiness checklist

A kernel that ticks all 10 boxes is one you can hand off to someone else without dread.

What this gives you

You've followed the whole arc:

  1. Lessons 01–03: the abstraction (tiles, not threads), the pipeline (Python → PTX), the why.
  2. Lessons 04–06: the DSL — loads/masks, tl.dot, reductions.
  3. Lessons 07–11: five canonical kernels, each one a one-step elaboration on the previous.
  4. Lesson 12: Flash Attention as the synthesis.
  5. Lessons 13–14: performance and shipping.

You can write a fused kernel, derive its bandwidth model, autotune across shapes, write the backward, gradcheck it, profile it, decide if Triton is the right tool, and ship it. That's the full skillset for kernels in 2026.

Where to go from here
Read the Triton tutorials for the kernels we skipped (dropout, group conv, sparse). Read FlashAttention v3 and FlashInfer source to see what production attention looks like. Read CUTLASS for the CUDA side. And when you next see a profile with a fat elementwise chain after a matmul — fuse it.