
Shared memory & tiled matmul

A naïve matmul is memory-bound because each operand is read O(N) times from HBM. Tile A and B into shared memory and every byte a block loads from HBM is reused T times from SMEM, where T is the tile size. HBM traffic falls by ~T, the kernel moves toward compute-bound, and the GPU's arithmetic units get to do work.

The naïve kernel and why it's slow

Matrix multiply: C = A · B with shapes (M, K) × (K, N) → (M, N). The textbook one-thread-per-output kernel:

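__global__ void matmul_naive(const float* A, const float* B, float* C,
                              int M, int N, int K) {
    // one thread per output element C[row, col]; all matrices are row-major
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {   // guard threads that fall past the matrix edge
        float acc = 0.0f;
        for (int k = 0; k < K; k++) {
            // both loads go to global memory on every iteration
            acc += A[row*K + k] * B[k*N + col];
        }
        C[row*N + col] = acc;
    }
}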

Count the HBM reads. Each output element C[i,j] requires K reads from A (the i-th row) and K reads from B (the j-th column). So total reads:

HBM_reads = M · N · K · 2 (one read of A and one of B per inner-loop step, K steps per output element)

For M = N = K = 1024, that's ≈ 2 · 10⁹ reads, or ≈ 8 GB of HBM traffic per matmul at 4 bytes per float. The matmul itself does M · N · K · 2 ≈ 2 · 10⁹ FLOPs, so the arithmetic intensity is 0.25 FLOP/byte. That is far below the H100 ridge, which is ≈ 20 FLOP/byte for fp32 and higher still for the tensor-core peaks. The kernel is memory-bound and reaches only a few percent of peak FLOPS.
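The arithmetic above can be checked with a few lines of host-side C++. This is a minimal sketch, not part of the lesson's kernels: the names NaiveCost and naive_cost are made up for illustration, and it keeps the text's simplification that every load goes to HBM (no cache hits).

```cpp
#include <cstdint>

// Traffic and work for the naive kernel, assuming every load misses cache
// and goes to HBM (the same simplification the text makes).
struct NaiveCost {
    std::uint64_t reads;   // element loads from A and B
    std::uint64_t bytes;   // reads * sizeof(float)
    std::uint64_t flops;   // one multiply + one add per inner-loop step
    double intensity;      // FLOP per byte
};

NaiveCost naive_cost(std::uint64_t M, std::uint64_t N, std::uint64_t K) {
    NaiveCost c{};
    c.reads = 2 * M * N * K;                 // one A load + one B load per step
    c.bytes = c.reads * sizeof(float);
    c.flops = 2 * M * N * K;
    c.intensity = static_cast<double>(c.flops) / static_cast<double>(c.bytes);
    return c;
}
```

Calling naive_cost(1024, 1024, 1024) reproduces the figures quoted above. The intensity does not depend on the matrix size at all, which is the real problem: making the matmul bigger never helps the naïve kernel.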

The reuse insight

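Look at the loads again. Thread (i, 0) reads A[i, 0..K-1]. Thread (i, 1) also reads A[i, 0..K-1]. So does (i, 2), …, (i, N-1). The same K values are read from HBM N times, and by the same argument each column of B is read M times. Of the M · N · K · 2 total reads, only M · K + K · N fetch something new; the rest are redundant. We're paying HBM bandwidth for data we have already fetched.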
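To put a number on that redundancy, here is a small sketch that compares the loads the naïve kernel issues with the number of distinct elements in A and B. The function names are hypothetical, and the count again assumes no caching between threads.

```cpp
#include <cstdint>

// Element loads the naive kernel issues: every output re-reads its row of A
// and its column of B.
std::uint64_t naive_loads(std::uint64_t M, std::uint64_t N, std::uint64_t K) {
    return 2 * M * N * K;
}

// Distinct elements that exist to be loaded: all of A plus all of B.
std::uint64_t unique_loads(std::uint64_t M, std::uint64_t N, std::uint64_t K) {
    return M * K + K * N;
}

// Average number of times each element is fetched.
double redundancy(std::uint64_t M, std::uint64_t N, std::uint64_t K) {
    return static_cast<double>(naive_loads(M, N, K)) /
           static_cast<double>(unique_loads(M, N, K));
}
```

For square matrices the ratio is exactly N, so at N = 1024 each element is fetched about a thousand times more often than it needs to be.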

The fix: load each value of A and B from HBM into SMEM once per block, then have all the threads in that block that need it read it from SMEM (which is on-chip and ~6× faster).

The tiled mechanic

Pick a tile size T (commonly 16, 32, or 64). The kernel computes one T × T tile of C per block, with one thread per output element. Inside the block:

  1. Phase A: cooperatively load a T × T chunk of A and a T × T chunk of B into SMEM (each thread loads one element of each).
  2. Synchronise (__syncthreads()).
  3. Phase B: each thread accumulates its partial dot product over the K-dimension using the SMEM tiles.
  4. Synchronise.
  5. Move along the K dimension to the next pair of tiles and repeat.
  6. After K/T iterations, each thread has its complete C[i, j]; write it out.

The code:

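#define T 16
// Launch with dim3 block(T, T) and dim3 grid((N + T - 1) / T, (M + T - 1) / T).
__global__ void matmul_tiled(const float* A, const float* B, float* C,
                              int M, int N, int K) {
    __shared__ float As[T][T];
    __shared__ float Bs[T][T];

    int row = blockIdx.y * T + threadIdx.y;
    int col = blockIdx.x * T + threadIdx.x;
    float acc = 0.0f;

    for (int kBlock = 0; kBlock < K; kBlock += T) {
        // cooperatively load one tile of A and one of B; pad with zeros past
        // the matrix edge so M, N, K need not be multiples of T
        int aCol = kBlock + threadIdx.x;
        int bRow = kBlock + threadIdx.y;
        As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row*K + aCol] : 0.0f;
        Bs[threadIdx.y][threadIdx.x] = (bRow < K && col < N) ? B[bRow*N + col] : 0.0f;
        __syncthreads();  // tiles fully written before anyone reads them

        // accumulate using SMEM
        for (int k = 0; k < T; k++)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();  // everyone done reading before the next overwrite
    }
    // no early return above: every thread must reach both barriers
    if (row < M && col < N)
        C[row*N + col] = acc;
}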

Figure: each block computes a T×T tile of C and iterates over A and B tiles along K. Panels: A (M×K) with the current A tile, B (K×N) with the current B tile, C (M×N) with the C tile. The A tile (in SMEM) is reused across the width of B (T elements per output); the B tile (in SMEM) is reused across the height of A.
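When debugging a tiled kernel it helps to have a CPU model that walks the same tiles in the same order. The sketch below is one way to write it, under stated assumptions: matmul_tiled_cpu and its tile parameter are illustrative names, two small staging buffers stand in for SMEM, and the partial sums are added into C per K step instead of living in a per-thread register, so floating-point rounding can differ slightly from the GPU result.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// CPU model of the tiled kernel: the three outer loops stand in for the
// block grid and the K loop; the staging buffers stand in for SMEM.
// `tile` plays the role of T. Row-major layout, as in the kernel.
std::vector<float> matmul_tiled_cpu(const std::vector<float>& A,
                                    const std::vector<float>& B,
                                    int M, int N, int K, int tile) {
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    std::vector<float> As(static_cast<std::size_t>(tile) * tile);
    std::vector<float> Bs(static_cast<std::size_t>(tile) * tile);

    for (int r0 = 0; r0 < M; r0 += tile)            // blockIdx.y
    for (int c0 = 0; c0 < N; c0 += tile)            // blockIdx.x
    for (int k0 = 0; k0 < K; k0 += tile) {          // kBlock
        // Load phase: stage one tile of A and one of B, zero-padding edges.
        for (int y = 0; y < tile; ++y)
        for (int x = 0; x < tile; ++x) {
            int ar = r0 + y, ac = k0 + x;
            int br = k0 + y, bc = c0 + x;
            As[y * tile + x] = (ar < M && ac < K) ? A[ar * K + ac] : 0.0f;
            Bs[y * tile + x] = (br < K && bc < N) ? B[br * N + bc] : 0.0f;
        }
        // Accumulate phase: each in-bounds (y, x) "thread" adds its partial
        // dot product for this K step into its output element.
        for (int y = 0; y < std::min(tile, M - r0); ++y)
        for (int x = 0; x < std::min(tile, N - c0); ++x) {
            float acc = 0.0f;
            for (int k = 0; k < tile; ++k)
                acc += As[y * tile + k] * Bs[k * tile + x];
            C[(r0 + y) * N + (c0 + x)] += acc;
        }
    }
    return C;
}
```

Comparing this against the GPU output with a small tolerance, on shapes that are not multiples of the tile size, is a quick way to catch missing bounds checks.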

Animated · tile-by-tile matmul, step through K

One block computes one T × T tile of C. Press play to step along the K dimension: A tile and B tile load into SMEM, threads accumulate partial sums into registers, sync, move to next K. The SMEM reuse counter ticks up — each loaded byte is used T times.

Figure (animation): Tiled matmul, cooperative SMEM loads and K-loop accumulation. Left: A (highlighted row of tiles). Middle: B (highlighted column of tiles). Right: the C tile being computed (one block, T×T threads, partial sums in registers). A bottom strip shows the phase clock. Readouts: phase, K iteration, SMEM reuse, accumulated FLOPs.

2D · output tile assignment

The output M × N matrix is partitioned into T × T tiles, one per thread block. Click a tile to see which threads in the block handle which output elements. Each block contains T × T threads; thread (ty, tx) computes C[blockRow·T + ty][blockCol·T + tx].
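The mapping is simple enough to write down directly. A small sketch follows; OutputIndex, owned_output and blocks_along are hypothetical helper names, and tile stands for T.

```cpp
// Which output element does thread (ty, tx) of block (blockRow, blockCol) own?
struct OutputIndex { int row; int col; };

OutputIndex owned_output(int blockRow, int blockCol, int ty, int tx, int tile) {
    return { blockRow * tile + ty, blockCol * tile + tx };
}

// Blocks needed along one dimension of length n (ceiling division), so the
// last block may hang past the matrix edge when n is not a multiple of tile.
int blocks_along(int n, int tile) {
    return (n + tile - 1) / tile;
}
```

With T = 16, a 1024 × 1024 output needs a 64 × 64 grid of blocks; a 1000 × 1000 output needs 63 × 63, and the last row and column of blocks are only partly inside the matrix.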

Figure (interactive): Output tile mapping, which thread does what. Left: tiled output grid. Right: zoom-in of the selected tile showing thread (ty, tx) indices, one cell per thread's output. Readouts: tiles total, selected tile, threads/block, block grid.

2D · arithmetic intensity slider on the roofline

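Tile size T sets arithmetic intensity to T/4 FLOP/B (fp32). Slide T and watch the kernel walk up the roofline. Against the H100 fp32 ridge (≈ 20 F/B) the crossover into compute-bound sits at T/4 = 20, i.e. around T=80, which is already more threads than a block can hold with one thread per output. Against the bf16 tensor-core peak (ridge ≈ 295 F/B) it sits around T=600, counting 2-byte elements so that intensity is T/2, and is only reached with register tiling + wmma.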
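The roofline reasoning fits in a handful of functions. This is a sketch under assumptions: the function names are invented here, peak throughput and bandwidth are passed in as parameters (no hardware figures are baked in), and the T/4 rule is the fp32 one-thread-per-output model from the text.

```cpp
#include <algorithm>

// Roofline model: attainable throughput is capped either by the compute
// peak or by bandwidth times arithmetic intensity. Units: FLOP/s, byte/s.
double ridge_point(double peak_flops, double bandwidth) {
    return peak_flops / bandwidth;          // FLOP per byte at the crossover
}

double attainable_flops(double intensity, double peak_flops, double bandwidth) {
    return std::min(peak_flops, intensity * bandwidth);
}

// Intensity of the one-thread-per-output tiled kernel, per the T/4 rule (fp32).
double tiled_intensity_fp32(int tile) {
    return tile / 4.0;
}

bool is_compute_bound(int tile, double peak_flops, double bandwidth) {
    return tiled_intensity_fp32(tile) >= ridge_point(peak_flops, bandwidth);
}
```

For example, with round illustrative figures of 60 TFLOP/s and 3 TB/s (chosen only so that the ridge comes out at 20 F/B), T = 32 is still memory-bound while T = 80 sits exactly on the ridge.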

Figure (interactive): Tile size → arithmetic intensity → roofline position, for a single N×N matmul. AI grows linearly in T. SMEM usage grows quadratically (2T² × 4 B), bounded by ~228 KB. Threads/block = T², bounded by 1024. Readouts: AI (F/B), SMEM/block, threads/block, verdict.

The accounting — how the win arises

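Per block (computing one T×T C tile):

  1. HBM loads: K/T iterations, each bringing in one T × T tile of A and one of B, so 2 · T² · (K/T) = 2 · T · K elements, or 8 · T · K bytes in fp32.
  2. FLOPs: T² outputs, each needing K multiply-adds, so 2 · T² · K.
  3. Arithmetic intensity: (2 · T² · K) / (8 · T · K) = T/4 FLOP/byte.

So with T=16, intensity is 4 FLOP/byte; with T=32, intensity is 8 FLOP/byte; with T=64, intensity is 16 FLOP/byte. Compared to the naïve 0.25 FLOP/byte, tiling pushes you 16× to 64× higher up the roofline (4/0.25, 8/0.25, 16/0.25 = 16×, 32×, 64× respectively). T=32 already covers a large part of the distance to the fp32 ridge.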
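The per-block accounting can be summed over the whole grid to get total HBM traffic. A minimal sketch, assuming M, N and K are multiples of the tile size and 4-byte elements; tiled_hbm_bytes and tiled_intensity are illustrative names.

```cpp
#include <cstdint>

// HBM traffic for the tiled kernel, assuming M, N, K are multiples of `tile`
// and 4-byte elements. Each block runs K/tile steps and loads two
// tile x tile tiles per step.
std::uint64_t tiled_hbm_bytes(std::uint64_t M, std::uint64_t N,
                              std::uint64_t K, std::uint64_t tile) {
    std::uint64_t blocks = (M / tile) * (N / tile);
    std::uint64_t bytes_per_block = (K / tile) * 2 * tile * tile * 4;
    return blocks * bytes_per_block;
}

// FLOPs are unchanged by tiling; only the bytes in the denominator shrink.
double tiled_intensity(std::uint64_t M, std::uint64_t N,
                       std::uint64_t K, std::uint64_t tile) {
    double flops = 2.0 * M * N * K;
    return flops / static_cast<double>(tiled_hbm_bytes(M, N, K, tile));
}
```

At M = N = K = 1024 and T = 16 the traffic drops from about 8 GB to 512 MiB, which is the 16× reduction the intensity figures imply.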

Why doesn't increasing T forever help? Because T² × 4 bytes (for the A tile) plus T² × 4 (for the B tile) has to fit in the block's SMEM allocation. At T=64 and fp32 that's 2 · 64² · 4 = 32 KB, a sizeable chunk of the 228 KB available. And the number of threads per block is T², which with T=64 is 4096, well past the 1024 limit. Real kernels use smaller tiles but compute multiple output elements per thread (lesson 27 on register tiling).
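Both limits are easy to check before launching anything. The sketch below uses hypothetical helper names and takes the SMEM and thread limits as parameters, since they vary by GPU and by how the kernel is configured.

```cpp
#include <cstddef>

// Resource footprint of a tile x tile block with one thread per output and
// two fp32 tiles in SMEM. The limits are parameters, not hard-coded facts.
std::size_t smem_bytes_fp32(std::size_t tile) {
    return 2 * tile * tile * sizeof(float);
}

std::size_t threads_per_block(std::size_t tile) {
    return tile * tile;
}

bool tile_fits(std::size_t tile, std::size_t smem_limit_bytes,
               std::size_t max_threads) {
    return smem_bytes_fp32(tile) <= smem_limit_bytes &&
           threads_per_block(tile) <= max_threads;
}
```

With the limits quoted in the text (228 KB and 1024 threads), T = 32 fits and T = 64 does not, and it is the thread count rather than SMEM that rules T = 64 out.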

Why __syncthreads is non-negotiable

Threads load each other's values from SMEM. Without a barrier, thread 0 might start reading Bs[2][0] before thread 32, which lives in a different warp when T=16, has finished writing it. __syncthreads() is a block-level barrier: no thread proceeds past it until every thread in the block has reached it. Skipping it gives you race conditions that often look like "the result is wrong by a small bit-pattern-dependent amount", the worst kind of bug.

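The two sync points in the tiled kernel are essential:

  1. After the load: no thread may read As or Bs until every thread has written its element of the current tiles.
  2. After the accumulate: no thread may overwrite As or Bs with the next tiles until every thread has finished reading the current ones.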

SMEM bank conflicts — the next-level concern

SMEM is divided into 32 banks, each 4 bytes wide, so consecutive floats fall in consecutive banks. A warp's 32 threads simultaneously accessing 32 different banks → all complete in 1 cycle. If two threads hit different addresses in the same bank, they serialise (reads of the same address are broadcast). The pattern As[threadIdx.y][threadIdx.x] in the tile load has threadIdx.x varying across a warp, hitting different SMEM columns → different banks → no conflict.

The inner loop reads As[threadIdx.y][k] and Bs[k][threadIdx.x]. For a fixed k, Bs[k][threadIdx.x] has threadIdx.x varying across the warp → different banks → no conflict. As[threadIdx.y][k] is the same address for every thread that shares a threadIdx.y, so it is served as a broadcast. Good.

The classic bank-conflict trap is the transpose pattern: Bs[threadIdx.x][k]. With T=32 and 32 banks, consecutive threads read addresses 32 floats apart, i.e. the same column → all 32 threads hit the same bank → a 32-way conflict and up to a 32× slowdown on that access. The cure: pad each row by one element: __shared__ float Bs[T][T+1]. This shifts each row by one bank, so the strided access no longer lines up on a single bank.
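The effect of the padding can be seen by computing bank indices on the host. This is a sketch assuming 32 banks of 4-byte words, as stated above; bank_of and conflict_degree are illustrative names, and the model covers only the transpose-style read where every lane touches a different address.

```cpp
#include <algorithm>
#include <array>

// Bank of a 4-byte word in SMEM, assuming 32 banks of 4-byte width.
int bank_of(int word_index) { return word_index % 32; }

// Worst-case number of lanes in a 32-lane warp that land in the same bank
// when lane `t` reads tile[t][k] from a row-major tile with `row_stride`
// floats per row (row_stride = T unpadded, T + 1 padded).
int conflict_degree(int row_stride, int k) {
    std::array<int, 32> hits{};
    for (int lane = 0; lane < 32; ++lane)
        ++hits[bank_of(lane * row_stride + k)];
    return *std::max_element(hits.begin(), hits.end());
}
```

With a row stride of 32 every lane maps to bank k, so the degree is 32. With a stride of 33 lane t maps to bank (t + k) mod 32, all distinct, so the degree drops to 1.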

Where this lives in real life

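The tiled-matmul kernel above gets you within 2–3× of cuBLAS on simple sizes, already a meaningful speedup over the naïve. Production matmul kernels (cuBLAS, CUTLASS) layer on more optimisations, among them the two this lesson has already named: register tiling (several output elements per thread) and tensor-core instructions (wmma).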

And one giant practical truth: don't write production matmul yourself. Use cuBLAS or CUTLASS. Write your own only for fused variants where the library can't help.

Why tiling is also the secret behind FlashAttention

FlashAttention (lesson 16) is a tiled matmul too: Q is the "left" matrix, K is the "right", and the inner loop does softmax-fused dot products instead of plain ones. It works for the same reason tiled matmul does: a tile of Q stays in SMEM and is reused across many tiles of K and V. The fusion (no materialised QK^T matrix) is the headline; the tiling is the foundation.

Interactive · count HBM reads as you tile

Slide the matrix size and tile size. The widget computes naïve HBM reads vs tiled HBM reads, and the resulting arithmetic intensity. Watch the intensity climb linearly with tile size — and the SMEM usage grow as T².

Figure (interactive): Tiled matmul, arithmetic intensity vs tile size. Naïve: every output element reads K elements of A + K of B. Tiled: each tile is loaded once, reused T times. Higher tile size = more reuse = higher intensity = compute-bound regime. Readouts: naïve HBM, tiled HBM, naïve intensity, tiled intensity.

Takeaway

Tiling is the move from "every operand goes through HBM once per use" to "every operand goes through HBM once per tile". The arithmetic intensity rises linearly with the tile size, until SMEM or register pressure caps you. The same tiling pattern reappears in FlashAttention, FlashConv, and every matmul-like kernel in the GPU canon. SMEM is the working set; tiles are the unit of work; __syncthreads is the price of cooperation.