The execution model
From @triton.jit to PTX in four stages, plus the grid that launches the kernel. The lesson answers three questions: what runs when, what tl.constexpr actually freezes, and why the first call with a new shape is slow.
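The pipeline in one diagram
@triton.jit Python function → Triton IR (TTIR) → TritonGPU IR (TTGIR) → LLVM IR → PTX, which ptxas then assembles into a cubin that the GPU runs.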
What "compile" means in Triton
Triton compilation is multi-versioned: one Python function in your source can have many compiled binaries, one per unique combination of:
- constexpr values (tile sizes, dtypes marked tl.constexpr),
- argument dtypes (bf16 vs fp16 vs fp32),
- pointer alignments (Triton specialises on whether pointers are 16-byte aligned),
- num_warps and num_stages from the config (autotune sweeps these).
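Hit the function with new values for any of those? You get a new compile, cached for next time. That compile runs the whole pipeline, which is why the first call with a new configuration is slow. Hit it with the same values? Cache lookup, ~µs.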
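To make the multi-versioning concrete, here is a toy model in plain Python. ToyJIT and its launch method are hypothetical and are not Triton's API or internals. The key bundles the factors listed above plus a hash of the kernel source, since editing the source is what invalidates cached binaries. Each miss counts as one "compile", and each hit is a dictionary lookup.

```python
import hashlib
from dataclasses import dataclass, field

# Toy model of a per-specialization compile cache. NOT Triton's real
# implementation or API: it only mimics which launch inputs this lesson
# says end up in the cache key.

@dataclass
class ToyJIT:
    source: str                               # kernel source text, hashed into the key
    compiles: int = 0                         # number of slow-path "compiles"
    cache: dict = field(default_factory=dict)

    def launch(self, constexprs, arg_dtypes, ptr_aligned_16,
               num_warps=4, num_stages=3):    # defaults are arbitrary for the toy
        key = (
            hashlib.sha256(self.source.encode()).hexdigest(),  # source hash
            tuple(sorted(constexprs.items())),                 # e.g. (("BLOCK", 1024),)
            tuple(arg_dtypes),                                 # e.g. ("fp16", "fp16", "fp16")
            tuple(ptr_aligned_16),                             # 16-byte alignment per pointer
            num_warps,
            num_stages,
        )
        if key not in self.cache:             # miss: pay for a full compile
            self.compiles += 1
            self.cache[key] = f"binary#{self.compiles}"
        return self.cache[key]                # hit: just a dictionary lookup


jit = ToyJIT(source="def add_kernel(a_ptr, b_ptr, c_ptr, N, BLOCK): ...")
fp16, aligned = ("fp16",) * 3, (True,) * 3

first = jit.launch({"BLOCK": 1024}, fp16, aligned)          # compile #1
again = jit.launch({"BLOCK": 1024}, fp16, aligned)          # cache hit, same binary
jit.launch({"BLOCK": 512}, fp16, aligned)                   # new constexpr: compile #2
jit.launch({"BLOCK": 1024}, ("fp32",) * 3, aligned)         # new dtype: compile #3
jit.launch({"BLOCK": 1024}, fp16, (True, True, False))      # misaligned c_ptr: compile #4
jit.launch({"BLOCK": 1024}, fp16, aligned, num_warps=8)     # new num_warps: compile #5
print(jit.compiles, first == again)   # 5 True
```

Only the repeated call is free. Every other call changes one component of the key and pays for a compile.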
tl.constexpr — the line that splits compile time from runtime
```python
import triton
import triton.language as tl

@triton.jit
def my_kernel(
    a_ptr, b_ptr, c_ptr,               # ← runtime args (pointers)
    N,                                 # ← runtime arg (scalar)
    BLOCK: tl.constexpr,               # ← compile time
    DTYPE: tl.constexpr = tl.float32,  # ← compile time
):
    ...
```
Mark any argument tl.constexpr and:
- Its value is known when the kernel compiles. Triton bakes it into the generated code.
- Loops over it can be unrolled.
- tl.arange(0, BLOCK) becomes a fixed-size vector. tl.arange only accepts constexpr bounds, and the range length must be a power of two.
- It contributes to the cache key: a new value forces a recompile.
Without tl.constexpr, the value is a runtime scalar. Loops bounded by it can't be unrolled, and it can't set a block shape at all: passing a runtime BLOCK to tl.arange is a compile error, not a slow fallback. Mark tile sizes, dtypes, and any "shape of the work" constants as constexpr.
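Because tl.arange(0, BLOCK) needs a power-of-two length, host code usually rounds a data-dependent size up before passing it as a constexpr. The sketch below uses a hypothetical pick_block helper. next_power_of_2 is re-implemented so the sketch runs without Triton, and it is intended to match triton.next_power_of_2 for n >= 1.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    # Intended to match triton.next_power_of_2 for n >= 1; re-implemented
    # here so the sketch runs without Triton installed.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pick_block(n_cols: int) -> int:
    # Hypothetical host-side helper: BLOCK is a constexpr used as
    # tl.arange(0, BLOCK), which needs a power-of-two length, so round the
    # row length up. The kernel's mask then discards the padding lanes.
    return next_power_of_2(n_cols)


for n in (1, 7, 128, 1000):
    print(n, "->", pick_block(n))   # 1 -> 1, 7 -> 8, 128 -> 128, 1000 -> 1024

# Every row length from 513 to 1024 maps to the same BLOCK, so they all
# share one compiled binary instead of forcing a recompile per length.
print({pick_block(n) for n in range(513, 1025)})   # {1024}
```

Rounding also caps the number of distinct BLOCK values, and therefore the number of compiles, at one per power of two.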
The grid — your problem decomposition
The grid is how you tell Triton "how many programs to launch and how to index them". It's a tuple with one entry per axis (1, 2, or 3 axes), each entry giving the number of programs along that axis, computed at launch time:
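```python
import torch
import triton

# add_kernel is assumed to be a @triton.jit vector-add kernel defined elsewhere,
# with signature (a_ptr, b_ptr, c_ptr, N, BLOCK: tl.constexpr).
def add_fn(a, b):
    N = a.numel()
    c = torch.empty_like(a)
    BLOCK = 1024                      # tile size
    grid = (triton.cdiv(N, BLOCK),)   # 1D grid: one program per tile
    add_kernel[grid](a, b, c, N, BLOCK=BLOCK)
    return c
```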
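Grid sizing is plain host-side integer math. The sketch below re-implements cdiv, which is intended to give the same result as triton.cdiv for positive integers. It also defines a hypothetical describe_1d_grid helper that reports how many lanes of the final tile hold real data. This shows why the kernel needs a mask whenever N is not a multiple of BLOCK.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: the number of BLOCK-sized tiles needed to cover a.
    return (a + b - 1) // b


def describe_1d_grid(N: int, BLOCK: int):
    # Hypothetical helper: the launch grid plus the live lanes in the last tile.
    grid = (cdiv(N, BLOCK),)
    last_valid = N - (grid[0] - 1) * BLOCK   # elements the final program really owns
    return grid, last_valid


print(describe_1d_grid(16 * 1024 * 1024, 1024))  # ((16384,), 1024): no ragged tail
print(describe_1d_grid(1000, 256))               # ((4,), 232): last tile needs a mask
```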
Or with autotune, where BLOCK is chosen by the autotuner and the grid must read it from the meta dict:
```python
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)  # grid depends on BLOCK
add_kernel[grid](a, b, c, N)  # no BLOCK argument: the autotuner supplies it
```
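The lambda matters because BLOCK is not known when you write the launch line. When the grid is callable, Triton calls it with the launch's meta-parameters, including the BLOCK of whichever config is running. The plain-Python sketch below shows that call pattern. The config dicts are illustrative and are not real triton.Config objects.

```python
def cdiv(a, b):
    # Ceiling division, standing in for triton.cdiv.
    return (a + b - 1) // b

N = 1_000_000
grid = lambda meta: (cdiv(N, meta["BLOCK"]),)   # same shape as the lesson's lambda

# For each candidate config, the launcher calls grid(meta) with that config's
# meta-parameters, so the program count follows whichever BLOCK is being tried.
for meta in ({"BLOCK": 256}, {"BLOCK": 1024}, {"BLOCK": 4096}):
    print(meta["BLOCK"], grid(meta))
# 256 (3907,)
# 1024 (977,)
# 4096 (245,)
```

A fixed tuple computed from a hard-coded BLOCK would launch the wrong number of programs for every other config.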
Inside the kernel you read your coords with tl.program_id(axis):
```python
pid_m = tl.program_id(0)  # which row tile
pid_n = tl.program_id(1)  # which column tile
```
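To see how program IDs partition a 2D problem, the sketch below runs a 2D launch sequentially in plain Python. emulate_2d_grid is a hypothetical helper. On a GPU the programs run in parallel and in no guaranteed order, but the index arithmetic is the same pattern a Triton kernel uses: pid * BLOCK + arange, followed by a bounds mask.

```python
def cdiv(a, b):
    # Ceiling division, standing in for triton.cdiv.
    return (a + b - 1) // b


def emulate_2d_grid(M, N, BLOCK_M, BLOCK_N):
    """Run a 2D launch sequentially and record which program owns each element."""
    owner = {}  # (row, col) -> (pid_m, pid_n)
    for pid_m in range(cdiv(M, BLOCK_M)):          # plays the role of tl.program_id(0)
        for pid_n in range(cdiv(N, BLOCK_N)):      # plays the role of tl.program_id(1)
            rows = [pid_m * BLOCK_M + i for i in range(BLOCK_M)]  # like tl.arange
            cols = [pid_n * BLOCK_N + j for j in range(BLOCK_N)]
            for r in rows:
                for c in cols:
                    if r < M and c < N:             # the mask a real kernel needs
                        assert (r, c) not in owner  # no two programs overlap
                        owner[(r, c)] = (pid_m, pid_n)
    return owner


owner = emulate_2d_grid(M=100, N=70, BLOCK_M=32, BLOCK_N=32)
print(len(owner), owner[(99, 69)])   # 7000 (3, 2)
```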
Launch overhead — and how to amortise it
A Triton kernel launch costs ~5–20µs. That's fine for a kernel that runs for ms, terrible for a kernel that runs for µs. Two tactics keep this in check:
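- Bigger tiles, more work per launch. The launch cost is paid per launch, not per program, so cover as much work as possible in each launch instead of calling many small kernels from an outer loop. A vector add over 16M elements with BLOCK=1024 is 16K programs in one launch; bigger tiles further cut the number of programs the GPU has to schedule.
- CUDA graphs. Capture a sequence of Triton launches into a single replayable graph. The launch overhead becomes O(1) per replay, not O(launches).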
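A back-of-envelope model shows when the overhead matters. It rests on three simplifying assumptions, none of them measurements:
- a fixed 10 µs per launch, inside the range above;
- asynchronous launches overlap with kernel execution, so the slower of the two sets the pace;
- a graph replay costs about one launch.

```python
def eager_us(n_kernels, launch_us, kernel_us):
    # Assumption: launches are asynchronous, so in steady state the CPU issuing
    # launches and the GPU running kernels overlap; the slower side sets the pace.
    return n_kernels * max(launch_us, kernel_us)


def graph_us(n_kernels, replay_us, kernel_us):
    # Assumption: one replay pays a single overhead, then the captured kernels
    # run back to back on the GPU.
    return replay_us + n_kernels * kernel_us


LAUNCH_US = 10  # assumed per-launch overhead, inside the lesson's ~5-20 µs range
for kernel_us in (2, 2000):  # a µs-scale kernel vs a ms-scale kernel
    e = eager_us(100, LAUNCH_US, kernel_us)
    g = graph_us(100, LAUNCH_US, kernel_us)  # assumes replay costs about one launch
    print(kernel_us, e, g)
# 2 1000 210
# 2000 200000 200010
```

In this model the µs-scale kernel is launch-bound, and the graph recovers most of the gap. The ms-scale kernel barely notices either way.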
Reading the cache
Triton caches compiled kernels under ~/.triton/cache/ (override with TRITON_CACHE_DIR). Each compiled binary gets one subdirectory containing the IR dumps, the PTX, and the cubin. The cache key includes a hash of the kernel source, so editing the kernel produces a new key and a fresh compile. If you haven't changed the source but the kernel is misbehaving, deleting the cache is a cheap reset:
```bash
rm -rf ~/.triton/cache   # forces a full recompile on the next call
```
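To see what is actually on disk, a small standard-library script can tally the cache by file extension. summarize_cache and triton_cache_dir are hypothetical helpers. The file types you will find vary by Triton version and backend; recent releases typically store .ttir, .ttgir, .llir, .ptx and .cubin files plus .json metadata.

```python
import os
from collections import Counter
from pathlib import Path


def triton_cache_dir() -> Path:
    # Lookup order described in the lesson: TRITON_CACHE_DIR if set,
    # otherwise ~/.triton/cache.
    return Path(os.environ.get("TRITON_CACHE_DIR", Path.home() / ".triton" / "cache"))


def summarize_cache(root: Path) -> Counter:
    """Count cached artifacts by file extension (e.g. .ttir, .ptx, .cubin)."""
    counts = Counter()
    if not root.is_dir():
        return counts               # nothing compiled yet (or wrong path)
    for path in root.rglob("*"):
        if path.is_file():
            counts[path.suffix or "<none>"] += 1
    return counts


print(summarize_cache(triton_cache_dir()))
```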
You can also dump the IR for any kernel by setting MLIR_ENABLE_DUMP=1 — useful when you suspect the compiler is generating something different from what you wrote (lesson 14 will use this).
What's next
You know what happens around the kernel. Now you can write inside one. Lesson 04 introduces the first DSL primitive: tl.load with masks. This is the op that moves bytes from HBM into your tile registers — and the one whose mask is the most common source of correctness bugs in Triton.