Lesson 23 of 24

torch.compile: compiler-driven fusion

Three components turn your eager Python into fused Triton kernels, replayed from captured CUDA graphs when you ask for it: Dynamo (capture), AOTAutograd (trace forward + backward), and Inductor (codegen). Most of the time it just works. The rest of the time you need to read the chain. This lesson teaches both.

The question this lesson answers

You read lesson 22 and you can write fused Triton. But you have 50 ops in a model and can't fuse them by hand. torch.compile does the fusion automatically. When does it work? When does it fall back to eager? And when does the fallback silently undo your speedup?

The three components

  1. Dynamo: hooks CPython frame evaluation, turns Python into an FX graph, and installs guards; code it can't trace becomes a graph break.
  2. AOTAutograd: decomposes forward ops, traces the backward ahead of time, and functionalizes side effects.
  3. Inductor: fuses pointwise and reduction ops, emits Triton (GPU) or C++ (CPU), and captures a CUDA graph when the mode enables it.

At runtime, compiled code is cached under a key of shapes, dtypes, and control flow (checked by the guards). A hit reuses the compiled code, replaying the CUDA graph if one was captured; a miss recompiles.

Each step has a way to fall back. Dynamo can graph-break and resume eager. AOTAutograd can fall back to eager for unsupported ops. Inductor can fall back to ATen kernels for ops it can't lower. None of these regress correctness; they just give up performance.

What you get when it works: pointwise chains fuse into one Triton kernel, GEMM epilogues (bias + activation) can fuse, and with CUDA graphs enabled the full forward + backward replays from captured graphs. 80–200 kernels per layer can collapse to a handful.
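To see what stage 1 produces on its own, you can pass torch.compile a custom backend: any callable that receives Dynamo's FX graph plus example inputs and returns something to run. A minimal sketch (inspect_backend and block are illustrative names) that prints the captured graph and runs it unchanged, skipping AOTAutograd and Inductor entirely:

```python
import torch

captured = []  # FX GraphModules handed to the backend, one per compiled graph

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # Stage 1 output: Dynamo has turned the Python function into an FX graph.
    captured.append(gm)
    print(gm.graph)      # the ops Dynamo captured, in execution order
    return gm.forward    # run the graph as-is: no AOTAutograd, no Inductor

def block(x, residual):
    # a small pointwise chain, the kind of thing Inductor would fuse
    return torch.nn.functional.gelu(x * 2.0 + residual)

compiled = torch.compile(block, backend=inspect_backend)
out = compiled(torch.randn(8), torch.randn(8))
```

Swapping inspect_backend for the default ("inductor") is what hands this same graph on to stages 2 and 3.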

Three terms used throughout this lesson

The minimum you need to know

import torch

# MyModel, loader and loss_fn stand in for your own model, data loader and loss
model = MyModel().cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# one call. defaults are usually right.
model = torch.compile(model)   # mode="default": Dynamo + AOTAutograd + Inductor

for batch in loader:
    out = model(batch["x"])    # first call(s): compile. Later: reuse the cached compiled code.
    loss = loss_fn(out, batch["y"])
    loss.backward()            # runs the backward graph AOTAutograd traced ahead of time
    opt.step()
    opt.zero_grad()

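Three modes you'll see:

  1. mode="default": the full Dynamo + AOTAutograd + Inductor pipeline with moderate compile time. Start here.
  2. mode="reduce-overhead": adds CUDA graph capture to cut per-launch CPU overhead. It costs extra GPU memory and wants stable shapes.
  3. mode="max-autotune": benchmarks candidate kernel configs (including Triton matmul templates) and keeps the fastest; on GPU it also uses CUDA graphs. Compiles take much longer.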

What gets fused, what doesn't

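| Pattern | Fusion result | Notes |
|---|---|---|
| Pointwise chain (add, mul, gelu) | Single Triton kernel. | The classic win: RMSNorm, residual add, activations. |
| Pointwise + reduction (sum, max) | Often one fused kernel. | Inductor folds the surrounding pointwise ops into the reduction kernel. |
| GEMM + activation | The matmul stays in cuBLAS (bias can fold into addmm); the activation runs as a separate Triton elementwise kernel. | In default mode Inductor doesn't replace cuBLAS for the matmul itself; max-autotune can benchmark Triton matmul templates with a fused epilogue and pick one if it's faster. |
| Attention (sdpa) | Calls the FlashAttention or memory-efficient SDPA backend. | The backend itself is hand-tuned; Inductor just routes to it. |
| Data-dependent control flow | Graph break. | Code before and after the break compiles as separate graphs; the break itself runs in eager Python, and nothing fuses across it. |
| Custom CUDA / Triton op | Opaque op: no fusion across it. | Register it with @torch.library.custom_op so Dynamo traces through it instead of breaking. |
| Mixed Python (NumPy, third-party) | Usually a graph break. | Move it out of the compiled region or rewrite it in torch. |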

Graph breaks — what they are and how to spot them

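A graph break happens when Dynamo can't capture part of a frame into the graph: a data-dependent branch, an unsupported op, a side effect, a mutation of a non-tensor global, and so on. Dynamo ends the current graph, compiles it, hands control back to the Python interpreter for the untraceable part, and starts a new graph after it. Two consequences:

  1. Nothing fuses across the break, so the ops on either side land in separate kernels.
  2. Every break adds a round-trip through Python on each call, and under mode="reduce-overhead" each graph is captured separately with eager Python in between, so launch overhead creeps back.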

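# show graph breaks while compiling (model and sample_input are your own)
import torch
import torch._dynamo as dynamo

dynamo.config.verbose = True                    # fuller stack traces in Dynamo errors and break reports
report = dynamo.explain(model)(sample_input)    # static report of graphs and breaks
print(report.graph_break_count)                 # how many breaks
print(report.break_reasons)                     # why each one happened

# at runtime, from the shell: TORCH_LOGS is an environment variable that turns on
# Dynamo's logging channels; these two print every graph break and recompile event
#   TORCH_LOGS="graph_breaks,recompiles" python train.py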

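Common causes you can fix:

  1. Branching in Python on a tensor value (if x.sum() > 0, .item()): express it as a tensor op such as torch.where, or make the decision outside the compiled region.
  2. print, logging, or other side effects inside forward: take them off the hot path.
  3. NumPy or third-party calls: move them out of the compiled region or rewrite them in torch.
  4. Custom CUDA ops Dynamo can't see into: register them with @torch.library.custom_op.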

Dynamic shapes

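By default torch.compile first specializes on the exact shapes it sees; once a dimension changes and forces a recompile, it marks that dimension dynamic. That avoids endless recompiles but loses some specialization. Three knobs:

  1. torch.compile(model, dynamic=False): always specialize; every new shape is a recompile.
  2. torch.compile(model, dynamic=True): compile shape-generic kernels from the first call.
  3. torch._dynamo.mark_dynamic(x, dim): mark one input dimension dynamic before the first call and keep the rest static.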

For LLM serving, sequence length is the main moving axis. Padding shapes up to a small set of "buckets" (e.g., 128, 256, 512, 1024) and using mode="reduce-overhead" gives you graph capture across a manageable number of bucket sizes.
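A minimal bucketing helper, assuming right-padded (batch, seq) token-id inputs. bucket_len, pad_to_bucket and pad_id are hypothetical names, and the bucket sizes are the example values above. The returned mask must reach attention so padded positions don't change the results:

```python
import bisect
import torch
import torch.nn.functional as F

BUCKETS = (128, 256, 512, 1024)   # example bucket sizes from the text

def bucket_len(seq_len: int, buckets=BUCKETS) -> int:
    """Smallest bucket >= seq_len."""
    i = bisect.bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError(f"sequence length {seq_len} exceeds largest bucket {buckets[-1]}")
    return buckets[i]

def pad_to_bucket(tokens: torch.Tensor, pad_id: int = 0):
    """Right-pad a (batch, seq) tensor along seq to its bucket length.
    Returns the padded tensor and a mask where 1 marks real tokens."""
    seq_len = tokens.shape[1]
    target = bucket_len(seq_len)
    padded = F.pad(tokens, (0, target - seq_len), value=pad_id)
    mask = torch.zeros_like(padded)
    mask[:, :seq_len] = 1
    return padded, mask

padded, mask = pad_to_bucket(torch.ones(2, 300, dtype=torch.long))
print(padded.shape)   # 300 tokens land in the 512 bucket
```

With four buckets, the compiled model sees at most four sequence lengths, so at most four specializations (and, under reduce-overhead, four captured graphs) per batch size.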

Production trap
The first call after deploy compiles for seconds to minutes (especially under max-autotune). Latency-sensitive services must warm up with representative shapes before serving traffic, or the first user requests can time out. The on-disk Inductor cache survives across runs and speeds up later compiles, but plan for cold starts anyway after rolling restarts, since a fresh machine or container may start with an empty cache.
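A warmup sketch, assuming a token-id model that takes a (batch, seq) input and the bucket sizes from the previous section. warmup and its parameters are hypothetical. Run it under the same grad mode you serve with, because grad mode is one of the things compiled code is guarded on:

```python
import torch

def warmup(model_c, buckets, batch_size=1, vocab_size=32000, device="cuda", iters=2):
    """Push every bucket shape through the compiled model before serving, so
    compilation (and any CUDA graph capture) happens off the request path."""
    with torch.no_grad():
        for seq_len in buckets:
            dummy = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
            # reduce-overhead may need more than one call per shape before it captures
            for _ in range(iters):
                model_c(dummy)
    if device == "cuda":
        torch.cuda.synchronize()   # make sure all warmup work has finished
```

Typical use: warmup(torch.compile(model, mode="reduce-overhead"), (128, 256, 512, 1024)) before the server starts accepting requests.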

Compile cost & cache

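| Cost | Where it lives | How to manage |
|---|---|---|
| First-call latency | Compilation runs synchronously the first time a shape is seen. | Warm up with representative shapes before timing or serving. |
| Inductor on-disk cache | $TORCHINDUCTOR_CACHE_DIR, by default /tmp/torchinductor_<user>. | Survives across runs on the same machine; point it at persistent storage if your containers are ephemeral. Clear it when bumping torch versions or kernel sources. |
| Recompile storms | Every shape miss triggers a recompile. | Reduce shape variety (bucketing), mark the varying dims dynamic, or set dynamic=True. |
| GPU memory headroom | CUDA graph capture keeps its input and output buffers alive in a dedicated memory pool. | Budget for the extra memory. In loops without a backward pass, call torch.compiler.cudagraph_mark_step_begin() at each step boundary so the previous step's buffers can be reused. |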

Verifying the win, not assuming it

Three checks before claiming a compile win:

  1. Same outputs. Run eager and compiled on the same input; compare with torch.allclose at the tolerance you expect for your dtype.
  2. Fewer kernels per step. Profile (lesson 20) and confirm the kernel count dropped, not just the wall clock.
  3. Steady state, not first call. Throw away the first few iterations, then compare the median of the next 50.
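
import statistics
import time
import torch

def bench(fn, warmup=10, iters=50):
    for _ in range(warmup):                 # warmup: absorbs compilation and first-call costs
        fn()
    torch.cuda.synchronize()
    times_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        torch.cuda.synchronize()            # kernels launch async; wait so we time the GPU work
        times_ms.append((time.perf_counter() - t0) * 1000)
    return statistics.median(times_ms)      # median of steady-state iterations, in ms

# model: your CUDA model; x: a representative input already on the GPU
eager_ms    = bench(lambda: model(x))
model_c     = torch.compile(model)
compiled_ms = bench(lambda: model_c(x))
print(f"eager={eager_ms:.2f}  compiled={compiled_ms:.2f}  speedup={eager_ms/compiled_ms:.2f}x")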
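Check 1 from the list can be wrapped in a small helper. A sketch: check_same_outputs and its default tolerances are assumptions to adjust per model and dtype, and it swaps torch.allclose for torch.testing.assert_close, which does the same comparison but reports the largest mismatch when it fails:

```python
import torch

def check_same_outputs(model, model_c, x, rtol=1e-3, atol=1e-3):
    """Compiled and eager outputs should agree within tolerance. Fusion can
    reorder floating-point math, so expect small differences, not bit-exact
    equality; loosen the tolerances for fp16/bf16."""
    with torch.no_grad():
        ref = model(x)
        got = model_c(x)
    torch.testing.assert_close(got, ref, rtol=rtol, atol=atol)  # raises with a diff report
```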

torch.compile vs hand-Triton — when does each win

| Scenario | Better choice | Why |
|---|---|---|
| Long pointwise chain (norm + residual + activation) | torch.compile | Inductor already fuses this pattern, for free. |
| Novel layer not in PyTorch | Hand-Triton, registered as a custom op | Registration lets the compiler trace around the op without a graph break; it still won't fuse into it. |
| Tight loop with many small tensors | torch.compile + reduce-overhead | CUDA graphs remove most of the launch overhead. |
| Backward of a tricky kernel | Hand-Triton with explicit autograd | You control the numerics; Inductor's autograd handling is good but not magic. |
| Very dynamic shapes | Eager, or a compile with carefully marked dynamic dims | Recompile cost can swamp the wins. |
| Already library-bound (attention, GEMM) | Either; compile just calls the library | You're not in compiler-fusion territory. |

Debugging when speedup is "1.0×"

  1. Run with TORCH_LOGS="recompiles". If you see many recompiles, you have a dynamic-shape problem.
  2. Run with TORCH_LOGS="graph_breaks". If many breaks, you're effectively still eager.
  3. Profile (lesson 20) and compare kernel counts. If they're unchanged, fusion didn't happen, most likely because most ops already call out to libraries.
  4. Check the regions: print(torch._dynamo.explain(model)(x)) reports how many graphs compiled, where the breaks are, and why.
  5. If reduce-overhead didn't help, check that shapes are stable. Every distinct shape needs its own CUDA graph capture, so a dimension that keeps changing erodes the win.

Estimating the compile speedup

How much should torch.compile help your model? It depends on the launch-bound share of your step time, the fusion potential, and how many graph breaks you have.


Tiny pointwise-heavy models with stable shapes get the biggest wins. Library-bound models (large GEMMs) barely move. Many graph breaks erase the win.
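As a back-of-envelope model (an assumption, not a formula from this lesson), Amdahl's law captures these levers: only the fraction of step time that compile can fuse or graph-capture gets faster, and graph breaks shrink that fraction. estimate_speedup and its parameters are illustrative:

```python
def estimate_speedup(fusible_frac: float, fused_speedup: float) -> float:
    """Amdahl's-law estimate of end-to-end compile speedup.

    fusible_frac: share of eager step time in work compile can fuse or
        graph-capture (pointwise chains, launch overhead), after discounting
        regions lost to graph breaks.
    fused_speedup: how much faster that share runs once compiled.
    The remainder (large GEMMs, attention kernels) is assumed unchanged."""
    if not 0.0 <= fusible_frac <= 1.0:
        raise ValueError("fusible_frac must be in [0, 1]")
    return 1.0 / ((1.0 - fusible_frac) + fusible_frac / fused_speedup)

for frac in (0.1, 0.5, 0.9):
    print(f"fusible share {frac:.0%}: ~{estimate_speedup(frac, 4.0):.2f}x")
```

Even a 4x gain on the fusible share gives only about 1.08x when that share is 10% of the step, which is the library-bound case above; at 90% it approaches 3x.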

What this gives you for the next lesson

You now have three tools to make code faster: hand-tuned Triton (22), torch.compile (this lesson), and the profiler that tells you which one to reach for (20–21). The final lesson stitches everything into PyTorch-level performance patterns — memory format, allocator pressure, async data movement, mixed precision — that hold whether or not you compile.