The tile programming model
CUDA programs threads; Triton programs tiles. This one shift is what defines the language. Once you see what the compiler is doing on your behalf, every Triton primitive will read naturally.
The two universes side by side
Take the simplest possible kernel: c[i] = a[i] + b[i] for N elements.
| CUDA (SIMT) | Triton (SPMD on tiles) |
|---|---|
| You write what one of the 128 threads does. The launch creates ⌈N/128⌉ blocks of 128 threads. | You write what one program does to a BLOCK of indices. The launch creates ⌈N/BLOCK⌉ programs. |
Same memory traffic, same arithmetic, same hardware utilisation. Different abstraction.
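To make the contrast concrete, here is a pure-Python model (not CUDA or Triton code) of the two launch styles for c[i] = a[i] + b[i]. `simt_add` runs a per-thread body once for every (block, thread) pair; `tile_add` runs a per-program body that handles a whole BLOCK of offsets at once, with a mask for the ragged tail. The helper names are hypothetical, and the sequential loops are a simplification: real hardware runs these instances in parallel.

```python
def cdiv(n: int, d: int) -> int:
    """Ceiling division: how many blocks/programs are needed to cover n elements."""
    return (n + d - 1) // d


def simt_add(a, b, threads_per_block=128):
    """CUDA-style model: the body describes what ONE thread does with ONE index."""
    n = len(a)
    c = [0.0] * n
    for block_idx in range(cdiv(n, threads_per_block)):   # gridDim.x blocks
        for thread_idx in range(threads_per_block):        # blockDim.x threads
            i = block_idx * threads_per_block + thread_idx
            if i < n:                                       # per-thread bounds branch
                c[i] = a[i] + b[i]
    return c


def tile_add(a, b, BLOCK=128):
    """Triton-style model: the body describes what ONE program does with BLOCK indices."""
    n = len(a)
    c = [0.0] * n
    for pid in range(cdiv(n, BLOCK)):                       # one instance per program id
        offs = [pid * BLOCK + j for j in range(BLOCK)]      # pid*BLOCK + arange(0, BLOCK)
        mask = [o < n for o in offs]                        # tile-level boolean
        x = [a[o] if m else 0.0 for o, m in zip(offs, mask)]  # masked load
        y = [b[o] if m else 0.0 for o, m in zip(offs, mask)]
        s = [xi + yi for xi, yi in zip(x, y)]               # lane-wise add
        for o, m, v in zip(offs, mask, s):
            if m:
                c[o] = v                                    # masked store
    return c


a = [float(i) for i in range(1000)]
b = [2.0 * i for i in range(1000)]
print(simt_add(a, b) == tile_add(a, b))  # both models compute the same result
```

Both models launch cdiv(1000, 128) = 8 instances of their outer body; only the unit of work the body describes differs.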
What "tile" means precisely
A tile is a fixed-size vector (or a 2D or 3D block) of indices whose size is known at compile time:

```python
import triton.language as tl

# Inside a @triton.jit kernel that takes BLOCK: tl.constexpr as a parameter:
offs = tl.arange(0, BLOCK)  # vector of BLOCK integers: 0, 1, ..., BLOCK-1
# BLOCK must be a constexpr power of two (typically 16, 32, 64, 128, 256)
# At compile time the compiler knows, e.g., BLOCK=128, so this is a 128-lane SIMD vector
```
Operations on the tile are SIMD-style: x + y, tl.exp(x), tl.dot(a, b) all act lane-wise (or, for tl.dot, on the whole tile at once).
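A rough pure-Python model of these semantics, assuming a tile is just a Python list (or a list of rows for 2D tiles): `arange` mirrors the power-of-two restriction on `tl.arange`, `add` and `exp` act lane by lane, and `dot` consumes whole 2D tiles. These helpers are hypothetical stand-ins for illustration, not Triton functions.

```python
import math


def arange(start: int, end: int) -> list[int]:
    """Model of tl.arange: the range length must be a power of two."""
    size = end - start
    if size <= 0 or (size & (size - 1)) != 0:
        raise ValueError(f"range length {size} is not a power of two")
    return list(range(start, end))


def add(x, y):
    """Lane-wise: lane j of the result depends only on lane j of each input."""
    return [xi + yi for xi, yi in zip(x, y)]


def exp(x):
    """Lane-wise elementwise math, in the spirit of tl.exp."""
    return [math.exp(xi) for xi in x]


def dot(a, b):
    """Whole-tile op, in the spirit of tl.dot: (M x K) @ (K x N) -> (M x N)."""
    k = len(b)
    assert all(len(row) == k for row in a), "inner dimensions must match"
    n = len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(len(a))]


offs = arange(0, 8)
print(add(offs, offs))                           # each lane doubled independently
print(dot([[1, 2], [3, 4]], [[5, 6], [7, 8]]))   # needs every lane of both tiles
```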
SPMD-on-tiles, not SIMT
In CUDA's SIMT model you visualise one thread, identified by `threadIdx`, stepping through your code. In Triton's SPMD-on-tiles model you visualise one program, whose variables are tiles, stepping through your code. The compiler turns each tile op into the right number of warp-level instructions.
Concretely: when you write
```python
x = tl.load(a + offs, mask=m)  # offs is a 128-vector; m masks off out-of-range lanes
```
With the default num_warps=4 and fp32 data, Triton emits 4 coalesced 32-wide global loads (one per warp), each pulling 32 contiguous floats (128 bytes) from HBM. You did not have to specify "this thread loads index i"; you wrote tile-level intent.
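The load mapping can be reproduced with a small pure-Python model of a simple 1D blocked layout. It assumes 32 lanes per warp, a fixed number of contiguous elements per thread (default 1), and elements handed to threads in order, wrapping around when the tile is larger than all threads combined; the real layout is chosen by the compiler. `lane_map` and `segments_per_warp` are hypothetical helpers: the first says which (warp, lane) owns each offset, the second counts the distinct 128-byte memory segments each warp touches for 4-byte floats.

```python
WARP_SIZE = 32


def lane_map(block: int, num_warps: int, size_per_thread: int = 1):
    """Return {offset: (warp, lane)} for a simplified 1D blocked layout."""
    threads = num_warps * WARP_SIZE
    mapping = {}
    for off in range(block):
        # Each thread owns `size_per_thread` contiguous elements; the pattern
        # wraps around when the tile is larger than one pass over all threads.
        thread = (off // size_per_thread) % threads
        mapping[off] = (thread // WARP_SIZE, thread % WARP_SIZE)
    return mapping


def segments_per_warp(block, num_warps, size_per_thread=1,
                      elem_bytes=4, segment_bytes=128, base_addr=0):
    """Count distinct memory segments touched by each warp (fewer = better coalescing)."""
    segs = {}
    for off, (warp, _lane) in lane_map(block, num_warps, size_per_thread).items():
        addr = base_addr + off * elem_bytes
        segs.setdefault(warp, set()).add(addr // segment_bytes)
    return {w: len(s) for w, s in sorted(segs.items())}


print(lane_map(128, 4)[33])        # offset 33 lives in warp 1, lane 1
print(segments_per_warp(128, 4))   # each warp touches exactly one 128-byte segment
print(segments_per_warp(256, 4))   # tile wraps: each warp now touches two segments
```

With BLOCK=128 and 4 warps, each warp's 32 lanes cover one contiguous, 128-byte-aligned run of floats, which is the single-segment pattern the four coalesced loads correspond to.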
When you write
```python
s = tl.sum(x, axis=0)  # reduce 128-vector to scalar
```
Triton emits intra-warp shuffles (__shfl_xor-style) inside each warp to produce 4 partial sums, then an inter-warp reduction through shared memory to fold them to one. You wrote one line.
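The two-level reduction can be modelled in pure Python. Within a warp, a butterfly over XOR lane offsets 16, 8, 4, 2, 1 (the exchange pattern `__shfl_xor_sync` enables) leaves every lane holding the warp's sum; one value per warp is then written to a list standing in for shared memory and folded. This is a simplified sketch of the idea, with hypothetical helper names, not the exact code Triton generates.

```python
WARP_SIZE = 32


def warp_allreduce_xor(vals):
    """Butterfly reduction over one warp: after log2(32) = 5 steps every lane holds the total."""
    assert len(vals) == WARP_SIZE
    vals = list(vals)
    offset = WARP_SIZE // 2
    while offset > 0:
        # Every lane adds the value held by lane (lane ^ offset), like a shfl_xor exchange.
        vals = [vals[lane] + vals[lane ^ offset] for lane in range(WARP_SIZE)]
        offset //= 2
    return vals


def block_sum(x, num_warps=4):
    """tl.sum-style reduction of a tile with one element per lane."""
    assert len(x) == num_warps * WARP_SIZE
    smem = []                                        # stands in for shared memory
    for w in range(num_warps):
        lanes = x[w * WARP_SIZE:(w + 1) * WARP_SIZE]
        smem.append(warp_allreduce_xor(lanes)[0])    # lane 0 writes the warp's partial sum
    # A barrier is needed here on real hardware before any warp reads smem.
    total = 0
    for partial in smem:                             # fold the num_warps partial sums
        total += partial
    return total


print(block_sum(list(range(128))))  # 4 partial sums folded into one
```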
What you give up by hiding the warp
Two kinds of optimisation get harder:
- Warp specialisation. "Warps 0–3 do TMA loads; warps 4–7 do mma compute" is a producer/consumer pattern used in FlashAttention-3 and CUTLASS Hopper kernels. Triton has some support for this (`num_consumer_groups`, `num_buffers_warp_spec`), but it's experimental and limited.
- Sub-warp scheduling. "Lanes 0–15 do A, lanes 16–31 do B" (say, for a matmul that wants 16×16 mma blocks) is something you can express in CUDA with lane-level control or inline PTX. In Triton you trust the compiler to pick the right mma shape.
For 95% of fused kernels you don't need either. The compiler does the right thing.
What you gain by hiding the warp
- Code that reads like the math. A softmax kernel is 12 lines, not 90.
- Tile sizes are constexpr knobs. Change `BLOCK=128` to `BLOCK=256` and recompile; the autotuner can sweep such values for you. You don't restructure thread blocks.
- Boundary conditions are masks, not branches. `mask=offs < N` is a tile-level boolean; lanes whose offsets are at or beyond N become no-ops without warp divergence.
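As an illustration of both points, here is a pure-Python model of a one-row softmax written in tile style, assuming a power-of-two BLOCK that may exceed the row length N. Out-of-range lanes are loaded as -inf (the role `other=-float("inf")` plays in a masked `tl.load`), so they contribute exp(-inf) = 0 to the sum without any per-lane branch in the math. The helper name `softmax_row_tile` is hypothetical, and the lists stand in for tiles.

```python
import math


def softmax_row_tile(row, BLOCK=8):
    """Model of one program computing softmax over a row of length 0 < N <= BLOCK."""
    n = len(row)
    assert (BLOCK & (BLOCK - 1)) == 0 and 0 < n <= BLOCK
    offs = list(range(BLOCK))                                      # arange(0, BLOCK)
    mask = [o < n for o in offs]                                   # mask = offs < N
    x = [row[o] if m else -math.inf for o, m in zip(offs, mask)]  # masked load, other=-inf
    x_max = max(x)                                                 # like tl.max(x, axis=0)
    num = [math.exp(xi - x_max) for xi in x]                       # padded lanes give 0.0
    den = sum(num)                                                 # like tl.sum(num, axis=0)
    out = [v / den for v in num]
    return [v for v, m in zip(out, mask) if m]                     # masked store: N lanes


print(softmax_row_tile([1.0, 2.0, 3.0]))  # three probabilities summing to 1
```

Changing BLOCK (to any power of two at least N) changes only how many padded lanes exist; the body of the kernel stays the same, which is what makes BLOCK a tunable knob.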
The five questions Triton asks the compiler to answer for you
| Question | CUDA answer | Triton answer |
|---|---|---|
| How many threads per block? | You set blockDim. | You set num_warps in the config; compiler picks lane assignment. |
| Which thread reads which element? | You compute threadIdx-based offsets. | You write a tile of offsets; compiler vectorises. |
| When do warps sync? | You insert __syncthreads(). | The compiler inserts barriers wherever warps exchange data through shared memory (tl.dot operand staging, cross-warp reductions, layout conversions). |
| How is shared memory laid out? | You declare __shared__ and pick a swizzle. | The compiler allocates SMEM for tl.dot operands and picks a swizzle to avoid bank conflicts. |
| How do async copies overlap with compute? | You write cp.async pipelines by hand. | The compiler builds a software pipeline of depth num_stages. |
If any row makes you nervous ("but I want to control that!"), Triton is the wrong tool for that kernel and you should reach for CUDA or CUTLASS. For everything else, Triton is the right tool, because the compiler's answers to these questions are routinely good enough.
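To picture what "a software pipeline of depth num_stages" means in the table above, here is a pure-Python model of the schedule for a loop over num_tiles input tiles. It rests on a simplifying assumption: the prologue issues num_stages - 1 async loads, and each loop iteration then issues one new load while computing on the oldest buffered tile. `pipeline_schedule` is a hypothetical helper that only produces a timeline; real generated code also contains the waits and buffer rotation this model leaves implicit.

```python
def pipeline_schedule(num_tiles: int, num_stages: int):
    """Return (step, action, tile) events for a simplified num_stages-deep pipeline."""
    assert num_tiles >= 0 and num_stages >= 1
    events = []
    step = 0
    issued = 0
    # Prologue: put num_stages - 1 loads in flight before any compute starts.
    while issued < min(num_stages - 1, num_tiles):
        events.append((step, "load", issued))
        issued += 1
    # Steady state: compute tile k while the next load is already in flight.
    for k in range(num_tiles):
        step += 1
        if issued < num_tiles:
            events.append((step, "load", issued))
            issued += 1
        events.append((step, "compute", k))
    return events


for event in pipeline_schedule(num_tiles=4, num_stages=3):
    print(event)
```

With num_stages=1 every load is immediately followed by its own compute, so nothing overlaps; larger values keep more tiles buffered (and therefore use more shared memory) in exchange for hiding load latency.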
What's next
You've seen the abstraction. The next question is mechanical: how does Python source become PTX? Lesson 03 walks the compile pipeline — what @triton.jit does at import time vs first-call time, what tl.constexpr bakes in, and how the autotuner caches one compiled binary per shape key.