The kernel stack — FlashInfer, Triton, CUDA graphs
RadixAttention and the scheduler save compute; the kernels make the surviving compute fast. SGLang's kernel choices are not a shopping list. They are a sequence of measured responses to whatever bottleneck remains. Walk down the layers and each one picks up the slack left by the previous one.
The runtime, drawn once
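Three parts of the runtime warrant individual attention (the FlashInfer attention kernels, the Triton support kernels, and CUDA-graph capture) because they are where SGLang's choices differ measurably from naive PyTorch and from vLLM.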
FlashInfer — why a separate library
FlashAttention (Dao et al.) is a family of kernels: FA1, FA2, FA3, plus paged variants. FlashInfer (Ye et al., CMU) is a separate library, closely integrated with SGLang, written specifically for the patterns that serving inference creates:
- Ragged batches. Decode steps mix sequences of very different lengths; FlashInfer's kernel exposes a CSR-like layout where each sequence's KV is addressed by its block table independently. No padding to the max length.
- Paged KV native. The kernel takes `kv_indices`, `kv_indptr`, and `page_size` directly, so there is no separate gather step.
- Pre-padded vs. variable. Two API tiers: a low-overhead path for fixed-shape decode (used inside the CUDA-graph capture), and a flexible path for prefill / varying batch shape.
- MLA-aware. DeepSeek-V3 / R1 use Multi-head Latent Attention. FlashInfer has a fused MLA decode kernel (lesson 09 covers MLA).
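- Quantization. fp8 KV cache is supported inside the attention kernels themselves. Weight formats such as int4 weight × fp16 activation and AWQ act on the linear layers rather than on attention, so they are handled by GEMM kernels that sit alongside the attention API.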
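To make the CSR-like paged layout concrete, here is a minimal pure-Python sketch that flattens per-sequence page lists into the arrays the bullets above name. The helper `build_paged_kv_metadata` and the `kv_last_page_len` array are illustrative assumptions; they mirror the shape of FlashInfer-style inputs rather than reproduce FlashInfer's actual API.

```python
from typing import List, Tuple


def build_paged_kv_metadata(
    page_tables: List[List[int]],  # per-sequence physical page ids, in order
    seq_lens: List[int],           # number of cached tokens per sequence
    page_size: int,
) -> Tuple[List[int], List[int], List[int]]:
    """Flatten per-sequence page tables into CSR-style arrays (hypothetical helper).

    kv_indices[kv_indptr[i]:kv_indptr[i + 1]] are the pages of sequence i;
    kv_last_page_len[i] is how many slots of sequence i's final page are filled.
    """
    kv_indptr = [0]
    kv_indices: List[int] = []
    kv_last_page_len: List[int] = []
    for pages, n_tokens in zip(page_tables, seq_lens):
        n_pages = (n_tokens + page_size - 1) // page_size  # ceil division
        if n_pages > len(pages):
            raise ValueError("page table too short for sequence length")
        kv_indices.extend(pages[:n_pages])
        kv_indptr.append(len(kv_indices))
        kv_last_page_len.append(n_tokens - (n_pages - 1) * page_size if n_pages else 0)
    return kv_indptr, kv_indices, kv_last_page_len


# Three sequences of very different lengths share one batch; nothing is padded.
indptr, indices, last = build_paged_kv_metadata(
    page_tables=[[7], [2, 9, 4], [5, 1]],
    seq_lens=[3, 40, 17],
    page_size=16,
)
print(indptr)   # [0, 1, 4, 6]
print(indices)  # [7, 2, 9, 4, 5, 1]
print(last)     # [3, 8, 1]
```

The pages of a sequence need not be contiguous in memory (sequence 1 uses pages 2, 9, 4); the offsets array is what lets one kernel walk every sequence without a padded rectangle.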
Why not just use FlashAttention-3? Historically FA shipped the attention math and left the serving glue (paged + ragged + variable batch) to the engine. Modern FA3 does ship paged-KV support, and SGLang now supports FA3 as a backend alongside FlashInfer — but the two libraries make different default choices about the API surface and ragged-batching layout. FlashInfer fills the same role as cuDNN: a library of kernels co-designed with the engine.
The two-tier decode kernel
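Both tiers address KV through the CSR offsets described above instead of padded tensors. Compared with padding every sequence to the longest one in the batch, the reduction in wasted FLOPs is workload-dependent but typically 1.5–3× on decode batches where length variance is high.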
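A back-of-the-envelope check shows where a ratio like that comes from. In decode, attention cost per sequence scales roughly with its KV length, so padding every sequence to the batch maximum costs `batch_size * max_len` token positions versus `sum(lens)` for a ragged kernel. The lengths below are made-up illustrative values, not measurements.

```python
from typing import List


def padding_waste_ratio(kv_lens: List[int]) -> float:
    """Padded attention work divided by ragged (useful) attention work.

    Assumes decode attention cost is proportional to each sequence's KV length,
    ignoring fixed per-sequence overheads.
    """
    padded = len(kv_lens) * max(kv_lens)  # every sequence padded to the longest
    ragged = sum(kv_lens)                 # each sequence only touches its own KV
    return padded / ragged


uniform = [2000, 2000, 2000, 2000]  # no length variance
mixed = [100, 2000, 300, 4000]      # high length variance (illustrative)
print(padding_waste_ratio(uniform))  # 1.0
print(padding_waste_ratio(mixed))    # 2.5
```

With equal lengths, padding costs nothing; the waste grows with the spread between the longest sequence and the rest.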
Triton — where the "small" kernels live
Attention dominates wall-clock, but it's not the only kernel. SGLang writes several supporting kernels in Triton — sampling, RMSNorm, RoPE, grammar mask application — for two reasons:
- Composition matters more than peak. A sampling kernel that fuses logit masking, temperature scaling, top-p truncation, and the categorical sample into one launch is substantially faster than running the same steps as separate PyTorch ops, even if each op is well-tuned on its own, because the intermediate logits never make a round trip through GPU memory between steps.
- Triton lets you ship a custom variant in a day. Speculative-decoding-aware sampling, grammar-aware sampling, beam-search-aware sampling — these are SGLang-specific shapes that no upstream library will ship pre-built.
The fused sampler, as a model
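```python
# Conceptually, the fused sampling step (one program instance per sequence).
# softmax_top_p and categorical_sample are placeholder helpers, not Triton
# built-ins; a real kernel would implement them inline.
import triton
import triton.language as tl

@triton.jit
def fused_sample(logits_ptr, mask_ptr, temp, top_p, rng_state, out_ptr,
                 V: tl.constexpr):                    # V: vocab size (power of 2 here)
    row = tl.program_id(0)                            # which sequence in the batch
    offsets = tl.arange(0, V)                         # every vocab position of this row
    logits = tl.load(logits_ptr + row * V + offsets)
    mask = tl.load(mask_ptr + row * V + offsets)      # grammar mask, nonzero = allowed
    logits = tl.where(mask != 0, logits / temp, float("-inf"))  # grammar + temp
    probs = softmax_top_p(logits, top_p)              # top-p in registers
    token = categorical_sample(probs, rng_state)      # one sampled token id
    tl.store(out_ptr + row, token)
```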
One launch per step instead of five. The savings are per step, not per sequence: at batch 8 the unfused path still launches 5 kernels per step, each just covering a larger batch. The fused win is removing the 4 kernel boundaries between those launches, each of which is an implicit synchronization point that writes an intermediate logits tensor to GPU memory and reads it back. That matters when a decode step is only ~2 ms to begin with.
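For reference, the math that `fused_sample` performs in one launch can be written as an unfused, pure-Python sketch for a single sequence: mask, temperature, softmax, top-p truncation, sample. It assumes top-p keeps the smallest set of highest-probability tokens whose cumulative mass reaches `top_p`, and that at least one token is allowed; it is an illustrative reference, not SGLang's sampler code.

```python
import math
import random
from typing import List


def reference_sample(logits: List[float], mask: List[bool], temp: float,
                     top_p: float, rng: random.Random) -> int:
    # 1) grammar mask + 2) temperature: disallowed tokens get -inf
    scaled = [l / temp if ok else -math.inf for l, ok in zip(logits, mask)]
    # 3) numerically stable softmax (exp(-inf) == 0.0 for masked tokens)
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 4) top-p over tokens with nonzero probability, highest first
    order = sorted((i for i in range(len(probs)) if probs[i] > 0.0),
                   key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # 5) categorical sample from the kept set, renormalized by its mass
    r = rng.random() * mass
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r < acc:
            return i
    return kept[-1]  # guard against floating-point round-off


logits = [2.0, 1.0, 0.5, 3.0]
mask = [True, True, False, False]  # grammar allows only tokens 0 and 1
print(reference_sample(logits, mask, temp=1.0, top_p=0.5, rng=random.Random(0)))  # 0
```

Each numbered step here would be its own kernel in an unfused PyTorch pipeline; the fused kernel keeps the whole chain in registers for one row.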
CUDA graphs — eliminating launch overhead at small batch
Decode at small batch (say, 4 sequences) is dominated by Python and CUDA-launch overhead, not GPU compute. For Llama-70B at batch 4 on H100:
| Phase | Time |
|---|---|
| Python framework overhead (per layer) | ~50 µs × 80 layers = 4 ms |
| CUDA kernel launches (per step) | ~3 ms |
| Actual kernel time on GPU | ~6 ms |
| Total per decode step | ~13 ms |
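Roughly half of every step (about 7 of 13 ms) goes to Python and launch overhead rather than GPU work. CUDA Graphs let you record a sequence of kernel launches once and replay them with a single API call. The replay costs ~10 µs of overhead in total, not 3 ms, and because the captured forward pass is replayed without re-running its Python code, the per-layer framework overhead is skipped as well.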
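The table's arithmetic can be reproduced directly. This sketch uses the table's figures and assumes that graph replay removes both the per-layer Python cost and the individual launches, leaving the GPU kernel time plus a ~10 µs replay call; real savings depend on how much Python still runs around the replay.

```python
LAYERS = 80
PY_PER_LAYER_MS = 0.050  # ~50 µs of framework overhead per layer
LAUNCH_MS = 3.0          # kernel launches per decode step
GPU_MS = 6.0             # kernel execution time on the GPU
REPLAY_MS = 0.010        # ~10 µs to replay a captured graph (assumed total)

eager_ms = LAYERS * PY_PER_LAYER_MS + LAUNCH_MS + GPU_MS
graph_ms = REPLAY_MS + GPU_MS
overhead_share = (eager_ms - GPU_MS) / eager_ms

print(f"eager step: {eager_ms:.1f} ms")                  # eager step: 13.0 ms
print(f"graph step: {graph_ms:.2f} ms")                  # graph step: 6.01 ms
print(f"overhead share (eager): {overhead_share:.0%}")   # overhead share (eager): 54%
```

Under these assumptions the step time roughly halves, and the GPU kernel time becomes the only significant term.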
What gets graph-captured
You can only graph kernels whose shapes, and whose input/output buffer addresses, are fixed at capture time. SGLang captures the decode path, but decode shapes aren't fixed because batch size varies. The fix: capture multiple graphs, one per (batch_size, kv_layout) configuration the engine expects to see. At runtime, pick the matching graph; configurations that were not pre-captured fall back to per-launch mode.
Prefill doesn't graph well — variable prompt lengths and ragged batches change shape too often. SGLang prefills in normal eager mode, then switches to graph for the decode loop.
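The capture-and-dispatch logic from the two paragraphs above can be sketched in plain Python. Everything here is a stand-in: `GraphPool`, `run_step`, the captured batch sizes, and the string results are hypothetical, the kv_layout dimension is dropped for brevity, and a real pool would hold captured CUDA graphs with their static buffers. One further assumption: a batch smaller than a captured size is padded up to the next captured size rather than requiring an exact match, so only batches larger than every captured size fall back to eager launches.

```python
import bisect
from typing import Dict, List, Optional


class GraphPool:
    """Hypothetical stand-in for a pool of decode graphs keyed by batch size."""

    def __init__(self, capture_batch_sizes: List[int]):
        self.capture_bs = sorted(capture_batch_sizes)
        # A real engine would store a captured graph plus static buffers here.
        self.graphs: Dict[int, str] = {bs: f"graph[bs={bs}]" for bs in self.capture_bs}

    def pick(self, batch_size: int) -> Optional[int]:
        """Smallest captured batch size >= batch_size, or None if none fits."""
        i = bisect.bisect_left(self.capture_bs, batch_size)
        return self.capture_bs[i] if i < len(self.capture_bs) else None


def run_step(pool: GraphPool, mode: str, batch_size: int) -> str:
    if mode == "prefill":
        return "eager"                # ragged prompt shapes: never graphed
    padded_bs = pool.pick(batch_size)
    if padded_bs is None:
        return "eager"                # no captured graph is large enough
    return pool.graphs[padded_bs]     # replay, with the batch padded to padded_bs


pool = GraphPool([1, 2, 4, 8, 16, 32])
print(run_step(pool, "prefill", 4))  # eager
print(run_step(pool, "decode", 4))   # graph[bs=4]
print(run_step(pool, "decode", 5))   # graph[bs=8]
print(run_step(pool, "decode", 64))  # eager
```

Padding trades a few wasted rows of compute for far fewer captured graphs, which keeps capture time and graph memory bounded.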
The kernel choices read as a chain of consequences
- Workloads are dominated by attention bandwidth → tile attention (FlashAttention-style) and put KV in pages (PagedAttention-style). FlashInfer ships both fused.
- Decode batches are ragged → kernels take CSR offsets instead of padded tensors.
- Sampling, RMSNorm, RoPE are now visible on the timeline → fuse them into Triton kernels.
- At small batch, the timeline is dominated by Python and launch overhead → record the decode path as a CUDA graph; replay.
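- Quantization (fp8 KV cache; int4 and AWQ weights) reduces bandwidth → the fp8 KV dtype has to flow through to the attention kernel, which needs FlashInfer's quant-aware paths, while weight quantization lands in the GEMM kernels of the linear layers.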
Each step is a measured response, not a stylistic choice. SGLang publishes detailed profiler traces; if you read them, the chain above is what you see.
What this looks like in the source
| Layer | File / module |
|---|---|
| FlashInfer attention wrapper | python/sglang/srt/layers/attention/flashinfer_backend.py |
| Triton sampler | python/sglang/srt/layers/sampler.py |
| CUDA-graph runner | python/sglang/srt/model_executor/cuda_graph_runner.py |
| Model executor / batch builder | python/sglang/srt/model_executor/model_runner.py |
Reading them in data-flow order (model_runner.py, then cuda_graph_runner.py, flashinfer_backend.py, and sampler.py) lets you trace one decode step from a tokenized request through scheduler → batch builder → CUDA-graph replay → FlashInfer kernel call → sampled token in about an hour.
What lesson 09 builds
The kernels are fast — but a 70B model doesn't fit on one GPU and a 671B MoE model doesn't fit on one node. Lesson 09 covers the three parallelism axes (tensor, data-attention, expert) that let the same kernels run on a model that's too big for one device. Each axis interacts with a different layer of the model and a different bottleneck the kernels have not removed.