Shared memory & tiled matmul
A naïve matmul is memory-bound because each operand element is read O(N) times from HBM. Tile A and B into shared memory (SMEM) and each element is read from HBM once per block, then reused T times from SMEM, where T is the tile size. HBM traffic falls by ~T, and with big enough tiles the kernel becomes compute-bound and the GPU's arithmetic units (ultimately its tensor cores) get to do work.
The naïve kernel and why it's slow
Matrix multiply: C = A · B with shapes (M, K) × (K, N) → (M, N). The textbook one-thread-per-output kernel:
```cuda
__global__ void matmul_naive(const float* A, const float* B, float* C,
                             int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float acc = 0;
        for (int k = 0; k < K; k++) {
            acc += A[row*K + k] * B[k*N + col];
        }
        C[row*N + col] = acc;
    }
}
```
Count the HBM reads. Each output element C[i,j] requires K reads from A (the i-th row) and K reads from B (the j-th column). So total reads:
$$\text{reads} = M \cdot N \cdot (K + K) = 2MNK$$
For M = N = K = 1024, that's 2 · 1024³ ≈ 2 · 10⁹ reads, about 8 GB of HBM traffic per matmul at 4 bytes per fp32 element. The actual matmul does M · N · K · 2 ≈ 2 · 10⁹ FLOPs, so the arithmetic intensity is 2MNK / (2MNK · 4 bytes) = 0.25 FLOP/byte, far below the H100's fp32 ridge of ~20 FLOP/byte. The kernel is memory-bound: the roofline caps it at 0.25/20 ≈ 1% of fp32 peak.
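The count above can be checked with a few lines of host-side C++ (plain C++, so it also compiles as host code in a .cu file). This is a sketch under the same simplifying assumption as the text: every load goes to HBM, with no L1/L2 reuse. `NaiveCost` and `naive_cost` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: traffic model for matmul_naive, assuming every load
// goes to HBM (no L1/L2 reuse), as in the count above.
struct NaiveCost {
    std::uint64_t reads;  // element loads from HBM
    std::uint64_t bytes;  // reads * element size
    std::uint64_t flops;  // one multiply + one add per (i, j, k)
    double intensity;     // FLOP per byte of HBM traffic
};

NaiveCost naive_cost(std::uint64_t M, std::uint64_t N, std::uint64_t K,
                     std::uint64_t elem_bytes = 4 /* fp32 */) {
    assert(elem_bytes > 0);
    NaiveCost c{};
    c.reads = M * N * 2 * K;  // K loads from a row of A + K from a column of B, per output
    c.bytes = c.reads * elem_bytes;
    c.flops = 2 * M * N * K;
    c.intensity = static_cast<double>(c.flops) / static_cast<double>(c.bytes);
    return c;
}
```

For M = N = K = 1024 it reports 2³¹ ≈ 2.1 · 10⁹ reads, 2³³ bytes (8 GiB), and an intensity of exactly 0.25 FLOP/byte.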
The reuse insight
Look at the loads again. Thread (i, 0) reads A[i, 0..K-1]. Thread (i, 1) also reads A[i, 0..K-1]. So do (i, 2), …, (i, N-1): the same K values are fetched from HBM N times, and each column of B is likewise fetched M times. The kernel issues 2 · M · N · K reads for only M · K + K · N distinct values, so we keep paying HBM bandwidth for data we already loaded.
The fix: load each value of A and B from HBM into SMEM once per block, then have all the threads in the block that need it read it from SMEM, which is on-chip and has far higher bandwidth and far lower latency than HBM.
The tiled mechanic
Pick a tile size T (commonly 16 or 32 for this one-thread-per-output kernel). The kernel computes one T × T tile of C per block, with one thread per output element. Inside the block:
- Phase A: cooperatively load a T × T chunk of A and a T × T chunk of B into SMEM (each thread loads one element of each).
- Synchronise (__syncthreads()).
- Phase B: each thread accumulates its partial dot product over the K-dimension using the SMEM tiles.
- Synchronise.
- Move along the K dimension to the next pair of tiles and repeat.
- After K/T iterations, each thread has its complete C[i, j]; write it out.
The code:
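```cuda
#define T 16  // tile size; launch with blockDim = (T, T), gridDim = (ceil(N/T), ceil(M/T))

__global__ void matmul_tiled(const float* A, const float* B, float* C,
                             int M, int N, int K) {
    __shared__ float As[T][T];
    __shared__ float Bs[T][T];
    int row = blockIdx.y * T + threadIdx.y;
    int col = blockIdx.x * T + threadIdx.x;
    float acc = 0.0f;
    for (int kBlock = 0; kBlock < K; kBlock += T) {
        // cooperatively load one tile of A and one of B; out-of-range cells get 0
        int aCol = kBlock + threadIdx.x;
        int bRow = kBlock + threadIdx.y;
        As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row*K + aCol] : 0.0f;
        Bs[threadIdx.y][threadIdx.x] = (bRow < K && col < N) ? B[bRow*N + col] : 0.0f;
        __syncthreads();  // whole tile written before anyone reads it
        // accumulate using SMEM
        for (int k = 0; k < T; k++)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();  // everyone done reading before the next overwrite
    }
    // guard only the store: every thread must still reach both barriers above
    if (row < M && col < N)
        C[row*N + col] = acc;
}
```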
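To see that the tile loop computes the same thing as the naïve kernel, here is a minimal host-side C++ emulation of its structure: the outer loops play the role of blocks, the inner (ty, tx) loops the role of threads, and two plain arrays stand in for As and Bs. `matmul_tiled_cpu` is a hypothetical helper, not part of any library; it zero-pads edge tiles so M, N and K need not be multiples of T, and because it runs sequentially the barriers reduce to comments.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side emulation of the tiled kernel's loop structure (not
// the GPU kernel itself). Row-major A is M×K, B is K×N, result C is M×N.
std::vector<float> matmul_tiled_cpu(const std::vector<float>& A,
                                    const std::vector<float>& B,
                                    int M, int N, int K, int T) {
    assert(T > 0);
    assert(A.size() == static_cast<std::size_t>(M) * K);
    assert(B.size() == static_cast<std::size_t>(K) * N);
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    std::vector<float> As(T * T), Bs(T * T);  // stand-ins for the __shared__ tiles
    std::vector<float> acc(T * T);            // one accumulator "register" per thread
    for (int by = 0; by * T < M; ++by) {      // blockIdx.y
        for (int bx = 0; bx * T < N; ++bx) {  // blockIdx.x
            std::fill(acc.begin(), acc.end(), 0.0f);
            for (int kBlock = 0; kBlock < K; kBlock += T) {
                // Phase A: thread (ty, tx) loads one element of each tile (0 outside the matrix).
                for (int ty = 0; ty < T; ++ty)
                    for (int tx = 0; tx < T; ++tx) {
                        int row = by * T + ty, col = bx * T + tx;
                        As[ty * T + tx] = (row < M && kBlock + tx < K)
                                              ? A[row * K + kBlock + tx] : 0.0f;
                        Bs[ty * T + tx] = (kBlock + ty < K && col < N)
                                              ? B[(kBlock + ty) * N + col] : 0.0f;
                    }
                // GPU: __syncthreads() here; sequential code gets the ordering for free.
                // Phase B: each thread accumulates over the T-wide slice of K.
                for (int ty = 0; ty < T; ++ty)
                    for (int tx = 0; tx < T; ++tx)
                        for (int k = 0; k < T; ++k)
                            acc[ty * T + tx] += As[ty * T + k] * Bs[k * T + tx];
                // GPU: second __syncthreads() here, before the tiles are overwritten.
            }
            // Write-out, skipping the padded cells of partial edge tiles.
            for (int ty = 0; ty < T; ++ty)
                for (int tx = 0; tx < T; ++tx) {
                    int row = by * T + ty, col = bx * T + tx;
                    if (row < M && col < N) C[row * N + col] = acc[ty * T + tx];
                }
        }
    }
    return C;
}
```

Running it with small integer inputs for several T (including a T larger than the matrix) and comparing against a textbook triple loop is a cheap way to test the index arithmetic before porting it to a kernel.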
Figure (animated): tile-by-tile matmul, stepping through K.
One block computes one T × T tile of C. At each K step an A tile and a B tile load into SMEM, threads accumulate partial sums into registers, sync, and move to the next K. Each loaded element is reused T times.
Figure: output tile assignment.
The output M × N matrix is partitioned into T × T tiles, one per thread block. Each block contains T² threads; thread (ty, tx) computes C[blockRow·T + ty][blockCol·T + tx].
Figure: arithmetic intensity vs tile size on the roofline.
Tile size T sets arithmetic intensity to T/4 FLOP/B for fp32 (T/2 for 2-byte bf16 elements). On H100 the fp32 ridge is ≈ 20 F/B, so the crossover into compute-bound happens around T = 80; against the bf16 tensor-core peak (ridge ≈ 295 F/B) it happens around T ≈ 600, reachable only with register tiling + wmma.
The accounting — how the win arises
Per block (computing one T×T C tile):
- HBM reads: (A tile + B tile) per K-step × K/T iterations = T·T·2 · K/T = 2 · T · K elements per block
- FLOPs: T·T · K · 2 = 2 · T² · K per block
- Arithmetic intensity: 2 · T² · K / (2 · T · K · 4) = T / 4 FLOP/byte (for fp32)
So with T=16, intensity is 4 FLOP/byte; with T=32, 8 FLOP/byte; with T=64, 16 FLOP/byte. Compared to the naïve 0.25 FLOP/byte, tiling pushes you 16× to 64× higher up the roofline (4/0.25, 8/0.25, 16/0.25 = 16×, 32×, 64× respectively). T=32 already gets you within ~2.5× of the H100 fp32 ridge (~20 FLOP/byte).
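The per-block counts above scale to the whole matrix by multiplying by the number of blocks. This host-side C++ sketch keeps the text's no-cache model and, for simplicity, requires M, N and K to be multiples of T; `tiled_reads` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: total HBM element reads of the tiled kernel under the
// no-cache model, for M, N, K that are multiples of T.
std::uint64_t tiled_reads(std::uint64_t M, std::uint64_t N, std::uint64_t K,
                          std::uint64_t T) {
    assert(T > 0 && M % T == 0 && N % T == 0 && K % T == 0);
    std::uint64_t blocks = (M / T) * (N / T);  // one block per T×T tile of C
    std::uint64_t steps = K / T;               // K-steps per block
    std::uint64_t per_step = 2 * T * T;        // one A tile + one B tile
    return blocks * steps * per_step;          // = 2·M·N·K / T
}
```

For M = N = K = 1024 and T = 32 it returns 2²⁶ reads, exactly 1/32 of the naïve 2 · M · N · K = 2³¹: the same factor T that separates the T/4 and 0.25 FLOP/byte intensities.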
Why doesn't increasing T forever help? Because the A tile (T² × 4 bytes) plus the B tile (T² × 4 bytes) has to fit in the block's SMEM allocation. At T=64 and fp32 that's 2 · 64² · 4 = 32 KB, a sizeable chunk of the 228 KB of SMEM per SM on H100. And the number of threads per block is T², which at T=64 is 4096, four times the 1024-thread limit. Real kernels use smaller tiles but compute multiple output elements per thread (lesson 27 on register tiling).
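Both limits can be checked mechanically. A host-side C++ sketch, with the limits passed in as parameters because they vary by GPU: 1024 threads per block is the current CUDA limit, and statically declared `__shared__` arrays are capped at 48 KB per block (larger budgets need dynamic shared memory with an explicit opt-in). `TileResources`, `resources_for` and `tile_fits` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical resource model for the one-thread-per-output tiled kernel.
struct TileResources {
    std::size_t smem_bytes;  // As + Bs
    int threads;             // one thread per output element
};

TileResources resources_for(int T, std::size_t elem_bytes) {
    assert(T > 0);
    std::size_t t = static_cast<std::size_t>(T);
    return TileResources{2 * t * t * elem_bytes, T * T};
}

// smem_budget and max_threads_per_block are inputs, not hard-coded facts.
bool tile_fits(int T, std::size_t elem_bytes, std::size_t smem_budget,
               int max_threads_per_block = 1024) {
    TileResources r = resources_for(T, elem_bytes);
    return r.threads <= max_threads_per_block && r.smem_bytes <= smem_budget;
}
```

`resources_for(64, 4)` gives 32 KB and 4096 threads, so T = 64 fails on threads even though its SMEM would fit; T = 32 needs 8 KB and exactly 1024 threads.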
Why __syncthreads is non-negotiable
Threads read each other's values from SMEM. Without a barrier, thread 0 might read Bs[2][0] before thread 32 (ty = 2, tx = 0 when T = 16, and in a different warp) has finished writing it. __syncthreads() is a block-level barrier: no thread proceeds past it until every thread in the block has reached it, and SMEM writes made before it are visible to the whole block after it. Skipping it gives you race conditions that often look like "the result is wrong by a small, bit-pattern-dependent amount", the worst kind of bug.
The two sync points in the tiled kernel are essential:
- After loading the tile (the first __syncthreads()): before any thread reads from SMEM.
- After the inner loop (the second __syncthreads()): before any thread overwrites the SMEM tile in the next iteration.
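A quick way to see why only a block-wide barrier will do: list which warps write the SMEM cells a single thread reads in Phase B. This host-side C++ sketch assumes the kernel's layout (thread (ty, tx) writes As[ty][tx] and Bs[ty][tx], linear id ty · T + tx, 32 threads per warp); `producer_warps` is a hypothetical helper.

```cpp
#include <cassert>
#include <set>

// Warps whose threads write the SMEM cells that thread (ty, tx) reads in
// Phase B: As[ty][k] is written by thread (ty, k), Bs[k][tx] by thread (k, tx).
std::set<int> producer_warps(int ty, int tx, int T, int warp_size = 32) {
    assert(T > 0 && ty >= 0 && ty < T && tx >= 0 && tx < T);
    std::set<int> warps;
    for (int k = 0; k < T; ++k) {
        warps.insert((ty * T + k) / warp_size);  // writer of As[ty][k]
        warps.insert((k * T + tx) / warp_size);  // writer of Bs[k][tx]
    }
    return warps;
}
```

For T = 16, thread (0, 0) depends on writes from all 8 warps of its 256-thread block; for T = 32, on all 32 warps. No ordering inside a single warp can cover that.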
SMEM bank conflicts — the next-level concern
SMEM is divided into 32 banks. If a warp's 32 threads access 32 different banks, the access completes in one pass. If two threads hit different words in the same bank, they serialise; threads reading the same word are served by a single broadcast. The store As[threadIdx.y][threadIdx.x] in the tile load has threadIdx.x varying across a warp, hitting consecutive SMEM columns → different banks → no conflict.
The inner loop's reads are conflict-free too. For a fixed k, Bs[k][threadIdx.x] has threadIdx.x varying across the warp → consecutive words → different banks. As[threadIdx.y][k] is the same word for every thread in a row, so it is broadcast. Good.
The classic bank-conflict trap is the transpose pattern: Bs[threadIdx.x][k]. With T=32 and 32 banks, each thread reads column k of a different row; because each row is 32 words long, all 32 addresses map to the same bank → a 32-way conflict, serialising the access 32×. The cure: pad the tile by one, __shared__ float Bs[T][T+1]. The row pitch becomes 33 words, so row r's element k lands in bank (r + k) mod 32 and the strided access spreads across all 32 banks.
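The bank arithmetic can be simulated directly. This host-side C++ sketch models 32 banks of 4-byte words (bank = word index mod 32) and counts, for one warp-wide access, the largest number of distinct words requested from a single bank; identical words count once because they are broadcast. It is a simplified model, not a cycle-accurate one, and the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Conflict degree of one warp-wide access: the largest number of distinct
// 4-byte words that any single bank must serve (bank = word index % n_banks).
int conflict_degree(const std::vector<int>& word_index, int n_banks = 32) {
    assert(!word_index.empty() && n_banks > 0);
    std::map<int, std::set<int>> words_in_bank;
    for (int w : word_index) words_in_bank[w % n_banks].insert(w);
    int worst = 0;
    for (const auto& entry : words_in_bank)
        worst = std::max(worst, static_cast<int>(entry.second.size()));
    return worst;  // 1 = conflict-free, 32 = fully serialised
}

// Words touched by a 32-thread warp reading a float tile with `pitch` floats per row.
std::vector<int> row_read(int k, int pitch) {  // Bs[k][threadIdx.x]
    std::vector<int> w(32);
    for (int tx = 0; tx < 32; ++tx) w[tx] = k * pitch + tx;
    return w;
}

std::vector<int> column_read(int k, int pitch) {  // Bs[threadIdx.x][k]
    std::vector<int> w(32);
    for (int tx = 0; tx < 32; ++tx) w[tx] = tx * pitch + k;
    return w;
}
```

`conflict_degree(column_read(0, 32))` is 32, the transpose trap; with the padded pitch, `conflict_degree(column_read(7, 33))` is 1, and the row read is conflict-free at either pitch.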
Where this lives in real life
The tiled-matmul kernel above gets you within 2–3× of cuBLAS on simple sizes, already a big step up from the naïve kernel. Production matmul kernels (cuBLAS, CUTLASS) layer on more optimisations:
- Multiple outputs per thread (register tiling). Each thread computes a 4×4 or 8×8 block of C in its registers, getting more FLOPs per SMEM load.
- Double buffering. While processing the current tile, load the next tile from HBM into a second SMEM buffer in parallel.
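- Tensor cores. Instead of acc += a*b, issue warp-level tensor-core operations (mma.sync in PTX, or wmma::mma_sync on 16×16×16 fragments), so a warp computes a whole small sub-matmul per call rather than one multiply-add per thread. Lesson 28.
- Async copies. cp.async on Ampere+ copies from global memory into SMEM without blocking the issuing thread, enabling more aggressive overlap.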
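The payoff of register tiling can be counted per step of the inner loop. A host-side C++ sketch, assuming the common outer-product formulation: a thread that owns an R × R micro-tile of C reads R values of A and R values of B from SMEM per k and performs R² fused multiply-adds. R is an illustrative parameter (R = 1 is the kernel above); the helpers are hypothetical and say nothing about how cuBLAS or CUTLASS actually schedule their work.

```cpp
#include <cassert>

// Work in one k-step of the inner loop for a thread that owns an R×R
// micro-tile of C (R = 1 is the one-output-per-thread kernel).
struct InnerStep {
    int smem_loads;  // R values from the A tile + R values from the B tile
    int fmas;        // outer product: every A value meets every B value
};

InnerStep inner_step(int R) {
    assert(R > 0);
    return InnerStep{2 * R, R * R};
}

double fmas_per_smem_load(int R) {
    InnerStep s = inner_step(R);
    return static_cast<double>(s.fmas) / s.smem_loads;  // = R / 2
}
```

The ratio is R / 2: 0.5 FMA per SMEM load for the plain tiled kernel, 2 for a 4 × 4 micro-tile, 4 for 8 × 8.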
And one giant practical truth: don't write production matmul yourself. Use cuBLAS or CUTLASS. Write your own only for fused variants where the library can't help.
Why tiling is also the secret behind FlashAttention
FlashAttention (lesson 16) is a tiled matmul too: Q is the "left" matrix, K is the "right", and the inner loop does softmax-fused dot products instead of plain ones. Why it works: each tile of Q sits in SMEM and is reused across many tiles of K and V. The fusion (no materialised QK^T matrix) is the headline; the tiling is the foundation.
Figure (interactive): naïve vs tiled HBM reads, arithmetic intensity, and SMEM usage as functions of matrix size and tile size. Intensity climbs linearly with tile size while SMEM usage grows as T².
__syncthreads is the price of cooperation.