tl.dot and tensor cores
One primitive, tl.dot(a, b), is the gateway to the tensor cores: the hardware that gives modern GPUs roughly 10× more matmul flops than the vector units. The lesson answers three questions: when does the lowering succeed, which dtypes are legal, and why bf16 in / fp32 accumulate is non-negotiable.
What tl.dot actually compiles to
The primitive looks unassuming:
```python
acc = tl.dot(a, b)           # a is (M, K), b is (K, N), result (M, N)
# or with an explicit accumulator:
acc = tl.dot(a, b, acc=acc)  # acc += a @ b
```
On a GPU with tensor cores, and with supported dtypes, Triton lowers this to a sequence of mma.sync (Ampere, Ada) or wgmma (Hopper) instructions, the hardware matmul ops. One mma.sync.m16n8k16 on Ampere has a whole warp multiply a 16×16 tile by a 16×8 tile and accumulate into a 16×8 output: 16·8·16 = 2,048 multiply-adds in a single instruction, versus 32 (one per thread) for a warp-wide vector FMA.
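The per-instruction arithmetic above can be checked with a few lines of plain Python. This is a sketch under stated assumptions: it uses the Ampere/Ada m16n8k16 shape and a 32-thread warp, ignores latency and issue rates, and `mma_ops_per_tile` is a hypothetical helper rather than a Triton API.

```python
# Back-of-the-envelope arithmetic only: no GPU or Triton needed.
MMA_M, MMA_N, MMA_K = 16, 8, 16  # Ampere/Ada mma.sync.aligned.m16n8k16
WARP_SIZE = 32


def macs_per_mma() -> int:
    """Multiply-adds performed by one warp-wide m16n8k16 mma instruction."""
    return MMA_M * MMA_N * MMA_K


def macs_per_vector_fma() -> int:
    """Multiply-adds performed by one warp-wide vector FMA (one per thread)."""
    return WARP_SIZE


def mma_ops_per_tile(bm: int, bn: int, bk: int) -> int:
    """How many m16n8k16 instructions one (bm x bk) @ (bk x bn) tl.dot dices into.

    Hypothetical helper; assumes every tile dim is a multiple of the mma shape.
    """
    if bm % MMA_M or bn % MMA_N or bk % MMA_K:
        raise ValueError("tile dims must be multiples of 16 x 8 x 16")
    return (bm // MMA_M) * (bn // MMA_N) * (bk // MMA_K)


print(macs_per_mma())                           # 2048
print(macs_per_mma() // macs_per_vector_fma())  # 64
print(mma_ops_per_tile(128, 128, 32))           # 256
```

The 64× is work per instruction, not a throughput prediction: sustained speedups also depend on how often each unit can issue and on memory bandwidth.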
The shape constraints — when lowering succeeds
Tensor cores want specific tile shapes. The dimensions of the operands to tl.dot must be at least 16 in every axis:
```python
tl.dot(a, b)  # a: (BM, BK), b: (BK, BN)
# Required: BM ≥ 16, BK ≥ 16, BN ≥ 16
```
For best utilisation of the mma unit you also want them to be multiples of the hardware mma shape: typically m16n8k16 on Ampere/Ada (Triton issues these in pairs, so people often quote it as "16×16×16"), and m64nNk16 via wgmma on Hopper, with N a multiple of 8 from 8 to 256. Pick BM ≥ 64 on Hopper for clean wgmma lowering.
| GPU | Native mma shape (M × N × K) | PTX instruction Triton emits |
|---|---|---|
| A100 (sm_80) | 16 × 8 × 16 | mma.sync.aligned.m16n8k16 |
| H100 (sm_90) | 64 × N × 16, N a multiple of 8 in 8..256 | wgmma.mma_async.sync |
| RTX 4090 (sm_89) | 16 × 8 × 16 | mma.sync.aligned.m16n8k16 |
If a tile dim is too small (e.g. BK=8), you don't get tensor-core throughput. Depending on the Triton version, tl.dot either refuses to compile (recent releases assert M, N, K ≥ 16) or produces correct results at a small fraction of the flops. The cure is a bigger BK.
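The shape rules and the table above can be folded into a small rule-of-thumb checker. It is a sketch, not Triton's actual lowering logic: `classify_tile` is a hypothetical helper that only models the three GPUs in the table, assumes power-of-two block dims (Triton block shapes must be powers of two), and encodes just the rules stated in this lesson: every dim ≥ 16, mma.sync on Ampere/Ada, wgmma on Hopper when BM ≥ 64.

```python
def is_pow2(x: int) -> bool:
    return x > 0 and x & (x - 1) == 0


def classify_tile(bm: int, bn: int, bk: int, sm: int) -> str:
    """Rule-of-thumb lowering for a (bm x bk) @ (bk x bn) tl.dot (hypothetical helper).

    sm is the compute capability from the table: 80 (A100), 89 (RTX 4090), 90 (H100).
    Returns "wgmma", "mma.sync", or "no tensor cores".
    """
    if sm not in (80, 89, 90):
        raise ValueError("this sketch only models sm_80, sm_89 and sm_90")
    if not all(is_pow2(d) for d in (bm, bn, bk)):
        raise ValueError("Triton block dims must be powers of two")
    if min(bm, bn, bk) < 16:
        return "no tensor cores"  # below the 16-per-axis floor
    if sm == 90 and bm >= 64:
        return "wgmma"            # m64nNk16 warpgroup MMA
    return "mma.sync"             # m16n8k16 warp-level MMA


for shape in [(128, 128, 32, 80), (128, 128, 64, 90), (32, 64, 32, 90), (64, 64, 8, 89)]:
    print(shape, classify_tile(*shape))
```

The third case shows why BM matters on Hopper: the same kernel with BM=32 stays on the older warp-level instruction.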
Dtype rules — bf16 in, fp32 accumulate
The mma instructions take low-precision inputs and accumulate in higher precision:
| Input dtypes | Allowed accumulator | Notes |
|---|---|---|
| bf16 × bf16 | fp32 | The modern default. Use this unless you have a reason not to. |
| fp16 × fp16 | fp32 (or fp16) | fp16 accumulate can be faster on some consumer GPUs, but loses precision and can overflow on long-K matmuls. |
| fp8 × fp8 | fp32 | Ada (sm_89) and Hopper+. Triton supports e4m3 and e5m2. |
| int8 × int8 | int32 | Ampere+. For quantised inference. |
| fp32 × fp32 (as TF32) | fp32 | fp32 inputs, internally truncated to TF32 (19 bits: 1 sign + 8 exponent + 10 mantissa). Triton's default for fp32 tl.dot on Ampere+. |
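To see what the bit counts in the table mean, here is a stdlib-only sketch that rounds an fp32 value to bf16 (round-to-nearest-even) and to TF32 (truncation, following the table's wording). The helper names are hypothetical, NaN/Inf handling is omitted, and this emulates the formats on the CPU; it is not how Triton or the hardware performs the conversion.

```python
import struct


def f32_bits(x: float) -> int:
    """Bit pattern of x after rounding it to IEEE fp32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def to_bf16(x: float) -> float:
    """Keep 7 of fp32's 23 mantissa bits, rounding to nearest, ties to even."""
    b = f32_bits(x)
    lsb = (b >> 16) & 1  # lowest kept mantissa bit, used for ties-to-even
    return bits_to_f32((b + 0x7FFF + lsb) & 0xFFFF0000)


def to_tf32_truncated(x: float) -> float:
    """Keep 10 of fp32's 23 mantissa bits by dropping the low 13 bits."""
    return bits_to_f32(f32_bits(x) & 0xFFFFE000)


x = 1.0 + 2.0**-9              # needs 9 mantissa bits
print(to_bf16(x))              # 1.0 (bf16 keeps only 7)
print(to_tf32_truncated(x))    # 1.001953125 (TF32 keeps 10)
```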
Concretely:
```python
acc = tl.zeros((BM, BN), dtype=tl.float32)  # ← fp32 accumulator
for k in range(0, K, BK):
    a = tl.load(...).to(tl.bfloat16)        # ← bf16 inputs
    b = tl.load(...).to(tl.bfloat16)
    acc = tl.dot(a, b, acc=acc)             # mma.sync (wgmma on Hopper): bf16×bf16→fp32
y = acc.to(tl.bfloat16)                     # cast back when storing
```
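Always initialise the accumulator with tl.zeros(..., dtype=tl.float32). Accumulating in bf16 instead is the single most common silent bug in Triton GEMMs: the kernel runs, but long-K sums lose precision.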
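The bug is easy to reproduce on the CPU. A minimal sketch, assuming a simplified model in which the K products are added one at a time (real tensor cores add in a different order, with hardware-specific rounding): inputs are rounded to bf16, and the running sum is rounded to either fp32 or bf16 after every add. Helper names are hypothetical; only the standard library's struct module is used.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bf16 with round-to-nearest-even (NaN/Inf handling omitted)."""
    b = struct.unpack("<I", struct.pack("<f", x))[0]
    b = (b + 0x7FFF + ((b >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", b))[0]


def emulated_dot(a, b, round_acc):
    """Length-K dot product with bf16 inputs; round_acc models the accumulator dtype."""
    acc = 0.0
    for x, y in zip(a, b):
        # bf16 x bf16 has at most 16 significant bits, so the product is exact in fp32.
        acc = round_acc(acc + to_bf16(x) * to_bf16(y))
    return acc


K = 4096
ones = [1.0] * K
print(emulated_dot(ones, ones, to_f32))   # 4096.0 with an fp32 accumulator
print(emulated_dot(ones, ones, to_bf16))  # 256.0: bf16 spacing at 256 is 2, and 257 ties back to 256
```

With a bf16 accumulator the sum stops growing after 256 additions, which is exactly the kind of error that passes a small-K unit test and fails at production K.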
The K-loop pattern
For matmuls where K is much larger than your tile's BK, you loop over K:
```python
# offs_m, offs_n, offs_k: this program's row, column and K offsets (tl.arange), defined earlier.
acc = tl.zeros((BM, BN), dtype=tl.float32)
for k in range(0, K, BK):
    a = tl.load(a_ptr + offs_m[:, None]*stride_am + (k + offs_k)[None, :]*stride_ak,
                mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < K), other=0.).to(tl.bfloat16)
    b = tl.load(b_ptr + (k + offs_k)[:, None]*stride_bk + offs_n[None, :]*stride_bn,
                mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N), other=0.).to(tl.bfloat16)
    acc = tl.dot(a, b, acc=acc)
```
Note:
- The K mask uses other=0., so masked-off K positions contribute nothing to the dot product, which is exactly what we want.
- The accumulator acc stays in registers across loop iterations. The compiler pipelines the next iteration's loads against the current iteration's tl.dot (lesson 13's num_stages).
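The K-loop and its masks can be emulated on the CPU to check the indexing logic. This is a pure-Python sketch, not Triton code: `load_tile` and `tiled_matmul` are hypothetical helpers, Python floats stand in for the fp32 accumulator (precision is not modelled here), and each (m0, n0) pair plays the role of one program instance. It shows that with other=0 a K that is not a multiple of BK still gives the exact product.

```python
def load_tile(mat, rows, cols, r0, c0, br, bc):
    """Emulate a masked tl.load of a (br x bc) tile at (r0, c0); masked-off reads give 0."""
    return [[mat[r0 + i][c0 + j] if (r0 + i < rows and c0 + j < cols) else 0.0
             for j in range(bc)]
            for i in range(br)]


def tiled_matmul(A, B, BM=4, BN=4, BK=4):
    """C = A @ B computed tile by tile, mirroring the Triton K-loop above."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for m0 in range(0, M, BM):                     # one program per (m0, n0) tile
        for n0 in range(0, N, BN):
            acc = [[0.0] * BN for _ in range(BM)]  # acc = tl.zeros((BM, BN), ...)
            for k0 in range(0, K, BK):             # for k in range(0, K, BK)
                a = load_tile(A, M, K, m0, k0, BM, BK)
                b = load_tile(B, K, N, k0, n0, BK, BN)
                for i in range(BM):                # acc = tl.dot(a, b, acc=acc)
                    for j in range(BN):
                        acc[i][j] += sum(a[i][t] * b[t][j] for t in range(BK))
            for i in range(BM):                    # masked tl.store
                for j in range(BN):
                    if m0 + i < M and n0 + j < N:
                        C[m0 + i][n0 + j] = acc[i][j]
    return C


# 5 x 7 times 7 x 3: no dimension is a multiple of the 4 x 4 x 4 tiles.
A = [[float(i + 2 * t) for t in range(7)] for i in range(5)]
B = [[float(t - j) for j in range(3)] for t in range(7)]
reference = [[sum(A[i][t] * B[t][j] for t in range(7)) for j in range(3)] for i in range(5)]
print(tiled_matmul(A, B) == reference)  # True
```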
Output: cast on store
The accumulator is fp32. The output tensor is typically bf16. Cast at the final store:
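```python
tl.store(c_ptr + ..., acc.to(tl.bfloat16), mask=...)
```
In current Triton releases, tl.store converts the value to the output pointer's element type on its own, so an uncast fp32 store does not write four bytes into a two-byte slot. The explicit .to(tl.bfloat16) still belongs there: it makes the fp32 → bf16 narrowing visible at the one place it happens.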
When tl.dot falls back to non-tensor-core
If the compiler can't lower to mma, it generates a vector-ALU FMA loop instead. Triggers:
- Tile dim < 16 on any axis (recent Triton releases reject this at compile time instead). Always use BM, BK, BN ≥ 16.
- Unsupported dtype combo on the target GPU (e.g. fp8 on Ampere).
- Strided / non-contiguous operands that can't be brought into a tensor-core-compatible layout.
- Awkward shapes: extremely small N (≤ 8) on Hopper can fall back from wgmma to mma.sync (still tensor cores, but slower).
In these fallback cases the compiler doesn't error out; it silently gives you slow code. To confirm tensor cores are firing, dump the PTX (lesson 14) and grep for mma.sync / wgmma. Or compare throughput with the cuBLAS reference: if you're roughly 10× slower, you're probably on the ALU.
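Grepping the PTX can be scripted. A minimal sketch, assuming you already have the PTX as a string (in many Triton versions the object returned by a kernel launch exposes it as `.asm["ptx"]`; check this against your version). `tensor_core_ops` is a hypothetical helper, and the PTX lines below are illustrative excerpts, not real compiler output.

```python
import re

# PTX mnemonics that mean the matmul is running on tensor cores.
_TENSOR_CORE_OPS = re.compile(r"\b(wgmma\.mma_async|mma\.sync)\b")


def tensor_core_ops(ptx: str) -> set:
    """Return which tensor-core MMA mnemonics appear in a PTX listing."""
    return {m.group(1) for m in _TENSOR_CORE_OPS.finditer(ptx)}


# Hypothetical one-line excerpts, not real compiler output.
ampere_ptx = "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%f1, %f2, %f3, %f4}, ..."
hopper_ptx = "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%f1, ...}, ..."
alu_ptx = "fma.rn.f32 %f5, %f6, %f7, %f5;"

for name, ptx in [("ampere", ampere_ptx), ("hopper", hopper_ptx), ("alu", alu_ptx)]:
    print(name, tensor_core_ops(ptx) or "no tensor-core MMA found")
```

An empty result on a kernel you expected to hit tensor cores is the cue to revisit the tile shapes and dtypes above.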
What's next
You can do dense matmuls. The other half of every transformer kernel is reductions — sums, maxes, and the softmax pattern. Lesson 06 introduces tl.sum and friends, and previews the online-softmax recurrence that makes Flash Attention possible.