Triton kernels · lesson 09 of 14 · matmul

Tiled matmul — the canonical GEMM

Every transformer-shaped kernel you ever write has a tiled matmul at its core. This lesson is the one canonical pattern: output tile in registers, K-loop accumulates, boundary masked on three axes, cast on store. Once it's memorised, lessons 10–12 read as elaborations.

The math the kernel implements

For C = A · B with $A \in \mathbb{R}^{M \times K}$ and $B \in \mathbb{R}^{K \times N}$:

$$C_{ij} = \sum_{k=0}^{K-1} A_{ik} \, B_{kj}$$

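We tile the output. Each program owns one BM × BN tile of C and computes it by streaming along K. At every step it loads a BM × BK block from its row panel of A and a BK × BN block from its column panel of B, then multiplies them into the accumulator.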

[Figure: a BK-wide strip of A (M × K) times a BK-tall strip of B (K × N) gives one BM × BN output tile of C (M × N). The K-loop slides both strips along K, for k = 0, BK, 2·BK, …, computing acc += A_strip · B_strip. Each iteration is one tl.dot, and the accumulator stays in registers. When the K-loop ends, acc is the final BM × BN output tile: cast and store.]
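To see the algorithm without a GPU, here is a host-side NumPy sketch that mirrors the kernel's structure. Two loops stand in for the program grid, and a K-loop accumulates in fp32. A hypothetical `masked_load` helper plays the role of `tl.load(..., mask=..., other=0.)`. This only illustrates the tiling. On the GPU the tiles run in parallel as separate programs.

```python
import numpy as np

def masked_load(x, rows, cols):
    """Gather the tile x[rows, cols]; out-of-range lanes read 0.0 (like other=0.)."""
    R, C = x.shape
    mask = (rows[:, None] < R) & (cols[None, :] < C)
    # clamp indices so the gather itself never goes out of bounds, then apply the mask
    safe = x[np.minimum(rows, R - 1)[:, None], np.minimum(cols, C - 1)[None, :]]
    return np.where(mask, safe, 0.0).astype(np.float32)

def tiled_matmul(a, b, BM=32, BN=32, BK=16):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    c = np.zeros((M, N), dtype=np.float32)
    for pid_m in range((M + BM - 1) // BM):          # grid axis 0
        for pid_n in range((N + BN - 1) // BN):      # grid axis 1
            offs_m = pid_m * BM + np.arange(BM)
            offs_n = pid_n * BN + np.arange(BN)
            acc = np.zeros((BM, BN), dtype=np.float32)  # fp32 accumulator
            for k in range(0, K, BK):                   # the K-loop
                offs_k = k + np.arange(BK)
                acc += masked_load(a, offs_m, offs_k) @ masked_load(b, offs_k, offs_n)
            # store mask: write only rows < M and columns < N
            m_ok, n_ok = offs_m < M, offs_n < N
            c[np.ix_(offs_m[m_ok], offs_n[n_ok])] = acc[np.ix_(m_ok, n_ok)]
    return c

rng = np.random.default_rng(0)
a = rng.standard_normal((100, 70)).astype(np.float32)   # no dimension is a tile multiple
b = rng.standard_normal((70, 45)).astype(np.float32)
print(np.allclose(tiled_matmul(a, b), a @ b, rtol=1e-4, atol=1e-4))  # True
```

The same three masks (A rows/K, B K/columns, C rows/columns) appear here as in the Triton kernel.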

The kernel — annotated line by line

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    sam, sak, sbk, sbn, scm, scn,                    # element strides
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    # 1 · which output tile am I?
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # 2 · row/col indices for my tile of C, plus the K-loop indices
    offs_m = pid_m * BM + tl.arange(0, BM)            # (BM,)
    offs_n = pid_n * BN + tl.arange(0, BN)            # (BN,)
    offs_k = tl.arange(0, BK)                         # (BK,)

    # 3 · accumulator in registers — fp32 always
    acc = tl.zeros((BM, BN), dtype=tl.float32)

    # 4 · K-loop: stream BK-wide strips of A and B
    for k in range(0, K, BK):
        a = tl.load(
            a_ptr + offs_m[:, None]*sam + (k + offs_k)[None, :]*sak,
            mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < K),
            other=0.,
        ).to(tl.bfloat16)
        b = tl.load(
            b_ptr + (k + offs_k)[:, None]*sbk + offs_n[None, :]*sbn,
            mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N),
            other=0.,
        ).to(tl.bfloat16)
        # tensor cores fire here
        acc = tl.dot(a, b, acc)

    # 5 · store the BM × BN tile of C, cast to output dtype
    tl.store(
        c_ptr + offs_m[:, None]*scm + offs_n[None, :]*scn,
        acc.to(tl.bfloat16),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )

Five sections. Each line carries one idea. Memorise this shape — every fused-matmul kernel in the rest of the series starts from this skeleton.
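Step 3's "fp32 always" deserves a quick demonstration. The NumPy sketch below accumulates a dot product of ones, rounding the running sum to the accumulator dtype after every add. NumPy has no bfloat16 type, so it substitutes float16, whose 11-bit significand represents integers exactly only up to 2048. bf16, with an 8-bit significand, would stall even earlier, at 256. `sequential_dot` is a hypothetical helper, not a Triton API.

```python
import numpy as np

def sequential_dot(x, y, acc_dtype):
    """Dot product with the running sum rounded to acc_dtype after every step."""
    acc = acc_dtype(0)
    for xi, yi in zip(x, y):
        acc = acc_dtype(acc + acc_dtype(xi) * acc_dtype(yi))
    return float(acc)

K = 4096
x = np.ones(K, dtype=np.float16)   # low-precision inputs, like the loaded tiles
y = np.ones(K, dtype=np.float16)
print(sequential_dot(x, y, np.float16))  # 2048.0: 2048 + 1 rounds back to 2048
print(sequential_dot(x, y, np.float32))  # 4096.0
```

The low-precision accumulator silently drops half of the sum. Keeping `acc` in fp32 and casting only on store avoids this.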

Three boundary masks — one for each axis

M, N, and K can each be non-multiples of their tile size, so every memory access needs a 2D mask covering the two axes it spans:

| Access | 2D mask | Why |
|---|---|---|
| A load (BM × BK) | (offs_m < M) & (k + offs_k < K) | M boundary on rows, K boundary on cols |
| B load (BK × BN) | (k + offs_k < K) & (offs_n < N) | K boundary on rows, N boundary on cols |
| C store (BM × BN) | (offs_m < M) & (offs_n < N) | M boundary on rows, N boundary on cols |

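Both loads also pass other=0., the additive identity. Lanes past the end of K contribute zero to every dot product, which is exactly what we want. Rows past M and columns past N load zeros too, and the store mask simply drops their results.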
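To see why zero is the right fill value, here is a one-row, one-column NumPy sketch of the K boundary. The names `strip_dot`, `a_row` and `b_col` are illustrative. `other=1.0` is a deliberately wrong fill, used only for contrast.

```python
import numpy as np

K, BK = 70, 32                   # 70 is not a multiple of 32, so the last strip is partial
rng = np.random.default_rng(0)
a_row = rng.standard_normal(K)   # one row of A
b_col = rng.standard_normal(K)   # one column of B

def strip_dot(other):
    """Dot product over BK-wide strips; lanes past K are filled with `other`."""
    acc = 0.0
    for k in range(0, K, BK):
        offs_k = k + np.arange(BK)
        mask = offs_k < K                  # the K part of the A and B masks
        idx = np.minimum(offs_k, K - 1)    # clamp so the gather stays in bounds
        acc += np.where(mask, a_row[idx], other) @ np.where(mask, b_col[idx], other)
    return acc

print(np.isclose(strip_dot(0.0), a_row @ b_col))   # True: zero padding adds nothing
print(strip_dot(1.0) - strip_dot(0.0))             # ~26.0: 96 lanes loaded, 26 are padding
```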

The wrapper

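import torch
import triton

def matmul(a, b):
    # a: (M, K), b: (K, N); any layout works because the kernel takes element strides
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    c = torch.empty(M, N, device=a.device, dtype=torch.bfloat16)
    # one program per BM x BN output tile; BM and BN come from the chosen config
    grid = lambda meta: (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    # BM, BN, BK are not passed here: the @triton.autotune decorator below supplies them
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),   # sam, sak
        b.stride(0), b.stride(1),   # sbk, sbn
        c.stride(0), c.stride(1),   # scm, scn
    )
    return c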
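The wrapper does two pieces of host-side arithmetic: the launch grid and the element strides. The plain-Python sketch below reproduces both. `grid_for` and `element_strides` are hypothetical helpers. `cdiv` mirrors `triton.cdiv`, and NumPy stands in for torch. One catch: NumPy reports strides in bytes, while torch's `.stride()`, which the kernel expects, counts elements.

```python
import numpy as np

def cdiv(x, y):
    # round-up integer division, the same arithmetic as triton.cdiv
    return (x + y - 1) // y

def grid_for(M, N, BM, BN):
    """The launch grid the wrapper's lambda produces for one config."""
    return (cdiv(M, BM), cdiv(N, BN))

def element_strides(x):
    """Convert NumPy byte strides to element strides (what torch's .stride() returns)."""
    return tuple(s // x.itemsize for s in x.strides)

print(grid_for(1000, 3000, 128, 64))   # (8, 47): the edge tiles are partially masked
a = np.zeros((1000, 512), dtype=np.float16)
print(element_strides(a))              # (512, 1): row-major
print(element_strides(a.T))            # (1, 512): transposed view, same buffer
```

Because the kernel indexes through strides, a transposed view can be passed straight in without a copy.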

Autotune configurations that actually matter

@triton.autotune(
    configs=[
        # small N, large M (e.g. attention output projection at low batch)
        triton.Config({'BM':128,'BN':64, 'BK':64}, num_warps=4, num_stages=4),
        # square mid-range
        triton.Config({'BM':128,'BN':128,'BK':32}, num_warps=4, num_stages=3),
        triton.Config({'BM':128,'BN':128,'BK':64}, num_warps=8, num_stages=3),
        # large N (e.g. vocabulary projection)
        triton.Config({'BM':64, 'BN':256,'BK':32}, num_warps=8, num_stages=4),
        # tiny K, wide tiles
        triton.Config({'BM':256,'BN':128,'BK':32}, num_warps=8, num_stages=3),
    ],
    key=['M','N','K'],
)
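@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    sam, sak, sbk, sbn, scm, scn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    ...  # body identical to the kernel above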

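These five configurations are a reasonable starting set for common transformer shapes. The first time the autotuner sees a new (M, N, K) key, it benchmarks every config, then caches the winner for that key.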
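One rough sanity check on a config is how much shared memory its pipelined tiles need. The sketch below assumes each of `num_stages` stages holds one bf16 A tile and one bf16 B tile. That is an approximation: the Triton compiler's real allocation can differ, for example by adding scratch space. `smem_bytes` is a hypothetical helper.

```python
def smem_bytes(BM, BN, BK, num_stages, elem_bytes=2):
    """Rough estimate: num_stages buffers of one A tile and one B tile (bf16 = 2 bytes)."""
    return num_stages * (BM * BK + BK * BN) * elem_bytes

# (BM, BN, BK, num_stages) for the five configs above
configs = [
    (128, 64, 64, 4),
    (128, 128, 32, 3),
    (128, 128, 64, 3),
    (64, 256, 32, 4),
    (256, 128, 32, 3),
]
for BM, BN, BK, stages in configs:
    kib = smem_bytes(BM, BN, BK, stages) / 1024
    print(f"BM={BM:3d} BN={BN:3d} BK={BK:2d} stages={stages}: ~{kib:.0f} KiB")
```

Under this estimate, all five stay below the roughly 227 KB of shared memory an H100 thread block can use. Raising `num_stages` or `BK` is usually what pushes a config over the limit.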

Where cuBLAS still beats hand-Triton

For vanilla bf16/fp16 GEMM on contiguous tensors, expect cuBLAS to be 5–15% faster than this kernel. The reasons:

  1. Hand-scheduled mma sequences. CUTLASS templates choose the exact order to issue mma.sync instructions to keep the tensor cores fed; Triton's scheduler is good but not surgical.
  2. Tensor Memory Accelerator (TMA) on Hopper. cuBLASLt uses TMA descriptors for asynchronous block loads, overlapping with mma. Triton supports TMA but the lowering isn't always as tight.
  3. Warp specialisation. CUTLASS Hopper kernels use producer-warps that issue TMA and consumer-warps that do mma. Triton's warp-specialisation support is experimental.

So when is Triton GEMM worth writing?

A blunt rule
Don't write a Triton matmul to replace torch.matmul. Write one when something else has to happen inside the kernel: a fused epilogue, a dequantised weight, an unusual layout. Otherwise call cuBLAS.

Tile ordering for L2 reuse: the grouped-launch trick

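The naive grid (cdiv(M, BM), cdiv(N, BN)) in practice launches programs with pid_m varying fastest: (0,0), (1,0), (2,0), …. Consecutive programs share one column panel of B, but each needs a different row panel of A. The programs that reuse a given A panel are cdiv(M, BM) launches apart. For large M, that panel may already have been evicted from L2 by the time it is needed again.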

A common Triton trick is to "swizzle" the launch order into groups of tile rows, so that programs running at the same time share more A and B reads. We won't dwell on it here. The kernel above keeps the naive order, and if you add grouping, the group size becomes one more parameter for the autotuner to choose. The point: a grid is just a numbering of programs, and the numbering affects cache reuse.
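The grouping can be sketched in plain Python. The `grouped_tile` mapping follows the grouped ordering used in Triton's official matrix-multiplication tutorial. Inside a kernel it would run on `tl.program_id(0)` of a 1D grid of cdiv(M, BM) · cdiv(N, BN) programs. `naive_tile`, `panels_touched` and the `wave` size are hypothetical. The sketch also assumes programs start in pid order, which is typical but not guaranteed.

```python
def naive_tile(pid, num_pid_m, num_pid_n):
    # grid axis 0 (pid_m) varies fastest in launch order
    return pid % num_pid_m, pid // num_pid_m

def grouped_tile(pid, num_pid_m, num_pid_n, GROUP_M):
    """Map a 1D program id to (pid_m, pid_n), walking GROUP_M tile rows at a time."""
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)  # last group may be shorter
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

def panels_touched(tiles):
    """Distinct A row panels plus distinct B column panels a set of tiles reads."""
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})

num_pid_m = num_pid_n = 32
wave = 64   # hypothetical number of programs in flight at once
naive = [naive_tile(p, num_pid_m, num_pid_n) for p in range(wave)]
grouped = [grouped_tile(p, num_pid_m, num_pid_n, 8) for p in range(wave)]
print(panels_touched(naive), panels_touched(grouped))   # 34 16
```

Both orders compute the same 64 tiles. The grouped order covers an 8 × 8 block of tiles, so programs in flight together need fewer than half as many distinct panels.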

Estimate · time vs cuBLAS

The model is deliberately simple. A GEMM performs 2·M·N·K FLOPs. Assume cuBLAS sustains ~85% of bf16 peak and a vanilla Triton kernel ~70–80%. Numbers are for an H100 (dense bf16 peak ≈ 990 TFLOP/s).
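As a worked example of that model, the sketch below plugs in a square 8192³ GEMM. `gemm_time_ms` is a hypothetical helper. The efficiencies are the assumed fractions above, not measurements. The model is only meaningful for shapes large enough to be compute-bound, because it ignores memory traffic.

```python
def gemm_time_ms(M, N, K, efficiency, peak_tflops=990.0):
    """Time for 2*M*N*K FLOPs at a fixed fraction of peak throughput."""
    flops = 2 * M * N * K
    return flops / (efficiency * peak_tflops * 1e12) * 1e3

M = N = K = 8192
cublas = gemm_time_ms(M, N, K, 0.85)
triton_fast = gemm_time_ms(M, N, K, 0.80)
triton_slow = gemm_time_ms(M, N, K, 0.70)
print(f"cuBLAS ~{cublas:.2f} ms, Triton ~{triton_fast:.2f} to {triton_slow:.2f} ms")
# cuBLAS ~1.31 ms, Triton ~1.39 to 1.59 ms
```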

What's next

You can multiply tiles. Lesson 10 takes the reduction skills from lesson 06 and builds a numerically stable softmax in Triton — and shows where the online recurrence comes in when one row doesn't fit a tile. That's the same recurrence that makes lesson 12 (Flash Attention) work.