Triton — the ML kernel DSL

Python that compiles to PTX. A "program" is a block, not a thread. SRAM ↔ HBM is your responsibility. The gap between "I can use cuBLAS" and "I can write CUDA" got narrower in 2021 and Triton is what landed in it.

Where Triton sits

The menu of options for "I need a kernel that doesn't exist in the library":

Tool | Lines of code (typical kernel) | Productivity | Peak performance
cuBLAS / cuDNN / library call | 1 | Highest | Vendor-tuned, near-peak
Inductor (auto-generated Triton) | 0 (you write PyTorch) | Implicit | ~70% of hand-tuned
Triton (hand-written) | 50–200 | Moderate | ~85% of hand-tuned CUDA
CUDA C++ (hand-written) | 500–5000 | Lowest | ~100%

Triton (Tillet et al. 2019, then dramatically reworked at OpenAI from 2021) bridges the productivity gap. It's not a replacement for cuBLAS — the vendor library is still faster for vanilla matmul — but for the long tail of "matmul + custom epilogue", "fused softmax", "FlashAttention-class fusion", it lets one engineer ship in a week what previously took a team a quarter.

The programming model — one program per block

CUDA C++ thinks in terms of threads: thousands of independent threads each running the same kernel, organised into thread blocks of (e.g.) 256 threads. You write code from the perspective of one thread, manage shared memory layout manually, and synchronise with __syncthreads().

Triton flips this. You write code from the perspective of one program, which corresponds to one CUDA thread block. The "data" your program works on is always a block of elements — a 1D, 2D, or 3D tile. There are no individual threads in your code; Triton vectorises the block-level operations across threads automatically.

The canonical Triton vector-add:

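import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)                       # which block am I?
    offs = pid * BLOCK + tl.arange(0, BLOCK)          # which elements (vector of BLOCK ints)
    mask = offs < n                                   # boundary check
    x = tl.load(x_ptr + offs, mask=mask)              # vector load
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)        # masked lanes are not written

# host side (assumes a CUDA device; the tensor size is illustrative):
x = torch.randn(100_000, device="cuda")
y = torch.randn_like(x)
out = torch.empty_like(x)
N = x.numel()
add_kernel[(triton.cdiv(N, 1024),)](x, y, out, N, BLOCK=1024)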

Three things to notice:

  1. The kernel handles BLOCK elements at a time. The pid says which slice it owns. offs is a 1D vector of BLOCK indices.
  2. No __syncthreads, no shared-memory declarations. Triton figures out how the vector ops map to threads under the hood.
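  3. The launch grid ((triton.cdiv(N, 1024),)) is computed in Python at the call site. Triton JIT-compiles the kernel on its first launch and recompiles only when a tl.constexpr value (such as BLOCK) or an argument's type or specialization changes, not for every new N.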
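
To make the block-per-program model concrete without a GPU, here is a minimal pure-Python emulation of the launch above. It is a sketch, not how Triton executes: cdiv, add_program, and launch are hypothetical helper names, lists stand in for device pointers, and the programs run one after another instead of in parallel.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b

def add_program(pid, x, y, out, n, BLOCK):
    """One 'program': owns the slice [pid*BLOCK, (pid+1)*BLOCK)."""
    offs = [pid * BLOCK + i for i in range(BLOCK)]   # like pid*BLOCK + tl.arange(0, BLOCK)
    mask = [o < n for o in offs]                     # lanes past the end are switched off
    for o, m in zip(offs, mask):
        if m:                                        # masked load/store: skip dead lanes
            out[o] = x[o] + y[o]

def launch(x, y, BLOCK=4):
    n = len(x)
    out = [0.0] * n
    grid = cdiv(n, BLOCK)                            # number of programs to run
    for pid in range(grid):                          # a GPU runs these concurrently
        add_program(pid, x, y, out, n, BLOCK)
    return out, grid

x = [float(i) for i in range(10)]
y = [10.0 * i for i in range(10)]
out, grid = launch(x, y, BLOCK=4)
print(grid)  # 3 programs cover 10 elements; the last one has 2 masked lanes
print(out)
```

Without the mask, the last program would read and write two elements past the end of the arrays, which is exactly what the offs < n check in the kernel prevents.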

The mental model — SRAM and HBM, explicitly

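The whole game is "load a tile into SRAM, do as much arithmetic as you can with it, then write the result back to HBM." In Triton, tl.load brings a tile from HBM onto the chip, tl.store writes a tile back, and every block-level op in between works on on-chip data.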

The "kernel optimization" question becomes: how many tl.loads and tl.stores do you do per useful arithmetic operation? This is the arithmetic intensity from lesson 01, made tactile. A good Triton kernel has many block-level ops between every load/store; a bad one has one op per load/store.
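
As a rough worked example of that ratio, the sketch below counts arithmetic operations per byte of HBM traffic for an elementwise kernel. The helper name elementwise_intensity and the 4-byte (fp32) element size are assumptions for illustration, and the model ignores caches and masked lanes.

```python
def elementwise_intensity(n_inputs, n_ops, dtype_bytes=4):
    """Ops per HBM byte for a kernel that loads n_inputs tiles, applies n_ops
    block-level ops per element, and stores one output tile."""
    bytes_per_element = (n_inputs + 1) * dtype_bytes   # all loads plus one store
    return n_ops / bytes_per_element

# Vector add: 2 loads, 1 store, 1 add per element.
unfused = elementwise_intensity(n_inputs=2, n_ops=1)
# A fused chain such as relu(x * a + b) with scalar a, b: 1 load, 1 store, 3 ops.
fused = elementwise_intensity(n_inputs=1, n_ops=3)

print(f"vector add : {unfused:.3f} ops/byte")
print(f"fused chain: {fused:.3f} ops/byte")
```

Under these assumptions the fused chain does several times more arithmetic per byte moved, which is the whole argument for fusing ops between one load and one store.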

Animated · one Triton program iteration

Watch one program execute its life cycle: compute pointers, mask the boundary, load a tile from HBM into SRAM, do block-level compute, store back. The blue rectangle at the top represents the kernel's "block of pointers"; the yellow tile in the middle is the on-chip SRAM tile (registers and shared memory).

Triton program lifecycle · one block of work
Block size = BLOCK elements. The program computes offs = pid · BLOCK + arange(BLOCK), masks out-of-range lanes, loads from HBM into SRAM tile, applies block-level ops, stores result.
Readouts: phase · active lanes · HBM traffic · SRAM resident

2D · SRAM ↔ HBM dance

HBM as a strip at the bottom, SRAM as a smaller box at the top. Click play and watch chunks travel up (loads) and down (stores). The bigger the BLOCK_SIZE, the fewer trips you make for the same total work — at the cost of holding more in SRAM at once.

SRAM ↔ HBM dance · slide BLOCK_SIZE
Each "trip" loads BLOCK contiguous elements up to SRAM, the kernel multiplies them by 2 (cheap stand-in for compute), then stores them back. Big BLOCK = few trips, big SRAM footprint. Small BLOCK = many trips, low SRAM occupancy.
Readouts: trips so far · trips per kernel · SRAM footprint · HBM bytes / trip
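
The trade-off the slider illustrates is simple arithmetic. A minimal sketch, assuming fp32 elements and a hypothetical helper name block_tradeoff; a real kernel's on-chip footprint also depends on how many tiles it keeps live at once.

```python
def block_tradeoff(n_elements, block, dtype_bytes=4):
    """Trips to HBM and per-trip tile size for a load -> compute -> store loop."""
    trips = (n_elements + block - 1) // block      # ceil(n / BLOCK)
    tile_bytes = block * dtype_bytes               # bytes resident on chip per trip
    return trips, tile_bytes

for block in (64, 256, 1024):
    trips, tile_bytes = block_tradeoff(1_000_000, block)
    print(f"BLOCK={block:5d}  trips={trips:6d}  tile={tile_bytes:5d} B")
```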

Fused softmax — the introductory non-trivial kernel

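Softmax across the last dimension of a matrix is a textbook case. A naive decomposition into separate PyTorch ops runs at least three elementwise kernels: subtract the row max for numerical stability, exp, divide by the row sum. HBM traffic per row is then read, write (after the subtraction), read, write (after exp), read, write: six passes over the row, before counting the extra reads for the max and sum reductions. A fused softmax reads each row once and writes it once: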

@triton.jit
def softmax(in_ptr, out_ptr, stride, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    x = tl.load(in_ptr + row * stride + cols, mask=mask, other=-float('inf'))
    m = tl.max(x, axis=0)                # numerical stability
    e = tl.exp(x - m)
    z = tl.sum(e, axis=0)
    tl.store(out_ptr + row * stride + cols, e / z, mask=mask)

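The whole row stays on chip throughout, which requires BLOCK to be a power of two no smaller than n_cols (for example triton.next_power_of_2(n_cols)). HBM traffic per row: 2 · n_cols · dtype_bytes. Speedup over naive PyTorch on a typical transformer: ~3–4×. This is the entry-level Triton tutorial because the win is so visible.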
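
A pure-Python emulation of one program of the kernel above shows why padding masked lanes with -inf is safe: exp(-inf) is 0, so padded lanes contribute nothing to the max or the sum. This is a sketch for checking the arithmetic, not Triton code; softmax_program is a hypothetical name and a list stands in for the row in memory.

```python
import math

def softmax_program(row, BLOCK):
    """Emulate one program: the row is padded to BLOCK lanes with -inf."""
    n_cols = len(row)
    assert BLOCK >= n_cols, "the whole row must fit in one block"
    mask = [c < n_cols for c in range(BLOCK)]
    x = [row[c] if m else -math.inf for c, m in enumerate(mask)]  # masked load, other=-inf
    m_val = max(x)                                # row max, for numerical stability
    e = [math.exp(v - m_val) for v in x]          # padded lanes become exp(-inf) = 0.0
    z = sum(e)
    return [ev / z for ev, keep in zip(e, mask) if keep]          # masked store

probs = softmax_program([1.0, 2.0, 3.0, 1000.0, 999.0], BLOCK=8)
print(probs)
print(sum(probs))
```

The inputs near 1000 are deliberate: without subtracting the row max, exp(1000) would overflow, which is why the kernel computes the max first.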

Matmul — where Triton starts to look like real CUDA

A Triton matmul is a few dozen lines that handles tiling, accumulation, and the small autotuning decisions (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages). The structure:

  1. Each program owns one output tile of shape (BLOCK_M, BLOCK_N).
  2. It walks the matching row strip of A (BLOCK_M rows) and column strip of B (BLOCK_N columns) in K-direction tiles of width BLOCK_K.
  3. For each K-tile: acc += tl.dot(a_tile, b_tile).
  4. It writes the accumulator tile out.
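
The loop structure in the steps above can be sketched in plain Python. This is a CPU illustration of the tiling only, with lists of lists standing in for tensors and a scalar sum standing in for tl.dot; tiled_matmul is a hypothetical name, and clamping the ranges plays the role that masks play in a real kernel.

```python
def tiled_matmul(A, B, BLOCK_M=2, BLOCK_N=2, BLOCK_K=2):
    """C = A @ B computed tile by tile; A is M x K, B is K x N (lists of lists)."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for m0 in range(0, M, BLOCK_M):               # step 1: one output tile per "program"
        for n0 in range(0, N, BLOCK_N):
            rows = range(m0, min(m0 + BLOCK_M, M))
            cols = range(n0, min(n0 + BLOCK_N, N))
            acc = {(i, j): 0.0 for i in rows for j in cols}   # accumulator tile
            for k0 in range(0, K, BLOCK_K):       # steps 2-3: walk K in tiles
                ks = range(k0, min(k0 + BLOCK_K, K))
                for i in rows:
                    for j in cols:
                        # stands in for acc += tl.dot(a_tile, b_tile)
                        acc[i, j] += sum(A[i][k] * B[k][j] for k in ks)
            for (i, j), v in acc.items():         # step 4: write the accumulator tile out
                C[i][j] = v
    return C

A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
B = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(tiled_matmul(A, B))
```

Note that each output element is written once, after the whole K loop, so the accumulator never makes a round trip to memory in between.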

Triton's tl.dot compiles to the GPU's tensor-core matmul instructions. The autotuner sweeps block sizes / pipeline stages / warp counts to find the best variant for your shape. Out of the box, hand-written Triton matmul lands within 5–15% of cuBLAS on H100 — not as fast as the vendor library, but close enough that you can usually fuse a custom epilogue and come out ahead.

The autotuner — Triton's secret weapon

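For a given problem, the best block sizes depend on hardware (SM count, SMEM size), tensor shape, dtype, and what comes after. Triton's @triton.autotune decorator runs the kernel under several configs on real inputs at first call, picks the fastest, and reuses it for every later call with the same key (the argument values you name in key=). Compiled binaries are cached on disk, but by default the winning config is cached in memory, so a fresh process repeats the timing sweep. Typical autotune sets sweep BLOCK_M, BLOCK_N, BLOCK_K, num_warps, and num_stages.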

The autotuner is what makes Triton's productivity claim real: you don't have to know the optimal block size for your kernel in advance; you list plausible candidates and Triton measures which one wins.
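
The pick-the-fastest-and-cache logic can be sketched in plain Python. This is an emulation under loud assumptions: autotune and fake_bench are hypothetical names, the candidate configs are illustrative, and the cost model is invented so the example is deterministic. In real Triton the equivalent is the @triton.autotune(configs=[triton.Config(...)], key=[...]) decorator, which times actual kernel launches.

```python
def cdiv(a, b):
    return (a + b - 1) // b

def autotune(configs, bench, key, cache):
    """Return the fastest config for `key`, timing each candidate once per key."""
    if key not in cache:
        timings = [(bench(cfg, key), i) for i, cfg in enumerate(configs)]
        _, best = min(timings)                    # lowest measured time wins
        cache[key] = configs[best]
    return cache[key]

# Illustrative candidates; parameter names mirror the ones swept in this section.
CONFIGS = [
    {"BLOCK_M": 64,  "BLOCK_N": 64,  "BLOCK_K": 32, "num_warps": 4, "num_stages": 3},
    {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "num_warps": 8, "num_stages": 3},
    {"BLOCK_M": 128, "BLOCK_N": 64,  "BLOCK_K": 64, "num_warps": 4, "num_stages": 4},
]

calls = []

def fake_bench(cfg, key):
    """Invented cost model standing in for a real GPU timing run."""
    M, N, K = key
    calls.append(cfg["BLOCK_M"])
    tiles = cdiv(M, cfg["BLOCK_M"]) * cdiv(N, cfg["BLOCK_N"])   # output tiles
    padded = tiles * cfg["BLOCK_M"] * cfg["BLOCK_N"]            # work incl. padding waste
    return padded / cfg["num_warps"] + 1000 * tiles             # plus per-tile overhead

cache = {}
small = autotune(CONFIGS, fake_bench, key=(64, 64, 256), cache=cache)
large = autotune(CONFIGS, fake_bench, key=(4096, 4096, 4096), cache=cache)
again = autotune(CONFIGS, fake_bench, key=(4096, 4096, 4096), cache=cache)  # cache hit
print(small["BLOCK_M"], large["BLOCK_M"], len(calls))
```

Under this made-up cost model the small problem prefers the small tile and the large problem prefers the large one, and the third call does no timing at all because its key is already cached.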

2D · code-to-kernels mini-diff

Same matmul, two ways of writing it. On the left, a few dozen lines of Triton with @triton.autotune sweeping BLOCK_M, BLOCK_N, BLOCK_K, num_warps. On the right, the pseudo-CUDA you'd have to write yourself if you committed to one config. Click a config below — the autotuner's pick gets highlighted in green; the diff updates.

Triton autotune vs pseudo-CUDA · per-shape pick
Each row in the right table is one config the autotuner timed. The fastest one is what Triton uses; the equivalent CUDA on the right shows what hand-writing that exact config would look like.
Readouts: autotuner picks · configs tried · winning TFLOP/s · vs naive config

What Triton makes hard, vs CUDA

How Triton fits with torch.compile (next lesson)

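PyTorch 2.x's compiler backend (Inductor) generates Triton kernels from your PyTorch code. Most elementwise chains and many reductions end up as Inductor-generated Triton. You can also write Triton yourself and call it from compiled code: torch.compile can trace direct launches of @triton.jit kernels, and recent PyTorch releases expose torch.library.triton_op for wrapping one as a custom operator, which is preferable to reaching into the private torch._library module.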

The practical division of labour:

Interactive · choose your stack

For each pattern, pick the tool you'd reach for, and the widget shows likely effort vs likely speedup vs alternatives. The point is the gradient: more effort = more speedup, but the curve flattens fast.

Tool selection by pattern
Heuristic only. Numbers are typical, not yours. Use as a starting point, not a recipe.
Columns: tool · relative effort · relative perf · verdict

Takeaway
Triton is the productivity sweet spot for kernels that don't yet exist. Block-level programming hides thread-level details; the autotuner finds reasonable block sizes for you from the candidates you list. Inductor uses it as the target for most of the GPU kernels it generates from PyTorch code. Hand-written Triton is the right level for "I need FlashAttention with one custom modification": fast to write, fast enough to ship.