Fused linear + bias + GELU
Eager PyTorch evaluates y = gelu(x @ w + b) as three kernel launches (matmul, bias-add, GELU), and the intermediate makes two extra round-trips through HBM between them. One Triton kernel folds + b and gelu into the GEMM's epilogue, and those round-trips disappear. The lesson formalises why epilogue fusion is the single highest-ROI Triton pattern.
The eager path — three launches, two extra HBM round-trips
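In eager mode the matmul writes z = x @ w to HBM, the bias-add reads z back and writes z′ = z + b, and GELU reads z′ back and writes y. Below is a minimal accounting sketch of those three launches. It assumes bf16 tensors, a weight stored as (K, N), and a negligible N-element bias read. `eager_traffic` and `total_bytes` are hypothetical helper names.

```python
# Toy HBM ledger for eager y = gelu(x @ w + b): one entry per kernel launch.
# Assumptions: bf16 (2 bytes/element), w stored as (K, N), bias read ignored.

def eager_traffic(M, N, K, elem_bytes=2):
    """Return (launch, bytes_read, bytes_written) for each eager kernel."""
    x = M * K * elem_bytes      # input activations
    w = K * N * elem_bytes      # weight
    out = M * N * elem_bytes    # any M x N tensor: z, z' or y
    return [
        ("matmul:   z  = x @ w", x + w, out),   # reads x, w; writes z
        ("bias-add: z' = z + b", out, out),     # reads z back; writes z'
        ("gelu:     y  = gelu(z')", out, out),  # reads z' back; writes y
    ]


def total_bytes(ledger):
    """Sum of all reads and writes across the launches."""
    return sum(r + w for _, r, w in ledger)


if __name__ == "__main__":
    ledger = eager_traffic(4096, 4096, 4096)
    for name, r, w in ledger:
        print(f"{name:26s} read {r / 2**20:5.0f} MiB  write {w / 2**20:5.0f} MiB")
    print(f"total {total_bytes(ledger) / 2**20:.0f} MiB")
```

The last two rows are pure overhead. They exist only to move the intermediate out of HBM and back in again.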
The fused path — one kernel, no extra HBM
In a Triton GEMM kernel, the accumulator acc lives in registers across the K-loop. Once the loop ends, acc already holds the matmul result for the output tile — but we haven't written it yet. That's the moment to fold in the epilogue:
```python
acc = tl.zeros((BM, BN), dtype=tl.float32)
for k in range(0, K, BK):
    a = tl.load(...).to(tl.bfloat16)    # x tile, BM x BK
    bw = tl.load(...).to(tl.bfloat16)   # w tile, BK x BN
    acc = tl.dot(a, bw, acc)            # fp32 accumulate in registers

# === FUSED EPILOGUE: runs while acc is still in registers ===
bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.).to(tl.float32)
z = acc + bias[None, :]
u = 0.7978845608 * (z + 0.044715 * z * z * z)   # sqrt(2/pi) * (z + 0.044715 z^3)
y = z * tl.sigmoid(2.0 * u)                      # == 0.5 * z * (1 + tanh(u))
# ============================================================
tl.store(y_ptr + ..., y.to(tl.bfloat16), mask=...)
```
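The bias costs one BN-element vector load per tile. The bias-add and the GELU math then run on register-resident values. Apart from the matmul's reads of x and w and that small bias read, the only HBM traffic is one write of y. The intermediate z never touches memory.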
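Core `triton.language` has no `tanh` in recent releases; it is reachable only through the libdevice extras. A portable way to evaluate the tanh-approximate GELU is the identity 0.5·z·(1 + tanh(u)) = z·sigmoid(2u), which needs only a sigmoid. The plain-Python check below confirms that identity and shows how close the approximation stays to the exact erf-based GELU, which is PyTorch's default for `F.gelu`. The function names are illustrative.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # ~0.7978845608


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def gelu_tanh(z):
    """Tanh approximation of GELU: 0.5 * z * (1 + tanh(u))."""
    u = SQRT_2_OVER_PI * (z + 0.044715 * z ** 3)
    return 0.5 * z * (1.0 + math.tanh(u))


def gelu_sigmoid_form(z):
    """Same function via 0.5 * (1 + tanh(u)) == sigmoid(2u)."""
    u = SQRT_2_OVER_PI * (z + 0.044715 * z ** 3)
    return z * sigmoid(2.0 * u)


def gelu_exact(z):
    """Exact GELU: z * Phi(z), with Phi the standard normal CDF."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


if __name__ == "__main__":
    for z in (-3.0, -1.0, 0.0, 0.5, 2.0):
        print(f"z={z:+.1f}  tanh-form={gelu_tanh(z):+.6f}  "
              f"sigmoid-form={gelu_sigmoid_form(z):+.6f}  exact={gelu_exact(z):+.6f}")
```

The two approximate forms agree to rounding error. Both sit within about 1e-3 of the exact GELU, so a fused kernel using either form should be compared against `F.gelu(..., approximate="tanh")` rather than the default.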
The full kernel
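```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64,  'BN': 64,  'BK': 64}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],  # re-tune whenever the problem shape changes
)
@triton.jit
def gelu_linear_kernel(
    x_ptr, w_ptr, b_ptr, y_ptr,
    M, N, K,
    sxm, sxk, swk, swn, sym, syn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    # Each program computes one BM x BN tile of y.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)

    # K-loop: accumulate x[tile_m, :] @ w[:, tile_n] in fp32 registers.
    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, K, BK):
        a = tl.load(x_ptr + offs_m[:, None] * sxm + (k + offs_k)[None, :] * sxk,
                    mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < K),
                    other=0.).to(tl.bfloat16)
        bw = tl.load(w_ptr + (k + offs_k)[:, None] * swk + offs_n[None, :] * swn,
                     mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N),
                     other=0.).to(tl.bfloat16)
        acc = tl.dot(a, bw, acc)

    # Fused epilogue: bias + tanh-approximate GELU on the register tile.
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.).to(tl.float32)
    z = acc + bias[None, :]
    u = 0.7978845608 * (z + 0.044715 * z * z * z)  # sqrt(2/pi) * (z + 0.044715 z^3)
    # 0.5 * z * (1 + tanh(u)) == z * sigmoid(2u); tl.sigmoid is a core op.
    y = z * tl.sigmoid(2.0 * u)

    # Single HBM write of the finished tile; masks drop out-of-range rows/cols.
    tl.store(y_ptr + offs_m[:, None] * sym + offs_n[None, :] * syn,
             y.to(tl.bfloat16),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def gelu_linear(x, w, b):
    """y = gelu(x @ w + b), matching F.gelu(..., approximate="tanh").

    x: (M, K), w: (K, N), b: (N,), all bf16 on the GPU. An nn.Linear weight
    is stored (N, K), so pass linear.weight.t(); strides handle the layout.
    """
    M, K = x.shape
    K2, N = w.shape
    assert K == K2, "inner dimensions must match"
    assert b.shape == (N,), "bias needs one entry per output column"
    y = torch.empty(M, N, device=x.device, dtype=x.dtype)
    # One program per output tile; BM/BN come from the autotuned config.
    grid = lambda meta: (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    gelu_linear_kernel[grid](
        x, w, b, y, M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        y.stride(0), y.stride(1),
    )
    return y
```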
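The grid lambda and the boundary masks decide which program writes each output element. The CPU model below replays the kernel's index math in plain Python for a ragged shape. It launches one program per (pid_m, pid_n) tile, uses the same `offs_m`/`offs_n` formulas, and applies the store mask `offs_m < M & offs_n < N`. `simulate_store_coverage` is a hypothetical helper that models the launch only. It is not Triton code.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def simulate_store_coverage(M, N, BM, BN):
    """Count how many times each (row, col) of y is stored by the tiled launch."""
    writes = [[0] * N for _ in range(M)]
    grid = (cdiv(M, BM), cdiv(N, BN))           # mirrors the grid lambda
    for pid_m in range(grid[0]):
        for pid_n in range(grid[1]):
            offs_m = [pid_m * BM + i for i in range(BM)]
            offs_n = [pid_n * BN + j for j in range(BN)]
            for m in offs_m:
                for n in offs_n:
                    if m < M and n < N:           # the tl.store mask
                        writes[m][n] += 1
    return grid, writes


if __name__ == "__main__":
    grid, writes = simulate_store_coverage(M=300, N=200, BM=128, BN=128)
    print("grid:", grid)
    print("every element stored exactly once:",
          all(c == 1 for row in writes for c in row))
```

Edge tiles compute some lanes that the mask then discards. That waste is the price of a grid rounded up with `cdiv`.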
Why this is the highest-ROI pattern
Three observations that make epilogue fusion the most reliable Triton win:
- Modern transformers are bandwidth-bound at low batch. Every standalone elementwise op over an activation tensor reads and writes that whole tensor in HBM once. Folding those ops into the matmul removes the traffic and saves real wall-clock time.
- The math fits. GELU, bias-add, dropout mask and residual add are all per-element operations on the output tile. Each needs at most one small extra load (a bias vector or a residual tile) on top of values already in registers after the dot product.
- cuBLAS / cuBLASLt have limited epilogue support. cuBLASLt offers a fixed menu of epilogues (bias, ReLU, GELU and their bias-fused variants), but it cannot express an arbitrary chain such as GELU + dropout + scale + residual in one call. Triton can fuse the whole stack.
The arithmetic intensity, before vs after
For M = K = N = 4096 in bf16, each of x, w, y and the intermediates z = x @ w and z′ = z + b is 4096 × 4096 × 2 bytes = 32 MiB:
| Variant | HBM traffic | FLOPs | Intensity (flop/byte) |
|---|---|---|---|
| Eager (3 launches) | x + w + z write + z read + z′ write + z′ read + y write = 7 × 32 MiB = 224 MiB | 2 · 4096³ ≈ 137 GFLOP | ~585 |
| Fused (1 kernel) | x + w + y write = 3 × 32 MiB = 96 MiB | 2 · 4096³ ≈ 137 GFLOP | ~1365 |
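Intensity rises by 7/3 ≈ 2.3×. At 4K shapes the matmul itself is compute-bound. Its intensity is well above the roofline knee of roughly 300 flop/byte for dense bf16 on an H100 SXM (about 989 TFLOP/s over 3.35 TB/s). The standalone bias-add and GELU kernels do only a few flops per element moved, so they run at memory speed. Eliminating their HBM round-trips is therefore pure savings on top of the matmul, typically 10–20% of step time for transformer blocks at common batch sizes.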
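A back-of-the-envelope roofline model makes the same comparison for any square or rectangular shape. It rests on several assumptions:

- Dense bf16 runs at about 989 TFLOP/s and HBM at 3.35 TB/s, as on an H100 SXM.
- Each kernel takes max(compute time, memory time).
- Launch overhead, caches and the bias read are ignored.
- GELU costs a rough 8 flops per element.

Treat the output as an estimate, not a measurement.

```python
# Roofline estimate: eager (3 kernels) vs fused (1 kernel) for y = gelu(x @ w + b).
# ASSUMED H100 SXM peaks below; nothing here is measured.
PEAK_FLOPS = 989e12   # assumed dense bf16 tensor-core peak, FLOP/s
PEAK_BW = 3.35e12     # assumed HBM3 bandwidth, bytes/s
ELEM = 2              # bytes per bf16 element
GELU_FLOPS = 8        # rough per-element flop count for tanh-GELU (assumption)


def kernel_time(flops, nbytes):
    """A kernel takes as long as the slower of its compute and memory roofs."""
    return max(flops / PEAK_FLOPS, nbytes / PEAK_BW)


def estimate(M, N, K):
    mm_flops = 2 * M * N * K
    x, w, out = M * K * ELEM, K * N * ELEM, M * N * ELEM
    eager_bytes = x + w + 5 * out   # z, z' each written and re-read, y written
    fused_bytes = x + w + out       # only y is written
    eager_time = (kernel_time(mm_flops, x + w + out)              # matmul writes z
                  + kernel_time(M * N, 2 * out)                   # bias-add: read z, write z'
                  + kernel_time(GELU_FLOPS * M * N, 2 * out))     # gelu: read z', write y
    fused_time = kernel_time(mm_flops + (1 + GELU_FLOPS) * M * N, fused_bytes)
    return {
        "eager_intensity": mm_flops / eager_bytes,
        "fused_intensity": mm_flops / fused_bytes,
        "eager_us": eager_time * 1e6,
        "fused_us": fused_time * 1e6,
    }


if __name__ == "__main__":
    for n in (1024, 4096, 8192):
        r = estimate(n, n, n)
        print(f"n={n}: intensity {r['eager_intensity']:.0f} -> {r['fused_intensity']:.0f} flop/B, "
              f"time {r['eager_us']:.1f} -> {r['fused_us']:.1f} us")
```

For square bf16 shapes the intensities reduce to n/7 and n/3, which is where the table's ~585 and ~1365 come from.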
What you fuse next
The pattern generalises. Common epilogues to fold into a matmul:
- Bias + activation (this lesson).
- Bias + activation + scale (output of a quantised path).
- Bias + dropout mask (training kernels — dropout RNG state lives in the kernel).
- Bias + residual add (a transformer block's output + residual).
- Quant + dequant (output goes to bf16 from the fp32 acc; the weight read dequantises from int4).
If you find yourself writing two consecutive elementwise PyTorch ops after a matmul, ask: can I fuse them into the matmul epilogue? The answer is almost always yes.
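A wider fused epilogue needs a reference to check it against. The plain-Python sketch below processes one output row, chaining bias, tanh-GELU, a scale and a residual add. The helper name `epilogue_row` and the op order are illustrative assumptions. In a Triton kernel the same per-element steps would run on the `acc` tile after the K-loop, with the residual arriving as one extra masked load.

```python
import math


def gelu_tanh(z):
    """Tanh-approximate GELU (same formula as the lesson's epilogue)."""
    return 0.5 * z * (1.0 + math.tanh(0.7978845608 * (z + 0.044715 * z ** 3)))


def epilogue_row(acc_row, bias, residual_row, scale):
    """Bias -> GELU -> scale -> residual add, applied per element of one row.

    Each step uses only the accumulator value plus at most one extra vector
    (bias or residual), which is why the chain fits in a matmul epilogue.
    """
    out = []
    for a, b, r in zip(acc_row, bias, residual_row):
        z = a + b                   # bias-add
        g = gelu_tanh(z)            # activation
        out.append(scale * g + r)   # scale, then residual add
    return out


if __name__ == "__main__":
    print(epilogue_row([0.0, 1.0, -2.0], [0.5, 0.0, 2.0], [1.0, 1.0, 1.0], scale=2.0))
```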
What's next
We've used the K-loop without explaining it. Lesson 09 zooms in on the matmul itself — the canonical tiled GEMM in Triton — including the boundary masks per axis and exactly where cuBLAS still beats hand-Triton.