system_ml / 13 · pytorch internals lesson 13 / 19

PyTorch internals — what x @ W actually does

Between a Python line and a GPU kernel sit roughly a dozen translation layers. The performance of every distributed system built on top of this depends on how few of them you make Python touch.

The naive picture vs the real one

The mental model most engineers start with: "PyTorch calls cuBLAS, cuBLAS runs on the GPU." That's true at the bottom. But between your Python y = x @ W and the cuBLAS kernel, every call passes through the following layers:

  1. Python interpreter looks up __matmul__ on the tensor.
  2. The tensor's C++ Python binding (auto-generated CPython C-API code, not pybind11) parses the arguments and calls into ATen.
  3. ATen's dispatcher picks the right implementation based on the input tensors' device, dtype, layout, and other "dispatch keys" (autograd, autocast, vmap, etc.).
  4. The autocast layer (if enabled) fires first, casting inputs to bf16/fp16 and redispatching to Autograd.
  5. The autograd layer then records a node in the backward graph if any input requires grad — on the already-cast tensors.
  6. The backend kernel (CUDA, here) finally runs. It is itself a dispatch into cuBLAS: shape/stride handling, workspace allocation, and binding the cuBLAS handle to the current stream.
  7. cuBLAS picks one of dozens of pre-tuned matmul implementations and launches it on the current CUDA stream (the default stream unless you have switched streams).
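You can watch part of this stack from Python with the built-in profiler. The sketch below runs on CPU, so there is no cuBLAS at the bottom. It also shrinks the lesson's (4, 4096) @ (4096, 4096) shapes to keep it fast. The profiler only sees the ATen-level ops from step 3 downward, not the Python attribute lookup, and exact op names may vary between PyTorch versions.

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Shapes shrunk from the lesson's (4, 4096) @ (4096, 4096); CPU so it runs anywhere.
x = torch.randn(4, 256, requires_grad=True)
W = torch.randn(256, 256, requires_grad=True)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    y = x @ W  # Tensor.__matmul__ -> aten::matmul -> (2-D inputs) aten::mm

# Names of the ATen ops the dispatcher ran inside the profiled region.
op_names = {evt.key for evt in prof.key_averages()}
print(sorted(name for name in op_names if name.startswith("aten::")))
print(y.grad_fn)  # the autograd layer attached a backward node to y
```

On a CUDA machine, moving the tensors to the GPU and adding ProfilerActivity.CUDA should also surface the cuBLAS kernel itself.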

The dispatcher and autograd work costs microseconds per call. That sounds free until you remember that decode emits one token per forward pass, and a forward pass of a 70B model is many kernel launches. At small batch, the framework overhead can be comparable to the actual HBM-bound compute of decode (lesson 11). This is the gap torch.compile and CUDA Graphs (lessons 18, 19) close.
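A quick way to feel "microseconds per call" is to time many ops on a 1-element tensor, where the arithmetic is negligible. This CPU-only micro-benchmark is a rough sketch: the numbers depend on your machine and PyTorch build. On CUDA you would also pay launch cost and would need to synchronise before reading the clock.

```python
import time

import torch


def per_op_overhead_us(n_ops: int = 10_000, requires_grad: bool = False) -> float:
    """Average wall time of one tiny eager op, in microseconds.

    With a 1-element CPU tensor this mostly measures Python + dispatcher
    overhead, plus autograd bookkeeping when requires_grad=True.
    """
    x = torch.ones(1, requires_grad=requires_grad)
    start = time.perf_counter()
    for _ in range(n_ops):
        x = x + 1  # each iteration is one full trip through the stack
    return (time.perf_counter() - start) / n_ops * 1e6


print(f"no autograd:   {per_op_overhead_us():.2f} us/op")
print(f"with autograd: {per_op_overhead_us(requires_grad=True):.2f} us/op")
```

The second line is typically higher, because each op now also creates a backward node.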

The dispatcher — one function, many implementations

The dispatcher is the data structure that decides which version of aten::matmul to run. A function in ATen isn't a single C++ function — it's a set of implementations indexed by dispatch keys. Every tensor carries a bitmap of dispatch keys: which device (CPU, CUDA, MPS, XLA), which layout (Strided, Sparse), which autograd backend, plus optional layers like Autocast, Functorch, NestedTensor.

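When you call x @ W, the dispatcher combines the key sets of all inputs with thread-local include/exclude state (which is how autocast is switched on and off). It then calls the kernel registered for the highest-priority key, and that kernel redispatches to the next key down. A typical call stack: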

x.shape = (4, 4096), W.shape = (4096, 4096), both CUDA fp32, requires_grad=True
                            │
                            ▼
       Python Tensor.__matmul__   ←─ C++ binding entry
                            │
                            ▼
       at::matmul  (dispatcher)   ←─ key set: {CUDA, Autograd, Autocast?}
                            │
                            ▼
       Autocast    (if enabled: cast to bf16/fp16; fires first, redispatches)
                            │
                            ▼
       Autograd    (record op for backward, on the already-cast tensors)
                            │
                            ▼
       CUDA        (call cuBLAS)

Each layer's kernel does its job and then redispatches to the next key down through the dispatcher's redispatch API (OperatorHandle::redispatchBoxed, or its unboxed sibling TypedOperatorHandle::redispatch). Autograd's kernel records the op, then redispatches to the underlying backend; the CUDA kernel does the real work; the result bubbles back up. (The narrower term "fall-through" in PyTorch's source refers to a kernel registered with torch::CppFunction::makeFallthrough(), which makes the dispatcher skip that key for that op. Autocast, for example, registers one as its fallback for ops it doesn't cast. Autograd's kernels and autocast's matmul kernel are real kernels that explicitly redispatch.)
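A toy dispatcher in plain Python makes the redispatch chain concrete. Everything here is hypothetical: three keys instead of c10's dozens, scalar "tensors", and a missing table entry standing in for a fall-through registration. It illustrates only the control flow. The highest-priority kernel runs first, then explicitly calls back into the dispatcher with its own key masked off.

```python
from dataclasses import dataclass
from typing import Callable, Dict, FrozenSet, List, Tuple

# Hypothetical priority order, highest first: a tiny subset of c10's real keys.
PRIORITY = ["Autocast", "Autograd", "CUDA"]


@dataclass
class ToyTensor:
    value: float
    dtype: str = "fp32"


class ToyDispatcher:
    def __init__(self) -> None:
        self.table: Dict[Tuple[str, str], Callable[..., ToyTensor]] = {}
        self.trace: List[str] = []

    def register(self, op: str, key: str, kernel: Callable[..., ToyTensor]) -> None:
        self.table[(op, key)] = kernel

    def call(self, op: str, keys: FrozenSet[str], *args: ToyTensor) -> ToyTensor:
        for rank, key in enumerate(PRIORITY):
            kernel = self.table.get((op, key))
            if key not in keys or kernel is None:
                continue  # key absent, or no kernel (toy fall-through): skip it
            self.trace.append(f"{op}@{key}")
            # Pass only the keys below this one, so a redispatch moves down.
            below = frozenset(k for k in keys if k in PRIORITY[rank + 1:])
            return kernel(self, below, *args)
        raise RuntimeError(f"no kernel for {op} with keys {sorted(keys)}")


def autocast_mm(d: ToyDispatcher, keys: FrozenSet[str], a: ToyTensor, b: ToyTensor) -> ToyTensor:
    # "Cast" to bf16 (only the dtype label changes here), then redispatch.
    return d.call("mm", keys, ToyTensor(a.value, "bf16"), ToyTensor(b.value, "bf16"))


def autograd_mm(d: ToyDispatcher, keys: FrozenSet[str], a: ToyTensor, b: ToyTensor) -> ToyTensor:
    # Record what backward needs, on whatever dtype arrived, then redispatch.
    d.trace.append(f"save ({a.dtype}, {b.dtype}) for MmBackward0")
    return d.call("mm", keys, a, b)


def cuda_mm(d: ToyDispatcher, keys: FrozenSet[str], a: ToyTensor, b: ToyTensor) -> ToyTensor:
    # The "real work": a scalar product standing in for the cuBLAS GEMM.
    return ToyTensor(a.value * b.value, a.dtype)


d = ToyDispatcher()
for key, kernel in [("Autocast", autocast_mm), ("Autograd", autograd_mm), ("CUDA", cuda_mm)]:
    d.register("mm", key, kernel)

out = d.call("mm", frozenset(PRIORITY), ToyTensor(2.0), ToyTensor(3.0))
print(out)      # ToyTensor(value=6.0, dtype='bf16')
print(d.trace)  # Autocast, then Autograd (saving bf16 inputs), then CUDA
```

Unregistering the Autocast kernel reproduces the "autocast off" path. The key is skipped, and autograd sees fp32 inputs.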

Animated · one matmul through the dispatcher stack

Below: a single torch.matmul(a, b) as it travels down the dispatch stack — Python entry → Autocast (casts to bf16 if active, fires first) → Autograd (records MmBackward0 on already-cast tensors) → CUDA backend → cuBLAS heuristic → kernel launch on the GPU. Each layer adds its own slice to the running "Python overhead" counter on the right. Hit play and watch where the microseconds go.

Dispatcher trace · scrub or play
Each layer in the dispatcher key stack contributes some overhead before any GPU work happens. Toggle autograd / autocast to add or remove their slices. The counter on the right is the cumulative host-side cost up to the current step.

The autograd graph — built as a side effect

You don't explicitly build a computational graph in PyTorch — the dispatcher's autograd layer does it for you, every time an op is called on a tensor with requires_grad=True. The graph is a DAG of Node objects (each one a "function with a saved context"), with edges pointing to the inputs that need gradients.

Concretely, when you do y = x @ W with both requires_grad: the autograd layer creates an MmBackward0 node, saves x and W as context (because the gradient of matmul needs both operands), and attaches the node to y.grad_fn. When you later call loss.backward(), PyTorch walks the graph in reverse topological order, calling each node's apply() method to compute input gradients from output gradients.
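The graph is inspectable from Python. This sketch uses small CPU tensors. The _saved_self / _saved_mat2 attributes are underscore-prefixed introspection hooks rather than a stable public API, so treat them as version-dependent.

```python
import torch

# Shapes shrunk from the lesson's example; CPU is enough to inspect the graph.
x = torch.randn(4, 64, requires_grad=True)
W = torch.randn(64, 64, requires_grad=True)

y = x @ W
loss = y.sum()

print(y.grad_fn.name())             # MmBackward0, attached to y by the autograd layer
print(loss.grad_fn.next_functions)  # edge from SumBackward0 back to MmBackward0
print(y.grad_fn.next_functions)     # edges to AccumulateGrad nodes for leaves x, W

# The operands MmBackward0 saved as context (both, since both require grad).
print(torch.equal(y.grad_fn._saved_self, x), torch.equal(y.grad_fn._saved_mat2, W))

loss.backward()  # reverse topological walk: SumBackward0 -> MmBackward0 -> leaves
# For loss = sum(x @ W): dL/dx = ones @ W.T and dL/dW = x.T @ ones
ones = torch.ones_like(y)
print(torch.allclose(x.grad, ones @ W.T), torch.allclose(W.grad, x.T @ ones))
```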

Why this design is fast — and where it bites
Eager autograd is "tape-style": the graph is constructed as you go, so dynamic control flow (if, while, varying shapes) just works. The cost is that the graph is rebuilt every forward pass — Python and dispatcher overhead per op are paid every step. Static-graph frameworks (XLA, TensorFlow 1) avoided this by capturing the graph once and replaying; PyTorch's bet was that the developer ergonomics paid for the runtime cost. The compromise — torch.compile — keeps the eager API and tries to capture chunks of static graph behind it.
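A pure-Python toy shows the "tape-style" idea. Every operation builds its backward closure as it executes, so an ordinary Python loop decides the graph's shape, and backward() walks whatever was recorded in reverse topological order. This is a hypothetical scalar-only sketch in the spirit of MmBackward0 / ReluBackward0, not how PyTorch's C++ engine is written.

```python
from __future__ import annotations

from typing import Callable, List, Set, Tuple


class Value:
    """A scalar that records, as it is computed, how to backprop through it."""

    def __init__(self, data: float, parents: Tuple[Value, ...] = ()) -> None:
        self.data = data
        self.grad = 0.0
        self.parents = parents
        self._backward: Callable[[], None] = lambda: None  # leaves do nothing

    def __mul__(self, other: Value) -> Value:
        out = Value(self.data * other.data, (self, other))

        def backward() -> None:
            # Needs both operands: this closure's captured "saved context".
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = backward
        return out

    def relu(self) -> Value:
        out = Value(max(0.0, self.data), (self,))

        def backward() -> None:
            # Masks with the saved *output*, like ReluBackward0.
            self.grad += (1.0 if out.data > 0 else 0.0) * out.grad

        out._backward = backward
        return out

    def backward(self) -> None:
        order: List[Value] = []
        seen: Set[int] = set()

        def visit(v: Value) -> None:  # post-order DFS gives a topological order
            if id(v) in seen:
                return
            seen.add(id(v))
            for p in v.parents:
                visit(p)
            order.append(v)

        visit(self)
        self.grad = 1.0
        for v in reversed(order):  # outputs before inputs
            v._backward()


def forward(x: Value, w: Value, steps: int) -> Value:
    h = x
    for _ in range(steps):  # plain Python control flow shapes the graph
        h = (h * w).relu()
    return h


x, w = Value(2.0), Value(3.0)
y = forward(x, w, steps=2)  # y = relu(relu(x*w)*w) = 18
y.backward()
print(y.data, x.grad, w.grad)  # 18.0 9.0 12.0
```

Calling forward again builds a brand-new set of Value nodes. That per-step rebuild is exactly the cost the paragraph above describes.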

Saving the right things — and the activation memory wall

Every autograd node has a "context" of tensors it saves for backward. MmBackward0 saves both inputs of the matmul. ReluBackward0 saves the output (so backward can mask gradients where the output was 0). SoftmaxBackward0 saves the output. Etc.
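You can observe what each node saves with torch.autograd.graph.saved_tensors_hooks. The pack hook runs whenever autograd saves a tensor. The sketch below only counts them and stores each tensor unchanged. Sizes are small CPU examples, and the helper name saved_tensor_stats is hypothetical.

```python
import torch


def saved_tensor_stats(fn, *inputs):
    """Run fn under saved-tensor hooks; return (count, bytes) autograd saved."""
    sizes = []

    def pack(t: torch.Tensor) -> torch.Tensor:
        sizes.append(t.numel() * t.element_size())
        return t  # store the tensor unchanged; we only observe it

    def unpack(t: torch.Tensor) -> torch.Tensor:
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        fn(*inputs)
    return len(sizes), sum(sizes)


x = torch.randn(256, 256, requires_grad=True)  # 256 KiB in fp32
W = torch.randn(256, 256, requires_grad=True)

print("mm:     ", saved_tensor_stats(torch.mm, x, W))                     # both operands
print("relu:   ", saved_tensor_stats(torch.relu, x))                      # its output
print("softmax:", saved_tensor_stats(lambda t: torch.softmax(t, -1), x))  # its output
```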

These saved tensors are exactly what lesson 01 called "activation memory". They live on the GPU until backward consumes them. For a 70B model, this is the GB-scale state we tackled with activation checkpointing, which works by not saving most intermediates and recomputing them during backward. Mechanism-wise, the classic (reentrant) checkpoint wraps a sub-graph in a custom autograd Function whose backward re-runs the forward. PyTorch's newer non-reentrant variant gets the same effect through saved-tensor hooks.
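The custom-Function mechanism can be sketched in a few lines. RecomputeBlock is a hypothetical name, and this is a simplified illustration of the reentrant idea: one tensor input, no RNG-state restore. It is not torch.utils.checkpoint's actual implementation, which is what you should use in practice.

```python
import torch


class RecomputeBlock(torch.autograd.Function):
    """Sketch of reentrant-style checkpointing: save only the block input."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, fn) -> torch.Tensor:
        ctx.fn = fn
        ctx.save_for_backward(x)  # the only tensor kept alive until backward
        # Grad mode is already off inside Function.forward, so fn's
        # intermediates are neither recorded nor saved.
        return fn(x)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            y = ctx.fn(x)  # re-run the forward, this time building a graph
        # The inner backward also accumulates into parameters fn closes over.
        torch.autograd.backward(y, grad_out)
        return x.grad, None  # no gradient for the non-tensor `fn` argument


torch.manual_seed(0)
w = torch.randn(8, requires_grad=True)


def block(t: torch.Tensor) -> torch.Tensor:
    return torch.tanh(t * w).sin()  # closes over the parameter w


x1 = torch.randn(8, requires_grad=True)
RecomputeBlock.apply(x1, block).sum().backward()
gx_ckpt, gw_ckpt = x1.grad.clone(), w.grad.clone()
```

Running block(x) directly with ordinary autograd should give the same x and w gradients. The difference is only in what stays resident between forward and backward.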

2D · autograd graph builder

Pick an expression. The widget builds the forward graph (top) node by node, and the backward graph (bottom, mirrored) grows in lock-step. Each forward op leaves a grad_fn on its output tensor. Each backward node's edges (its next_functions) point to the grad_fns of its inputs, or to AccumulateGrad for leaves, while its saved tensors are the context it needs to compute its gradient. Notice that, by default, leaf tensors with requires_grad=True are the only nodes whose .grad gets populated; a non-leaf tensor needs .retain_grad().

Autograd graph · forward (top) + backward (bottom)
Step through to grow the graph one op at a time. The bottom half is the chain of *Backward nodes that loss.backward() will walk in reverse. Solid nodes are tensors; rounded boxes are grad_fns.

Eager mode vs graph mode

Property                Eager (default)            Graph (compile/JIT)
Python overhead per op  ~10 μs                     Amortised once
Op fusion               None (kernel boundaries)   Lots (lessons 16, 18)
Memory planning         Per-op alloc/free          Whole-program plan
Dynamic control flow    Free                       Graph breaks (lesson 18)
Debuggability           Set breakpoints anywhere   Need to inspect captured IR

Three "graph modes" coexist:

The cost of Python in the decode loop

Numbers to keep in head (rough, very workload-dependent): each PyTorch op in eager mode costs roughly

So small-op eager-mode cost is typically 15–50 μs per op before any GPU work. For a transformer decode step (hundreds of ops at batch=1), the framework overhead can easily approach the kernel time on a 7B-class model decoded on an H100, where the actual compute is also a few ms. Treat all of these numbers as Fermi estimates, not benchmarks: the exact split depends on op size, dtype, and PyTorch version.
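The Fermi estimate can be written out. Every input here is an assumption:

- 7B parameters in bf16 (2 bytes each).
- Roughly 3.35 TB/s of HBM bandwidth (the H100 SXM figure).
- A hypothetical 300 ops per step, at the low end of the 15–50 μs range.

The kernel side counts only streaming the weights once and ignores the KV cache. The framework side treats the host tax as fully serial.

```python
from typing import Tuple


def decode_step_ms(params_billion: float, bytes_per_param: float, hbm_tb_per_s: float,
                   n_ops: int, host_us_per_op: float) -> Tuple[float, float]:
    """Fermi estimate of one batch=1 decode step: (kernel_ms, framework_ms)."""
    weight_bytes = params_billion * 1e9 * bytes_per_param
    kernel_ms = weight_bytes / (hbm_tb_per_s * 1e12) * 1e3  # read every weight once
    framework_ms = n_ops * host_us_per_op / 1e3              # serial per-op host tax
    return kernel_ms, framework_ms


kernel_ms, framework_ms = decode_step_ms(7, 2, 3.35, n_ops=300, host_us_per_op=15)
print(f"kernel ~{kernel_ms:.1f} ms, framework ~{framework_ms:.1f} ms")
```

Even at the optimistic end of the per-op range, the two terms land in the same few-millisecond ballpark, which is the point of the paragraph above.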

Animated · the per-op Python tax timeline

Two timelines, side-by-side: N small unfused ops vs M fused ops. The fused version does identical math but in fewer kernel launches. Watch the unfused timeline pay a fresh slice of Python overhead per op — the GPU sits idle between launches when the dispatcher is the bottleneck. Slide "ops/step" to see how the gap widens as op count grows.
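The widget's timeline can be reduced to a toy formula. Each kernel is preceded by its own host-side slice, with no overlap. That is the worst case the widget draws, where the dispatcher is the bottleneck; real CUDA launches are asynchronous, so host work can overlap earlier kernels. The 200 μs of math and the 10 μs host tax are hypothetical numbers.

```python
from typing import Tuple


def serial_step_us(n_kernels: int, host_us: float, total_gpu_us: float) -> Tuple[float, float]:
    """Toy serial timeline: returns (total_us, tax_fraction).

    The same total GPU work is split evenly across n_kernels, so fusing
    changes only how many host slices are paid.
    """
    gpu_per_kernel = total_gpu_us / n_kernels
    total = n_kernels * (host_us + gpu_per_kernel)
    return total, n_kernels * host_us / total


unfused, unfused_tax = serial_step_us(40, 10.0, 200.0)
fused, _ = serial_step_us(5, 10.0, 200.0)
print(f"unfused {unfused:.0f} us ({unfused_tax:.0%} tax), fused {fused:.0f} us, "
      f"speedup {unfused / fused:.1f}x")
```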

Python tax over time · unfused vs fused
Red blocks = Python / dispatcher overhead (host-side, GPU idle). Blue blocks = GPU kernel running. Same total math; the fused version has many fewer red blocks because it pays the host tax once per fused kernel.

Strides, views, contiguity

A PyTorch tensor is a tuple (storage, offset, shape, strides). Many ops return a view: a new tensor descriptor pointing into the same storage. transpose, permute, view, squeeze, slicing — all return views, no data movement. The trap: most kernels want their inputs contiguous (stride[i] = product of shape[i+1:]). A non-contiguous input often triggers an implicit copy kernel — invisible HBM bandwidth burned.
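The (storage, offset, shape, strides) view model and the contiguity rule are easy to check from Python. contiguous_strides is a hypothetical helper that spells out the stride[i] = product of shape[i+1:] formula.

```python
import torch


def contiguous_strides(shape) -> tuple:
    """stride[i] = product of shape[i+1:], the row-major layout kernels expect."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


x = torch.arange(12).view(3, 4)  # storage holds 0..11 in row-major order
t = x.t()                        # a view: same storage, strides swapped

print(x.stride(), t.stride())                          # (4, 1) (1, 4)
print(t.data_ptr() == x.data_ptr())                    # True: no data moved
print(t.is_contiguous(), contiguous_strides(t.shape))  # False (3, 1)

c = t.contiguous()               # this is where a copy kernel actually runs
print(c.stride(), c.data_ptr() == x.data_ptr())        # (3, 1) False
```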

This is a real overhead source in custom code. A mystery aten::copy_ in your profiler often means the upstream op produced a non-contiguous view and the downstream kernel forced a contiguous copy. Dtype casts and device transfers also show up as copy_, so check shapes and strides before blaming layout. Cure: an explicit .contiguous() at the point you control, or a different op order.
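A small CPU sketch reproduces the "mystery copy". reshape returns a view when the layout allows it and silently materialises a copy when it does not. The expectation that the copy appears as aten::copy_ in the profiler reflects how current eager PyTorch implements the fallback; verify it on your version.

```python
import torch
from torch.profiler import ProfilerActivity, profile

x = torch.randn(256, 512)
t = x.t()  # non-contiguous view, nothing copied yet

with profile(activities=[ProfilerActivity.CPU]) as prof:
    flat_view = x.reshape(-1)  # contiguous input: reshape can return a view
    flat_copy = t.reshape(-1)  # transposed input: reshape must copy

names = {evt.key for evt in prof.key_averages()}
print("aten::copy_" in names)                # the copy kernel shows up here
print(flat_view.data_ptr() == x.data_ptr())  # True: shares storage
print(flat_copy.data_ptr() == x.data_ptr())  # False: fresh buffer
```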

Interactive · per-op overhead and amortisation

Pick a model decode size and a batch. The widget plots: per-step kernel time (the work), Python framework overhead (the launch cost), and what torch.compile / CUDA Graphs would buy. Try batch=1, model=7B: the framework is half the cost. Crank batch to 32: the framework's fraction drops because each kernel does more work per launch.
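The widget's toy model fits in a few lines. Eager pays launch_us of host tax per op. A captured graph is modelled, as a simplification, as paying it once per step. The parameters are hypothetical: 600 ops, 10 μs of tax, and per-op kernel time rising from 10 μs at batch=1 to 40 μs at batch=32.

```python
from typing import Tuple


def eager_vs_graph_us(ops: int, launch_us: float, kernel_us: float) -> Tuple[float, float, float]:
    """Toy model: returns (eager_us, graph_us, overhead_fraction_of_eager)."""
    eager = ops * (launch_us + kernel_us)
    graph = launch_us + ops * kernel_us  # host tax paid once per captured step
    return eager, graph, ops * launch_us / eager


for batch, kernel_us in [(1, 10.0), (32, 40.0)]:
    eager, graph, frac = eager_vs_graph_us(600, 10.0, kernel_us)
    print(f"batch={batch:>2}: eager {eager / 1e3:.1f} ms, graph {graph / 1e3:.2f} ms, "
          f"overhead {frac:.0%}, speedup {eager / graph:.2f}x")
```

The fixed tax shrinks as a fraction of the step once each kernel does more work per launch, which is the trend the widget shows.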

Per-decode overhead: kernel time vs framework tax
Toy model: ops/step kernel launches, each one paying launch_us of Python + kernel_us of GPU work. Bars stack; the framework slice is what graph capture removes.
Takeaway
Every line of PyTorch traverses a fixed-cost dispatcher + autograd stack. Eager mode pays that cost per op; graph mode (torch.compile, CUDA Graphs) pays it once. When you're decoding small batches of large models, framework overhead is a first-class system bottleneck, not a footnote — and the rest of Part IV exists to flatten it.