tl.dot and tensor cores

One primitive — tl.dot(a, b) — is the gateway to the tensor cores, the hardware that gives modern GPUs ~10× more flops than the vector units. The lesson answers: when does the lowering succeed, what dtypes are legal, and why bf16 in / fp32 accumulate is non-negotiable.

What tl.dot actually compiles to

The primitive looks unassuming:

acc = tl.dot(a, b)         # a is (M, K), b is (K, N), result (M, N)
# or with explicit accumulator:
acc = tl.dot(a, b, acc=acc)   # acc += a @ b

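On a tensor-core capable GPU with supported dtypes, Triton lowers this to a sequence of mma.sync (Ampere, Ada) or wgmma (Hopper) instructions, the hardware matmul ops. One mma.sync.m16n8k16 on Ampere is a warp-wide matrix multiply-accumulate: it multiplies a 16×16 fragment of A by a 16×8 fragment of B and adds the product into a 16×8 accumulator, which is 2,048 multiply-adds in a single instruction. A vector FMA instruction performs 32 per warp, one per thread, so the same FMA chain needs far more instructions on the vector ALU.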

[Figure: tl.dot tiles get diced into hardware mma blocks. A (BM × BK = 128 × 32) × B (BK × BN = 32 × 128) = C (BM × BN), accumulated in fp32. Each colored sub-block of A and B is a single 16×8×16 mma.sync instruction. The compiler chooses the dicing.]

You write tl.dot once. The compiler emits dozens of mma's that the warp schedules together.
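To get a feel for the dicing, the arithmetic can be sketched in plain Python. This is a back-of-the-envelope count, not what the compiler actually runs: the function name `count_mma_instructions` is made up for illustration, and it assumes the tile is covered by non-overlapping m16n8k16 blocks. The total is for the whole tile and is shared among the program's warps.

```python
from math import ceil

def count_mma_instructions(bm, bn, bk, mma_m=16, mma_n=8, mma_k=16):
    """Count the mma_m x mma_n x mma_k blocks needed to cover a BM x BN x BK tile.

    Hypothetical helper: assumes a simple non-overlapping covering.
    """
    return ceil(bm / mma_m) * ceil(bn / mma_n) * ceil(bk / mma_k)

# The 128 x 128 x 32 tile from the figure, diced into m16n8k16 blocks.
total = count_mma_instructions(128, 128, 32)
print(total)  # 8 * 16 * 2 = 256
```

With 4 or 8 warps per program, that works out to 64 or 32 instructions per warp for one tl.dot call.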

The shape constraints — when lowering succeeds

Tensor cores want specific tile shapes. The dimensions of the operands to tl.dot must be at least 16 in every axis:

tl.dot(a, b)   # a: (BM, BK), b: (BK, BN)
# Required: BM ≥ 16, BK ≥ 16, BN ≥ 16

And for best utilization of the mma unit, you want them to be multiples of the hardware mma shape: typically m16n8k16 on Ampere/Ada (Triton issues these in pairs so people often quote it as "16×16×16"), and m64nNk16 via wgmma on Hopper, where N is a multiple of 8 between 8 and 256. Pick BM ≥ 64 on Hopper for clean wgmma lowering.

| GPU | Native mma shape (M × N × K) | Triton calls this |
|---|---|---|
| A100 (sm_80) | 16 × 8 × 16 | mma.sync.aligned.m16n8k16 |
| H100 (sm_90) | 64 × N × 16, N ∈ {8..256} | wgmma.mma_async.sync |
| RTX 4090 (sm_89) | 16 × 8 × 16 | mma.sync.aligned.m16n8k16 |

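If a tile dim is below the minimum (e.g. BK=8), you do not get tensor cores. Depending on the Triton version, the kernel is either rejected at compile time or lowered to the vector ALU, where you get correct results but a small fraction of the flops. The cure is a bigger BK.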
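The two rules in this section (every dim at least 16, and dims that are multiples of the native mma shape) can be written down as a quick pre-flight check. This is a rough sketch of the rule of thumb only: `NATIVE_MMA_SHAPE` and `tile_report` are illustrative names, the Hopper entry uses the smallest N, and real lowering also depends on layouts and dtypes.

```python
# Native mma shapes (M, N, K) per architecture, taken from the table above.
# For sm_90 the N entry is the smallest wgmma N; larger values are also legal.
NATIVE_MMA_SHAPE = {
    "sm_80": (16, 8, 16),
    "sm_89": (16, 8, 16),
    "sm_90": (64, 8, 16),
}

def tile_report(bm, bn, bk, arch="sm_80"):
    """Check a tile against the lesson's two rules of thumb (hypothetical helper)."""
    mma_m, mma_n, mma_k = NATIVE_MMA_SHAPE[arch]
    return {
        # Rule 1: every dimension of the tl.dot operands is at least 16.
        "meets_minimum": bm >= 16 and bn >= 16 and bk >= 16,
        # Rule 2: every dimension is a whole number of native mma blocks.
        "fills_mma": bm % mma_m == 0 and bn % mma_n == 0 and bk % mma_k == 0,
    }

print(tile_report(128, 128, 32))          # both checks pass
print(tile_report(128, 128, 8))           # BK too small
print(tile_report(32, 64, 32, "sm_90"))   # BM below the 64-row wgmma block
```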

Dtype rules — bf16 in, fp32 accumulate

The mma instructions take low-precision inputs and accumulate in higher precision:

| Input dtypes | Allowed accumulator | Notes |
|---|---|---|
| bf16 × bf16 | fp32 | The modern default. Use this unless you have a reason not to. |
| fp16 × fp16 | fp32 (or fp16) | fp16 accumulate is faster but can overflow on long-K matmuls. |
| fp8 × fp8 | fp32 | Ada (sm_89) and Hopper+. Triton supports e4m3 and e5m2. |
| int8 × int8 | int32 | In hardware since Turing (sm_75), so not Hopper-only. For quantised inference. |
| tf32 × tf32 | fp32 | fp32 inputs, internally truncated to TF32 (19-bit: 1 sign bit + 8-bit exponent + 10-bit mantissa). The default for fp32 matmul on Ampere+. |
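To make the TF32 row concrete, here is a small stdlib-only sketch of what dropping fp32's mantissa from 23 bits to 10 does to a value. It follows the truncation described in the table; the actual hardware conversion may round instead, and `truncate_to_tf32` is an illustrative name rather than a Triton function.

```python
import struct

def truncate_to_tf32(x):
    """Keep the sign, 8 exponent bits and top 10 mantissa bits of an fp32 value."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the low 13 of fp32's 23 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# 2**-10 is the smallest step above 1.0 that TF32 can still represent.
print(truncate_to_tf32(1.0 + 2**-10) == 1.0 + 2**-10)  # True: survives
print(truncate_to_tf32(1.0 + 2**-11) == 1.0)           # True: truncated away
```

The products are still accumulated in fp32; only the inputs lose precision.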

Concretely:

acc = tl.zeros((BM, BN), dtype=tl.float32)         # ← fp32 accumulator
for k in range(0, K, BK):
    a = tl.load(...).to(tl.bfloat16)                # ← bf16 inputs
    b = tl.load(...).to(tl.bfloat16)
    acc = tl.dot(a, b, acc)                          # mma.sync emits bf16×bf16→fp32
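y = acc.to(tl.bfloat16)                              # cast back when storing

Why fp32 accumulate is non-negotiable

Accumulating bf16 into bf16 over a K-loop of length 4096 produces drift visible in loss curves and broken downstream inference. bf16 stores 7 mantissa bits (8 bits of precision with the implicit leading one), so once the running sum is more than a few hundred times larger than the next product, that product is rounded away entirely. fp32's 24 bits of precision leave ample headroom for a K-loop of this length. Always initialise the accumulator with tl.zeros(..., dtype=tl.float32); accumulating in a narrower dtype is the single most common silent bug in Triton GEMMs.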
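The stall is easy to reproduce without a GPU. The sketch below emulates a bf16 accumulator in plain Python by rounding the running sum after every add; the helper names are illustrative, and the all-ones input is a deliberately simple stand-in for a row of a*b products.

```python
import struct

def round_to_bf16(x):
    """Round to the nearest bf16 value (round-to-nearest-even on the top 16 bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def round_to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def accumulate(terms, round_fn):
    """Sum the terms, rounding the accumulator to its storage dtype after each add."""
    acc = 0.0
    for t in terms:
        acc = round_fn(acc + t)
    return acc

K = 4096
terms = [1.0] * K  # each term stands in for one a*b product

bf16_sum = accumulate(terms, round_to_bf16)
fp32_sum = accumulate(terms, round_to_fp32)
print(bf16_sum, fp32_sum)  # 256.0 4096.0
```

The bf16 accumulator stops growing at 256: the next value, 257, needs nine significant bits and rounds back down, so the remaining 3,840 terms contribute nothing.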

The K-loop pattern

For matmuls where K is much larger than your tile's BK, you loop over K:

acc = tl.zeros((BM, BN), dtype=tl.float32)
for k in range(0, K, BK):
    a = tl.load(a_ptr + offs_m[:, None]*stride_am + (k + offs_k)[None, :]*stride_ak,
                mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < K), other=0.).to(tl.bfloat16)
    b = tl.load(b_ptr + (k + offs_k)[:, None]*stride_bk + offs_n[None, :]*stride_bn,
                mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N), other=0.).to(tl.bfloat16)
    acc = tl.dot(a, b, acc)

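Note: the masks zero-fill out-of-range elements (other=0.), so a ragged final K tile adds nothing to the sum; the operands are cast to bf16 only after the load; and acc stays fp32 across every iteration.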
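The loop structure is easier to check in a CPU reference. The following is a plain-Python model of the K-loop, assuming a single program that owns the whole output; `load_tile` and `tiled_matmul` are hypothetical names, and `load_tile` stands in for a masked tl.load with other=0.

```python
def load_tile(mat, row0, col0, rows, cols):
    """Emulate a masked load: elements outside the matrix read as 0.0."""
    n_rows, n_cols = len(mat), len(mat[0])
    return [[mat[r][c] if r < n_rows and c < n_cols else 0.0
             for c in range(col0, col0 + cols)]
            for r in range(row0, row0 + rows)]

def tiled_matmul(a, b, bk):
    """Compute a @ b by walking K in steps of bk, accumulating into acc."""
    m, k_total, n = len(a), len(a[0]), len(b[0])
    acc = [[0.0] * n for _ in range(m)]
    for k in range(0, k_total, bk):
        a_tile = load_tile(a, 0, k, m, bk)  # (M, BK) slice of A
        b_tile = load_tile(b, k, 0, bk, n)  # (BK, N) slice of B
        for i in range(m):
            for j in range(n):
                acc[i][j] += sum(a_tile[i][p] * b_tile[p][j] for p in range(bk))
    return acc

# K = 5 is not a multiple of BK = 2, so the last tile is half padding.
a = [[1.0, 2.0, 3.0, 4.0, 5.0],
     [6.0, 7.0, 8.0, 9.0, 10.0]]
b = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 3.0],
     [1.0, 1.0, 0.0],
     [2.0, 0.0, 1.0],
     [0.0, 3.0, 1.0]]
print(tiled_matmul(a, b, bk=2))  # [[12.0, 20.0, 17.0], [32.0, 45.0, 52.0]]
```

Because the padded elements are zero, the ragged final tile produces the same result as an unpadded matmul.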

Output: cast on store

The accumulator is fp32. The output tensor is typically bf16. Cast at the final store:

tl.store(c_ptr + ..., acc.to(tl.bfloat16), mask=...)

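tl.store implicitly casts the value to the pointer's element type, so leaving out the cast does not corrupt memory. Writing acc.to(tl.bfloat16) explicitly is still worthwhile: it makes the one rounding step visible at the point where precision is lost.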

When tl.dot falls back to non-tensor-core

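If the compiler can't lower to mma, it generates a vector-ALU FMA loop instead. The triggers covered in this lesson are a tile dimension below the hardware minimum, an input dtype combination with no matching mma instruction, and a GPU without tensor cores.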

The compiler doesn't error out; it silently gives you slow code. The way to confirm tensor cores are firing is to dump the PTX (lesson 14) and grep for mma.sync / wgmma. Or just compare throughput to the cuBLAS reference — if you're 10× slower, you're on the ALU.
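The grep step can be scripted once you have the PTX as text, however you obtained it. This is a minimal sketch: `uses_tensor_cores` is an illustrative helper, and the two sample strings are hand-written PTX-style lines, not output from a real compile.

```python
import re

# Matches the two instruction families named in this lesson.
TENSOR_CORE_PATTERN = re.compile(r"\b(mma\.sync|wgmma)\b")

def uses_tensor_cores(ptx_text):
    """Return True if the PTX text contains an mma.sync or wgmma instruction."""
    return TENSOR_CORE_PATTERN.search(ptx_text) is not None

# Hand-written sample lines for illustration only.
sample_mma = "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%f1, %f2, %f3, %f4}, ..."
sample_fma = "fma.rn.f32 %f5, %f6, %f7, %f8;"

print(uses_tensor_cores(sample_mma))  # True
print(uses_tensor_cores(sample_fma))  # False
```

A match only shows that tensor-core instructions are present somewhere in the kernel; the throughput comparison is still the better end-to-end check.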

What's next

You can do dense matmuls. The other half of every transformer kernel is reductions — sums, maxes, and the softmax pattern. Lesson 06 introduces tl.sum and friends, and previews the online-softmax recurrence that makes Flash Attention possible.