Writing Triton kernels

Triton trades CUDA's thread-level control for a per-program tile language with an autotuner. You give up some last-mile control over mma scheduling and warp specialization in exchange for far less code, automatic shape tuning, and easier maintenance. For 95 % of "I need a fused kernel" cases, that trade is the right one.

The question this lesson answers

You profiled (lesson 20), located a bandwidth-bound chain of small ops (lesson 21), and decided fusion is the lever. Now what? You could write CUDA, but you don't need 100 % of peak — you need a fused kernel today that's within 80–90 % of optimal, autotunable across shapes, easy to maintain. That's exactly what Triton is for.

Triton's programming model in one diagram

Diagram: the Triton programming model.

Triton "program": 1D / 2D / 3D grid id; handles one tile of work.
Tile: BLOCK_M × BLOCK_N etc.; vectorized loads/stores.
On-chip math: tl.dot, tl.sum, tl.exp; no explicit warps.
Autotune: configs per shape key; cached compiled binaries.

vs CUDA: Triton hides threads and warps. You write tile-level code; the compiler maps it. You give up some last-mile control (mma scheduling, warp specialization) for far less code and an autotuner.

When Triton wins: custom fusions, novel reductions, kernels that need to track new layouts. cuBLAS/CUTLASS will still beat hand-written Triton for vanilla bf16 matmul because they hand-tune mma sequences and async copies.

The Triton DSL, just the parts you need

Primitive | Purpose | Notes
@triton.jit | JIT-compile this Python function as a kernel. | The function is staged: its body is compiled on first call and lowered to PTX rather than executed as ordinary Python.
tl.program_id(axis) | Which tile this program handles. | Up to 3 axes. Map your problem decomposition here.
tl.arange(0, BLOCK) | Vector of compile-time-known length. | The "thread vector." Operations on it are SIMD. BLOCK must be a power of two.
tl.load(ptr, mask, other) | Coalesced load with predicate. | Masked-off (out-of-bounds) lanes get other. Always pass a mask for the boundary tile.
tl.store(ptr, val, mask) | Coalesced store with predicate. | Same boundary contract.
tl.dot(a, b) | Tile matmul. Hits tensor cores when shapes allow. | The right primitive for GEMM-shaped inner loops.
tl.sum / tl.max / tl.min / tl.cumsum | Reductions (and, for tl.cumsum, a scan) across a tile axis. | Compiles to warp shuffles + SMEM.
tl.exp / tl.log / tl.rsqrt | Math intrinsics (fast hardware versions). | Lower precision than standard libm but usually fine.
tl.constexpr | Marks an argument as compile-time. | Tile sizes, dtypes. Anything that affects code generation.
@triton.autotune(configs, key=...) | Compile several configs; pick the fastest on first call per key. | The autotuner is your friend. Use it.
num_warps, num_stages (config knobs) | num_warps: warps per program (occupancy vs work-per-warp). num_stages: depth of the software pipeline Triton emits for overlapping memory loads with compute. | Both are config-level, not per-instruction. The autotuner sweeps them.
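
To make the program-per-tile model concrete, here is a minimal sketch in plain Python. It is an emulation, not Triton: lists stand in for tiles, a for-loop stands in for the parallel grid, and the names scale_program and launch are made up for this illustration. It shows how a program id becomes a range of offsets, and how the mask keeps the last, partial tile from touching memory it does not own.

```python
# Plain-Python emulation of Triton's "one program per tile" model (not real Triton).
def cdiv(a, b):
    """Ceiling division, the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b

def scale_program(pid, x, y, n, block, factor):
    """Body of one 'program': handles elements [pid*block, (pid+1)*block)."""
    offs = [pid * block + i for i in range(block)]   # like pid*BLOCK + tl.arange(0, BLOCK)
    mask = [o < n for o in offs]                     # boundary predicate
    # masked load: out-of-bounds lanes get the 'other' value (0.0) instead of reading memory
    vals = [x[o] if m else 0.0 for o, m in zip(offs, mask)]
    out = [v * factor for v in vals]                 # element-wise math on the whole tile
    # masked store: out-of-bounds lanes write nothing
    for o, m, v in zip(offs, mask, out):
        if m:
            y[o] = v

def launch(x, factor, block=4):
    n = len(x)
    y = [None] * n
    grid = cdiv(n, block)                            # number of programs to launch
    for pid in range(grid):                          # on a GPU these run in parallel
        scale_program(pid, x, y, n, block, factor)
    return y, grid

y, grid = launch([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], factor=2.0, block=4)
print(grid, y)
```

Six elements with a block of four need two programs; the second one has two live lanes and two masked lanes. Dropping the mask in this emulation raises an IndexError, which is the friendly version of what an unmasked GPU load does.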

A complete fused kernel: GELU + linear bias

Suppose downstream we want y = gelu(x @ w + b). With eager PyTorch that is 3 launches (matmul → add → gelu). A Triton kernel folds the bias and activation into the GEMM's epilogue.
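
A back-of-envelope count shows what the fusion buys. The sketch below tallies only the global-memory traffic for the [M, N] intermediate, assuming bf16 storage and no cache reuse between launches; the helper name epilogue_traffic_bytes and the example shape are assumptions made for illustration.

```python
def epilogue_traffic_bytes(M, N, elem_bytes=2):
    """Bytes moved for the [M, N] result after the matmul itself.

    Assumes bf16 (2 bytes per element) and that nothing stays in cache
    between kernel launches. The reads of x and w are identical in both
    variants, so they are left out.
    """
    tile = M * N * elem_bytes
    bias = N * elem_bytes
    unfused = (
        tile            # matmul writes x @ w
        + tile + bias   # add reads it, plus the bias
        + tile          # add writes the sum
        + tile          # gelu reads the sum
        + tile          # gelu writes y
    )
    fused = bias + tile  # read the bias once, write y once
    return unfused, fused

unfused, fused = epilogue_traffic_bytes(M=4096, N=4096)
print(f"unfused {unfused / 2**20:.1f} MiB, fused {fused / 2**20:.1f} MiB")
```

Under these assumptions the unfused chain moves roughly five times as many bytes for the output as the fused kernel, which is why a bandwidth-bound epilogue is worth folding in. The kernel that follows is the fused variant.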

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64,  'BN': 64,  'BK': 64}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def gelu_linear(
    x_ptr, w_ptr, b_ptr, y_ptr,
    M, N, K,
    sxm, sxk, swk, swn, sym, syn,         # strides
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)

    acc = tl.zeros((BM, BN), dtype=tl.float32)

    # tiled K-loop: each iteration multiplies a BM×BK tile by a BK×BN tile
    # and accumulates the partial product on chip in fp32
    for k in range(0, K, BK):
        a = tl.load(x_ptr + offs_m[:, None]*sxm + (k+offs_k)[None, :]*sxk,
                    mask=(offs_m[:, None] < M) & ((k+offs_k)[None, :] < K),
                    other=0.).to(tl.bfloat16)
        bw = tl.load(w_ptr + (k+offs_k)[:, None]*swk + offs_n[None, :]*swn,
                     mask=((k+offs_k)[:, None] < K) & (offs_n[None, :] < N),
                     other=0.).to(tl.bfloat16)
        acc += tl.dot(a, bw)             # tensor-core matmul on tile

    # fused epilogue: bias + GELU
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.).to(tl.float32)
    z = acc + bias[None, :]
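    # GELU, tanh approximation: 0.5 * z * (1 + tanh(u)), u = sqrt(2/pi) * (z + 0.044715 * z^3).
    # Written as z * sigmoid(2u), which is the same function, because tl.tanh is not
    # available in every Triton release while tl.sigmoid is.
    c = 0.7978845608    # sqrt(2/pi)
    y = z * tl.sigmoid(2.0 * c * (z + 0.044715 * z*z*z))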

    tl.store(y_ptr + offs_m[:, None]*sym + offs_n[None, :]*syn,
             y.to(tl.bfloat16),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
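
The epilogue uses the tanh approximation of GELU rather than the exact erf form. A minimal CPU-side check, using only the standard library, can confirm that the approximation stays close to the exact function and that the algebraically equivalent sigmoid form gives the same values. The function names here are made up for this sketch; the sigmoid form can be handy when a tanh intrinsic is not available.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

def gelu_exact(z):
    """Exact GELU: z * Phi(z), with Phi the standard normal CDF."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

def gelu_tanh(z):
    """Tanh approximation, the form used in the kernel's epilogue."""
    u = SQRT_2_OVER_PI * (z + 0.044715 * z ** 3)
    return 0.5 * z * (1.0 + math.tanh(u))

def gelu_sigmoid_form(z):
    """Same approximation rewritten with a sigmoid.

    Identity: 1 + tanh(u) == 2 * sigmoid(2u), so
    0.5 * z * (1 + tanh(u)) == z * sigmoid(2u).
    """
    u = SQRT_2_OVER_PI * (z + 0.044715 * z ** 3)
    return z / (1.0 + math.exp(-2.0 * u))

zs = [i / 10.0 for i in range(-60, 61)]          # sample points in [-6, 6]
max_approx_err = max(abs(gelu_tanh(z) - gelu_exact(z)) for z in zs)
max_form_diff = max(abs(gelu_tanh(z) - gelu_sigmoid_form(z)) for z in zs)
print(f"tanh vs exact: {max_approx_err:.2e}, tanh vs sigmoid form: {max_form_diff:.2e}")
```

The first number is the approximation error you accept by using the tanh form; the second is floating-point noise only, since the two forms are the same function.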

The wrapper that makes it a PyTorch op

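def gelu_linear_fn(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # x: [M, K], w: [K, N], b: [N]; all on the same CUDA device
    M, K = x.shape
    K2, N = w.shape
    assert K == K2, "inner dimensions of x and w must match"
    assert b.shape == (N,), "bias needs one entry per output column"
    y = torch.empty(M, N, device=x.device, dtype=x.dtype)
    # the grid depends on the tile sizes the autotuner picks, so it is a function of meta
    grid = lambda meta: (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    gelu_linear[grid](
        x, w, b, y,
        M, N, K,
        x.stride(0), x.stride(1),   # sxm, sxk
        w.stride(0), w.stride(1),   # swk, swn
        y.stride(0), y.stride(1),   # sym, syn
    )
    return y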
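
The grid lambda is simple arithmetic, and it is worth seeing what it produces for a shape that is not a multiple of the tile. The sketch below is plain Python: the tile shapes are copied from the autotune configs above, while the problem size and the helper name launch_stats are made up for illustration.

```python
def cdiv(a, b):
    """Ceiling division, as triton.cdiv computes it."""
    return (a + b - 1) // b

def launch_stats(M, N, BM, BN):
    """Grid shape, plus the fraction of tile lanes the boundary mask switches off."""
    grid = (cdiv(M, BM), cdiv(N, BN))
    covered = grid[0] * BM * grid[1] * BN     # elements the tiles span, including overshoot
    masked_fraction = 1.0 - (M * N) / covered
    return grid, masked_fraction

# tile shapes from the three autotune configs; M=1030 is deliberately awkward
for BM, BN in [(128, 128), (128, 256), (64, 64)]:
    grid, waste = launch_stats(M=1030, N=4096, BM=BM, BN=BN)
    print(f"BM={BM:3d} BN={BN:3d} grid={grid} masked lanes={waste:.1%}")
```

Larger tiles launch fewer programs but overshoot the matrix edge by more, so more lanes sit idle behind the mask. This is one of the trade-offs the autotuner settles empirically per shape.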

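For backward, wrap the kernel in a torch.autograd.Function with explicit forward and backward Triton kernels. Autograd cannot differentiate through a hand-written Triton kernel on its own, so the alternative is to express the computation in plain PyTorch and let torch.compile (lesson 23) generate fused forward and backward kernels for you.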

The design loop

1 · hypothesis: "fusion will help"
2 · sketch kernel: prologue / load / math / epilogue
3 · correctness: vs PyTorch reference
4 · benchmark: do_bench / autotune
5 · ncu: why slow? Refine tiles, autotune configs, fix register spill, etc.

What autotune actually does

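The autotuner stores a dict keyed by your key=[...] values. For each new key, it compiles every config, runs each one a few times, picks the fastest, and caches the choice. Subsequent calls with the same key skip the search. Two practical consequences follow. First, the first call for each new key is slow, because it pays for compiling and timing every config, so warm up every shape you care about before you measure or serve. Second, every distinct key value triggers a fresh search, so keying on a dimension that changes from call to call (a dynamic batch size or sequence length, for example) means paying the tuning cost again and again.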

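# benchmark the autotuned wrapper; x, w, b are CUDA tensors of shape (M, K), (K, N), (N,)
# the first call for a new (M, N, K) runs the autotune search; later calls reuse the winner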
ms = triton.testing.do_bench(lambda: gelu_linear_fn(x, w, b), warmup=25, rep=100)
print(f"{ms:.3f} ms")
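
The caching behaviour described above can be mimicked in a few lines of plain Python. This is a toy stand-in, not Triton's implementation: ToyAutotuner and fake_bench are hypothetical names, the "benchmark" is a made-up deterministic cost model instead of a real timing, and the key here is only (M, N) so that the effect of a non-key argument is visible.

```python
class ToyAutotuner:
    """Toy model of the autotuner's cache; not Triton's actual code."""

    def __init__(self, configs, key, bench):
        self.configs = configs      # candidate configs (plain dicts here)
        self.key = key              # argument names whose values form the cache key
        self.bench = bench          # bench(config, args) -> cost, lower is better
        self.cache = {}             # key tuple -> winning config
        self.searches = 0           # how many full sweeps have run

    def pick(self, **args):
        k = tuple(args[name] for name in self.key)
        if k not in self.cache:     # new key: sweep every config once
            self.searches += 1
            self.cache[k] = min(self.configs, key=lambda c: self.bench(c, args))
        return self.cache[k]        # known key: no search at all

def fake_bench(config, args):
    # Made-up cost model: cost grows with the padded problem size,
    # so tiles that overshoot M or N lose.
    def pad(n, b):
        return -(-n // b) * b       # round n up to a multiple of b
    return pad(args["M"], config["BM"]) * pad(args["N"], config["BN"])

tuner = ToyAutotuner(
    configs=[{"BM": 128, "BN": 128}, {"BM": 64, "BN": 64}],
    key=["M", "N"],
    bench=fake_bench,
)
small = tuner.pick(M=64, N=64, K=512)
again = tuner.pick(M=64, N=64, K=1024)    # K is not in the key: cache hit, no search
other = tuner.pick(M=256, N=256, K=512)   # new (M, N): second search
print(small, tuner.searches)
```

Three calls produce two searches, because the second call differs only in an argument that is not part of the key. The same logic explains why a key containing a per-request dimension keeps re-tuning.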

Common pitfalls and how to spot them

Pitfall | Symptom | Fix
No boundary mask | Out-of-bounds reads → CUDA illegal-memory-access error or silently wrong outputs at irregular shapes. | Always pass mask= and other= for shapes that are not multiples of BLOCK.
Wrong stride argument order | Garbled output for non-contiguous tensors. | Pass x.stride(0), x.stride(1) in the same order the kernel multiplies.
Accumulating in bf16 | Loss drift; large-K matmuls go visibly wrong. | Always accumulate in fp32 (tl.zeros(..., dtype=tl.float32)).
Forgot tl.constexpr | Compilation fails, because tl.arange and tl.zeros need compile-time sizes; tile sizes cannot be runtime values. | Mark BM, BN, BK, dtypes as tl.constexpr.
Register spill | Local-memory loads/stores in ncu; achieved occupancy <25 %; ncu shows "No Eligible" stalls. | Smaller tile, fewer accumulators, or set maxnreg.
SMEM bank conflicts (non-dot patterns) | SM throughput much lower than peak in reductions or custom SMEM scratchpad usage. | Triton's layout planner picks swizzle patterns to avoid bank conflicts in tl.dot automatically. If you see conflicts, they are almost always in custom reductions; pad strides or restructure the access pattern.
Autotune never re-runs | Cached config is stale after you change the kernel. | Clear ~/.triton/cache or bump a constant in the function.
Reduction across blocks | You wrote tl.atomic_add across many programs and lost determinism. | Use a two-pass scheme: per-block partial → second kernel to combine.
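
The "accumulating in bf16" row is easy to reproduce on the CPU. The sketch below emulates bfloat16 rounding with the standard library (the helper to_bf16 is written for this example and ignores NaN and infinity) and uses a Python float as a stand-in for the fp32 accumulator, which is an assumption: a Python float is wider than fp32, but the values here are small integers and exact in both.

```python
import struct

def to_bf16(x):
    """Round a float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round half to even
    bits &= 0xFFFF0000                                    # keep the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def dot_accumulate(a, b, acc_in_bf16):
    """Dot product of bf16 inputs with either a wide or a bf16 accumulator."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += to_bf16(x) * to_bf16(y)     # inputs are bf16 in both variants
        if acc_in_bf16:
            acc = to_bf16(acc)             # accumulator rounded after every step
    return to_bf16(acc)                    # result stored as bf16 either way

K = 1024
a = [1.0] * K
b = [1.0] * K
wide = dot_accumulate(a, b, acc_in_bf16=False)
narrow = dot_accumulate(a, b, acc_in_bf16=True)
print(wide, narrow)
```

bfloat16 carries 8 significant bits, so once the running sum reaches 256 adding 1.0 rounds straight back to 256 and the accumulator stops moving. The wide accumulator returns the correct 1024. Real data is less tidy, but the mechanism is the same, which is why the error grows with K.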

When NOT to write Triton

Use Triton when (a) no library covers the operation, (b) the operation is bandwidth-bound and fusion saves bytes, or (c) you need autotuning across many shapes and don't have time to hand-tune CUDA per shape.

Triton's sweet spot: novel fusion, bandwidth-bound, multiple shapes. Outside that, prefer a library or skip.

What this gives you for the next lesson

You can write a fused kernel by hand. Lesson 23 covers the lazy alternative: let torch.compile generate Triton for you. That's right far more often than people think — and when it's not, you have the skills from this lesson to step in.