Triton kernels · Lesson 08 of 14 · Fused activation

Fused linear + bias + GELU

Eager PyTorch evaluates y = gelu(x @ w + b) as three kernels, and the intermediate z makes two extra HBM round-trips between them. One Triton kernel folds the + b and the gelu into the GEMM's epilogue, and those round-trips disappear. This lesson explains why epilogue fusion is the single highest-ROI Triton pattern.

The eager path — three launches, two extra HBM round-trips

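Eager PyTorch (3 launches):

  1. cuBLAS matmul: z = x @ w, z written to HBM
  2. add kernel: z = z + b, z read from and written back to HBM
  3. gelu kernel: y = gelu(z), z read from HBM, y written to HBM

Memory traffic for the elementwise tail:

  add: read z, write z back → 2 × |z| bytes
  gelu: read z, write y → 2 × |z| bytes
  total elementwise traffic: 4 × |z| bytes (plus the bias read)

Only the final write of y is actually required. The other z-sized transfers (the matmul's write of z, the add's read and write, gelu's read of z) add up to the same 4 × |z| bytes, and none of the math requires them.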
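The byte accounting above can be tallied kernel by kernel. The sketch below is a minimal model in plain Python, assuming every kernel streams each operand exactly once (no cache reuse) and ignoring both the bias vector and the matmul's reads of x and w. `tail_traffic` is a hypothetical helper, not part of any library.

```python
# Per-kernel HBM traffic for the z/y tensors of y = gelu(x @ w + b).
# Assumption: each kernel reads/writes each operand exactly once (no cache reuse);
# the bias vector and the matmul's x/w reads are ignored.


def tail_traffic(m: int, n: int, bytes_per_elem: int = 2) -> dict:
    """Return bytes moved for z-sized tensors in the eager and fused paths."""
    z = m * n * bytes_per_elem  # |z| == |y| for an M x N output
    eager = {
        "matmul: write z": z,
        "add: read z, write z": 2 * z,
        "gelu: read z, write y": 2 * z,
    }
    fused = {"fused: write y": z}  # y is stored once, z never exists in HBM
    return {
        "eager": sum(eager.values()),
        "fused": sum(fused.values()),
        "saved": sum(eager.values()) - sum(fused.values()),
    }


if __name__ == "__main__":
    MiB = 2**20
    for name, nbytes in tail_traffic(4096, 4096).items():
        print(f"{name:>5}: {nbytes / MiB:.0f} MiB")
```

For a 4096 × 4096 bf16 output this gives 160 MiB eager vs 32 MiB fused, i.e. 128 MiB (4 × |z|) saved.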

The fused path — one kernel, no extra HBM

In a Triton GEMM kernel, the accumulator acc lives in registers across the K-loop. Once the loop ends, acc already holds the matmul result for the output tile — but we haven't written it yet. That's the moment to fold in the epilogue:

acc = tl.zeros((BM, BN), dtype=tl.float32)
for k in range(0, K, BK):
    a = tl.load(...).to(tl.bfloat16)
    b = tl.load(...).to(tl.bfloat16)
    acc = tl.dot(a, b, acc)

# === FUSED EPILOGUE: runs while acc is still in registers ===
bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.).to(tl.float32)
z    = acc + bias[None, :]
c    = 0.7978845608        # sqrt(2/π)
u    = c * (z + 0.044715 * z*z*z)
y    = z * tl.sigmoid(2.0 * u)   # == 0.5 * z * (1 + tanh(u)): tanh-approximate GELU
# ==============================================================

tl.store(c_ptr + ..., y.to(tl.bfloat16), mask=...)

The bias add and the GELU math both happen in registers. Apart from the N-element bias vector, the only HBM traffic is the matmul's reads of x and w plus one write of y. The intermediate z never touches memory.
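To make the register-resident epilogue concrete without a GPU, here is a minimal CPU emulation in plain Python. It illustrates the dataflow only, not how Triton actually schedules work: each output tile keeps a private accumulator (standing in for registers) through the K-loop, bias and GELU are applied to that accumulator, and only then is the tile written to y. The function names are hypothetical; `unfused_linear_gelu` materialises z the way the eager path does.

```python
import math


def gelu_tanh(z: float) -> float:
    """tanh-approximate GELU, the same formula as the Triton epilogue."""
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * z * (1.0 + math.tanh(c * (z + 0.044715 * z ** 3)))


def fused_tiled_linear_gelu(x, w, b, BM=2, BN=2, BK=2):
    """CPU emulation of one-kernel matmul + bias + GELU, tile by tile.

    x is M x K, w is K x N, b has N entries (plain nested lists).
    Each (BM, BN) tile owns an accumulator (the stand-in for registers);
    bias and GELU are applied to it before the single write into y.
    """
    M, K, N = len(x), len(w), len(w[0])
    y = [[0.0] * N for _ in range(M)]  # the only output buffer ("HBM")
    for m0 in range(0, M, BM):
        for n0 in range(0, N, BN):
            rows = range(m0, min(m0 + BM, M))  # clipping plays the role of masks
            cols = range(n0, min(n0 + BN, N))
            acc = {(i, j): 0.0 for i in rows for j in cols}
            for k0 in range(0, K, BK):  # the K-loop
                for i in rows:
                    for j in cols:
                        for k in range(k0, min(k0 + BK, K)):
                            acc[i, j] += x[i][k] * w[k][j]
            # epilogue: bias + GELU on the accumulator, then one store per element
            for (i, j), v in acc.items():
                y[i][j] = gelu_tanh(v + b[j])
    return y


def unfused_linear_gelu(x, w, b):
    """Eager-style reference: materialise z, add the bias, then apply GELU."""
    M, K, N = len(x), len(w), len(w[0])
    z = [[sum(x[i][k] * w[k][j] for k in range(K)) for j in range(N)] for i in range(M)]
    z = [[z[i][j] + b[j] for j in range(N)] for i in range(M)]
    return [[gelu_tanh(v) for v in row] for row in z]
```

Both functions return the same y; the fused version simply never allocates a full M × N buffer for z.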

Triton fused (1 launch): one kernel runs the tl.dot K-loop → + bias → gelu → store. Both the + b and the gelu happen in registers, followed by a single bf16 store.

HBM traffic: the matmul's reads plus one store. The bias is tiny (N elements) and every output tile re-reads the same vector, so it is mostly served from cache. Bandwidth saved: ~4 × |z| bytes vs eager. For a 4096 × 4096 bf16 z (32 MiB) that is ~128 MiB.

The full kernel

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BM':128,'BN':128,'BK':32}, num_warps=4, num_stages=3),
        triton.Config({'BM':128,'BN':256,'BK':32}, num_warps=8, num_stages=3),
        triton.Config({'BM':64, 'BN':64, 'BK':64}, num_warps=4, num_stages=4),
    ],
    key=['M','N','K'],
)
@triton.jit
def gelu_linear_kernel(
    x_ptr, w_ptr, b_ptr, y_ptr,
    M, N, K,
    sxm, sxk, swk, swn, sym, syn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)

    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, K, BK):
        a = tl.load(x_ptr + offs_m[:, None]*sxm + (k+offs_k)[None, :]*sxk,
                    mask=(offs_m[:, None] < M) & ((k+offs_k)[None, :] < K),
                    other=0.).to(tl.bfloat16)
        bw = tl.load(w_ptr + (k+offs_k)[:, None]*swk + offs_n[None, :]*swn,
                     mask=((k+offs_k)[:, None] < K) & (offs_n[None, :] < N),
                     other=0.).to(tl.bfloat16)
        acc = tl.dot(a, bw, acc)

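    # Epilogue: bias + tanh-approximate GELU on the fp32 accumulator, in registers.
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.).to(tl.float32)
    z = acc + bias[None, :]
    c = 0.7978845608    # sqrt(2/π)
    u = c * (z + 0.044715 * z*z*z)
    y = z * tl.sigmoid(2.0 * u)   # == 0.5 * z * (1 + tanh(u))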

    # Single store of y, cast once from fp32 to the output tensor's dtype.
    tl.store(y_ptr + offs_m[:, None]*sym + offs_n[None, :]*syn,
             y.to(y_ptr.dtype.element_ty),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

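def gelu_linear(x, w, b):
    """y = gelu_tanh(x @ w + b) in a single kernel launch.

    x: (M, K), w: (K, N), b: (N,), all CUDA tensors. For an nn.Linear layer,
    pass linear.weight.t(): the strides handle the transposed view.
    """
    M, K = x.shape
    K2, N = w.shape
    assert K == K2, "inner dimensions of x and w must match"
    assert b.shape == (N,) and b.stride(0) == 1, "the kernel reads b with unit stride"
    y = torch.empty(M, N, device=x.device, dtype=x.dtype)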
    grid = lambda meta: (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    gelu_linear_kernel[grid](
        x, w, b, y, M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        y.stride(0), y.stride(1),
    )
    return y
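The epilogue's constants and the GELU formula can be sanity-checked on the CPU. This plain-Python sketch confirms that 0.7978845608 is √(2/π), that 0.5 · z · (1 + tanh(u)) equals z · sigmoid(2u) (a form that needs only a sigmoid), and measures how far the tanh approximation strays from the exact erf-based GELU on [−6, 6]. The helper names are illustrative.

```python
import math

C = 0.7978845608  # constant used in the kernel's epilogue


def gelu_exact(z: float) -> float:
    """Reference GELU: z * Phi(z), with Phi the standard normal CDF."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def gelu_tanh(z: float) -> float:
    """tanh approximation, the same expression as the kernel."""
    return 0.5 * z * (1.0 + math.tanh(C * (z + 0.044715 * z ** 3)))


def gelu_sigmoid(z: float) -> float:
    """Same function rewritten via 0.5 * (1 + tanh(u)) == sigmoid(2u)."""
    u = C * (z + 0.044715 * z ** 3)
    return z / (1.0 + math.exp(-2.0 * u))


def max_abs_error(f, g, lo=-6.0, hi=6.0, steps=12001):
    """Largest |f(z) - g(z)| on an evenly spaced grid over [lo, hi]."""
    pts = (lo + (hi - lo) * i / (steps - 1) for i in range(steps))
    return max(abs(f(z) - g(z)) for z in pts)


if __name__ == "__main__":
    print("sqrt(2/pi)          =", math.sqrt(2.0 / math.pi))
    print("tanh vs sigmoid form:", max_abs_error(gelu_tanh, gelu_sigmoid))
    print("tanh approx vs exact:", max_abs_error(gelu_tanh, gelu_exact))
```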

Why this is the highest-ROI pattern

Three observations that make epilogue fusion the most reliable Triton win:

  1. Modern transformers are bandwidth-bound at low batch. Every elementwise op that touches an activation-sized tensor costs a full HBM round-trip (read it, write it back). Folding those ops into the matmul saves real wall-clock time.
  2. The math fits. GELU, bias-add, dropout mask, residual add: all are O(output_size) elementwise operations that compose trivially on the accumulator after the dot product. At most they need one small extra load (the bias vector, or the residual tile).
  3. cuBLAS / cuBLASLt have limited epilogue support. cuBLASLt offers a fixed menu of epilogues (bias, ReLU, GELU, and combinations such as GELU + bias), but not arbitrary chains like GELU + dropout + scale + residual. Triton can fuse the whole stack.
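Point 2 can be illustrated on the CPU. The plain-Python sketch below applies bias → GELU → dropout → scale → residual to one accumulator tile. It assumes inverted dropout and uses a per-element `random.Random` seeded from the (row, column) index as a stand-in for in-kernel RNG; in Triton one would typically call `tl.rand(seed, offsets)`, as the official low-memory dropout tutorial does. The function name and the seeding scheme are hypothetical. What it shows: beyond the accumulator, the only reads are the bias values and one residual tile.

```python
import math
import random


def gelu_tanh(z: float) -> float:
    # tanh-approximate GELU, as in the kernel's epilogue
    return 0.5 * z * (1.0 + math.tanh(0.7978845608 * (z + 0.044715 * z ** 3)))


def fused_epilogue(acc, bias, residual, p_drop=0.1, scale=1.0, seed=0):
    """Bias -> GELU -> dropout -> scale -> residual on one BM x BN tile.

    acc:      accumulator tile (nested lists), already "in registers"
    bias:     BN values, one per output column (small load)
    residual: BM x BN tile (the only tile-sized extra load)
    p_drop must be < 1. The keep/drop decision is recomputed from
    (seed, row, col), so no mask tensor is read from memory.
    """
    n_cols = len(bias)
    out = []
    for i, row in enumerate(acc):
        out_row = []
        for j, v in enumerate(row):
            z = gelu_tanh(v + bias[j])
            # hypothetical per-element RNG seeded by a flat (row, col) index
            keep = random.Random(seed * 1_000_003 + i * n_cols + j).random() >= p_drop
            z = z / (1.0 - p_drop) if keep else 0.0  # inverted dropout scaling
            out_row.append(scale * z + residual[i][j])
        out.append(out_row)
    return out
```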

The arithmetic intensity, before vs after

For M = K = N = 4096 in bf16 (each of x, w, y and z is 4096 × 4096 × 2 bytes = 32 MiB):

Variant            | HBM bytes                                                                | Flops                 | Intensity
Eager (3 launches) | x + w + z_write + z_read + z_write + z_read + y_write = 7 × 32 = 224 MiB | 2 × 4096³ ≈ 137 GFLOP | 4096/7 ≈ 585 flop/byte
Fused (1 kernel)   | x + w + y = 3 × 32 = 96 MiB                                              | 2 × 4096³ ≈ 137 GFLOP | 4096/3 ≈ 1365 flop/byte

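Intensity more than doubles (7/3 ≈ 2.3×). At 4K shapes the matmul itself is compute-bound: on its own it moves only x + w + z = 96 MiB (≈ 1365 flop/byte), far above the roofline ridge point of roughly 300 flop/byte for dense bf16 on an H100 SXM (about 989 TFLOP/s ÷ 3.35 TB/s). The separate add and gelu kernels, by contrast, are purely bandwidth-bound, so eliminating their HBM round-trips removes their entire runtime while adding only a few flops per output element to a kernel that is already compute-bound. That is typically 10–20% of step time for transformer blocks at common batch sizes.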
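The intensity figures can be reproduced with a few lines of plain Python using the same byte accounting as the table. The H100 numbers are assumed values from NVIDIA's published SXM specs (about 989 TFLOP/s dense bf16 and 3.35 TB/s HBM3), and the ridge point is simply their ratio. Function and constant names are illustrative.

```python
def intensity(m: int, n: int, k: int, bytes_per_elem: int = 2):
    """Arithmetic intensity (flop/byte) of y = gelu(x @ w + b): (eager, fused).

    Counts only the matmul's 2*M*N*K flops and ignores the bias vector.
    Eager: x, w, then z written by the matmul, read+written by the add,
    read by gelu, plus the final y write. Fused: x, w and one y write.
    """
    flops = 2 * m * n * k
    x = m * k * bytes_per_elem
    w = k * n * bytes_per_elem
    z = m * n * bytes_per_elem          # |y| == |z|
    eager_bytes = x + w + 5 * z         # z_w + z_r + z_w + z_r + y_w
    fused_bytes = x + w + z             # y written once
    return flops / eager_bytes, flops / fused_bytes


# Roofline ridge point = peak flop/s / bandwidth in bytes/s.
# Assumption: H100 SXM, ~989 TFLOP/s dense bf16 and ~3.35 TB/s HBM3.
H100_RIDGE = 989e12 / 3.35e12

if __name__ == "__main__":
    eager, fused = intensity(4096, 4096, 4096)
    print(f"eager {eager:.0f}, fused {fused:.0f}, ridge {H100_RIDGE:.0f} flop/byte")
```

For the square 4096 case this prints about 585, 1365 and 295 flop/byte respectively.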

What you fuse next

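The pattern generalises. Besides bias and GELU, the other epilogues mentioned above (dropout masks, scaling, residual adds) fold into a matmul in exactly the same way.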

If you find yourself writing two consecutive elementwise PyTorch ops after a matmul, ask: can I fuse them into the matmul epilogue? The answer is almost always yes.

The bandwidth saving

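Memory model: eager makes three passes over z-sized data (matmul write of z; add read and write; gelu read of z and write of y), five transfers in all. Fused makes one (the store of y at the end). The difference is the 4 × |z| bytes of avoidable traffic.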
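A static wall-clock estimate follows from that memory model. The plain-Python roofline sketch below charges each kernel max(compute time, memory time) and sums over launches. It is deliberately crude: the defaults are assumed H100 SXM peak figures, and it ignores launch overhead, cache reuse and the elementwise flops themselves. `predict_ms` is a hypothetical name. Under this model the gap is always 4 × |z| / bandwidth, because the add and gelu kernels are pure memory traffic.

```python
def predict_ms(m, n, k, peak_flops=989e12, hbm_bw=3.35e12, bytes_per_elem=2):
    """Roofline estimate in milliseconds for gelu(x @ w + b): (eager, fused).

    Assumptions: H100 SXM defaults (dense bf16 peak, HBM3 bandwidth), every
    kernel costs max(flops / peak, bytes / bandwidth), no launch overhead,
    no cache reuse, and the bias-add/GELU flops are treated as free.
    """
    x = m * k * bytes_per_elem
    w = k * n * bytes_per_elem
    z = m * n * bytes_per_elem
    flops = 2 * m * n * k

    def kernel_time(f, nbytes):
        return max(f / peak_flops, nbytes / hbm_bw)

    matmul = kernel_time(flops, x + w + z)    # reads x, w; writes z (or y)
    add = kernel_time(0, 2 * z)               # read z, write z
    gelu = kernel_time(0, 2 * z)              # read z, write y
    eager = matmul + add + gelu
    fused = matmul                            # same roofline cost, one launch
    return eager * 1e3, fused * 1e3


if __name__ == "__main__":
    eager, fused = predict_ms(4096, 4096, 4096)
    print(f"eager {eager:.3f} ms, fused {fused:.3f} ms")
```

For M = N = K = 4096 this model gives roughly 0.18 ms eager vs 0.14 ms fused.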

What's next

We've used the K-loop without explaining it. Lesson 09 zooms in on the matmul itself — the canonical tiled GEMM in Triton — including the boundary masks per axis and exactly where cuBLAS still beats hand-Triton.