Triton kernels · lesson 03 / 14

The execution model

From @triton.jit to PTX in four lowering stages, plus the grid that launches the kernel. This lesson answers three questions: what runs when, what tl.constexpr actually freezes, and why the first call with a new shape is slow.

The pipeline in one diagram

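Python AST → Triton IR (.ttir) → TritonGPU IR (.ttgir) → LLVM IR → PTX → SASS

  1. Python AST: @triton.jit def my_kernel(...). The decorator runs at import.
  2. Triton IR (.ttir): tile-level ops with no layout choices yet. It is specialised per call signature.
  3. TritonGPU IR (.ttgir): layouts, warps, software pipelining and shared-memory (SMEM) swizzling. This is where the magic happens.
  4. LLVM IR: target-specific optimisation and scheduling, then emitted as PTX.
  5. PTX → SASS: ptxas assigns physical registers and produces the binary. The binary is cached per shape key under ~/.triton/cache.

When does each stage run? @triton.jit runs at import and only wraps the function, so nothing is compiled yet. The lowering chain runs lazily on the first call with a new (constexpr values, dtypes) signature, and the result is cached.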

What "compile" means in Triton

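Triton compilation is multi-versioned. There is one Python function in your source but many compiled binaries, one per unique combination of:

  1. the kernel source (by content hash);
  2. the values of all tl.constexpr arguments;
  3. the dtypes of the arguments;
  4. the launch options num_warps and num_stages;
  5. alignment: whether pointer and integer arguments are divisible by 16.
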
Hit the function with new values for any of those? You get a new compile, cached for next time. Hit it with the same values? Cache lookup, ~µs.
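Here is a toy model in plain Python that makes "lazy and multi-versioned" concrete. ToyJit and toy_jit are hypothetical names, not Triton's JITFunction. The model skips the real lowering and also ignores Triton's alignment and value-equals-1 specialisations. It shows only the shape of the behaviour:

- decorating compiles nothing;
- the first call per key "compiles";
- repeat calls are cache lookups.

```python
from typing import Any, Callable, Dict, Tuple


class ToyJit:
    """Toy model of lazy, multi-versioned compilation (not Triton's JITFunction)."""

    def __init__(self, fn: Callable, constexprs: Tuple[str, ...]) -> None:
        # Decoration ("import") time: record the function and which parameters
        # are compile-time constants. Nothing is compiled yet.
        self.fn = fn
        self.constexprs = constexprs
        self.cache: Dict[Tuple, Callable] = {}
        self.compiles = 0

    def _key(self, args: Tuple, kwargs: Dict[str, Any]) -> Tuple:
        # Specialisation key: constexpr *values* plus runtime argument *types*.
        # (In this toy, constexprs must be passed by keyword.)
        const_vals = tuple((name, kwargs[name]) for name in self.constexprs)
        arg_types = tuple(type(a).__name__ for a in args)
        return const_vals + arg_types

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        key = self._key(args, kwargs)
        binary = self.cache.get(key)
        if binary is None:
            # Cache miss: stands in for the AST -> TTIR -> TTGIR -> LLVM IR -> PTX chain.
            self.compiles += 1
            binary = self.fn
            self.cache[key] = binary
        return binary(*args, **kwargs)


def toy_jit(*constexprs: str) -> Callable[[Callable], ToyJit]:
    """Hypothetical stand-in for @triton.jit that only models the caching."""
    return lambda fn: ToyJit(fn, constexprs)


@toy_jit("BLOCK")
def scale(x, n, BLOCK):
    return [v * n for v in x[:BLOCK]]


assert scale.compiles == 0         # decorating compiled nothing
scale([1, 2, 3], 2, BLOCK=2)       # first call: compile
scale([4, 5, 6], 3, BLOCK=2)       # same key (BLOCK=2, list, int): cache hit
scale([1, 2, 3], 2, BLOCK=4)       # new constexpr value: compile
scale([1, 2, 3], 2.0, BLOCK=4)     # new runtime type (float instead of int): compile
print(scale.compiles)              # 3
```

Note that changing the runtime value of n (2 to 3) did not trigger a compile; only the constexpr value and the argument types did.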

First call per shape is slow
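The first call with a new cache key runs the whole lowering chain, which often takes hundreds of milliseconds. Autotuning is heavier still. @triton.autotune compiles and benchmarks every config in its sweep, and repeats this for each new value of its key arguments, so a big sweep can take seconds. Always run a warmup pass before timing, or you'll measure the compiler, not the kernel.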
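This is a minimal sketch of the warmup-then-measure pattern in plain Python. fake_compile is a hypothetical stand-in that sleeps 50 ms on its first call per argument to mimic a compile. bench discards the warmup calls and reports a median. For real GPU kernels, Triton ships triton.testing.do_bench, which also warms up before measuring. This CPU-only sketch does not synchronise the GPU, which timing a real kernel requires.

```python
import statistics
import time
from functools import lru_cache
from typing import Callable


def bench(fn: Callable[[], object], warmup: int = 3, rep: int = 20) -> float:
    """Median wall time of fn() in seconds, after discarding warmup calls."""
    for _ in range(warmup):
        fn()                      # any first-call compile cost is paid here
    times = []
    for _ in range(rep):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    return statistics.median(times)


@lru_cache(maxsize=None)
def fake_compile(block: int) -> Callable[[int], int]:
    time.sleep(0.05)              # stand-in for a ~50 ms first-call compile
    return lambda n: n * block


def run() -> int:
    return fake_compile(1024)(16)


t0 = time.perf_counter()
run()                             # cold: includes the fake compile
cold = time.perf_counter() - t0
warm = bench(run)                 # warm: cache hits only
print(f"cold {cold * 1e3:.1f} ms, warm median {warm * 1e6:.1f} us")
```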

tl.constexpr — the line that splits compile time from runtime

import triton
import triton.language as tl

@triton.jit
def my_kernel(
    a_ptr, b_ptr, c_ptr,         # ← runtime args (pointers)
    N,                            # ← runtime arg (scalar)
    BLOCK: tl.constexpr,          # ← compile time
    DTYPE: tl.constexpr = tl.float32,  # ← compile time
):
    ...                           # body omitted; lesson 04 starts filling it in

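Mark any argument tl.constexpr and:

  1. its value is baked into the compiled binary as a constant, so the compiler can unroll loops over it and size vectors by it;
  2. it becomes part of the cache key, so every distinct value triggers its own compile;
  3. it can be used where Triton requires a compile-time value, such as the bounds of tl.arange and block shapes.
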
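Without tl.constexpr, the value is a runtime argument. Loops over it can't be unrolled, and memory accesses can't be vectorised to a known width. The result is a slower generic kernel, or a compile error wherever Triton needs the value at compile time (tl.arange bounds, for example). Mark tile sizes, dtypes, and any "shape of the work" constants as constexpr.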
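Here is why unrolling needs the value at compile time, in miniature. A toy code generator emits a fully unrolled tile sum when the block size is known while generating code, and a generic loop when it is only known at run time. The generator is plain Python and unrelated to Triton's real codegen; emit_tile_sum and build are hypothetical helpers.

```python
from typing import Callable, Optional


def emit_tile_sum(block: Optional[int]) -> str:
    """Return Python source for a function that sums the first `block` items of x.

    block is an int  -> a compile-time constant: emit a fully unrolled expression.
    block is None    -> only known at run time: emit a generic loop over `block`.
    """
    if block is None:
        return (
            "def tile_sum(x, block):\n"
            "    acc = 0\n"
            "    for i in range(block):\n"
            "        acc += x[i]\n"
            "    return acc\n"
        )
    if block < 1:
        raise ValueError("block must be >= 1")
    terms = " + ".join(f"x[{i}]" for i in range(block))
    return f"def tile_sum(x):\n    return {terms}\n"


def build(src: str) -> Callable:
    namespace: dict = {}
    exec(src, namespace)  # "compile" the generated source into a function
    return namespace["tile_sum"]


print(emit_tile_sum(4), end="")
# def tile_sum(x):
#     return x[0] + x[1] + x[2] + x[3]

unrolled = build(emit_tile_sum(4))
generic = build(emit_tile_sum(None))
assert unrolled([1, 2, 3, 4]) == generic([1, 2, 3, 4], 4) == 10
```

The unrolled version has no loop counter, bounds check or branch. That is the kind of freedom a constexpr BLOCK gives the real compiler.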

The grid — your problem decomposition

The grid is how you tell Triton "how many programs to launch and how to index them". It's a tuple of 1, 2, or 3 axes, computed at launch time:

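import torch
import triton

def add_fn(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # add_kernel: a @triton.jit vector-add kernel, assumed defined elsewhere
    N = a.numel()
    c = torch.empty_like(a)
    BLOCK = 1024                                        # tile size (constexpr in the kernel)
    grid = (triton.cdiv(N, BLOCK),)                     # 1D grid: one program per tile
    add_kernel[grid](a, b, c, N, BLOCK=BLOCK)
    return c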

Or with autotune, where BLOCK is chosen by the autotuner and the grid must read it from the meta dict:

grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)   # grid depends on BLOCK
add_kernel[grid](a, b, c, N)                           # no BLOCK=...: the autotuner supplies it
[Figure: a 1D grid for vector ops, grid = (cdiv(N, BLOCK),), with programs pid=0, 1, 2, 3, ...; and a 2D grid for matrix ops, grid = (cdiv(M, BM), cdiv(N, BN)), with programs indexed (0,0), (0,1), (0,2), (0,3), ..., (1,0), (2,0), ...]
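The grid arithmetic is plain integer math, so it can be checked without a GPU. This sketch assumes the launcher evaluates a callable grid with a dict of meta-parameters, as described above. cdiv mirrors triton.cdiv. resolve_grid is a hypothetical helper, not Triton's launcher code.

```python
from typing import Callable, Dict, Tuple, Union

GridSpec = Union[Tuple[int, ...], Callable[[Dict[str, int]], Tuple[int, ...]]]


def cdiv(x: int, y: int) -> int:
    """Ceiling division: the same arithmetic as triton.cdiv."""
    return (x + y - 1) // y


def resolve_grid(grid: GridSpec, meta: Dict[str, int]) -> Tuple[int, ...]:
    # A tuple is used as-is; a callable is evaluated with the launch's
    # meta-parameters (e.g. the BLOCK chosen by the autotuner).
    return grid(meta) if callable(grid) else grid


N = 16_777_216                                     # 16M elements
static_grid = (cdiv(N, 1024),)
tuned_grid = lambda meta: (cdiv(N, meta["BLOCK"]),)

print(resolve_grid(static_grid, {}))               # (16384,)
print(resolve_grid(tuned_grid, {"BLOCK": 4096}))   # (4096,)
print(cdiv(1000, 256))                             # 4: the last program has 1000 - 768 = 232 valid elements
```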

Inside the kernel you read your coords with tl.program_id(axis):

pid_m = tl.program_id(0)      # which row tile
pid_n = tl.program_id(1)      # which column tile
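This plain-Python sketch shows how program ids tile a matrix. It enumerates a 2D grid and computes the rows and columns each (pid_m, pid_n) program would own. tiles_2d is a hypothetical helper, and the sizes are arbitrary. min() stands in for the mask a real kernel applies on edge tiles.

```python
from itertools import product
from typing import Iterator, Tuple


def cdiv(x: int, y: int) -> int:
    return (x + y - 1) // y


def tiles_2d(M: int, N: int, BM: int, BN: int) -> Iterator[Tuple[int, int, range, range]]:
    """Yield (pid_m, pid_n, rows, cols) for every program in a 2D grid."""
    grid = (cdiv(M, BM), cdiv(N, BN))
    for pid_m, pid_n in product(range(grid[0]), range(grid[1])):
        # pid_m / pid_n play the role of tl.program_id(0) / tl.program_id(1);
        # min() stands in for the mask that trims edge tiles.
        rows = range(pid_m * BM, min((pid_m + 1) * BM, M))
        cols = range(pid_n * BN, min((pid_n + 1) * BN, N))
        yield pid_m, pid_n, rows, cols


M, N, BM, BN = 100, 70, 32, 32
owned = set()
for _, _, rows, cols in tiles_2d(M, N, BM, BN):
    for cell in product(rows, cols):
        assert cell not in owned      # no element is owned by two programs
        owned.add(cell)

print(len(list(tiles_2d(M, N, BM, BN))))  # 12 programs: a (4, 3) grid
print(len(owned) == M * N)                # True: every element is covered
```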

Launch overhead — and how to amortise it

A Triton kernel launch costs ~5–20µs. That's fine for a kernel that runs for ms, terrible for a kernel that runs for µs. Two tactics keep this in check:

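  1. More work per launch. Launch overhead is paid once per launch, not once per program. Cover the whole problem with a single grid instead of launching per chunk from a Python loop. A vector add over 16M elements with BLOCK=1024 is 16K programs (16,777,216 / 1024 = 16,384) but still only one launch. Bigger tiles cut the number of programs and their per-program setup, not the launch count.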
  2. CUDA graphs. Capture a sequence of Triton launches into a single replayable graph (for example with torch.cuda.graph). Each replay then pays the host-side launch overhead once for the whole sequence: O(1) per replay, not O(launches).
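Here is a back-of-envelope model of why launch cost matters. It makes two simplifying assumptions:

- a fixed 10 µs per launch, inside the range quoted above;
- no overlap between launch overhead and kernel execution.

overhead_fraction is a hypothetical helper.

```python
def overhead_fraction(launch_us: float, work_us: float, launches: int) -> float:
    """Share of wall time spent on launch overhead.

    Toy model: `work_us` of GPU work is split across `launches` launches, each
    paying a fixed `launch_us`, and launch cost never overlaps kernel execution.
    """
    overhead = launch_us * launches
    return overhead / (overhead + work_us)


LAUNCH_US = 10.0  # assumed per-launch cost

print(f"{overhead_fraction(LAUNCH_US, 2000.0, 1):.3f}")    # 0.005: 2 ms kernel, one launch
print(f"{overhead_fraction(LAUNCH_US, 5.0, 1):.3f}")       # 0.667: 5 us kernel, one launch
print(f"{overhead_fraction(LAUNCH_US, 2000.0, 100):.3f}")  # 0.333: same 2 ms of work in 100 launches
```

The third line is the case both tactics target: the same work split into many launches. Fewer, larger launches, or a graph replay that pays the host-side cost once, push it back toward the first line.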

Reading the cache

Triton caches compiled kernels under ~/.triton/cache/ (override with TRITON_CACHE_DIR). Each compiled kernel gets its own subdirectory. It holds the IR dumps (.ttir, .ttgir, .llir), the PTX, and the cubin that contains the SASS. The cache key includes a hash of the kernel source, so editing the source makes the next call miss the old entries and recompile. If you haven't changed anything but the kernel is misbehaving, deleting the cache is a cheap reset:

rm -rf ~/.triton/cache    # forces full recompile next call
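This small script peeks inside the cache without Triton installed. It assumes the layout described above, with one subdirectory per compiled kernel. It also assumes that TRITON_CACHE_DIR, when set, replaces the default path. File names and suffixes vary between Triton versions, and the demo's files are made up. triton_cache_dir and list_cache are hypothetical helpers.

```python
import os
import tempfile
from pathlib import Path
from typing import Dict, List


def triton_cache_dir() -> Path:
    # TRITON_CACHE_DIR, if set, replaces the default ~/.triton/cache location.
    return Path(os.environ.get("TRITON_CACHE_DIR") or Path.home() / ".triton" / "cache")


def list_cache(cache_dir: Path) -> Dict[str, List[str]]:
    """Map each cache-entry subdirectory to the sorted file suffixes it holds."""
    entries: Dict[str, List[str]] = {}
    if not cache_dir.is_dir():
        return entries
    for entry in sorted(p for p in cache_dir.iterdir() if p.is_dir()):
        entries[entry.name] = sorted({f.suffix for f in entry.iterdir() if f.is_file()})
    return entries


# Demo on a throwaway directory shaped like one cache entry (file names are made up).
with tempfile.TemporaryDirectory() as tmp:
    entry = Path(tmp) / "example_entry"
    entry.mkdir()
    for name in ("my_kernel.ttir", "my_kernel.ptx", "my_kernel.cubin"):
        (entry / name).touch()
    demo = list_cache(Path(tmp))
print(demo)  # {'example_entry': ['.cubin', '.ptx', '.ttir']}

# On a machine that has run Triton kernels, list the real cache.
for name, suffixes in list_cache(triton_cache_dir()).items():
    print(name, " ".join(suffixes))
```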

You can also dump the IR by setting MLIR_ENABLE_DUMP=1, which prints the MLIR before each compiler pass. It's useful when you suspect the compiler is generating something different from what you wrote (lesson 14 will use this).


A new compile happens when any element of the cache key changes: (source_hash, constexpr_vals, dtypes, num_warps, num_stages, alignment). Everything else hits the cache.
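The cache key above boils down to a set lookup. Here is a hypothetical plain-Python version; Launch, cache_key and CacheSim are illustrative names.

- The source hash stands in for Triton's content hash.
- aligned models the divisible-by-16 flags behind the alignment entry.
- Triton's real key also mixes in other inputs (for example the compiler version and target) and specialises integers equal to 1. This toy omits both.

```python
import hashlib
from typing import NamedTuple, Set, Tuple


class Launch(NamedTuple):
    source: str                               # kernel source text
    constexprs: Tuple[Tuple[str, object], ...]
    dtypes: Tuple[str, ...]                   # one entry per argument
    num_warps: int = 4
    num_stages: int = 3
    aligned: Tuple[bool, ...] = ()            # per-argument "divisible by 16" flags


def cache_key(launch: Launch) -> Tuple:
    # The runtime *value* of N is not in the key; only its alignment flag is.
    src_hash = hashlib.sha256(launch.source.encode()).hexdigest()
    return (src_hash, launch.constexprs, launch.dtypes,
            launch.num_warps, launch.num_stages, launch.aligned)


class CacheSim:
    def __init__(self) -> None:
        self.seen: Set[Tuple] = set()

    def call(self, launch: Launch) -> str:
        key = cache_key(launch)
        if key in self.seen:
            return "hit"
        self.seen.add(key)
        return "compile"


SRC_V1 = "def k(x_ptr, N, BLOCK): ..."
SRC_V2 = "def k(x_ptr, N, BLOCK): ...  # edited"
base = Launch(SRC_V1, (("BLOCK", 1024),), ("*fp32", "i32"), aligned=(True, True))

sim = CacheSim()
trace = [
    sim.call(base),                                          # cold cache
    sim.call(base),                                          # identical launch
    sim.call(base._replace(constexprs=(("BLOCK", 512),))),   # constexpr changed
    sim.call(base._replace(dtypes=("*fp16", "i32"))),        # dtype changed
    sim.call(base._replace(num_warps=8)),                    # num_warps changed
    sim.call(base._replace(aligned=(True, False))),          # N no longer divisible by 16
    sim.call(base._replace(source=SRC_V2)),                  # source edited
    sim.call(base),                                          # back to the first launch
]
print(trace)
# ['compile', 'hit', 'compile', 'compile', 'compile', 'compile', 'compile', 'hit']
```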

What's next

You know what happens around the kernel. Now you can write inside one. Lesson 04 introduces the first DSL primitive: tl.load with masks. This is the op that moves bytes from HBM into your tile registers — and the one whose mask is the most common source of correctness bugs in Triton.