Tiled matmul — the canonical GEMM
Every transformer-shaped kernel you ever write has a tiled matmul at its core. This lesson is the one canonical pattern: output tile in registers, K-loop accumulates, boundary masked on three axes, cast on store. Once it's memorised, lessons 10–12 read as elaborations.
The math the kernel implements
For $C = A \cdot B$ with $A \in \mathbb{R}^{M \times K}$, $B \in \mathbb{R}^{K \times N}$:
$$C_{ij} = \sum_{k} A_{ik} \, B_{kj}$$
We tile the output. Each program owns a BM × BN tile of C and computes it by streaming along K: at each step it loads a BM × BK block of A and a BK × BN block of B, multiplies them, and adds the product to its accumulator.
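The tiling can be checked on the CPU before any GPU code is involved. What follows is a minimal pure-Python sketch, not the Triton kernel: `tiled_matmul`, `naive_matmul`, and the tiny tile sizes are names and values made up for illustration, and the boundary is handled by clamping ranges rather than by masks.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def tiled_matmul(A, B, BM=4, BN=4, BK=2):
    """Multiply list-of-lists matrices A (M x K) and B (K x N) tile by tile."""
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # One iteration of the two outer loops plays the role of one Triton program.
    for pid_m in range(cdiv(M, BM)):
        for pid_n in range(cdiv(N, BN)):
            rows = range(pid_m * BM, min((pid_m + 1) * BM, M))
            cols = range(pid_n * BN, min((pid_n + 1) * BN, N))
            # Accumulator for this output tile, kept for the whole K-loop.
            acc = {(i, j): 0.0 for i in rows for j in cols}
            for k0 in range(0, K, BK):
                ks = range(k0, min(k0 + BK, K))  # clamp instead of masking
                for i in rows:
                    for j in cols:
                        acc[i, j] += sum(A[i][k] * B[k][j] for k in ks)
            # Write the finished tile back once.
            for (i, j), v in acc.items():
                C[i][j] = v
    return C


def naive_matmul(A, B):
    """Reference: the definition C_ij = sum_k A_ik * B_kj, with no tiling."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


# Shapes chosen so that none of M, N, K is a multiple of its tile size.
A = [[float(i + k) for k in range(5)] for i in range(7)]
B = [[float(k * j + 1) for j in range(6)] for k in range(5)]
assert tiled_matmul(A, B) == naive_matmul(A, B)
```

The accumulator lives across the whole K-loop and C is written exactly once per tile, which is the same shape the kernel has.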
The kernel — annotated line by line
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    sam, sak, sbk, sbn, scm, scn,  # element strides
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    # 1 · which output tile am I?
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # 2 · row/col indices for my tile of C, plus the K-loop indices
    offs_m = pid_m * BM + tl.arange(0, BM)  # (BM,)
    offs_n = pid_n * BN + tl.arange(0, BN)  # (BN,)
    offs_k = tl.arange(0, BK)               # (BK,)
    # 3 · accumulator in registers, always fp32
    acc = tl.zeros((BM, BN), dtype=tl.float32)
    # 4 · K-loop: stream BK-wide strips of A and B
    for k in range(0, K, BK):
        a = tl.load(
            a_ptr + offs_m[:, None]*sam + (k + offs_k)[None, :]*sak,
            mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < K),
            other=0.,
        ).to(tl.bfloat16)
        b = tl.load(
            b_ptr + (k + offs_k)[:, None]*sbk + offs_n[None, :]*sbn,
            mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N),
            other=0.,
        ).to(tl.bfloat16)
        # tensor cores fire here
        acc = tl.dot(a, b, acc)
    # 5 · store the BM × BN tile of C, cast to output dtype
    tl.store(
        c_ptr + offs_m[:, None]*scm + offs_n[None, :]*scn,
        acc.to(tl.bfloat16),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )
Five sections. Each line carries one idea. Memorise this shape — every fused-matmul kernel in the rest of the series starts from this skeleton.
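Why the accumulator is fp32 even though the inputs and output are bf16 can be seen without a GPU. The sketch below emulates bf16 rounding in plain Python; `to_bf16` and `accumulate` are illustrative helpers, not Triton functions, and a Python float stands in for the wide accumulator.

```python
import struct


def to_bf16(x):
    """Round a Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(values, round_fn):
    """Running sum where the accumulator is rounded by round_fn after each add."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


ones = [1.0] * 1024
bf16_sum = accumulate(ones, to_bf16)      # accumulator rounded to bf16 every step
wide_sum = accumulate(ones, lambda x: x)  # wide accumulator, rounded only at the end
print(bf16_sum, to_bf16(wide_sum))
```

With a bf16 accumulator the sum stalls at 256, because 257 is not representable in 8 significant bits and rounds back to 256. Accumulating wide and casting once on store gives 1024. A K-loop over thousands of elements hits the same wall, which is why the cast happens only in section 5.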
Three boundary masks — one for each axis
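M, N, and K can each be non-multiples of their tile size. Each load, and the final store, needs a 2D mask that covers both of its axes (the [:, None] and [None, :] broadcasts are omitted in the table for brevity):
| Access | 2D mask | Why |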
|---|---|---|
| A tile (BM × BK) | (offs_m < M) & (k + offs_k < K) | M boundary on rows, K boundary on cols |
| B tile (BK × BN) | (k + offs_k < K) & (offs_n < N) | K boundary on rows, N boundary on cols |
| C store (BM × BN) | (offs_m < M) & (offs_n < N) | M boundary on rows, N boundary on cols |
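The loads also pass other=0., the identity for accumulation. Lanes off the end of K contribute zero to the dot product, which is exactly what we want. Lanes off the end of M or N load zeros as well, and the store mask discards whatever they produce.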
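The effect of the K mask can be reproduced on a single row-times-column product. This is a small emulation in plain Python, assuming a hypothetical helper `masked_load` that mimics how `tl.load` treats `mask` and `other` along one axis; the numbers are arbitrary.

```python
def masked_load(row, start, width, limit, other=0.0):
    """Emulate tl.load on one row: lanes at index >= limit return `other`."""
    return [row[i] if i < limit else other for i in range(start, start + width)]


K, BK = 5, 4                            # K is not a multiple of BK
a_row = [1.0, 2.0, 3.0, 4.0, 5.0]       # one row of A
b_col = [10.0, 20.0, 30.0, 40.0, 50.0]  # one column of B

acc = 0.0
for k0 in range(0, K, BK):
    a = masked_load(a_row, k0, BK, K)   # second tile is [5.0, 0.0, 0.0, 0.0]
    b = masked_load(b_col, k0, BK, K)
    acc += sum(x * y for x, y in zip(a, b))

exact = sum(x * y for x, y in zip(a_row, b_col))
assert acc == exact
```

The padded lanes multiply to zero, so the full-width tile gives the same sum as the ragged one. Any other fill value would leak into the result.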
The wrapper
def matmul(a, b):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    c = torch.empty(M, N, device=a.device, dtype=torch.bfloat16)
    # BM and BN come from whichever autotune config is being launched
    grid = lambda meta: (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
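Passing strides rather than assuming a layout is what lets the same kernel read a transposed view without a copy. The sketch below replays the kernel's pointer arithmetic on a flat Python list; `element_offsets` is an illustrative helper, and the stride values are what `stride(0)` and `stride(1)` would report for a row-major 3 × 4 tensor and its transpose.

```python
def element_offsets(offs_m, offs_k, stride_m, stride_k):
    """Flat offsets for a 2D tile: offs_m[:, None]*stride_m + offs_k[None, :]*stride_k."""
    return [[m * stride_m + k * stride_k for k in offs_k] for m in offs_m]


# A 3 x 4 row-major matrix stored flat; element (i, j) holds the value 10*i + j.
rows, cols = 3, 4
storage = [10 * i + j for i in range(rows) for j in range(cols)]

# Contiguous view: strides (cols, 1).
tile = element_offsets([0, 1], [0, 1, 2], cols, 1)
contiguous = [[storage[o] for o in row] for row in tile]

# Transposed view of the same storage (shape 4 x 3): the strides swap to (1, cols).
tile_t = element_offsets([0, 1, 2], [0, 1], 1, cols)
transposed = [[storage[o] for o in row] for row in tile_t]

print(contiguous)   # a 2 x 3 tile of the matrix
print(transposed)   # the matching 3 x 2 tile of its transpose
```

Only the two stride arguments change between the views; the offset formula and the storage are identical.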
Autotune configurations that actually matter
@triton.autotune(
    configs=[
        # small N, large M (e.g. attention output projection at low batch)
        triton.Config({'BM': 128, 'BN': 64,  'BK': 64}, num_warps=4, num_stages=4),
        # square mid-range
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 128, 'BK': 64}, num_warps=8, num_stages=3),
        # large N (e.g. vocabulary projection)
        triton.Config({'BM': 64,  'BN': 256, 'BK': 32}, num_warps=8, num_stages=4),
        # tiny K, wide tiles
        triton.Config({'BM': 256, 'BN': 128, 'BK': 32}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...): ...
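These five configurations cover most production shapes. The autotuner benchmarks every configuration the first time it sees a new (M, N, K) key, then caches the winner for later calls with the same key.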
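Two back-of-envelope numbers help when reading or extending the list of configurations. The sketch below is an estimate under stated assumptions, not something Triton reports: it assumes each pipeline stage holds one bf16 tile of A and one of B, and it counts only the A and B traffic of a single K-step. `staged_tile_bytes` and `flops_per_byte` are illustrative names.

```python
# (BM, BN, BK, num_stages) copied from the autotune configurations.
CONFIGS = [
    (128, 64, 64, 4),
    (128, 128, 32, 3),
    (128, 128, 64, 3),
    (64, 256, 32, 4),
    (256, 128, 32, 3),
]
BYTES_PER_ELEMENT = 2  # bf16


def staged_tile_bytes(bm, bn, bk, num_stages):
    """Rough bytes of A and B tiles in flight, assuming one buffered pair per stage."""
    return (bm * bk + bk * bn) * BYTES_PER_ELEMENT * num_stages


def flops_per_byte(bm, bn, bk):
    """FLOPs per byte of A and B loaded in one K-step (BK cancels out)."""
    flops = 2 * bm * bn * bk
    loaded = (bm * bk + bk * bn) * BYTES_PER_ELEMENT
    return flops / loaded


for bm, bn, bk, stages in CONFIGS:
    kib = staged_tile_bytes(bm, bn, bk, stages) / 1024
    print(f"BM={bm:3d} BN={bn:3d} BK={bk:2d} stages={stages}: "
          f"{kib:5.1f} KiB staged, {flops_per_byte(bm, bn, bk):5.1f} FLOP/byte")
```

Under these assumptions the intensity simplifies to BM·BN / (BM + BN), so larger and squarer output tiles do more math per byte, while BK and num_stages mostly trade staging memory for latency hiding. Whether a given footprint fits depends on the GPU, which is one reason the choice is left to the autotuner.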
Where cuBLAS still beats hand-Triton
For vanilla bf16/fp16 GEMM on contiguous tensors, expect cuBLAS to be 5–15% faster than this kernel. The reasons:
- Hand-scheduled mma sequences. CUTLASS templates choose the exact order to issue mma.sync instructions to keep the tensor cores fed; Triton's scheduler is good but not surgical.
- Tensor Memory Accelerator (TMA) on Hopper. cuBLASLt uses TMA descriptors for asynchronous block loads, overlapping with mma. Triton supports TMA but the lowering isn't always as tight.
- Warp specialisation. CUTLASS Hopper kernels use producer-warps that issue TMA and consumer-warps that do mma. Triton's warp-specialisation support is experimental.
So when is Triton GEMM worth writing?
- Fused epilogue (lesson 08): the bias + activation fusion makes up for the 5–15%.
- Quantised inputs / weights (W4A16, NF4): cuBLAS doesn't cover these dequant-on-load patterns.
- Unusual layouts (block-sparse, low-rank): library kernels don't exist.
- Backward of a fused forward: you need a matching custom backward, which can't be cuBLAS.
Don't write a Triton GEMM just to replace torch.matmul. Write one when something else has to happen inside the kernel: a fused epilogue, a dequantised weight, an unusual layout. Otherwise call cuBLAS.
Pipeline tiles for HBM overlap — the M-major trick
The naive grid (cdiv(M, BM), cdiv(N, BN)) visits output tiles in column-major order (program (0,0), (1,0), (2,0), …). Consecutive programs therefore share a column panel of B, but each one needs a different row panel of A. By the time the grid moves on to the next column and asks for the same A panels again, they may already have been evicted from L2.
A common Triton trick is "swizzling" the grid into super-blocks so that programs running at the same time share both A and B reads. We won't dwell on it here. The autotuner can choose the group size only if you expose it as a constexpr in the configs, which the kernel in this lesson does not do. The point: a grid is just a numbering of programs, and the numbering affects cache reuse.
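The effect of renumbering can be counted directly. The following sketch compares the naive ordering with a grouped ordering in the style commonly used in Triton matmul kernels; `grouped_tile`, `panels_touched`, the 8 × 8 grid, the group size of 4, and the wave of 16 concurrent programs are all illustrative assumptions, and a real kernel would compute the same mapping from a 1D program id.

```python
def naive_tile(pid, num_m, num_n):
    """Launch order of a 2D grid when axis 0 (pid_m) varies fastest."""
    return pid % num_m, pid // num_m


def grouped_tile(pid, num_m, num_n, group_m):
    """Grouped ordering: walk group_m rows of tiles at a time, column by column."""
    pids_per_group = group_m * num_n
    group_id = pid // pids_per_group
    first_m = group_id * group_m
    rows_in_group = min(num_m - first_m, group_m)  # the last group may be short
    local = pid % pids_per_group
    return first_m + local % rows_in_group, local // rows_in_group


def panels_touched(tiles):
    """Distinct A row-panels plus distinct B column-panels a set of tiles reads."""
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


NUM_M, NUM_N, GROUP_M, WAVE = 8, 8, 4, 16  # illustrative sizes only
naive_wave = [naive_tile(p, NUM_M, NUM_N) for p in range(WAVE)]
grouped_wave = [grouped_tile(p, NUM_M, NUM_N, GROUP_M) for p in range(WAVE)]
print(panels_touched(naive_wave), panels_touched(grouped_wave))
```

For these sizes the first 16 naive programs need 8 A panels and 2 B panels, while the grouped ones form a 4 × 4 block and need 4 of each. The saving grows with the grid, and the mapping still covers every output tile exactly once.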
What's next
You can multiply tiles. Lesson 10 takes the reduction skills from lesson 06 and builds a numerically stable softmax in Triton — and shows where the online recurrence comes in when one row doesn't fit a tile. That's the same recurrence that makes lesson 12 (Flash Attention) work.