CUDA Graphs & TensorRT — serve-time graph capture
Decode emits one token per forward; each forward is hundreds of kernel launches; each launch is ~1–2 μs of GPU-side overhead. Capture the launches once, replay forever — and at the limit, hand the whole model to TensorRT and let it pick a globally-optimised plan.
What launches cost
In lesson 13 we counted host-side Python overhead (~5–10 μs per op). There's a second cost we haven't isolated yet: the device-side launch overhead, the time the GPU spends accepting and starting each kernel. On H100 this is roughly:
- ~1 μs per kernel for the GPU command processor to receive and queue it.
- Plus ~0.5–2 μs of "bubble" between kernels when the GPU has finished one and is waiting for the next to be dispatched.
This is invisible to most training code because backward kernels are large enough that the bubble is negligible. Decode is different: it issues hundreds of launches per token, and each kernel runs only ~10–50 μs for a 7B model, so the bubble becomes a meaningful fraction of step time. Even with torch.compile reducing host overhead, the device-side bubble between the dozens of small kernels in each attention block doesn't go away.
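A back-of-envelope version of that claim, using the ranges above. N = 400 launches per token is an illustrative assumption, not a measurement; t_o lumps the command-processor cost and the bubble together.

```latex
\begin{aligned}
T_{\text{step}} &= N\,(t_k + t_o), \qquad
f_{\text{overhead}} = \frac{N\,t_o}{N\,(t_k + t_o)} = \frac{t_o}{t_k + t_o} \\
t_o &\approx 1\,\mu\text{s} + (0.5\text{--}2)\,\mu\text{s} = 1.5\text{--}3\,\mu\text{s} \\
f_{\text{best}} &= \frac{1.5}{50 + 1.5} \approx 3\%, \qquad
f_{\text{worst}} = \frac{3}{10 + 3} \approx 23\% \\
N = 400,\ t_k = 10\,\mu\text{s},\ t_o = 3\,\mu\text{s}:\quad
T_{\text{step}} &= 400 \times 13\,\mu\text{s} = 5.2\,\text{ms},
\text{ of which } 400 \times 3\,\mu\text{s} = 1.2\,\text{ms is overhead}
\end{aligned}
```

The fraction does not depend on N; it is set by how small each kernel is relative to the fixed per-launch cost, which is why short decode kernels suffer most.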
CUDA Graphs — capture once, replay many times
A CUDA Graph is a recorded sequence of kernel launches (and memcpys, and events) that the driver hands to the GPU as a single submission. The GPU's command processor sees one launch, expands it into the captured sequence internally, and runs through them with no host involvement and minimal between-kernel bubble.
Two ways to build one:
- Stream capture. Wrap a regular sequence of kernel launches in cuStreamBeginCapture/cuStreamEndCapture. The driver records every launch on that stream into a graph instead of executing it immediately.
- Manual construction. Build the graph as a DAG of node objects via the CUDA Graph API. More control, more code; rarely used by frameworks.
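To make "record instead of execute" concrete, here is a toy, pure-Python model of stream capture. ToyStream and ToyGraph are hypothetical stand-ins, not a CUDA or PyTorch API: launches on a capturing stream are appended to a list, and replay walks that list. The graph keeps references to the exact buffers it saw at capture time.

```python
class ToyGraph:
    """A recorded launch sequence; replay runs the whole sequence as one unit."""
    def __init__(self, nodes):
        self.nodes = list(nodes)

    def replay(self):
        for kernel, buffers in self.nodes:
            kernel(*buffers)


class ToyStream:
    """Toy stream: launches run immediately unless the stream is capturing."""
    def __init__(self):
        self.capturing = False
        self.recorded = []

    def begin_capture(self):
        self.capturing, self.recorded = True, []

    def end_capture(self):
        self.capturing = False
        return ToyGraph(self.recorded)

    def launch(self, kernel, *buffers):
        if self.capturing:
            self.recorded.append((kernel, buffers))  # record, don't run
        else:
            kernel(*buffers)                         # eager: run now


def scale(src, dst):   # toy "kernel": dst = 2 * src
    for i, v in enumerate(src):
        dst[i] = 2 * v

def add_one(buf):      # toy "kernel": buf += 1
    for i in range(len(buf)):
        buf[i] += 1


inp, out = [0.0] * 4, [0.0] * 4
stream = ToyStream()
stream.begin_capture()
stream.launch(scale, inp, out)
stream.launch(add_one, out)
graph = stream.end_capture()
assert out == [0.0] * 4            # nothing ran during capture

inp[:] = [1.0, 2.0, 3.0, 4.0]      # new data written into the *same* input buffer
graph.replay()
print(out)                         # [3.0, 5.0, 7.0, 9.0]
```

Rebinding inp to a new list (inp = [...]) instead of writing into it would leave the recorded graph reading the old buffer: the toy version of the fixed-address rule.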
PyTorch exposes this via torch.cuda.CUDAGraph:
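```python
import torch

# Assumes: `model` is a CUDA nn.Module in eval mode, `x` is an example GPU input
# with the serving shape, and `loader` yields batches of that same shape.
static_input = x.clone()              # fixed-address input buffer

with torch.no_grad():
    # warmup on a side stream (builds cuDNN/cuBLAS workspaces, autotunes kernels)
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # capture: launches are recorded into g, not executed
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)

    # replay: one submission per step
    for batch in loader:
        static_input.copy_(batch)     # new data, same buffer
        g.replay()
        use(static_output)            # read before the next replay overwrites it
```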
Replay is a single device-side submission. For a decoder model with batch=1, this typically cuts step time by 20–40% — entirely from launch overhead reduction.
Animated · launch-overhead timeline
Same workload, two timelines. Top: eager dispatch, where each kernel launch is preceded by a Python/CUDA-driver bubble (~10 μs). Bottom: CUDA Graph replay, with one submission and kernels running back-to-back with minimal gap. Both lines run at the same kernel speed; only the bubbles differ.
The shape-dependence trap
A captured graph has every tensor's shape, stride, and address baked in. Two things follow:
- Inputs and outputs must be at fixed addresses. You can't pass a new tensor each call; you copy the new data into the same buffer (the "static input"). For inference this fits naturally (static_input.copy_(batch)); for training, the data loader has to write into a fixed, preallocated device buffer.
- Shapes must be fixed. A different sequence length, batch size, or KV-cache length needs a different captured graph. Production serving stacks capture one graph per (batch, sequence) bucket; for decode (sequence length 1), batch buckets of 1, 2, 4, 8, 16, 32 are common.
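A minimal sketch of how a server might route a request batch to a captured graph: round the batch size up to the nearest captured bucket and pad the remainder. The bucket list and the helper names (pick_bucket, pad_to_bucket) are illustrative assumptions, not any particular framework's API.

```python
import bisect

# Assumed capture list; real stacks choose their own.
CAPTURED_BATCH_BUCKETS = [1, 2, 4, 8, 16, 32]

def pick_bucket(batch_size, buckets=CAPTURED_BATCH_BUCKETS):
    """Smallest captured bucket >= batch_size, or None if it exceeds every bucket."""
    i = bisect.bisect_left(buckets, batch_size)
    return buckets[i] if i < len(buckets) else None

def pad_to_bucket(requests, pad_item, buckets=CAPTURED_BATCH_BUCKETS):
    """Pad requests up to their bucket so the captured graph's shapes match.
    Returns (padded_requests, bucket); bucket None means: run ungraphed."""
    bucket = pick_bucket(len(requests), buckets)
    if bucket is None:
        return requests, None
    return requests + [pad_item] * (bucket - len(requests)), bucket

padded, bucket = pad_to_bucket(["a", "b", "c"], pad_item=None)
print(bucket, len(padded))  # 4 4
```

The padding rows are wasted compute; that is the price of always landing on a captured shape instead of paying for re-capture.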
This is also why mode="reduce-overhead" in torch.compile only kicks in when shapes are stable: the compiler wraps the compiled forward in a CUDA Graph, but if shapes change it has to give up and re-capture (or fall back to ungraphed execution).
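The KV cache makes this acute, because its length grows by one on every decode step. There are two ways out: (a) capture a graph for each KV length up to max_length and switch between them (memory-expensive), or (b) over-allocate KV to max_length and write into the next slot each step (one graph, used everywhere; vLLM does this with paged KV). The second is why paged attention was the right architecture for graph capture.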
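Option (b) in its simplest, contiguous form (a static cache, not vLLM's paged layout) can be sketched in pure Python. The buffers are allocated once at max_length and never change shape or identity; the current length is just a value the kernel reads, and a mask hides the unused slots. StaticKVCache is a hypothetical name for illustration.

```python
class StaticKVCache:
    """Contiguous KV cache preallocated to max_length (hypothetical sketch).

    Shapes never change, so a graph captured once stays valid; only the
    data (which slots are filled, and the current length) changes per step."""

    def __init__(self, max_length, head_dim):
        self.k = [[0.0] * head_dim for _ in range(max_length)]
        self.v = [[0.0] * head_dim for _ in range(max_length)]
        self.length = 0  # a value the kernel reads, not a shape

    def append(self, k_vec, v_vec):
        if self.length == len(self.k):
            raise RuntimeError("KV cache full: max_length reached")
        self.k[self.length][:] = k_vec  # write in place into the next slot
        self.v[self.length][:] = v_vec
        self.length += 1

    def valid_mask(self):
        """True for filled slots; attention would ignore the rest."""
        return [i < self.length for i in range(len(self.k))]


cache = StaticKVCache(max_length=8, head_dim=2)
k_buf_id = id(cache.k)
for step in range(3):  # three decode steps
    cache.append([float(step)] * 2, [float(step)] * 2)

assert id(cache.k) == k_buf_id and len(cache.k) == 8  # same buffer, same shape
print(cache.valid_mask())  # [True, True, True, False, False, False, False, False]
```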
2D · the shape-dependence trap
Here a graph was captured at batch size B_capture = 4. A request batch that matches the captured shape replays instantly (fast). Any other batch size misses the graph, so you pay re-capture (slow) or fall back to ungraphed execution. Production stacks capture one graph per bucket so every request lands on a captured shape.
TensorRT — graph capture's professional cousin
TensorRT (NVIDIA's inference compiler) sits one level up from CUDA Graphs. The classical TRT path takes an ONNX model (or a network defined through its own builder API) and outputs a plan file: a pre-compiled, layer-by-layer optimised binary. (TRT-LLM, the transformer-specific stack built on TRT, doesn't go through ONNX; it has its own Python builder that consumes Hugging Face / Megatron checkpoints and emits the same plan-file format.) The plan file includes:
- Kernel selection per layer, picked by trying every available implementation in cuBLAS/cuDNN and a TRT-internal library and timing each one on the target GPU at build time, for the input shapes you declare. This is "autotuning per shape," done once.
- Layer fusion much more aggressive than Inductor: convolution + bias + activation, GEMM + dequant + activation, multi-head attention as one kernel.
- Precision-aware planning: among the precisions you enable (e.g. fp16, int8), it picks one per layer, typically the fastest, and you can constrain accuracy-sensitive layers to higher precision.
- Memory layout decisions (NCHW vs NHWC for conv-heavy models; for transformers it's all just GEMM).
- Optional persistent kernels: one kernel grid stays resident on all SMs and pulls work off an in-GPU queue, so repeated invocations of that kernel (e.g. the per-token attention/MLP loop) collapse from N launches into 1. The model still has many distinct kernels; persistent kernels eliminate launches for hot ones, not all of them.
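"Autotuning per shape" is easy to sketch even though TRT's real builder is far more elaborate: for each declared shape, time every candidate implementation and record the winner in a plan. The two pure-Python matmul loop orders below stand in for competing kernels; autotune, the candidates, and the plan dict are illustrative assumptions, not TRT's API.

```python
import time

def matmul_ijk(a, b):
    """Naive loop order: one inner product per output element."""
    k, m = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(m)] for row in a]

def matmul_ikj(a, b):
    """Same math, loop order changed to stream through rows of b."""
    m = len(b[0])
    c = [[0.0] * m for _ in a]
    for ci, row in zip(c, a):
        for aip, bp in zip(row, b):
            for j in range(m):
                ci[j] += aip * bp[j]
    return c

def make_inputs(shape):
    """Deterministic (n x k) and (k x m) matrices for a shape (n, k, m)."""
    n, k, m = shape
    a = [[float(i + p) for p in range(k)] for i in range(n)]
    b = [[float(p - j) for j in range(m)] for p in range(k)]
    return a, b

CANDIDATES = {"ijk": matmul_ijk, "ikj": matmul_ikj}
SHAPES = [(8, 8, 8), (16, 16, 16), (32, 32, 32)]  # shapes declared at build time

def autotune(candidates, shapes, reps=3):
    """Build a plan: for each shape, the name of the fastest candidate."""
    plan = {}
    for shape in shapes:
        args = make_inputs(shape)
        timings = {}
        for name, fn in candidates.items():
            fn(*args)  # warmup
            t0 = time.perf_counter()
            for _ in range(reps):
                fn(*args)
            timings[name] = (time.perf_counter() - t0) / reps
        plan[shape] = min(timings, key=timings.get)
    return plan

plan = autotune(CANDIDATES, SHAPES)
print(plan)  # winner per shape; depends on the machine
```

All the timing happens once, at build time; at serve time the plan is only a lookup, plan[shape].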
TensorRT-LLM is a higher-level TRT-based stack specifically for transformer LLMs. It adds:
- Pre-tuned attention kernels (multi-head, GQA, MQA, FlashAttention-style).
- KV-cache-aware paged attention.
- In-flight batching (the continuous-batching scheduler from the vLLM series, in TRT form).
- fp8 / int8 / int4 quantisation paths with calibration.
- Speculative decoding, multi-LoRA, prefix-cache.
TRT-LLM is what most NVIDIA reference benchmarks (MLPerf, etc.) run. The serving-time speedup over a well-tuned vLLM is usually 10–30% on H100, with the gap closing as vLLM picks up similar techniques.
3D · TensorRT layer plan, isometric
A TRT plan file is a stack of layers, and for each layer the builder picked one pre-tuned kernel per supported input shape. The figure shows the stack isometrically: each floor is one layer, the cards in a floor are the per-shape kernel choices, and the bottom strip shows which kernel runs for each shape bucket.
Where the stack ends — and where the wins are
| Mode | What it removes | Typical decode win on a 7B model |
|---|---|---|
| Eager PyTorch | nothing — baseline | 1.0× |
| torch.compile default | Python overhead, some kernel-launch fusion | ~1.3–1.6× |
| torch.compile reduce-overhead (CUDA Graphs) | device-side bubble, residual host overhead | ~1.6–2.0× |
| vLLM (paged attn + Triton/CUDA kernels + CUDA Graphs) | same + paged KV + continuous batching | ~2.5–4× (at scale) |
| TensorRT-LLM (custom CUDA kernels + persistent kernels + autotuned per shape) | everything torch.compile can't | ~3–5× (at scale) |
The relative wins shrink at higher batch sizes — once you're at batch=32 decoding, the per-launch overhead matters less because each launch is doing more work. The biggest wins from graph capture are at the latency-sensitive batch=1 chat regime.
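A toy cost model shows why. Assume, hypothetically, that each kernel's time grows linearly with batch while the per-launch gap stays fixed; the speedup from shrinking that gap then falls toward 1× as batch grows. All constants below are illustrative, not measurements.

```python
def step_us(batch, gap_us, n_kernels=400, base_us=10.0, per_seq_us=0.5):
    """Decode step time under a hypothetical linear kernel-cost model."""
    kernel_us = base_us + per_seq_us * batch  # more sequences -> more work per kernel
    return n_kernels * (kernel_us + gap_us)

def graph_speedup(batch, eager_gap_us=3.0, graph_gap_us=0.5):
    """Eager step time divided by graph-replay step time at the same batch."""
    return step_us(batch, eager_gap_us) / step_us(batch, graph_gap_us)

for b in (1, 8, 32, 128):
    print(f"batch={b:>3}  speedup={graph_speedup(b):.2f}x")
```

The gap saving per launch is the same at every batch size; what changes is how much useful work each launch carries, so the saving becomes a smaller share of the step.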
The asymmetry between training and inference compile
Training compile and inference compile look similar but have opposite priorities:
- Training: shapes are stable (the batch is yours), backward exists and is the bulk of compute. Compile pays off in the first hour. torch.compile is the right level.
- Inference: shapes vary per request (different prompt lengths, different KV positions), there is no backward, and latency is the metric. You need to pre-plan for the shapes you'll see. Graph capture per shape bucket + paged KV + autotuned kernels per shape is the answer. torch.compile alone is not enough; production stacks add TRT-LLM or vLLM machinery on top.
What we walked through — the whole stack, top to bottom
- Lessons 01–10: the cluster-level layout — how the model is sharded across thousands of GPUs.
- Lessons 11–12: the inference-time layer — how the model is reassembled to serve users.
- Lesson 13: the PyTorch framework that lets you write the model in Python.
- Lesson 14: the precisions the math runs in.
- Lesson 15: the allocator that recycles memory between ops.
- Lesson 16: the fusion principle — kernels' boundaries are HBM round-trips.
- Lesson 17: Triton, the DSL that lets one engineer write a fused kernel.
- Lesson 18: torch.compile, which writes the fused kernels for you.
- Lesson 19 (here): CUDA Graphs and TensorRT, which capture the whole forward into one replayable artifact.
Every layer of this stack exists because the layer above wanted something — fewer launches, less overhead, less memory, more fusion, faster reads. Reading top-down, that's why each layer exists. Reading bottom-up, that's why each layer is shaped the way it is.
Interactive · the launch-overhead simulator
The simulator compares eager, compile-default, compile-reduce-overhead, and TRT-LLM-style execution as you vary the model (ops per forward), the batch, and the per-launch cost. The point: as ops-per-forward grows and per-launch work shrinks (i.e. you're in the decode regime), the gap between modes widens.