torch.compile: compiler-driven fusion
Three components — Dynamo (capture), AOTAutograd (trace forward+backward), Inductor (codegen) — turn your eager Python into fused Triton + CUDA-graph-captured kernels. Most of the time it just works. The other times you need to read the chain. This lesson teaches both.
The question this lesson answers
You read lesson 22 and you can write fused Triton. But you have 50 ops in a model and can't fuse them by hand. torch.compile does the fusion automatically. When does it work? When does it fall back to eager? And when does the fallback silently undo your speedup?
The three components
Three terms used throughout this lesson
- FX graph: PyTorch's symbolic IR. A flat list of (op, args, kwargs, output) nodes plus tensor metadata. Dynamo's output, Inductor's input.
- Functionalize: rewrite in-place mutations (e.g., x.add_(y), x.relu_()) as pure ops that return a new tensor, so the graph is referentially transparent. AOTAutograd does this.
- Decompose: replace a high-level op (e.g., F.layer_norm) with its primitive components (mean, var, divide, scale, bias). Inductor needs primitives so it can fuse across boundaries.
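To make the first two terms concrete, here is a minimal sketch that can be run on CPU. It assumes a recent PyTorch 2.x. The names `inspect_backend`, `f`, and `mutating` are hypothetical, and the backend is only a debugging stand-in that runs the captured graph unoptimized, not what Inductor does.

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

captured = []  # FX GraphModules that Dynamo hands to the backend


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical debugging backend: record the FX graph, then run it as-is."""
    captured.append(gm)
    print(gm.graph)  # one line per node: placeholder / call_function / output
    return gm.forward  # a backend must return a callable


def f(x, y):
    return torch.relu(x + y) * 2


x, y = torch.randn(4), torch.randn(4)
out = torch.compile(f, backend=inspect_backend)(x, y)


def mutating(t):
    u = t.clone()
    u.add_(1)  # in-place mutation
    return u


# Trace the functionalized version: the in-place add_ is rewritten as a pure add.
functional_gm = make_fx(functionalize(mutating))(torch.zeros(3))
print(functional_gm.code)
```

The first print shows the graph Dynamo captured; the second shows traced code in which the mutation has been replaced by an out-of-place op.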
The minimum you need to know
import torch
# MyModel, loader, and loss_fn are placeholders for your own model, data, and loss.
model = MyModel().cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
# one call. defaults are usually right.
model = torch.compile(model)  # mode="default": Dynamo + Inductor + autograd
for batch in loader:
    out = model(batch)  # first call(s): compile. Later: replay.
    loss = loss_fn(out, batch['y'])
    loss.backward()
    opt.step()
    opt.zero_grad()
Three modes you'll see:
- mode="default": balanced. Use this first.
- mode="reduce-overhead": also captures CUDA graphs. Best for small-batch decode; needs stable shapes.
- mode="max-autotune": exhaustive Inductor autotune; can take minutes to compile but usually produces the fastest kernels at run time.
What gets fused, what doesn't
| Pattern | Fusion result | Notes |
|---|---|---|
| Pointwise chain (add, mul, gelu) | Single Triton kernel. | The classic win. RMSNorm, residual add, activations. |
| Pointwise + reduction (sum, max) | Often a fused kernel. | Inductor recognizes the template. |
| GEMM + activation | cuBLAS epilogue fuses; or a separate Triton elementwise after. | By default Inductor doesn't replace cuBLAS for the matmul itself; max-autotune can also try Triton matmul templates. |
| Attention (sdpa) | Calls FlashAttention or efficient SDPA backend. | The backend itself is hand-tuned; Inductor just routes to it. |
| Data-dependent control flow | Graph break. | The chunk before the break compiles; the rest is eager. |
| Custom CUDA / Triton op | Treated as an opaque op (no fusion across it). | Register it with @torch.library.custom_op to integrate. |
| Mixed Python (numpy, third-party) | Graph break. | Move out of the compiled region or rewrite in torch. |
Graph breaks — what they are and how to spot them
A graph break happens when Dynamo can't statically capture a frame: data-dependent branch, unsupported op, side effect, mutating a non-tensor global, etc. Dynamo stops the current graph, compiles it, hands control back to the Python interpreter for the un-traceable part, and starts a new graph after. Two consequences:
- Fusion stops at the break. Pointwise ops on either side won't fuse with each other.
- CUDA graph capture stops at the break, so reduce-overhead mode loses its win.
# show graph breaks while compiling
import torch
import torch._dynamo as dynamo
dynamo.config.verbose = True  # log decisions to stderr
report = dynamo.explain(model)(sample_input)  # runs the model once and returns a report
print(report.graph_break_count, report.break_reasons)
# At runtime, enable the logging channels with an environment variable (shell, not Python):
#   TORCH_LOGS="graph_breaks,recompiles" python train.py
# This prints every graph break / recompile event from Dynamo.
Common causes you can fix:
- print(...) / logger.info(...) on a tensor. Move outside the compiled region or call .item() deliberately at a sync point.
- if x.sum() > 0:. Data-dependent. Restructure into a masked op or accept the break.
- tensor.to('cpu') mid-forward. Move data movement out.
- List/dict mutation. Make tensors flow through return values instead.
- Calling a non-torch library (e.g., numpy.linalg). Replace with torch equivalent.
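As an illustration of the data-dependent case, the sketch below counts how many graphs Dynamo produces for a branchy function and for a masked rewrite. The backend `counting_backend` and both functions are hypothetical names. Note that `torch.where` evaluates both branches, so this rewrite only pays off when both sides are cheap.

```python
import torch
import torch._dynamo

graphs = []  # one entry per graph Dynamo compiles


def counting_backend(gm, example_inputs):
    """Hypothetical backend: count graphs, run them unoptimized."""
    graphs.append(gm)
    return gm.forward


def branchy(x):
    y = x * 2
    if y.sum() > 0:  # Python branch on a tensor value: graph break
        return y + 1
    return y - 1


def masked(x):
    y = x * 2
    # Same result, but the condition stays inside the graph.
    return torch.where(y.sum() > 0, y + 1, y - 1)


x = torch.ones(4)
torch.compile(branchy, backend=counting_backend)(x)
branchy_graphs = len(graphs)

graphs.clear()
torch._dynamo.reset()  # drop cached compilations before the second experiment
out = torch.compile(masked, backend=counting_backend, fullgraph=True)(x)
masked_graphs = len(graphs)

print(f"branchy: {branchy_graphs} graphs, masked: {masked_graphs} graph")
```

Passing `fullgraph=True` is a handy guard in tests: it turns any graph break into an error instead of a silent fallback.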
Dynamic shapes
By default torch.compile treats shapes as dynamic after the first variation it sees. That avoids endless recompiles but loses some specialization. Three knobs:
- torch.compile(model, dynamic=True): opt all dimensions into dynamic from the start.
- torch.compile(model, dynamic=False): specialize on every shape; one compile per shape (good if shapes are few and stable).
- Default: first call specializes, second call with a different shape promotes to dynamic.
For LLM serving, sequence length is the main moving axis. Padding shapes up to a small set of "buckets" (e.g., 128, 256, 512, 1024) and using mode="reduce-overhead" gives you graph capture across a manageable number of bucket sizes.
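A minimal sketch of the bucketing idea, assuming token tensors shaped (batch, seq_len) and the example bucket sizes from the text. The helper names `bucket_len` and `pad_to_bucket` and the `pad_id` default are hypothetical.

```python
import bisect

import torch
import torch.nn.functional as F

BUCKETS = (128, 256, 512, 1024)  # example bucket sizes from the text


def bucket_len(seq_len: int) -> int:
    """Return the smallest bucket that fits seq_len."""
    i = bisect.bisect_left(BUCKETS, seq_len)
    if i == len(BUCKETS):
        raise ValueError(f"sequence length {seq_len} exceeds largest bucket")
    return BUCKETS[i]


def pad_to_bucket(tokens: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """Right-pad a (batch, seq_len) tensor up to its bucket length."""
    seq_len = tokens.shape[1]
    return F.pad(tokens, (0, bucket_len(seq_len) - seq_len), value=pad_id)


padded = pad_to_bucket(torch.ones(2, 130, dtype=torch.long))
print(padded.shape)  # torch.Size([2, 256])
```

With four buckets the compiled model only ever sees four sequence lengths, so the number of compilations and captured graphs stays bounded. The attention mask has to account for the padding; that part is left out here.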
Compilation takes time (longest under max-autotune). Latency-sensitive services must warm up with representative shapes before serving traffic, or the first user requests time out. Cache the Inductor output (it survives across runs) but plan for cold start anyway after rolling restarts.
Compile cost & cache
| Cost | Where it lives | How to manage |
|---|---|---|
| First-call latency | Compilation runs synchronously the first time a shape is seen. | Warm up with representative shapes before timing or serving. |
| Inductor on-disk cache | /tmp/torchinductor_<username> by default; override with TORCHINDUCTOR_CACHE_DIR | Survives across runs unless the temp directory is cleaned. Clear when bumping torch versions or kernel sources. |
| Recompile storms | Every shape miss triggers a recompile. | Reduce shape variety (bucketing) or set dynamic=True. |
| GPU memory headroom | CUDA graph capture pins buffers. | Use torch.compiler.cudagraph_mark_step_begin() at step boundaries. |
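One way to handle first-call latency is an explicit warmup pass at startup. The sketch below is an assumption about how you might structure it, not a PyTorch API: `warm_up` and `step` are hypothetical names, and `backend="eager"` is used only so the example runs without a compiler toolchain. In a real service you would use your actual compiled model and its real input shapes.

```python
import time

import torch


def warm_up(fn, shapes, make_input=torch.zeros):
    """Call fn once per shape and return the seconds spent on each."""
    timings = {}
    for shape in shapes:
        t0 = time.perf_counter()
        fn(make_input(shape))
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure GPU work has finished
        timings[tuple(shape)] = time.perf_counter() - t0
    return timings


def step(x):
    return torch.nn.functional.gelu(x) + x


# dynamic=False: one specialization per shape, so each shape must be warmed up.
step_c = torch.compile(step, backend="eager", dynamic=False)

shapes = [(1, 128), (1, 256)]
first = warm_up(step_c, shapes)   # pays the compile cost
second = warm_up(step_c, shapes)  # should replay already-compiled code
print(first, second)
```

Logging the first-pass timings at startup also gives you a record of how long a cold start takes after a rolling restart.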
Verifying the win, not assuming it
Three checks before claiming a compile win:
- Same outputs. Run eager and compiled on the same input; compare with torch.allclose at expected tolerance.
- Fewer kernels per step. Profile (lesson 20) and confirm the kernel count dropped, not just the wall clock.
- Steady state, not first call. Throw away the first few iterations. Compare median of the next 50.
import time
import torch

def bench(fn, iters=100):
    for _ in range(10):  # warmup; also absorbs compile time
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000  # mean ms per call

# model and x are your eager model and a representative CUDA input
eager_ms = bench(lambda: model(x))
model_c = torch.compile(model)
compiled_ms = bench(lambda: model_c(x))
print(f"eager={eager_ms:.2f} compiled={compiled_ms:.2f} speedup={eager_ms/compiled_ms:.2f}x")
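The first and third checks can be sketched together: compare outputs, then time the median of 50 steady-state iterations. The helper names `same_outputs` and `median_ms`, the tolerances, and the tiny model are all hypothetical. `backend="eager"` is a stand-in so the sketch runs on any machine; use the default backend when measuring a real speedup.

```python
import statistics
import time

import torch


def _sync():
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def same_outputs(eager_fn, compiled_fn, x, rtol=1e-4, atol=1e-5):
    """Check 1: eager and compiled agree within tolerance."""
    with torch.no_grad():
        return torch.allclose(eager_fn(x), compiled_fn(x), rtol=rtol, atol=atol)


def median_ms(fn, warmup=5, iters=50):
    """Check 3: discard warmup iterations, report the median of the rest."""
    for _ in range(warmup):
        fn()
    _sync()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        _sync()  # per-iteration sync adds a little overhead but isolates each sample
        samples.append((time.perf_counter() - t0) * 1000)
    return statistics.median(samples)


model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64)
)
model_c = torch.compile(model, backend="eager")
x = torch.randn(8, 64)

ok = same_outputs(model, model_c, x)
with torch.no_grad():
    eager_med = median_ms(lambda: model(x))
    compiled_med = median_ms(lambda: model_c(x))
print(f"match={ok} eager={eager_med:.3f}ms compiled={compiled_med:.3f}ms")
```

The median is less sensitive than the mean to a stray slow iteration, such as a late recompile or an allocator hiccup.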
torch.compile vs hand-Triton — when does each win
| Scenario | Better choice | Why |
|---|---|---|
| Long pointwise chain (norm + residual + activation) | torch.compile | Inductor already knows this pattern. Free. |
| Novel layer not in PyTorch | Hand-Triton + register as custom op | Inductor needs to know about the op to fuse around it. |
| Tight loop with many small tensors | torch.compile + reduce-overhead | CUDA graphs erase the launch overhead. |
| Backward of a tricky kernel | Hand-Triton with explicit autograd | You control numerics; Inductor's autograd handling is good but not magic. |
| Very dynamic shapes | Eager or carefully marked dynamic compile | Recompile cost can swamp wins. |
| Already library-bound (attention, GEMM) | Either — compile just calls the library | You're not in compiler-fusion territory. |
Debugging when speedup is "1.0×"
- Run with TORCH_LOGS="recompiles". If you see many recompiles, you have a dynamic-shape problem.
- Run with TORCH_LOGS="graph_breaks". If there are many breaks, you're effectively still eager.
- Profile (lesson 20) and compare kernel counts. If unchanged, fusion didn't happen; most ops are probably already calling out to libraries.
- Check the regions: torch._dynamo.explain(model)(x) reports how many graphs compiled and why each break happened.
- If reduce-overhead didn't help, check that shapes are stable. A dynamic dim forces a fresh capture for each new size, which can erase the win.
How much should torch.compile help your model? It depends on the launch-bound share of step time, the fusion potential, and how many graph breaks you have.
What this gives you for the next lesson
You now have three tools to make code faster: hand-tuned Triton (22), torch.compile (this lesson), and the profiler that tells you which one to reach for (20–21). The final lesson stitches everything into PyTorch-level performance patterns — memory format, allocator pressure, async data movement, mixed precision — that hold whether or not you compile.