
The kernel stack — FlashInfer, Triton, CUDA graphs

RadixAttention and the scheduler save compute; the kernels make the surviving compute fast. SGLang's kernel choices are not "buy everything"; they're a sequence of measured responses to the next remaining bottleneck. Walk down the layers and each picks up the slack of the previous one.

The runtime, drawn once

From top to bottom:

  1. HTTP / OpenAI server · SGLang program runtime
  2. Scheduler · RadixAttention cache · constraint compiler
  3. Model executor (PyTorch) · weight loader · KV pool
  4. FlashInfer (paged attention, ragged batching, MLA) · Triton kernels (sampling, RMSNorm, grammar mask apply) · CUDA-graph capture (decode replay path; eliminates launch overhead)
  5. CUDA / cuBLAS / cuDNN · NCCL · NVRTC (GEMM, NVLink all-reduce, on-the-fly Triton compile)

Three boxes (FlashInfer, the Triton kernels, and CUDA-graph capture) warrant individual attention because that is where SGLang's choices differ measurably from naive PyTorch and from vLLM.

FlashInfer — why a separate library

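FlashAttention (Dao et al.) is a family of kernels: FA1, FA2, FA3, plus paged variants. FlashInfer (Ye et al., University of Washington, CMU, and NVIDIA) is a separate library, closely integrated with SGLang, written specifically for the patterns that serving inference creates: paged KV caches, ragged batches of variable-length sequences, and MLA.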

Why not just use FlashAttention-3? Historically, FA shipped the attention math and left the serving glue (paged KV, ragged sequences, variable batch size) to the engine. Modern FA3 does support paged KV, and SGLang now offers FA3 as an attention backend alongside FlashInfer, but the two libraries make different default choices about API surface and ragged-batching layout. FlashInfer plays a role comparable to cuDNN's: a kernel library the engine calls into, except that its API is designed around serving workloads.

The two-tier decode kernel

Figure: naive padded batch vs. FlashInfer ragged + paged. Padded: every sequence is padded to the maximum length, and the padding in each row is wasted compute. Ragged + paged: one flat tensor described by three CSR-style index arrays (qo_indptr, kv_indptr, kv_indices); the kernel walks the block tables with no padding.
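To make the ragged layout concrete, here is a minimal plain-Python sketch of how per-request block tables flatten into the three arrays named in the caption. The function name `build_paged_csr` and its argument layout are hypothetical illustrations, not FlashInfer's API; a real paged layout also has to record how full each request's last page is, which this sketch leaves out.

```python
from typing import List, Tuple

def build_paged_csr(
    page_tables: List[List[int]],  # per-request KV page ids (the block table)
    qo_lens: List[int],            # query tokens per request (1 for pure decode)
) -> Tuple[List[int], List[int], List[int]]:
    """Flatten per-request block tables into CSR-style index arrays.

    Request i owns rows qo_indptr[i]:qo_indptr[i+1] of the flat query tensor
    and KV pages kv_indices[kv_indptr[i]:kv_indptr[i+1]].
    """
    if len(page_tables) != len(qo_lens):
        raise ValueError("need one page table and one query length per request")
    qo_indptr, kv_indptr, kv_indices = [0], [0], []
    for pages, q_len in zip(page_tables, qo_lens):
        qo_indptr.append(qo_indptr[-1] + q_len)
        kv_indices.extend(pages)              # pages need not be contiguous
        kv_indptr.append(kv_indptr[-1] + len(pages))
    return qo_indptr, kv_indptr, kv_indices

# Three decode requests holding 2, 1 and 3 KV pages respectively.
qo_indptr, kv_indptr, kv_indices = build_paged_csr([[7, 2], [5], [0, 3, 9]], [1, 1, 1])
# qo_indptr == [0, 1, 2, 3], kv_indptr == [0, 2, 3, 6], kv_indices == [7, 2, 5, 0, 3, 9]
```

No row is padded: a request's cost is set only by how many pages its slice of `kv_indices` covers.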

The reduction in wasted FLOPs is workload-dependent but typically 1.5–3× on decode batches where length variance is high.
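The padding cost is simple arithmetic: a padded batch does batch × max_len units of attention work, a ragged one does only the sum of the actual lengths. A sketch, assuming decode attention cost per sequence is proportional to its KV length; the lengths below are made-up illustrative values, not measurements.

```python
from typing import Sequence

def padded_vs_ragged_work(kv_lens: Sequence[int]) -> float:
    """Ratio of padded attention work to ragged attention work for one decode batch."""
    padded = len(kv_lens) * max(kv_lens)  # every row padded to the longest sequence
    ragged = sum(kv_lens)                 # each row does only its own work
    return padded / ragged

# Hypothetical KV lengths with high variance: 4 * 3000 / 4700, roughly 2.55x.
ratio = padded_vs_ragged_work([100, 400, 1200, 3000])
```

With equal lengths the ratio is exactly 1, which is why the benefit depends on how much length variance the batch has.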

Triton — where the "small" kernels live

Attention dominates wall-clock, but it's not the only kernel. SGLang writes several supporting kernels in Triton — sampling, RMSNorm, RoPE, grammar mask application — for two reasons:

  1. Composition matters more than peak. A sampling kernel that fuses logit masking, temperature scaling, top-p truncation, and the categorical sample into one launch is dramatically faster than four PyTorch ops, even if each individual op is well-tuned.
  2. Triton lets you ship a custom variant in a day. Speculative-decoding-aware sampling, grammar-aware sampling, beam-search-aware sampling — these are SGLang-specific shapes that no upstream library will ship pre-built.

The fused sampler, as a model

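```python
# Conceptual sketch of the fused sampling step, not SGLang's actual kernel.
# softmax_top_p and categorical_sample are hypothetical in-register helpers.
import triton
import triton.language as tl

@triton.jit
def fused_sample(logits_ptr, mask_ptr, temp, top_p, out_ptr, seed,
                 V: tl.constexpr, BLOCK_V: tl.constexpr):
    block_id = tl.program_id(0)                   # one program per sequence (row)
    offsets = tl.arange(0, BLOCK_V)               # BLOCK_V: power of two >= V
    in_vocab = offsets < V                        # a real kernel would tile over V instead
    logits = tl.load(logits_ptr + block_id * V + offsets, mask=in_vocab, other=-float("inf"))
    mask = tl.load(mask_ptr + block_id * V + offsets, mask=in_vocab, other=0)
    logits = tl.where(mask != 0, logits / temp, -float("inf"))  # grammar + temp
    probs = softmax_top_p(logits, top_p)                        # top-p in registers
    token = categorical_sample(probs, seed, block_id)           # one token id per row
    tl.store(out_ptr + block_id, token)
```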

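One launch per step instead of one per stage (five in this sketch: mask, temperature, softmax, top-p, sample). The savings are per step, not per sequence: at batch 8 the unfused path still launches five kernels per step, each just covering a larger batch. The fused win comes from removing the four kernel boundaries in between, each of which costs a launch plus a round trip of a batch × vocab tensor through GPU memory. That matters when a whole decode step for a small model takes only ~2 ms to begin with.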
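What the fused kernel computes is easier to check against an unfused reference. The sketch below is a hypothetical plain-Python version for one row, where each numbered stage stands in for what would be a separate kernel launch on the unfused path. It uses a straightforward sort-based top-p, one standard way to implement nucleus sampling, not necessarily how SGLang or any library does it.

```python
import math
import random
from typing import List, Sequence

def sample_unfused(logits: Sequence[float], allowed: Sequence[bool],
                   temp: float, top_p: float, rng: random.Random) -> int:
    """Sample one token id for one row; assumes temp > 0 and >= 1 allowed token."""
    # Stages 1-2: grammar mask and temperature (disallowed tokens get -inf).
    scaled = [x / temp if ok else -math.inf for x, ok in zip(logits, allowed)]
    # Stage 3: numerically stable softmax; exp(-inf) is 0.0, so masked tokens drop out.
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Stage 4: top-p, keeping the most likely tokens until their mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    mass = 0.0
    for i in order:
        if probs[i] == 0.0:
            break
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Stage 5: categorical draw from the renormalized nucleus.
    r = rng.random() * mass
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r < acc:
            return i
    return kept[-1]  # guard against floating-point round-off

rng = random.Random(0)
logits = [1.0, 5.0, 3.0, 0.0]
allowed = [True, False, True, True]  # grammar forbids token 1, the raw argmax
draws = {sample_unfused(logits, allowed, 0.7, 0.9, rng) for _ in range(100)}
# draws == {2}: after masking, token 2 alone holds more than 90% of the mass
```

Every intermediate list here (`scaled`, `probs`, the sorted order) is a full vocabulary-sized tensor in the unfused GPU version; the fused kernel keeps them in registers.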

CUDA graphs — eliminating launch overhead at small batch

Decode at small batch (say, 4 sequences) is dominated by Python and CUDA-launch overhead, not GPU compute. An illustrative per-step breakdown for Llama-70B at batch 4 on H100:

Phase                                   Time
Python framework overhead               ~50 µs per layer × 80 layers = 4 ms
CUDA kernel launches (per step)         ~3 ms
Actual kernel time on GPU               ~6 ms
Total per decode step                   ~13 ms

More than half of every step (7 of ~13 ms) goes to Python and launch overhead rather than to GPU work. CUDA Graphs let you record a sequence of kernel launches once and replay the whole sequence with a single API call. Replay costs on the order of 10 µs in total, instead of ~3 ms of individual launches plus the per-layer Python that issued them.
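The table's arithmetic can be written as a toy latency model. This sketch takes the table's numbers at face value and assumes graph replay removes both the per-layer Python and the per-kernel launches of the model forward; work outside the captured graph (scheduling, batch preparation) is not modeled, and the function name is hypothetical.

```python
def decode_step_ms(py_us_per_layer: float, layers: int, launch_ms: float,
                   kernel_ms: float, graphed: bool, replay_ms: float = 0.010) -> float:
    """Toy per-step decode latency in milliseconds."""
    if graphed:
        # One replay call issues the whole captured step: Python and launch
        # costs are replaced by a single ~10 µs replay overhead.
        return kernel_ms + replay_ms
    python_ms = py_us_per_layer * layers / 1000.0  # µs per layer -> ms per step
    return python_ms + launch_ms + kernel_ms

eager = decode_step_ms(50, 80, 3.0, 6.0, graphed=False)   # 4 + 3 + 6 = 13 ms
graphed = decode_step_ms(50, 80, 3.0, 6.0, graphed=True)  # ~6.01 ms
overhead_share = (eager - 6.0) / eager                    # 7/13, about 0.54
speedup = eager / graphed                                 # a little over 2x
```

Under these assumptions the speedup grows as `kernel_ms` shrinks, which is why graphs matter most at the smallest batch sizes.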

Figure: per-launch (default) mode vs. CUDA-graph replay. In default mode each of the 240 ops (3 ops × 80 layers) pays Python overhead and a CUDA launch before its actual kernel work, so more than half of wall-time is overhead. With graph replay, one launch runs all 240 ops back-to-back on the GPU with no per-kernel Python or API call, and wall-time approaches actual GPU compute.

What gets graph-captured

You can only graph kernels whose shapes (and input buffer addresses) are fixed at capture time. SGLang captures the decode path, but decode shapes aren't fixed because batch size varies. The fix: capture multiple graphs, one per (batch_size, kv_layout) configuration the engine expects to see, and at runtime pick the matching graph, padding a batch up to the nearest captured size when it falls in between. Configurations not pre-captured fall back to per-launch mode.
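Picking a graph at runtime can be sketched as a sorted lookup. This is a hypothetical illustration assuming a pad-up strategy: captured batch sizes are kept in a sorted list, a live batch uses the smallest captured size that fits (the extra rows are filled with dummy entries), and anything larger than the largest capture falls back to eager mode. The capture list and function name are made up.

```python
import bisect
from typing import Optional, Sequence

def pick_captured_bs(captured: Sequence[int], bs: int) -> Optional[int]:
    """Smallest captured batch size >= bs, or None to signal eager fallback.

    `captured` must be sorted ascending.
    """
    i = bisect.bisect_left(captured, bs)
    return captured[i] if i < len(captured) else None

CAPTURED = [1, 2, 4, 8, 16, 24, 32]  # hypothetical capture list

graph_bs = pick_captured_bs(CAPTURED, 3)  # -> 4, with one dummy row of padding
fallback = pick_captured_bs(CAPTURED, 33) # -> None, run this step eagerly
```

The design trade-off is memory and startup time (one captured graph per size) against wasted rows from padding; denser lists waste fewer rows but cost more captures.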

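Prefill doesn't graph well: variable prompt lengths and ragged batches change shapes too often, and a prefill step is long enough that launch overhead is a small fraction of it anyway. SGLang runs prefill in normal eager mode, then switches to graph replay for the decode loop.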

The kernel choices read as a chain of consequences

  1. Workloads are dominated by attention bandwidth → tile attention (FlashAttention-style) and put KV in pages (PagedAttention-style). FlashInfer ships both fused.
  2. Decode batches are ragged → kernels take CSR offsets instead of padded tensors.
  3. Sampling, RMSNorm, RoPE are now visible on the timeline → fuse them into Triton kernels.
  4. At small batch, the timeline is dominated by Python and launch overhead → record the decode path as a CUDA graph; replay.
  5. Quantization (fp8, int4 via AWQ) reduces the bytes moved per step → quantized weights need quantized GEMMs, and a quantized KV cache needs its dtype pushed through to the attention kernel, which requires FlashInfer's quant-aware paths.
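Point 5 is easiest to see with the KV cache. Every decode step reads each sequence's full K and V, so the bytes per cached token set the attention kernel's bandwidth bill. A worked sketch, assuming Llama-3-70B's shape (80 layers, 8 KV heads under grouped-query attention, head dimension 128); the helper name is hypothetical.

```python
def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int, bytes_per_elem: int) -> int:
    """Bytes of KV cache stored (and read every decode step) per cached token."""
    # K and V each hold kv_heads * head_dim values in every layer.
    return 2 * layers * kv_heads * head_dim * bytes_per_elem

bf16 = kv_bytes_per_token(80, 8, 128, 2)  # 327,680 bytes = 320 KiB per token
fp8 = kv_bytes_per_token(80, 8, 128, 1)   # 163,840 bytes = 160 KiB per token

# A 4096-token sequence: 1.25 GiB of KV in bf16, half that in fp8.
seq_gib_bf16 = bf16 * 4096 / 2**30
```

Halving the element size halves the bytes attention must stream per step, but only if the kernel reads the fp8 cache directly instead of first upcasting it into a bf16 copy.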

Each step is a measured response, not a stylistic choice. SGLang publishes detailed profiler traces; if you read them, the chain above is what you see.

What this looks like in the source

Layer                             File / module
FlashInfer attention wrapper      python/sglang/srt/layers/attention/flashinfer_backend.py
Triton sampler                    python/sglang/srt/layers/sampler.py
CUDA-graph runner                 python/sglang/srt/model_executor/cuda_graph_runner.py
Model executor / batch builder    python/sglang/srt/model_executor/model_runner.py

Read them in call order (model_runner.py, then cuda_graph_runner.py, flashinfer_backend.py, and finally sampler.py) and you can trace one decode step from a tokenized request through scheduler → batch builder → CUDA-graph replay → FlashInfer kernel call → sampled token in about an hour.

What lesson 09 builds

The kernels are fast — but a 70B model doesn't fit on one GPU and a 671B MoE model doesn't fit on one node. Lesson 09 covers the three parallelism axes (tensor, data-attention, expert) that let the same kernels run on a model that's too big for one device. Each axis interacts with a different layer of the model and a different bottleneck the kernels have not removed.

The overhead chain at small batch

At batch 1–4, launch overhead dominates, and CUDA graphs are the largest single win. At batch 32 and above, kernel time dominates, and FlashInfer's ragged path is where the gains come from.