torch.compile — Dynamo + AOT Autograd + Inductor

Three components, each with its own job. The first (Dynamo) captures a graph out of your Python. The second (AOT Autograd) attaches a backward pass. The third (Inductor) generates kernels. Most cases of "torch.compile is slower than I expected" come down to one of them silently falling back.

The three-stage pipeline

Figure: torch.compile internal stages. Dynamo (capture FX graph from Python bytecode) → AOT Autograd (build forward + backward graph, decompose into prims) → Inductor (fuse, lower to Triton, or C++ for CPU) → runtime (call generated kernels). Guards (runtime shape/dtype checks) run on every call; on a miss → recompile or fall back to eager.

Stage 1 — Dynamo, the Python-bytecode tracer

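Dynamo intercepts a function call at the Python bytecode level via CPython's frame-evaluation API (PEP 523, added in CPython 3.6). It runs the function's bytecode symbolically on "FakeTensor" placeholders, which carry shape, dtype, and device but no data, recording tensor operations into an FX graph. The result is two things:

  1. An FX graph of the captured tensor operations, which is handed to the next stage.
  2. A set of guards: cheap runtime checks on the inputs (shapes, dtypes, devices, relevant Python values) that must all pass for the compiled graph to be reused.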

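This is genius and constraining at once. Because tracing runs on FakeTensors with no actual data, anything that depends on tensor values (data-dependent control flow such as if x.sum() > 0, or pulling a value into Python with .item()) cannot be traced. Dynamo emits a graph break: it compiles the graph captured so far, runs the unsupported chunk in eager Python, then resumes tracing afterwards. Branches on shapes are handled differently: Dynamo specializes on the shape it saw and records a guard, so a new shape triggers a recompile rather than a break. A function with three graph breaks compiles into as many as four smaller graphs interleaved with eager Python.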

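Graph breaks are the failure mode you'll fight
Common causes: .item() and other tensor-to-Python conversions, Python control flow on tensor values, and calls into code Dynamo cannot trace. Diagnose with TORCH_LOGS="graph_breaks". Often a single .item() inside a loop or a repeated layer is to blame, and refactoring it away turns a 12-graph trace into a 1-graph trace.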

Dynamo reads the bytecode in order, appending each captured tensor op as a node in the current FX graph. When it hits a .item() or another unsupported op, the graph breaks: the partial graph is sealed and compiled, that step runs in eager Python, then a fresh graph starts.

Stage 2 — AOT Autograd

Dynamo's FX graph is forward only. To compile a training step we also need backward. AOT (Ahead-Of-Time) Autograd runs PyTorch's autograd machinery once at compile time to trace a joint forward-backward graph, then a partitioner splits it into a forward graph and a backward graph (deciding which intermediates the forward must save for backward) and hands each one to Inductor.
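A minimal sketch of that split using functorch.compile.aot_function, the standalone entry point to AOT Autograd. Its fw_compiler and bw_compiler hooks receive the two partitioned graphs; here they only record them and run them as-is (make_boxed_func adapts gm.forward to AOT Autograd's list-of-arguments calling convention). The function f is illustrative.

```python
import torch
from functorch.compile import aot_function, make_boxed_func

graphs = {}

def recording_compiler(name):
    # AOT Autograd calls this once per partitioned graph (forward, then backward).
    # We keep the graph and run it unchanged: no real codegen happens here.
    def compiler(gm, example_inputs):
        graphs[name] = gm
        return make_boxed_func(gm.forward)
    return compiler

def f(x, w):
    return torch.tanh(x @ w).sum()

compiled_f = aot_function(f, fw_compiler=recording_compiler("forward"),
                          bw_compiler=recording_compiler("backward"))

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 5, requires_grad=True)
compiled_f(x, w).backward()  # by now both graphs have been handed to the hooks

x_ref = x.detach().clone().requires_grad_(True)
w_ref = w.detach().clone().requires_grad_(True)
f(x_ref, w_ref).backward()   # plain eager autograd, for comparison

print(graphs["forward"].code)
print(graphs["backward"].code)
```

The forward graph's outputs should be the loss plus the tensors saved for backward; the backward graph takes those saved tensors and the incoming gradient as inputs.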

It also decomposes high-level ops into a smaller set of "primitive" ops. For example, aten::softmax decomposes into max, sub, exp, sum, div. The decomposition exposes fusion opportunities that the original opaque op hides. Eager PyTorch already runs softmax as a single kernel, so the win is not softmax itself: after decomposition, Inductor can fuse those primitives together with adjacent pointwise ops (dropout, scale, residual) into one Triton kernel, removing the HBM round-trips between them. The primitives do not get fused into the matmul body itself. By default the matmul stays in cuBLAS, so its output still makes one trip through HBM; with mode="max-autotune", Inductor may instead choose a Triton matmul template and fuse following pointwise ops into its epilogue, which removes that round-trip too.
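The decomposition can be observed by passing a decomposition table to aot_function. This sketch assumes torch._decomp.get_decompositions (an internal but long-standing module) and requests only aten._softmax, the op torch.softmax reaches during tracing. Inductor applies its own, much larger table, so real compiled graphs contain more decompositions than this one.

```python
import torch
from functorch.compile import aot_function, make_boxed_func
from torch._decomp import get_decompositions

captured = []

def record(gm, example_inputs):
    # Keep the traced graph; run it unchanged
    captured.append(gm)
    return make_boxed_func(gm.forward)

def f(x):
    return torch.softmax(x, dim=-1) * 0.5  # softmax followed by a pointwise scale

decomp_table = get_decompositions([torch.ops.aten._softmax])
g = aot_function(f, fw_compiler=record, decompositions=decomp_table)

x = torch.randn(2, 8)  # no requires_grad: a single inference graph is traced
out = g(x)

# Short op names, e.g. "aten.amax.default" -> "amax"
names = {str(n.target).split(".")[1] for n in captured[0].graph.nodes
         if isinstance(n.target, torch._ops.OpOverload)}
print(sorted(names))
```

All of these are now small pointwise ops or reductions that a fusing backend can schedule together with the trailing multiply.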

AOT Autograd also handles the functionalization step: in-place ops (add_, copy_) are rewritten as out-of-place versions, so the downstream pipeline can reason about a pure data-flow graph. This is one of the more invasive transformations and occasionally causes subtle bugs in code that depended on in-place semantics.
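Functionalization can be observed by capturing the graph that aot_function hands to its compiler hook. In this sketch the (illustrative) function mutates an intermediate with add_ and mul_; the captured graph should contain only out-of-place aten ops.

```python
import torch
from functorch.compile import aot_function, make_boxed_func

captured = []

def record(gm, example_inputs):
    # Keep the functionalized graph; run it unchanged
    captured.append(gm)
    return make_boxed_func(gm.forward)

def f(x):
    y = x.clone()
    y.add_(1)   # in-place update of an intermediate
    y.mul_(2)   # and another one
    return y

g = aot_function(f, fw_compiler=record)
x = torch.randn(5)
out = g(x)

ops = [str(n.target) for n in captured[0].graph.nodes
       if isinstance(n.target, torch._ops.OpOverload)]
print(ops)  # in-place add_/mul_ should appear as out-of-place add/mul
```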

The backward graph mirrors the forward graph in reverse order: each forward op has a backward partner, and the intermediate tensors that partner needs are saved from the forward pass. Those saved tensors are the activations. They live in HBM between forward and backward and dominate the activation memory budget.

Stage 3 — Inductor, the codegen

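Inductor takes a pure FX graph and outputs generated kernels (Triton on GPU, C++ with OpenMP on CPU) plus a wrapper, Python by default, that allocates buffers and launches those kernels alongside library calls such as cuBLAS matmuls.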

Inductor's three big optimisations:

  1. Vertical fusion. Chains of elementwise / pointwise ops become a single Triton kernel — the lesson-16 fusion case.
  2. Horizontal fusion. Independent ops with the same shape can share a kernel launch.
  3. Memory planning. The compiler knows the lifetime of every tensor and reuses buffers, avoiding allocator churn for known intermediates. This is one of the larger wins in practice.
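A back-of-the-envelope model of why vertical fusion matters, under a simplifying assumption: every unfused pointwise kernel reads its input from HBM and writes its output back, while the fused kernel reads once and writes once. Caches, launch overhead, horizontal fusion, and memory planning are ignored; the numbers are arithmetic, not measurements.

```python
# Simplified HBM-traffic model for a chain of pointwise ops (an assumption,
# not how Inductor measures anything): each kernel reads and writes one
# full tensor; a fused chain keeps intermediates in registers.

def pointwise_chain_bytes(n_elems: int, n_ops: int,
                          bytes_per_elem: int = 4, fused: bool = False) -> int:
    if fused:
        # one kernel: read the input once, write the final output once
        return 2 * n_elems * bytes_per_elem
    # one kernel per op, each with a full read and a full write
    return 2 * n_ops * n_elems * bytes_per_elem

n = 1 << 20  # 1M fp32 elements; y = relu(x * 2 + 1) is 3 pointwise ops
unfused = pointwise_chain_bytes(n, 3)
fused = pointwise_chain_bytes(n, 3, fused=True)

print(unfused / 2**20, "MiB unfused")  # 24.0 MiB unfused
print(fused / 2**20, "MiB fused")      # 8.0 MiB fused
print(unfused / fused, "x less traffic")  # 3.0 x less traffic
```

For a memory-bound chain, the expected speedup tracks that traffic ratio, which is why longer pointwise chains benefit more from fusion.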

The four modes of torch.compile

Mode | What it adds | Use case
--- | --- | ---
"default" | Dynamo + AOT Autograd + Inductor | Most training. 1.3–2× typical speedup.
"reduce-overhead" | default + CUDA Graphs wrapping (lesson 19) | Decode-heavy inference. Drops launch overhead.
"max-autotune" | default + autotuning over candidate matmul/conv kernels, plus CUDA Graphs on GPU | Long training runs that justify minutes of compile time.
"max-autotune-no-cudagraphs" | max-autotune without CUDA Graph capture | Dynamic-shape workloads.

Three failure modes, in order of how often they bite

  1. Graph breaks. A 100-line model with one .item() deep in it gets cut into smaller graphs (many of them, if that .item() sits in a loop or a repeated layer), and the Python tax you wanted to avoid comes back. Diagnose with TORCH_LOGS="graph_breaks"; fix by removing data-dependent Python.
  2. Recompilation. Every new shape, dtype, or device combination that fails a guard triggers a fresh compile. Dynamo caps recompiles per function (8 by default); past that it gives up and runs that function in eager. Diagnose with TORCH_LOGS="recompiles"; fix by using dynamic=True or by padding inputs to a fixed shape.
  3. Inductor fallbacks. Some ops (sparse, custom, unusual strides) aren't supported by Inductor's codegen and run as fallback calls to the regular eager ATen kernel. Correctness is unaffected, but those ops can't fuse with their neighbours, so the fusion win is lost around them. Diagnose with TORCH_LOGS="output_code".

How each failure mode shows up in step time across a training run: graph breaks add a small, constant overhead to every step; recompilations produce a large spike each time a new shape is seen; fallbacks add a per-step overhead that never recovers.
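The step-time patterns described above can be reproduced with a toy model. Everything here except the assumed 50 ms clean step is a hypothetical cost chosen for illustration; the point is which statistic each failure mode moves. Recompiles wreck p99 and total time while leaving the median untouched, whereas breaks and fallbacks shift every step.

```python
import statistics

# Toy 200-step run. All costs are hypothetical illustrations, not measurements.
BASE_MS = 50.0           # assumed clean compiled step
BREAK_MS = 3.0           # hypothetical per-step eager glue from graph breaks
RECOMPILE_MS = 20_000.0  # hypothetical cost of one recompile
FALLBACK_MS = 8.0        # hypothetical per-step cost of unfused fallback ops

def simulate(steps=200, breaks=False, recompiles=False, fallbacks=False):
    seen_shapes = set()
    times = []
    for step in range(steps):
        t = BASE_MS
        if breaks:
            t += BREAK_MS
        if fallbacks:
            t += FALLBACK_MS
        if recompiles:
            shape = step // 50            # a new input shape every 50 steps
            if shape not in seen_shapes:  # first sighting: pay for a recompile
                seen_shapes.add(shape)
                t += RECOMPILE_MS
        times.append(t)
    return times

def summarize(times):
    ordered = sorted(times)
    p99 = ordered[(99 * len(ordered) + 99) // 100 - 1]  # nearest-rank p99
    return statistics.median(ordered), p99, sum(ordered) / 1000  # ms, ms, s

for label, kwargs in [("clean", {}), ("graph breaks", {"breaks": True}),
                      ("recompiles", {"recompiles": True}),
                      ("fallbacks", {"fallbacks": True})]:
    med, p99, total = summarize(simulate(**kwargs))
    print(f"{label:13s} median={med:7.1f} ms  p99={p99:8.1f} ms  total={total:6.1f} s")
```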

Compile cost — when is it worth it?

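First-call compile time is real: Dynamo trace + AOT decomposition + Inductor codegen + autotuning often adds 10–60 seconds to the first forward, with per-shape recompiles on top. Breakeven against eager is the compile time divided by the time saved per step: a 30 s compile that saves 30 ms per training step pays for itself after about 1000 steps, while one that saves only 3 ms needs about 10,000. Inference breaks even after fewer requests (roughly ~100 prompts), because each prompt runs many decode steps and the per-step savings add up across them.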
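The breakeven arithmetic as a small function, with hypothetical timings: steps to break even = compile time / time saved per step. The inference line treats one prompt as 200 decode steps that each save 1.5 ms, which is one assumed way to land near the ~100-prompt figure.

```python
def breakeven_steps(compile_s: float, eager_ms: float, compiled_ms: float) -> float:
    """Number of steps before a one-time compile cost is paid back."""
    saved_ms = eager_ms - compiled_ms
    if saved_ms <= 0:
        return float("inf")  # compiled is not faster: compile never pays off
    return compile_s * 1000 / saved_ms

# Hypothetical training steps
print(breakeven_steps(30, eager_ms=100, compiled_ms=70))  # 1000.0  (30 ms saved/step)
print(breakeven_steps(30, eager_ms=20, compiled_ms=17))   # 10000.0 (3 ms saved/step)

# Hypothetical inference: one "step" = one prompt of 200 decode steps
print(breakeven_steps(30, eager_ms=200 * 10.0, compiled_ms=200 * 8.5))  # 100.0
```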

If you're doing a long training run, compile pays off in the first hour. If you're doing one-off research code, the compile time can dominate. PyTorch's compile cache (on-disk by default since 2.5) makes the second run of the same model much faster.

Where torch.compile beats hand-written CUDA — and where it doesn't

Inference vs training compile

For training, the AOT-Autograd path matters: backward is compiled too, and the activation-saving decision is a real lever (the min-cut partitioner used with Inductor chooses, based on cost, which intermediates to save for backward and which cheap ones to recompute). For inference, you usually want mode="reduce-overhead": the same compile path plus CUDA Graphs around the compiled forward. The graph-break problem is worse at inference (decoder loops have inherent Python in them), but the one-time compile cost is amortised over thousands of generated tokens.

Takeaway
torch.compile is three stages: capture, derive backward, codegen. Each can fall back. Most of the speedup comes from Inductor's vertical fusion + memory planning; the rest from CUDA Graphs in reduce-overhead. The job of the engineer is to keep the graph clean — no .item(), no Python control flow on tensor values, no exotic ops — and let the compiler do its work.