Triton — the ML kernel DSL
Python that compiles to PTX. A "program" is a block, not a thread. SRAM ↔ HBM is your responsibility. The gap between "I can use cuBLAS" and "I can write CUDA" got narrower in 2021 and Triton is what landed in it.
Where Triton sits
The options for "I need a kernel that doesn't exist in the library":
| Tool | Lines of code (typical kernel) | Productivity | Peak performance |
|---|---|---|---|
| cuBLAS / cuDNN / library call | 1 | Highest | Vendor-tuned, near-peak |
| Inductor (auto-generated Triton) | 0 (you write PyTorch) | Implicit | ~70% of hand-tuned |
| Triton (hand-written) | 50–200 | Moderate | ~85% of hand-tuned CUDA |
| CUDA C++ (hand-written) | 500–5000 | Lowest | ~100% |
Triton (Tillet et al. 2019, then dramatically reworked at OpenAI from 2021) bridges the productivity gap. It's not a replacement for cuBLAS — the vendor library is still faster for vanilla matmul — but for the long tail of "matmul + custom epilogue", "fused softmax", "FlashAttention-class fusion", it lets one engineer ship in a week what previously took a team a quarter.
The programming model — one program per block
CUDA C++ thinks in terms of threads: thousands of independent threads each running the same kernel, organised into thread blocks of (e.g.) 256 threads. You write code from the perspective of one thread, manage shared memory layout manually, and synchronise with __syncthreads().
Triton flips this. You write code from the perspective of one program, which corresponds to one CUDA thread block. The "data" your program works on is always a block of elements — a 1D, 2D, or 3D tile. There are no individual threads in your code; Triton vectorises the block-level operations across threads automatically.
The canonical Triton vector-add:
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)                  # which block am I?
    offs = pid * BLOCK + tl.arange(0, BLOCK)     # which elements (vector of BLOCK ints)
    mask = offs < n                              # boundary check
    x = tl.load(x_ptr + offs, mask=mask)         # vector load
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)   # masked vector store

# host side (x and y are same-shape CUDA tensors):
out = torch.empty_like(x)
N = x.numel()
add_kernel[(triton.cdiv(N, 1024),)](x, y, out, N, BLOCK=1024)
Three things to notice:
- The kernel handles BLOCK elements at a time. The pid says which slice it owns. offs is a 1D vector of BLOCK indices.
- No __syncthreads, no shared-memory declarations. Triton figures out how the vector ops map to threads under the hood.
- The launch grid ((triton.cdiv(N, 1024),)) is computed in Python at the call site. Triton JIT-compiles the kernel the first time it sees a new combination of argument types and constexpr values (such as BLOCK), then reuses the cached binary.
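To make the indexing concrete without a GPU, here is a plain-Python emulation of the launch. It is a sketch, not Triton: emulate_add and its sequential loop over pid are illustrative stand-ins for the grid, and ordinary list indexing stands in for the masked vector load and store.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b


def emulate_add(x, y, block):
    """Run the vector-add 'kernel' once per program id, sequentially."""
    n = len(x)
    out = [None] * n
    for pid in range(cdiv(n, block)):            # one iteration = one program
        offs = [pid * block + i for i in range(block)]
        mask = [o < n for o in offs]             # lanes past the end are off
        for o, live in zip(offs, mask):
            if live:                             # masked load, add, masked store
                out[o] = x[o] + y[o]
    return out


x = list(range(10))
y = [10 * v for v in x]
result = emulate_add(x, y, block=4)   # 3 programs; the last has only 2 live lanes
print(result)
```

With 10 elements and a block of 4, the third program computes offsets 8 through 11, and the mask switches off the last two lanes so nothing is read or written out of bounds.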
The mental model — SRAM and HBM, explicitly
The whole game is "load a tile into SRAM, do as much arithmetic as you can with it, then write the result back to HBM." In Triton:
- tl.load moves a tile from HBM to SRAM (and from there into registers).
- Block-level operations (+, *, tl.dot, tl.exp) act on those SRAM tiles. They produce new SRAM tiles.
- tl.store moves a tile back to HBM.
The "kernel optimization" question becomes: how many tl.loads and tl.stores do you do per useful arithmetic operation? This is the arithmetic intensity from lesson 01, made tactile. A good Triton kernel has many block-level ops between every load/store; a bad one has one op per load/store.
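A rough way to put numbers on this is to count FLOPs per byte of HBM traffic. The sketch below assumes a chain of elementwise ops on fp16 data (2 bytes per element), where the unfused version launches one kernel per op and each kernel reads and writes the whole tensor. The function name and the one-FLOP-per-element-per-op accounting are illustrative assumptions, not a profiler measurement.

```python
def intensity(num_ops, n_elems, dtype_bytes, fused):
    """FLOPs per byte of HBM traffic for a chain of elementwise ops."""
    flops = num_ops * n_elems                  # assume 1 FLOP per element per op
    passes = 2 if fused else 2 * num_ops       # each kernel: one read + one write
    return flops / (passes * n_elems * dtype_bytes)


unfused = intensity(num_ops=4, n_elems=1 << 20, dtype_bytes=2, fused=False)
fused = intensity(num_ops=4, n_elems=1 << 20, dtype_bytes=2, fused=True)
print(unfused, fused)   # 0.25 1.0
```

Under these assumptions, fusing four ops into one kernel raises the intensity by 4×, because the arithmetic stays the same while the load/store count drops from eight passes to two.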
Animated · one Triton program iteration
Watch one program execute its life-cycle: compute pointers, mask boundary, load a tile from HBM into SRAM, do block-level compute, store back. The blue rectangle at the top represents the kernel's "block of pointers"; the yellow tile in the middle is the SRAM register file. Step through the phases or hit play.
2D · SRAM ↔ HBM dance
HBM as a strip at the bottom, SRAM as a smaller box at the top. Click play and watch chunks travel up (loads) and down (stores). The bigger the BLOCK_SIZE, the fewer trips you make for the same total work — at the cost of holding more in SRAM at once.
Fused softmax — the introductory non-trivial kernel
Softmax across the last dimension of a matrix is a textbook case. The naive PyTorch decomposition is three kernels: subtract max for numerical stability, exp, divide by sum. HBM traffic per row: read row, write (after max), read, write (after exp), read, write, for six passes in total. A fused softmax reads each row once and writes it once:
@triton.jit
def softmax(in_ptr, out_ptr, stride, n_cols, BLOCK: tl.constexpr):
    # One program per row; BLOCK must be a power of two >= n_cols
    # so the whole row fits in a single tile.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    # Padding lanes read as -inf so they cannot affect the max or the sum.
    x = tl.load(in_ptr + row * stride + cols, mask=mask, other=-float('inf'))
    m = tl.max(x, axis=0)                        # numerical stability
    e = tl.exp(x - m)
    z = tl.sum(e, axis=0)
    tl.store(out_ptr + row * stride + cols, e / z, mask=mask)
The whole row is in SRAM throughout. HBM traffic per row: 2 · n_cols · dtype_bytes. Speedup over naive PyTorch on a typical transformer: ~3–4×. This is the entry-level Triton tutorial because the win is so visible.
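The one subtle line in that kernel is the masked load with other=-inf. Because tl.arange needs a power-of-two length, BLOCK is usually rounded up past n_cols, and the padding lanes must not disturb the max or the sum. The plain-Python emulation below sketches why -inf is the right filler. emulate_softmax_row and next_power_of_2 are illustrative helpers written for this example, not Triton calls.

```python
import math


def next_power_of_2(n):
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def emulate_softmax_row(row, block):
    """Mimic one program: padded load, max, exp, sum, masked store."""
    n_cols = len(row)
    assert block >= n_cols, "the whole row must fit in one block"
    # Masked load: lanes past n_cols read as -inf (the 'other' value).
    x = [row[c] if c < n_cols else -math.inf for c in range(block)]
    m = max(x)                                   # -inf lanes never win the max
    e = [math.exp(v - m) for v in x]             # exp(-inf) is 0.0 for padding
    z = sum(e)                                   # so padding adds nothing here
    return [e[c] / z for c in range(n_cols)]     # masked store drops padding


row = [1.0, 2.0, 3.0]
probs = emulate_softmax_row(row, block=next_power_of_2(len(row)))
print(probs, sum(probs))
```

Here a 3-element row is padded to a block of 4. The fourth lane contributes exactly zero to z, so the result matches an unpadded softmax.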
Matmul — where Triton starts to look like real CUDA
A Triton matmul is a few dozen lines that handles tiling, accumulation, and the small autotuning decisions (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages). The structure:
- For each output tile (BLOCK_M, BLOCK_N):
  - Load the corresponding A row-strip and B column-strip in K-direction tiles.
  - For each K-tile: acc += tl.dot(a_tile, b_tile).
  - Write the accumulator tile out.
Triton's tl.dot compiles to the GPU's tensor-core matmul instructions. The autotuner sweeps block sizes / pipeline stages / warp counts to find the best variant for your shape. Out of the box, hand-written Triton matmul lands within 5–15% of cuBLAS on H100 — not as fast as the vendor library, but close enough that you can usually fuse a custom epilogue and come out ahead.
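The tiling loop described above can be sketched in plain Python. This is a CPU emulation for reading, not a Triton kernel: tiled_matmul is a hypothetical helper, the two outer loops play the role of the 2D launch grid, and the inner sum stands in for tl.dot on a pair of tiles.

```python
def tiled_matmul(a, b, block_m, block_n, block_k):
    """C = A @ B computed tile by tile; each (pm, pn) pair is one 'program'."""
    M, K, N = len(a), len(b), len(b[0])
    c = [[0.0] * N for _ in range(M)]
    for pm in range(0, M, block_m):                  # grid axis 0: tile rows
        for pn in range(0, N, block_n):              # grid axis 1: tile cols
            rows = range(pm, min(pm + block_m, M))   # boundary handling (mask)
            cols = range(pn, min(pn + block_n, N))
            acc = {(i, j): 0.0 for i in rows for j in cols}   # accumulator tile
            for pk in range(0, K, block_k):          # march along K
                ks = range(pk, min(pk + block_k, K))
                for i in rows:                       # stands in for tl.dot
                    for j in cols:
                        acc[i, j] += sum(a[i][k] * b[k][j] for k in ks)
            for (i, j), v in acc.items():            # single write of the tile
                c[i][j] = v
    return c


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]     # 2 x 3
b = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 x 2
print(tiled_matmul(a, b, block_m=2, block_n=1, block_k=2))
# [[4.0, 5.0], [10.0, 11.0]]
```

The accumulator lives for the whole K loop and is written out once per output tile. In the real kernel that accumulator sits in registers, which is why the block sizes trade reuse against on-chip capacity.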
The autotuner — Triton's secret weapon
For a given problem, the best block sizes depend on hardware (SM count, SMEM size), tensor shape, dtype, and what comes after. Triton's @triton.autotune decorator runs the kernel under several configs on real inputs at first call, picks the fastest, and reuses that choice for every later call with the same key arguments. Typical autotune sets sweep:
- BLOCK_M, BLOCK_N, BLOCK_K: the tile shapes.
- num_warps: how many warps per program (1, 2, 4, 8). More warps = more parallelism, less SRAM per warp.
- num_stages: software pipelining depth. More stages overlap loads with compute better but use more shared memory for the extra in-flight tiles.
The autotuner is what makes Triton's productivity claim real: you don't have to know the optimal block size for your kernel; Triton finds it.
2D · code-to-kernels mini-diff
Same matmul, two ways of writing it. On the left, a few dozen lines of Triton with @triton.autotune sweeping BLOCK_M, BLOCK_N, BLOCK_K, num_warps. On the right, the pseudo-CUDA you'd have to write yourself if you committed to one config. Click a config below — the autotuner's pick gets highlighted in green; the diff updates.
What Triton makes hard, vs CUDA
- Asynchronous copies and double buffering. Triton exposes num_stages as a knob, but you don't write the cp.async instructions yourself. You're at the mercy of the compiler. On Hopper specifically, the "warp-specialisation" pattern that gets the absolute best matmul throughput is awkward to express.
- Persistent kernels. Patterns where a kernel runs across all SMs, picks work off a queue, and keeps the SMs alive for many iterations (common in serving / dynamic batching). Triton can do them but it's not idiomatic.
- Cross-block coordination. CUDA can use cooperative_groups and grid-wide barriers. Triton's model is "every program is independent": there is no grid-wide barrier, and the only way for programs to interact within one kernel is through atomics on global memory.
How Triton fits with torch.compile (next lesson)
PyTorch 2.x's compiler (Inductor) generates Triton kernels from your PyTorch code. Most elementwise chains and many reductions end up as Inductor-generated Triton. You can also write Triton yourself and call it from compiled code via torch._library (exposed publicly as torch.library.triton_op in recent PyTorch releases).
The practical division of labour:
- Inductor-generated Triton — elementwise chains, simple reductions, the long tail of "this was three eager ops, now it's one kernel."
- Hand-written Triton — operator-class fusions (FlashAttention, fused MoE, fused RoPE), kernels with shapes that Inductor's autotuner doesn't explore.
- Hand-written CUDA — the very top of the performance budget (FlashAttention itself, persistent kernels in TRT-LLM, cuBLAS replacements).
Interactive · choose your stack
For each pattern, pick the tool you'd reach for, and the widget shows likely effort vs likely speedup vs alternatives. The point is the gradient: more effort = more speedup, but the curve flattens fast.