Triton — Writing GPU Kernels in Python
A linearized tour of OpenAI Triton, the kernel DSL — built so you understand the tile programming model first, then build kernels (vector add → matmul → softmax → flash attention) from those primitives.
This series of fourteen interactive lessons builds Triton up from scratch. Part I (lessons 01–03) covers the execution model: why Triton exists, what a "program" is, and how Python source becomes PTX. Part II (lessons 04–06) covers every DSL primitive you'll touch: pointers, masks, tl.dot, and reductions. Part III (lessons 07–11) walks you through five real kernels, each a one-step elaboration on the previous: vector add, fused linear+activation, tiled matmul, online softmax, fused norm. Part IV (lesson 12) is the flagship: Flash Attention as a synthesis of every primitive. Part V (lessons 13–14) covers performance and production: autotuning, pipelining, profiling, backward passes, and the decision tree for when not to write Triton. Each lesson has at least one interactive widget so you can grab a knob and feel the consequence.
The model you're learning
Triton is a Python-embedded DSL with one central abstraction: a program handles one tile of work. You write tile-level code; the compiler maps tiles onto warps, schedules loads against compute, and picks layouts.
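To make the tile model concrete, here is a pure-Python emulation of how a 1D problem is cut into tiles. It is an illustration only: it runs on the CPU and uses no Triton API. The launch grid has cdiv(n, BLOCK) programs, each program derives its offsets from its own id, and a mask switches off the lanes that run past the end. The function names are hypothetical; the comments point at the Triton constructs (tl.program_id, tl.arange, masked tl.load/tl.store) that each line stands in for.

```python
# CPU emulation of Triton's "one program per tile" model: pure Python, no GPU.
# vector_add_program plays the role of one program instance; the comments name
# the Triton construct each line stands in for. All names here are hypothetical.

def cdiv(n: int, block: int) -> int:
    """Ceiling division: number of BLOCK-sized tiles needed to cover n elements."""
    return (n + block - 1) // block


def vector_add_program(pid: int, x: list, y: list, out: list, n: int, BLOCK: int) -> None:
    """Work done by one program; pid is what tl.program_id(axis=0) would return."""
    offsets = [pid * BLOCK + i for i in range(BLOCK)]  # pid * BLOCK + tl.arange(0, BLOCK)
    mask = [off < n for off in offsets]                # guards the ragged last tile
    for off, ok in zip(offsets, mask):
        if ok:                                         # masked tl.load / tl.store
            out[off] = x[off] + y[off]


def vector_add(x: list, y: list, BLOCK: int = 4) -> list:
    n = len(x)
    out = [0.0] * n
    num_programs = cdiv(n, BLOCK)                      # grid = (cdiv(n, BLOCK),)
    for pid in range(num_programs):                    # on a GPU these run in parallel
        vector_add_program(pid, x, y, out, n, BLOCK)
    return out


print(vector_add([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0]))
# → [11.0, 22.0, 33.0, 44.0, 55.0]
```

Deleting the `if ok:` guard makes the second program index past the end of the lists and raise IndexError. On a GPU the equivalent unmasked access touches memory outside the tensor, often without any error at all.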
Part I · The model (lessons 01–03 · why tiles, not threads)
What @triton.jit actually does. What tl.constexpr controls. The grid is your problem decomposition.
Part II · The DSL (lessons 04–06 · every primitive you'll touch)
Triton has a small DSL. By the end of these three lessons you'll know every op you need to write 90% of kernels — and what each one compiles to.
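Strides are the whole address calculation. The pure-Python sketch below computes the element offsets and the 2D mask for one tile of a small matrix, then applies a deliberately crude coalescing test: adjacent lanes along the column axis should touch adjacent addresses. The names offs_m, offs_n, stride_m, stride_n, BLOCK_M, BLOCK_N follow common Triton tutorial conventions but are assumptions here. Real coalescing is decided per warp at the granularity of memory transactions, so treat the check as a rule of thumb.

```python
# Pure-Python sketch of stride-based addressing for one 2D tile, mirroring the
# common Triton idiom
#     ptrs = base + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# with lists of lists standing in for 2D tensors. Names are illustrative.

def tile_offsets(offs_m, offs_n, stride_m, stride_n):
    """Element offset of every (row, col) position in the tile."""
    return [[m * stride_m + n * stride_n for n in offs_n] for m in offs_m]


def tile_mask(offs_m, offs_n, M, N):
    """Equivalent of (offs_m < M)[:, None] & (offs_n < N)[None, :]."""
    return [[m < M and n < N for n in offs_n] for m in offs_m]


def contiguous_along_cols(offsets):
    """Crude coalescing check: do neighbouring columns sit at neighbouring addresses?"""
    return all(row[j + 1] - row[j] == 1 for row in offsets for j in range(len(row) - 1))


# Tile (pid_m=1, pid_n=1) of a 3 x 5 matrix with BLOCK_M=2, BLOCK_N=4.
BLOCK_M, BLOCK_N, M, N = 2, 4, 3, 5
offs_m = [1 * BLOCK_M + i for i in range(BLOCK_M)]   # [2, 3]
offs_n = [1 * BLOCK_N + j for j in range(BLOCK_N)]   # [4, 5, 6, 7]

row_major = tile_offsets(offs_m, offs_n, stride_m=N, stride_n=1)
col_major = tile_offsets(offs_m, offs_n, stride_m=1, stride_n=M)
print(row_major, contiguous_along_cols(row_major))  # → [[14, 15, 16, 17], [19, 20, 21, 22]] True
print(col_major, contiguous_along_cols(col_major))  # → [[14, 17, 20, 23], [15, 18, 21, 24]] False
print(tile_mask(offs_m, offs_n, M, N))  # → [[True, False, False, False], [False, False, False, False]]
```

For the same logical tile, the row-major layout gives contiguous runs along columns while the column-major layout jumps by M between neighbours. In both cases only one lane is in bounds; the mask is what keeps the other seven from touching memory outside the 15-element buffer.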
tl.load and tl.store with predicates. Why every tile needs a mask and what happens when it doesn't. Strides as the address calculator. Live "coalesce-or-not" address visualiser.
How tl.dot lowers to mma/wgmma and when it falls back to FMA. Accumulation dtype rules: bf16 in, fp32 accumulate. The shape constraints that decide whether you hit tensor cores.
tl.sum, tl.max, tl.cumsum. How a tile reduction lowers to warp shuffles + SMEM. The online softmax recurrence (Milakov-Gimelshein), your first taste of why Flash Attention works.
Part III · Building real kernels (lessons 07–11 · five canonical examples in order of complexity)
Five kernels you'd ship in a production stack. Each is a one-step elaboration on the previous — read in order and the last one (RMSNorm) is straightforward; skip the order and it isn't.
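Of the five, tiled_matmul introduces the structure the later kernels reuse: a 2D grid in which each program owns one output tile, plus a loop over K that accumulates partial products. The pure-Python sketch below keeps exactly that control structure and replaces tl.dot with explicit loops; the function names and block sizes are illustrative assumptions, not the lesson's code.

```python
# Pure-Python sketch of the tiled-matmul structure (no GPU, no Triton).
# Program (pid_m, pid_n) owns one BM x BN tile of C and walks K in BK-sized steps.
# In a real kernel the inner loops are a single tl.dot on two loaded tiles,
# and the accumulator is fp32 even when A and B are bf16.

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def matmul_program(pid_m, pid_n, A, B, C, M, N, K, BM, BN, BK):
    acc = [[0.0] * BN for _ in range(BM)]              # like tl.zeros((BM, BN), tl.float32)
    for k0 in range(0, K, BK):                         # the K-loop
        for i in range(BM):
            for j in range(BN):
                for kk in range(BK):
                    m, n, k = pid_m * BM + i, pid_n * BN + j, k0 + kk
                    # masked loads with other=0.0 make out-of-range terms vanish;
                    # skipping them here has the same effect
                    if m < M and n < N and k < K:
                        acc[i][j] += A[m][k] * B[k][n]
    for i in range(BM):                                # masked store of the tile
        for j in range(BN):
            m, n = pid_m * BM + i, pid_n * BN + j
            if m < M and n < N:
                C[m][n] = acc[i][j]


def matmul(A, B, BM=2, BN=2, BK=2):
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for pid_m in range(cdiv(M, BM)):                   # 2D launch grid
        for pid_n in range(cdiv(N, BN)):
            matmul_program(pid_m, pid_n, A, B, C, M, N, K, BM, BN, BK)
    return C


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
B = [[1, 0], [0, 1], [1, 1]]
print(matmul(A, B))  # → [[4.0, 5.0], [10.0, 11.0], [16.0, 17.0]]
```

With integer inputs the block sizes change only how the work is split, never the result. With real floating-point tiles the summation order changes too, which is why kernel outputs are compared to PyTorch with a tolerance.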
vector_add (1D, 1 op)
  │
  ▼
fused_linear_act (epilogue fusion)
  │
  ▼
tiled_matmul (2D + K-loop)
  │
  │ online reduction
  ▼
softmax
  │
  │ stat + scale fused
  ▼
rms_norm
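The online reduction step in this chain is the Milakov-Gimelshein recurrence: keep a running max and a running sum of exponentials, and rescale the sum whenever the max grows. Below is a minimal pure-Python sketch, assuming a row is read in fixed-size blocks (the block size stands in for a Triton tile width). It computes the softmax normalizer in one pass instead of the two that a max-then-sum softmax needs, and is checked against the classic version.

```python
import math


def online_softmax(xs, block=4):
    """Softmax with the running-max / running-sum (Milakov-Gimelshein) recurrence.

    xs is consumed block by block, the way a program walks tiles along a row.
    m is the running max and d the running sum of exp(x - m); whenever the max
    grows, the old sum is rescaled by exp(m_old - m_new).
    """
    m, d = -math.inf, 0.0
    for start in range(0, len(xs), block):
        tile = xs[start:start + block]
        m_new = max(m, max(tile))                      # tile max (tl.max)
        # rescale the old sum to the new max, then add this tile's terms (tl.sum)
        d = d * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    return [math.exp(x - m) / d for x in xs]          # normalize with the final (m, d)


def reference_softmax(xs):
    """Classic 'safe' softmax: a max pass, a sum pass, then a normalize pass."""
    mx = max(xs)
    e = [math.exp(x - mx) for x in xs]
    s = sum(e)
    return [v / s for v in e]
```

Up to floating-point rounding, the result does not depend on the block size, which is what lets a kernel choose its tile width purely for performance.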
@jit, grid, launch, mask the tail, autotune one config. Verify against PyTorch. Benchmark with do_bench. The minimum a Triton kernel can be.
Part IV · The flagship (lesson 12 · everything composes)
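Flash Attention carries the online softmax recurrence one step further: alongside the running max and normalizer, each program keeps an unnormalized output accumulator that is rescaled by the same factor whenever the max grows. The sketch below is a deliberate simplification (one query row, one head, no causal mask, pure Python in place of tl.dot), meant only to show the update order; the function names are hypothetical, and the result is checked against attention computed from the full score row.

```python
import math


def flash_attention_row(q, K, V, block=2):
    """Attention output for a single query vector q, visiting K/V in blocks.

    K and V are lists of rows (seq_len x d and seq_len x d_v). The full row of
    scores is never materialized: only the running max m, the running
    normalizer l, and the unnormalized output acc survive between blocks.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]  # q @ K_blk^T
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)              # rescales everything accumulated so far
        p = [math.exp(si - m_new) for si in s]
        l = l * alpha + sum(p)
        acc = [a * alpha + sum(pj * v[c] for pj, v in zip(p, v_blk))  # acc * alpha + p @ V_blk
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]                  # normalize once, at the end


def attention_row_reference(q, K, V):
    """Materialize all scores, softmax them, then weight the rows of V."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(len(V[0]))]
```

Deferring the division by l to the very end is the design choice that matters: inside the loop every block costs one rescale of acc instead of a renormalization of all previous outputs.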
Part V · Performance & production (lessons 13–14 · shipping it)
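Underneath, autotuning is a cache: time each candidate config once per key, keep the fastest, and reuse it. Triton's own @triton.autotune takes a list of triton.Config objects and a key list of argument names. The pure-Python sketch below is not that decorator; it is a hypothetical make_autotuner showing the key, configs, cache logic, with the benchmark function passed in so the example runs without a GPU.

```python
def make_autotuner(configs, key_fn, bench):
    """Hypothetical sketch: benchmark every config the first time a key is seen,
    cache the winner, and reuse it for every later call with the same key.

    configs: candidate launch settings, e.g. {"BLOCK": 64, "num_warps": 4}
    key_fn:  maps the call's arguments to a cache key (e.g. the problem size)
    bench:   bench(config, *args) -> elapsed time; on a GPU this would wrap a
             timer such as triton.testing.do_bench around a real launch
    """
    cache = {}

    def best_config(*args):
        key = key_fn(*args)
        if key not in cache:                     # first time: time every config
            cache[key] = min(configs, key=lambda cfg: bench(cfg, *args))
        return cache[key]                        # afterwards: a dict lookup

    return best_config


# Usage with a fake, deterministic cost model standing in for real timings.
calls = []

def fake_bench(cfg, n):
    calls.append(cfg["BLOCK"])
    return abs(cfg["BLOCK"] - n / 4)

pick = make_autotuner([{"BLOCK": 32}, {"BLOCK": 64}, {"BLOCK": 128}],
                      key_fn=lambda n: n, bench=fake_bench)
print(pick(256), len(calls))   # → {'BLOCK': 64} 3
print(pick(256), len(calls))   # → {'BLOCK': 64} 3   (cache hit, no re-benchmark)
```

The key is where stale choices come from: if key_fn ignores an argument that changes which config wins, later calls silently reuse a winner tuned for a different problem.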
key, configs, cache. num_warps vs num_stages demystified: software pipelining is what makes tl.dot-heavy kernels fast. The full pitfall checklist: register spill, bf16 accumulation, missing mask, stale cache.
torch.autograd.Function with explicit Triton forward + backward. Profiling: do_bench, Nsight Compute, dumping TTGIR/PTX. The decision tree: Triton vs CUDA vs torch.compile vs library.
How to use this
- Linearly. Each lesson assumes the previous. Lesson 12 (Flash Attention) calls every primitive from lessons 04–10; skip them and it won't make sense.
- Run every kernel. The lessons include complete, runnable code. Paste it into a Colab with a T4 or better and time it. The point is the wall-clock surprise.
- Touch every knob. Every widget has a setting that makes the kernel wrong or slow. Find it. The bugs are the lesson.