PyTorch internals — what x @ W actually does
From Python line to GPU kernel is roughly twelve translation layers. The performance of every distributed system above this rides on how few of them you make Python touch.
The naive picture vs the real one
The mental model most engineers start with: "PyTorch calls cuBLAS, cuBLAS runs on the GPU." That's true at the bottom. But between your Python y = x @ W and the cuBLAS kernel, every such call goes through the following steps:
- Python interpreter looks up __matmul__ on the tensor.
- Tensor's C++ binding (via pybind11) calls into ATen.
- ATen's dispatcher picks the right implementation based on the input tensors' device, dtype, layout, and other "dispatch keys" (autograd, autocast, vmap, etc.).
- The autocast layer (if enabled) fires first, casting inputs to bf16/fp16 and redispatching to Autograd.
- The autograd layer then records a node in the backward graph if any input requires grad — on the already-cast tensors.
- The backend kernel (CUDA, here) finally runs. It in turn dispatches into cuBLAS after handling shape/stride selection, workspace allocation, and stream binding.
- cuBLAS picks one of dozens of pre-tuned matmul implementations and launches it on the current CUDA stream (the default stream unless you have selected another).
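One way to see the bottom of this stack from Python is to log the ATen ops that reach the backend. The sketch below uses TorchDispatchMode from torch.utils._python_dispatch, which is an underscore-prefixed module and so may change between releases. OpLogger is a hypothetical helper name, and the tensors live on the CPU so the example runs without a GPU.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpLogger(TorchDispatchMode):
    """Hypothetical helper: records each ATen op that reaches the backend level."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.seen.append(str(func))            # name of the ATen overload
        return func(*args, **(kwargs or {}))   # run the real kernel


x = torch.randn(4, 64, requires_grad=True)
W = torch.randn(64, 64, requires_grad=True)

with OpLogger() as logger:
    y = x @ W   # Python __matmul__ -> at::matmul -> backend kernel

print(logger.seen)
```

For 2-D inputs, the logged op is normally the plain matrix-multiply overload rather than matmul itself, because matmul decomposes into more specific ops before it reaches a backend.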
The dispatcher and autograd work costs microseconds per call. That sounds free until you remember that decode emits one token per forward pass, and a forward pass of a 70B model is many kernel launches. Summed over those launches, the framework overhead can be comparable to the actual HBM-bound compute of decode (lesson 11). This is the gap torch.compile and CUDA Graphs (lessons 18, 19) close.
The dispatcher — one function, many implementations
The dispatcher is the data structure that decides which version of aten::matmul to run. A function in ATen isn't a single C++ function — it's a set of implementations indexed by dispatch keys. Every tensor carries a bitmap of dispatch keys: which device (CPU, CUDA, MPS, XLA), which layout (Strided, Sparse), which autograd backend, plus optional layers like Autocast, Functorch, NestedTensor.
When you call x @ W, the dispatcher takes the union of the keys present on the inputs, picks the highest-priority one, and runs the kernel registered for it; that kernel can then redispatch to the next key down. A typical call stack:
x.shape = (4, 4096), W.shape = (4096, 4096), both CUDA fp32, requires_grad=True
    │
    ▼
Python Tensor.__matmul__   ←─ pybind11 entry
    │
    ▼
at::matmul (dispatcher)    ←─ key set: {CUDA, Autograd, Autocast?}
    │
    ▼
Autocast   (if enabled: fires first, casts to bf16, redispatches)
    │
    ▼
Autograd   (records the op for backward, on the already-cast tensors)
    │
    ▼
CUDA       (calls cuBLAS)
Each layer's kernel optionally redispatches (through either the boxed or the unboxed calling convention) down to the next key. Autograd's kernel records the op and redispatches to the underlying backend; the CUDA kernel does the real work; the result bubbles back up. (The narrower term "fall-through" in PyTorch's source refers specifically to keys that have no registered kernel and pass through automatically. Autograd and Autocast both have real kernels that explicitly redispatch.)
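The redispatch pattern can be illustrated with a toy model in plain Python. This is only a sketch: PRIORITY, KERNELS, TAPE, and dispatch() are hypothetical names invented for the example, and the real dispatcher uses key bitsets and C++ kernel tables rather than lists and dicts.

```python
# Toy model only. PRIORITY, KERNELS, TAPE and dispatch() are hypothetical names,
# not PyTorch's real data structures.
PRIORITY = ["Autocast", "Autograd", "CUDA"]  # highest priority first
TAPE = []                                    # stands in for the backward graph


def autocast_kernel(x, w, redispatch):
    # Stand-in for the bf16 cast: convert every entry to float, then go down a level.
    cast = lambda m: [[float(v) for v in row] for row in m]
    return redispatch(cast(x), cast(w))


def autograd_kernel(x, w, redispatch):
    out = redispatch(x, w)                # run everything below Autograd first
    TAPE.append(("MmBackward0", x, w))    # then record the op and its saved inputs
    return out


def cuda_kernel(x, w, redispatch):
    # Bottom of the stack: plain-Python matmul standing in for cuBLAS.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


KERNELS = {"Autocast": autocast_kernel, "Autograd": autograd_kernel, "CUDA": cuda_kernel}


def dispatch(key_set, x, w, trace):
    """Run the highest-priority kernel; each kernel may redispatch to the next key."""
    pending = [k for k in PRIORITY if k in key_set]

    def call(i, x, w):
        trace.append(pending[i])
        return KERNELS[pending[i]](x, w, lambda x2, w2: call(i + 1, x2, w2))

    return call(0, x, w)


trace = []
result = dispatch({"CUDA", "Autograd", "Autocast"}, [[1, 2], [3, 4]], [[5, 6], [7, 8]], trace)
print(trace)   # order in which the layers were entered
print(result)
```

Calling dispatch with only {"CUDA"} in the key set skips the upper layers entirely, which mirrors why tensors that do not require grad avoid autograd bookkeeping.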
Animated · one matmul through the dispatcher stack
Below: a single torch.matmul(a, b) as it travels down the dispatch stack — Python entry → Autocast (casts to bf16 if active, fires first) → Autograd (records MmBackward0 on already-cast tensors) → CUDA backend → cuBLAS heuristic → kernel launch on the GPU. Each layer adds its own slice to the running "Python overhead" counter on the right. Hit play and watch where the microseconds go.
The autograd graph — built as a side effect
You don't explicitly build a computational graph in PyTorch — the dispatcher's autograd layer does it for you, every time an op is called on a tensor with requires_grad=True. The graph is a DAG of Node objects (each one a "function with a saved context"), with edges pointing to the inputs that need gradients.
Concretely, when you do y = x @ W and both operands have requires_grad=True, the autograd layer creates an MmBackward0 node, saves x and W as context (because the gradient of matmul needs both operands), and attaches the node to y.grad_fn. When you later call loss.backward(), PyTorch walks the graph in reverse topological order, calling each node's apply() method to compute input gradients from output gradients.
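The graph described above can be inspected directly from Python. The following is a minimal sketch on small CPU tensors; the shapes are arbitrary, and the exact node class name may differ across PyTorch versions.

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
W = torch.randn(8, 8, requires_grad=True)

y = x @ W                               # autograd attaches a backward node to y
node_name = type(y.grad_fn).__name__
print(node_name)                        # the matmul backward node

# Edges to the nodes that receive gradients: one per input that requires grad.
print(y.grad_fn.next_functions)

loss = y.sum()
loss.backward()                         # walks the graph in reverse
print(x.grad.shape, W.grad.shape)
```

Since loss is the sum of all entries of y, the gradient with respect to x is a matrix of ones multiplied by the transpose of W, which is a quick way to sanity-check the result by hand.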
Because the graph is built as the code runs, ordinary Python control flow (if, while, varying shapes) just works. The cost is that the graph is rebuilt every forward pass: Python and dispatcher overhead per op are paid every step. Static-graph frameworks (XLA, TensorFlow 1) avoided this by capturing the graph once and replaying it; PyTorch's bet was that the developer ergonomics paid for the runtime cost. The compromise, torch.compile, keeps the eager API and tries to capture chunks of static graph behind it.
Saving the right things — and the activation memory wall
Every autograd node has a "context" of tensors it saves for backward. MmBackward0 saves both inputs of the matmul. ReluBackward0 saves the output (so backward can mask gradients where the output was 0). SoftmaxBackward0 saves the output. Etc.
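To see which tensors a forward pass actually stashes, one option is torch.autograd.graph.saved_tensors_hooks, which calls a user function each time autograd saves a tensor. The sketch below only records shapes and byte counts; pack, unpack, and saved are names chosen for this example, and the tensors are small CPU ones.

```python
import torch

saved = []  # (shape, nbytes) of every tensor autograd saves for backward


def pack(t):
    saved.append((tuple(t.shape), t.numel() * t.element_size()))
    return t    # keep the tensor itself, unchanged


def unpack(t):
    return t    # hand the same tensor back when backward asks for it


x = torch.randn(4, 64, requires_grad=True)
W = torch.randn(64, 64, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.relu(x @ W)

total_bytes = sum(n for _, n in saved)
print(saved)
print(total_bytes)
```

Summing the byte counts across a whole model's forward pass gives a rough estimate of its activation memory.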
These saved tensors are exactly what lesson 01 called "activation memory". They live on the GPU until backward consumes them. For a 70B model, this is the GB-scale state we tackled with activation checkpointing — which works by not saving most intermediates and recomputing them during backward. Mechanism-wise, checkpointing wraps a sub-graph in a custom Function whose backward re-runs the forward.
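A small sketch of the recompute trade-off, using torch.utils.checkpoint.checkpoint with use_reentrant=False. The block function, the call counter, and the tensor sizes are illustrative assumptions, not anything from the lessons referenced above.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"forward": 0}


def block(x, W1, W2):
    calls["forward"] += 1                 # count how often the forward body runs
    return torch.relu(x @ W1) @ W2


torch.manual_seed(0)
x = torch.randn(4, 32, requires_grad=True)
W1 = torch.randn(32, 32, requires_grad=True)
W2 = torch.randn(32, 32, requires_grad=True)

# Reference: ordinary forward/backward, intermediates kept alive until backward.
block(x, W1, W2).sum().backward()
ref = [t.grad.clone() for t in (x, W1, W2)]
for t in (x, W1, W2):
    t.grad = None

# Checkpointed: intermediates are dropped and recomputed during backward.
calls["forward"] = 0
out = checkpoint(block, x, W1, W2, use_reentrant=False)
out.sum().backward()

print(calls["forward"])                   # forward body ran again during backward
print(all(torch.allclose(t.grad, r) for t, r in zip((x, W1, W2), ref)))
```

The gradients match the reference run, and the counter shows the price: the forward body executes a second time inside backward.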
2D · autograd graph builder
Pick an expression. The widget builds the forward graph (top) node by node, and the backward graph (bottom, mirrored) grows in lock-step. Each forward op leaves a grad_fn on its output tensor; the backward graph's edges are exactly the saved-tensor references the forward pass left behind. Notice that, by default, leaf tensors with requires_grad=True are the only nodes whose .grad gets populated (a non-leaf tensor keeps its gradient only if you call retain_grad() on it).
Eager mode vs graph mode
| Property | Eager (default) | Graph (compile/JIT) |
|---|---|---|
| Python overhead per op | ~10 μs | Amortised once |
| Op fusion | None (kernel boundaries) | Lots (lesson 16, 18) |
| Memory planning | Per-op alloc/free | Whole-program plan |
| Dynamic control flow | Free | Graph breaks (lesson 18) |
| Debuggability | Set breakpoints anywhere | Need to inspect captured IR |
Three "graph modes" coexist:
- TorchScript (legacy). A Python subset and tracer; now in maintenance mode.
- torch.compile (lesson 18). Capture-by-bytecode (Dynamo) + Inductor codegen. The current default for "make eager code faster."
- CUDA Graphs (lesson 19). Below the framework level — captures a sequence of kernel launches into a single replayable graph.
The cost of Python in the decode loop
Numbers to keep in your head (rough, very workload-dependent): each PyTorch op in eager mode costs roughly
- ~5 μs for the actual kernel launch (driver submission, command-processor queue)
- ~10–30 μs for Python interpreter + pybind + dispatcher + autograd bookkeeping (the "Python tax")
- actual kernel time on top of that
So the small-op eager-mode cost is typically 15–35 μs per op (launch plus Python tax) before any GPU work. A transformer decode step at batch=1 runs hundreds of ops, so this overhead adds up to a few milliseconds, which can easily approach the kernel time for a 7B-class model decoded on an H100 (where the actual compute is also a few ms). Take any of these numbers as a Fermi estimate, not a benchmark: the exact split depends on op size, dtype, and PyTorch version.
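The arithmetic behind that estimate is short enough to write down. The sketch below just multiplies the per-op figures from the bullets above by an op count; framework_overhead_ms is a hypothetical helper, and the 300 ops per step is an assumed value standing in for "hundreds of ops".

```python
def framework_overhead_ms(ops_per_step, launch_us=5.0, python_us=(10.0, 30.0)):
    """Fermi estimate of per-step framework overhead, in milliseconds.

    launch_us and python_us are the rough per-op figures quoted in the text,
    not measurements.
    """
    low = ops_per_step * (launch_us + python_us[0]) / 1000.0
    high = ops_per_step * (launch_us + python_us[1]) / 1000.0
    return low, high


low, high = framework_overhead_ms(300)   # 300 ops/step is an assumption
print(f"{low:.1f} ms to {high:.1f} ms of overhead per decode step")
```

With these inputs the overhead lands in the single-digit to low-double-digit millisecond range, the same order of magnitude as the few milliseconds of kernel time quoted for a 7B-class decode step.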
Animated · the per-op Python tax timeline
Two timelines, side-by-side: N small unfused ops vs M fused ops. The fused version does identical math but in fewer kernel launches. Watch the unfused timeline pay a fresh slice of Python overhead per op — the GPU sits idle between launches when the dispatcher is the bottleneck. Slide "ops/step" to see how the gap widens as op count grows.
Strides, views, contiguity
A PyTorch tensor is a tuple (storage, offset, shape, strides). Many ops return a view: a new tensor descriptor pointing into the same storage. transpose, permute, view, squeeze, slicing — all return views, no data movement. The trap: most kernels want their inputs contiguous (stride[i] = product of shape[i+1:]). A non-contiguous input often triggers an implicit copy kernel — invisible HBM bandwidth burned.
This is a real overhead source in custom code. When you see a mystery aten::copy_ in your profiler, a likely cause is that the upstream op produced a non-contiguous view and the downstream kernel forced a contiguous copy. Cure: call .contiguous() explicitly at a point you control, or pick a different op order.
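The view and contiguity rules are easy to check on a tiny tensor. In this sketch, contiguous_strides is a hypothetical helper that implements the stride formula given above; everything else is standard tensor API on the CPU.

```python
import torch


def contiguous_strides(shape):
    """Row-major strides: stride[i] = product of shape[i+1:]."""
    strides, running = [], 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


x = torch.arange(12).reshape(3, 4)
t = x.t()                              # view: same storage, swapped strides

print(x.stride(), t.stride())
print(t.data_ptr() == x.data_ptr())    # no data was moved
print(t.is_contiguous())               # strides no longer match the row-major formula

c = t.contiguous()                     # explicit copy into a fresh row-major buffer
print(c.stride(), c.data_ptr() != x.data_ptr())
```

The explicit .contiguous() call is the same copy a kernel would otherwise trigger implicitly; making it visible in your own code lets you decide where the bandwidth is spent.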
Interactive · per-op overhead and amortisation
Pick a model decode size and a batch. The widget plots: per-step kernel time (the work), Python framework overhead (the launch cost), and what torch.compile / CUDA Graphs would buy. Try batch=1, model=7B: the framework is half the cost. Crank batch to 32: the framework's fraction drops because each kernel does more work per launch.