
Kernels for RL — what an RL infra engineer actually writes

"New kernels for rollout and training" is shorthand for six distinct kernel surfaces, each with its own bottleneck physics. This lesson maps every surface, derives why it matters, and names the kernels that pay back for the engineer who writes them.

The first-principles map

RL post-training is forward-heavy. From lesson 24: a single step does ~1M forward passes for rollout, one forward + backward for training, one forward for the reference model, and a parameter broadcast for weight sync. Each of these is a distinct kernel optimization surface. Listing them honestly:

| # | Surface | Workload shape | Compute / memory regime |
|---|---|---|---|
| 1 | Rollout / inference | Autoregressive decode + occasional prefill | Decode is HBM-bound (load full weights + KV per token). Prefill is compute-bound. |
| 2 | Log-prob matching | Single forward over completed trajectories | Bandwidth on LM head; numerics on softmax. |
| 3 | Training (policy update) | Forward + backward over packed trajectories | FFN/attention matmuls (compute-bound), CE loss (memory-bound), KL/PPO surrogate. |
| 4 | Weight sync | Broadcast / cast / reshape of parameters | Network bandwidth (NCCL) + HBM bandwidth for cast. |
| 5 | Memory & scheduling | KV cache allocator, activation offload, gradient accumulation | HBM allocator + PCIe overlap. |
| 6 | Agentic primitives | Multi-turn KV reuse, varlen masking, tool-output masking | Mixed: kernel + control flow. |

Now drill into each.

Surface 1 — rollout / inference kernels (the ~60% of step time)

Decode is memory-bound: one token at a time, you load the entire model's weights from HBM, plus the entire KV cache so far, to emit one new token. The compute is trivial; the bandwidth dominates. So all rollout kernel work targets either reducing bytes moved or hiding the cost.
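
The bandwidth argument can be turned into a rough lower bound on decode step time. The sketch below is a back-of-envelope calculation, not a measurement: the 7B model shape (32 layers, 8 KV heads, head dimension 128), the batch of 32 sequences at 4,096 tokens of context, and the helper names are all assumptions chosen for illustration; 3.35 TB/s is the HBM bandwidth figure used later in this lesson.

```python
# Lower bound on one decode step from HBM traffic alone.
# Every number here is an illustrative assumption, not a benchmark.

def kv_cache_bytes(layers, kv_heads, head_dim, context_tokens, batch, bytes_per_elem):
    # Factor 2 covers the separate K and V tensors.
    return 2 * layers * kv_heads * head_dim * context_tokens * batch * bytes_per_elem

def decode_step_floor_ms(weight_bytes, kv_bytes, hbm_bytes_per_s):
    """Time to stream the weights and the KV cache once from HBM, in ms."""
    return (weight_bytes + kv_bytes) / hbm_bytes_per_s * 1e3

weight_bytes = 7e9 * 2                      # 7B parameters in bf16
kv_bf16 = kv_cache_bytes(32, 8, 128, context_tokens=4096, batch=32, bytes_per_elem=2)
kv_int8 = kv_cache_bytes(32, 8, 128, context_tokens=4096, batch=32, bytes_per_elem=1)

floor_bf16 = decode_step_floor_ms(weight_bytes, kv_bf16, hbm_bytes_per_s=3.35e12)
floor_int8 = decode_step_floor_ms(weight_bytes, kv_int8, hbm_bytes_per_s=3.35e12)

print(f"bf16 KV: {floor_bf16:.2f} ms/step, int8 KV: {floor_int8:.2f} ms/step")
```

Under these assumptions the KV cache is already larger than the weights, which is why both "move fewer KV bytes" and "share KV across rollouts" show up as rollout kernels.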

| Kernel | What it does | Why it's an RL engineer's problem |
|---|---|---|
| FlashAttention 2/3, FlashInfer | Tiled attention with online softmax: the N×N score matrix is never materialised, so extra memory is O(N) instead of O(N²) and HBM traffic drops sharply. FA3 adds asynchronous Hopper TMA pipelining. | You'll choose the version, port between vLLM and SGLang, and patch when sliding-window or per-head dropout breaks something. gpu_kernel_serving/12. |
| PagedAttention | KV cache stored as paged blocks (like OS virtual memory) so non-contiguous KV reads work and fragmentation drops to ~zero. Block table in HBM. | Variable-length trajectories in RL fragment a contiguous KV allocator badly. PagedAttention is non-optional once you have >2 concurrent rollouts. |
| Prefix caching / RadixAttention | Detect shared prefixes across requests (system prompt, few-shot, task header) and reuse the KV cache. RadixAttention organizes prefixes in a radix tree. | The single highest-leverage RL-specific kernel. K-rollout sampling means K trajectories share the same prompt prefix. Without prefix caching: K full prefills. With: 1 prefill + K decodes. Speedup is 3–10× on math-style benchmarks. |
| Continuous batching | Dynamically add / remove sequences from the running batch each step, so completed sequences don't waste compute on padding. | RL trajectories have wildly different lengths (200 to 10000 tokens). Static batching wastes 50%+ of compute. |
| Chunked prefill | Split a long prefill into chunks, interleaved with decode steps from other sequences. Keeps a long prompt from stalling in-flight decodes, which bounds their inter-token latency. | In K-rollouts-per-prompt RL, you want decode and prefill from different rollouts to overlap. Chunked prefill makes the scheduler symmetric. |
| GQA / MQA attention | Multiple query heads share K, V → smaller KV cache, smaller HBM read per decode step. | You don't write the kernel often (models ship with their head layout), but you'll choose GQA-aware kernels and verify the inference engine respects the KV-head grouping correctly. |
| Speculative decoding (EAGLE-2/3, Medusa, lookahead) | A draft proposer (small model, or a learned head like EAGLE/Medusa) proposes K tokens; the target verifies all K in one forward pass; accepted tokens emit "for free". vllm/07. | For greedy or low-T rollouts, ~2× throughput. For high-T exploration (typical RL), naive draft-model speculation degrades; EAGLE-style proposers (which condition on target's hidden state) hold up better because they bias toward what the target would sample. |
| KV cache quantization (int8 / fp8) | Quantize K, V tensors in the cache → half the KV bytes read per decode token compared with bf16. | Decode is HBM-bound, so where KV reads dominate the traffic (long contexts, large batches) an int8 KV cache approaches a 2× decode speedup; the weight reads are unchanged. Affects log-prob match: if rollout uses int8 KV but trainer uses bf16, the logprobs differ. You decide whether to mirror in trainer or recompute. |
| CUDA-graph re-capture under shape change | CUDA graphs themselves are table-stakes in 2026: vLLM, SGLang, and TensorRT-LLM all enable them by default. The engineering work is invalidating and re-capturing graphs when batch shape, KV layout, or paged-block count change, which happens constantly under continuous batching. | Without careful re-capture, you either re-graph every step (kills the win) or fall off the graph path (kills the win the other way). Production frameworks have hundreds of cached graph variants keyed on shape. |
| Fused sampling | top-k / top-p / temperature / penalties fused into one kernel after the logits step. | For RL, you often sample with temperature ≠ 1 and need exact reproducibility. A fused sampler that uses a deterministic Philox RNG is necessary for replay-debug. |
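
The padding-waste claim in the continuous batching row can be checked with a toy scheduler. This is a minimal sketch under stated assumptions: the trajectory lengths, the slot count, and both function names are hypothetical, one "slot-step" stands for one sequence position processed in one decode step, and prefill is ignored entirely.

```python
from collections import deque

def static_batch_utilization(lengths, slots):
    """Static batching: every batch runs until its longest sequence finishes."""
    busy = total = 0
    for i in range(0, len(lengths), slots):
        batch = lengths[i:i + slots]
        busy += sum(batch)
        total += max(batch) * slots
    return busy / total

def continuous_batch_utilization(lengths, slots):
    """Step-level scheduling: a finished sequence frees its slot immediately
    and the next queued sequence takes it on the following step."""
    queue = deque(lengths)
    running = []              # tokens still to generate, per active sequence
    busy = total = 0
    while queue or running:
        while queue and len(running) < slots:
            running.append(queue.popleft())
        busy += len(running)
        total += slots
        running = [r - 1 for r in running if r > 1]
    return busy / total

lengths = [200, 10000, 800, 3000, 1500, 6000, 400, 2500]   # hypothetical
static_util = static_batch_utilization(lengths, slots=4)
continuous_util = continuous_batch_utilization(lengths, slots=4)
print(f"static: {static_util:.1%}  continuous: {continuous_util:.1%}")
```

With these made-up lengths, static batching keeps about 38% of slot-steps busy and step-level scheduling about 61%; the remaining idle time comes from the single 10000-token straggler, which is a tail-latency problem rather than a padding problem.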
The first kernel you write for RL
Almost always: prefix-aware paged attention. The math: K rollouts share a P-token prompt and each generates G tokens. Without sharing, the step processes K·P prefill tokens plus K·G decode tokens; with sharing, P prefill tokens plus the same K·G decode tokens. Prefill work alone drops K×. Counted in tokens, total work drops by (K·P + K·G) / (P + K·G) = K·(1 + G/P) / (1 + K·G/P): for K=16 and G/P=0.1 that is ~6.8×; for K=64, ~9.5×. (Wall-clock wins differ from the token count, because decode is unchanged and HBM-bound while prefill is compute-bound; in practice this lands as a 3–10× rollout speedup depending on the regime.) SGLang's RadixAttention handles sub-block prefixes via a radix tree; vLLM's APC (Automatic Prefix Caching) matches whole blocks by hash, which is coarser-grained. On a fresh framework you integrate one; you do not write paged attention from scratch in 2026.
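
The token counting behind the prefix-sharing argument is easy to reproduce. The sketch below counts prompt and generated tokens on both sides of the comparison and treats a prefill token and a decode token as equal units of work, which is a simplifying assumption; the function name and the example sizes are illustrative.

```python
def token_work_ratio(k, prompt_tokens, gen_tokens):
    """Tokens processed without prefix sharing divided by tokens with sharing."""
    unshared = k * prompt_tokens + k * gen_tokens   # every rollout prefills the prompt
    shared = prompt_tokens + k * gen_tokens         # the prompt is prefilled once
    return unshared / shared

# G/P = 0.1, as in the text; the absolute sizes are arbitrary.
for k in (4, 16, 64):
    ratio = token_work_ratio(k, prompt_tokens=1000, gen_tokens=100)
    print(f"K={k:3d}: prefill work drops {k}x, total token work drops {ratio:.2f}x")
```

The ratio saturates near 1 + P/G as K grows, so short prompts with long generations gain much less from prefix caching than long prompts with short answers.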

Surface 2 — log-prob matching kernels (the silent bug surface)

Recall the PPO surrogate:

L_PPO = E[ min( ρ · A, clip(ρ, 1−ε, 1+ε) · A ) ]
ρ = π_θ(a|s) / π_old(a|s)    (importance ratio)

The ratio's numerator is the trainer's current policy log-prob. The denominator is "the policy that produced the rollout", i.e., the rollout engine's log-prob at sample time. If those two engines compute slightly different logprobs for the same (model, token sequence), ρ is biased — even at the moment of the rollout.

| Where mismatch comes from | Numerical magnitude |
|---|---|
| Different attention kernels (FlashAttention vs vLLM's paged variant) | ~1e-4 in logprob → ρ ≈ 1.0001 per token. PPO is per-token, so the failure mode isn't a multiplicative blow-up across the sequence; it's that a slowly accumulating systematic bias pushes more and more per-token ratios outside the clip range [1−ε, 1+ε], zeroing gradients on those tokens and biasing the surviving ones. |
| Mixed precision: rollout in fp16 / int8, trainer in bf16 | ~1e-3 in logprob → catastrophic over long sequences. |
| Different reduction order in softmax | ~1e-6, usually OK. |
| Tokenizer asymmetry (extra BOS, different whitespace) | Categorical: wrong tokens entirely. |
| Sampling temperature applied differently | The logprob computed at the post-temperature distribution must match on both sides. |

So the engineer's responsibilities here:

  1. Decide the match policy. Two choices:
    • Strict-recompute: have the trainer recompute logprobs for the rollout sequences. Adds one trainer forward per step (10–15% time cost). Simple, correct.
    • Numerically-equivalent engines: guarantee the rollout and trainer produce byte-identical logprobs. No extra forward, but every engine change is a potential regression.
    Modern frameworks (verl, OpenRLHF, TRL) default to strict-recompute. Frontier labs sometimes use equivalent engines for the throughput.
  2. Write the recompute kernel. It's a packed-sequence forward over completed trajectories, plus a log_softmax over the vocabulary, plus a gather to select the sampled token's log-prob. Standard shape:
    logits[B, T, V] → log_softmax over V → gather(token_ids[B, T]) → logπ[B, T]
    The naive implementation materialises the (B, T, V) logits tensor. Packed across a step (B·T = 32k tokens, V=128k, bf16): ~8 GB just for logits, plus another 8 GB for softmax intermediates, plus the full backward-pass gradients on that tensor. The peak hits ≥ 30 GB on a 32k packed batch — competing for the same HBM as the rest of the activation set. Hence:
  3. Chunked / fused log-prob kernel. Compute log_softmax over chunks of V and reduce to logπ[B, T] without ever materialising the full logits tensor. LinkedIn's Liger-Kernel (fused linear cross-entropy) and Megatron's vocab_parallel_cross_entropy are different production answers. Writing this kernel saves ~30–50% of trainer-forward memory.
  4. Reference-model logprob fusion. The KL anchor needs logπ_ref. If ref runs in the same forward sweep (shared model on colocated GPUs), you fuse a "gather both ref and policy logπ in one pass". If ref runs separately, you orchestrate the two streams' outputs into the loss kernel.
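
The chunked log-prob idea from the list above can be sketched in plain Python for a single token. It walks the vocabulary in chunks and keeps a running maximum and a running sum of exponentials (an online log-sum-exp), so only one chunk of logits is alive at a time. The function names, the tiny LM head, and the chunk sizes are illustrative assumptions; a production kernel does the same reduction per tile on the GPU, and some libraries chunk over tokens rather than over the vocabulary.

```python
import math

def token_logprob_naive(hidden, weight_rows, token_id):
    """Reference: materialise all V logits, then log_softmax, then gather."""
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in weight_rows]
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return logits[token_id] - lse

def token_logprob_chunked(hidden, weight_rows, token_id, chunk=4):
    """Same value, but only `chunk` logits exist at any moment."""
    running_max = -math.inf
    running_sum = 0.0          # sum of exp(logit - running_max) seen so far
    target_logit = None
    for start in range(0, len(weight_rows), chunk):
        rows = weight_rows[start:start + chunk]
        logits = [sum(w * h for w, h in zip(row, hidden)) for row in rows]
        new_max = max(running_max, max(logits))
        # Rescale the old partial sum to the new maximum before adding to it.
        running_sum = running_sum * math.exp(running_max - new_max)
        running_sum += sum(math.exp(z - new_max) for z in logits)
        running_max = new_max
        if start <= token_id < start + len(rows):
            target_logit = logits[token_id - start]
    return target_logit - (running_max + math.log(running_sum))

hidden = [0.5, -1.0, 2.0]                                   # toy hidden state
weight_rows = [[0.1 * i, 1.0 - 0.2 * i, 0.05 * i * i - 1.0] for i in range(10)]
print(token_logprob_naive(hidden, weight_rows, 7))
print(token_logprob_chunked(hidden, weight_rows, 7, chunk=3))
```

The two functions agree to floating-point rounding for any chunk size, which is the property to assert in a unit test before swapping the naive path out.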
The bug worth a week
Rollout engine computed logprobs with sampling temperature applied (so logπ = log softmax(z/T)); trainer recomputed at temperature 1 (logπ = log softmax(z)). Importance ratio looked fine on average but was systematically biased high for tokens far from the mode. Symptom: clipping fraction climbing slowly; rewards flat. Fix: 1 line. Cost: 6 days of training thrown away. Always log ρ's histogram per step.
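
The temperature bug above is easy to reproduce on paper. The sketch below uses made-up logits and an assumed rollout temperature of 0.7, and computes the importance ratio a trainer would see at the very first step, before any weight update, if it recomputed logprobs at temperature 1.

```python
import math

def log_softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - lse for s in scaled]

logits = [4.0, 2.0, 0.0, -2.0]     # hypothetical next-token logits, mode first
T = 0.7                            # assumed rollout sampling temperature

logp_rollout = log_softmax(logits, T)       # distribution the sampler actually used
logp_buggy = log_softmax(logits, 1.0)       # trainer recompute that forgot T
logp_fixed = log_softmax(logits, T)         # trainer recompute with matching T

ratio_buggy = [math.exp(t - r) for t, r in zip(logp_buggy, logp_rollout)]
ratio_fixed = [math.exp(t - r) for t, r in zip(logp_fixed, logp_rollout)]

for i, (rb, rf) in enumerate(zip(ratio_buggy, ratio_fixed)):
    print(f"token {i}: rho with bug = {rb:.3f}, rho with fix = {rf:.3f}")
```

With T below 1 the mismatched ratio sits slightly under 1 at the mode and grows well above 1 toward the tail, so rarely sampled tokens are the ones that land outside the clip range. A per-step histogram of ρ makes this visible immediately, since a correct setup shows ρ = 1 everywhere at the first optimizer step.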

Surface 3 — training kernels (policy update)

Once you have packed trajectories, advantages, and logprobs, the actual policy-update step is "forward + loss + backward". Three RL-specific kernel needs:

| Kernel | What it does | Win |
|---|---|---|
| Sequence-packed varlen attention | Pack many variable-length trajectories into a single (B=1, T_total) tensor with a cu_seqlens table, then run FlashAttention's varlen entry point (flash_attn_varlen_func). | 30–60% fewer FLOPs vs padded batches (real ratio depends on length variance). |
| Chunked cross-entropy / Liger CE | Compute log_softmax + gather + NLL in chunks over the vocab dimension. Memory: O(B·T·V_chunk) of live logits instead of O(B·T·V). | Reduces peak memory by ~30–50%. Enables larger sequence lengths. |
| Fused PPO ratio + clip | One kernel: ρ = exp(logπ − logπ_old); clipped_ρ = clamp(ρ, 1−ε, 1+ε); loss = -min(ρ·A, clipped_ρ·A); reduce. | Saves a few intermediate tensors and one HBM round-trip per token. ~5% step time. |
| Fused KL estimator | One kernel for the k3 estimator: kl = exp(logπ_ref − logπ) − (logπ_ref − logπ) − 1. Combined with policy gradient as a single backward target. | Removes a separate elementwise pass over the reference logprobs (the reference forward itself is still required). Numerical stability: k3 is non-negative by construction. |
| Loss-mask aware reduction | Reduce the per-token loss with a mask that zeros prompt and tool-output positions before any normalization (Dr.GRPO style). | Correctness: bias-free per-token average over only model-generated tokens. |
| Activation checkpointing decision | For long trajectories, decide which activations to save (attention output, FFN intermediate) vs recompute. Selective checkpointing of attention only saves ~60% activation memory at ~10% extra compute. | Enables longer trajectories at same memory budget. Key for agentic RL (10k-token trajectories). |
| Grouped advantage normalization | For GRPO/RLOO: compute group statistics (mean, std) of the rewards per prompt, then normalize to obtain advantages. Custom kernel because the grouping is data-dependent. | Numerical stability + correct gradient through normalization (or stop-grad through the stats). |
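
Two rows of the table above, grouped advantage normalization and loss-mask aware reduction, are mostly bookkeeping, and a reference implementation is worth having before any fused kernel exists. The sketch below is a plain-Python reference under stated assumptions: it uses the population standard deviation and an epsilon of 1e-6, both of which are illustrative choices rather than what any particular framework does, and the function names are hypothetical.

```python
from collections import defaultdict
from statistics import mean, pstdev

def group_normalized_advantages(rewards, prompt_ids, eps=1e-6):
    """One scalar advantage per trajectory: (reward - group mean) / (group std + eps),
    where a group is all trajectories sampled from the same prompt."""
    groups = defaultdict(list)
    for reward, pid in zip(rewards, prompt_ids):
        groups[pid].append(reward)
    stats = {pid: (mean(rs), pstdev(rs)) for pid, rs in groups.items()}
    return [(r - stats[p][0]) / (stats[p][1] + eps) for r, p in zip(rewards, prompt_ids)]

def masked_token_mean(per_token_loss, mask):
    """Average over model-generated tokens only (mask = 1); prompt and
    tool-output tokens (mask = 0) contribute to neither sum nor count."""
    kept = sum(mask)
    return sum(l * m for l, m in zip(per_token_loss, mask)) / max(kept, 1)

rewards = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
prompt_ids = ["a", "a", "a", "a", "b", "b", "b", "b"]
advantages = group_normalized_advantages(rewards, prompt_ids)
loss = masked_token_mean([1.0, 2.0, 3.0, 4.0], mask=[0, 1, 1, 0])
print(advantages, loss)
```

Group "b" has identical rewards, so its standard deviation is zero; the epsilon keeps the division finite and every advantage in that group comes out as 0, meaning the group contributes no gradient.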
The trainer's heaviest line
For a typical RL step, the single most memory-expensive line is the language-model head's cross-entropy. With vocab=128k, packed sequence T=32k, bf16: 8 GB of logits before any activation. Chunked CE (process 1k vocab tokens at a time, accumulating loss + grad scalar) brings this to ~64 MB. Every RL framework worth using has this kernel. If yours doesn't, write it first.
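
The 8 GB and 64 MB figures follow directly from the tensor shape. A quick check, assuming "k" means 1024 and 2 bytes per bf16 element (the helper name is illustrative):

```python
def logits_bytes(tokens, vocab, bytes_per_elem=2):
    """Size of a (tokens, vocab) logits tensor in bytes."""
    return tokens * vocab * bytes_per_elem

full_logits = logits_bytes(tokens=32 * 1024, vocab=128 * 1024)   # whole vocabulary at once
one_chunk = logits_bytes(tokens=32 * 1024, vocab=1024)           # one 1k-wide vocab chunk

print(f"full: {full_logits / 2**30:.0f} GiB, chunked: {one_chunk / 2**20:.0f} MiB")
```

The ratio between the two is simply V / V_chunk, so halving the chunk width halves the live logits at the cost of more loop iterations over the LM head.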

Surface 4 — weight-sync kernels (the often-forgotten one)

After every training step, new policy weights must move to the rollout engine. The kernels here are unglamorous but their cost is real.

| Kernel | What it does | Cost / win |
|---|---|---|
| NCCL broadcast / all-gather | Move parameter shards (FSDP-sharded fp32 master → bf16) from the trainer rank to all rollout ranks. | Network-bandwidth bound. For 70B bf16: 140 GB over NVLink (≤ 1s) or IB (~5–20s depending on fabric). |
| In-place dequant / cast | Convert master fp32 weights to the inference dtype (bf16, fp8, int8) without a roundtrip to HBM. | Saves 1× HBM read/write of the whole model. For 70B: ~140 GB · (1/3.35 TB/s) ≈ 40 ms. Per sync. |
| FSDP-shard → TP-shard layout conversion | Trainer shards along FSDP's row-wise dim; rollout uses TP shards along a different axis. Kernel reshuffles tensors during the broadcast. | Avoids an explicit gather + scatter. Fuses with the broadcast collective. |
| IPC weight handoff (colocated) | In colocated mode, trainer and rollout share GPU memory. The "sync" is a CUDA IPC handle exchange + pointer swap. | ~zero copy. Synchronization overhead only. |
| Sharded checkpoint write/read | For async architectures: write a sharded checkpoint to a fast tier (e.g. shared NVMe), rollout reads it asynchronously. | Decouples sync from rollout step boundary. |
| Overlap of sync with next-step rollout | While the next rollout begins on stale weights, copy the new weights into a staging buffer on a background CUDA stream. Swap at a forward-pass boundary, between decode iterations. | Hides sync almost entirely. Requires extreme care: a swap in the middle of a forward pass gives inconsistent activations. |
The sync regression
Frameworks often start with a "stop everything, broadcast, resume" sync. Throughput is fine at small scale, terrible at multi-node. The path to async is: (1) shard the broadcast (now it's an all-gather, not a serial broadcast), (2) fuse the dtype cast, (3) overlap with the next rollout's prefill. Each step is a kernel + scheduler change worth weeks. Many production frameworks haven't done (3).
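
The sync costs quoted in the table above are bandwidth arithmetic. The sketch below recomputes them; the 900 GB/s NVLink figure and the two effective InfiniBand bandwidths are assumptions picked to bracket the quoted range, not measurements of any fabric, and 3.35 TB/s is the HBM figure from the table.

```python
def transfer_seconds(num_bytes, bandwidth_bytes_per_s):
    """Time to move num_bytes at a sustained bandwidth, ignoring latency."""
    return num_bytes / bandwidth_bytes_per_s

model_bytes = 70e9 * 2                                   # 70B parameters in bf16

cast_ms = transfer_seconds(model_bytes, 3.35e12) * 1e3   # one extra pass over HBM
nvlink_s = transfer_seconds(model_bytes, 900e9)          # assumed NVLink bandwidth
ib_fast_s = transfer_seconds(model_bytes, 25e9)          # assumed good effective IB
ib_slow_s = transfer_seconds(model_bytes, 7e9)           # assumed congested IB

print(f"cast pass: {cast_ms:.1f} ms")
print(f"NVLink broadcast: {nvlink_s:.2f} s")
print(f"IB broadcast: {ib_fast_s:.1f} s to {ib_slow_s:.1f} s")
```

The spread between the NVLink and IB numbers is why a stop-the-world sync looks free on one node and dominates the step on many.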

Surface 5 — memory & scheduling kernels

RL has unique memory pressure because it juggles multiple roles (rollout, trainer, ref) on the same hardware. Several kernels live in the allocator / scheduler layer rather than in the model.

The allocator quietly dominates
Most rollout slowdowns reported as "the engine is slow" are actually KV allocator pressure: fragmentation, repeated alloc/free, or page-table cache misses. The first 30 minutes of any rollout perf investigation should be in the allocator: nvidia-smi --query-gpu=memory.free over time, plus the engine's KV page stats.

Surface 6 — agentic kernels (the new frontier)

Multi-turn, tool-using, branchy RL has spawned a new class of kernels that didn't exist for single-turn pretraining or even single-turn RLHF. The engineer building agentic RL infra writes these:

| Kernel / primitive | What it does |
|---|---|
| Multi-turn KV cache reuse | Within a trajectory, keep the KV cache alive across turns; append tool-output tokens to the cache without re-prefilling. The cache layout must respect the agent's interleaving of assistant / user / tool tokens. |
| Tool-output mask construction | Per-token mask (1 for assistant-generated, 0 for everything else) computed alongside generation. Must be byte-identical between rollout and trainer to avoid the loss-mask bug from lesson 24. |
| Branchy rollout scheduling | K-sampling at intermediate turns (e.g. self-consistency or PRM scoring) requires forking KV caches at a turn boundary. The "fork" is a refcount bump on the shared KV blocks (in SGLang's radix tree implementation), physically copied only on write to the diverging suffix. The kernel work is the bookkeeping, not the data movement. |
| Async tool dispatch | Issue tool calls (code execution, search) on a CPU pool while GPU continues generating other trajectories. The kernel is the queue + completion handler; the GPU never blocks on a tool. |
| Per-turn reward broadcasting | For dense rewards (PRM-style), broadcast per-turn rewards across the corresponding tokens with the right attribution. The kernel is a scatter from (B, num_turns) to (B, T) using turn-boundary indices. |
| Per-trajectory tensor co-packing | Same kernel as surface 3's varlen packing, but with the extra constraint that loss mask, advantage tensor, KL tensor, and reward tensor all share the trajectory layout. The bookkeeping is where bugs live: an off-by-one between the cu_seqlens table and the advantage tensor silently shifts gradient to the wrong tokens. |
| Verifier dispatcher | Batched evaluation of verifiers (math grader, code runner) on GPU/CPU pool. Often the longest tail in step time; needs care to not block rollout. |
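
The refcount-based fork described in the branchy rollout row can be illustrated with a toy block pool. This is a teaching sketch, not the data structure of SGLang or any other engine: the class, its method names, and the integer "blocks" are all hypothetical, and no KV data is actually stored or copied.

```python
class KVBlockPool:
    """Toy paged-KV allocator: blocks are integers, refcounts track sharing."""

    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))
        self.refcount = {}

    def alloc(self):
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def fork(self, block_table):
        """Child shares every parent block: refcount bumps only, no data copy."""
        for block in block_table:
            self.refcount[block] += 1
        return list(block_table)

    def writable_last_block(self, block_table):
        """Copy-on-write: make the last block private before appending to it."""
        last = block_table[-1]
        if self.refcount[last] > 1:
            self.refcount[last] -= 1
            block_table[-1] = self.alloc()   # a real engine copies the KV data here
        return block_table[-1]

    def release(self, block_table):
        for block in block_table:
            self.refcount[block] -= 1
            if self.refcount[block] == 0:
                del self.refcount[block]
                self.free.append(block)

pool = KVBlockPool(num_blocks=8)
parent = [pool.alloc() for _ in range(3)]    # KV blocks for turns so far
child = pool.fork(parent)                    # branch at the turn boundary
pool.writable_last_block(child)              # child diverges: private last block
print(parent, child, pool.refcount)
```

Only the block being written is duplicated; the shared prefix stays single-copy until the last reference is released, which is what makes K-way branching at a turn boundary cheap.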

How a real RL engineer prioritizes

A first-week heuristic, in order of return-on-investment:

  1. Profile first. Use the per-component breakdown from lesson 24. Anything that's not 60–75% rollout means the profiler is wrong or your setup is unusual.
  2. Prefix caching. If you don't have it, integrate it. Single biggest win for K-rollout RL.
  3. Continuous batching + chunked prefill. Pad-free, decode-prefill mix. Easy 2× rollout throughput.
  4. Chunked CE on the trainer. Memory-only win; enables longer sequences.
  5. Log-prob match. Audit, fix, monitor. Save weeks of training before you save days of throughput.
  6. Speculative decoding for low-temperature rollouts. Conditional 1.5–2× decode.
  7. KV cache quantization (int8). 1.5–2× decode if your log-prob match story can tolerate it.
  8. Async weight sync overlap. Hide sync latency; multi-node win only.
  9. Custom fused PPO / KL kernel. Marginal; do last.
What "develop new kernels" really means in 2026
Almost nobody writes a from-scratch attention kernel anymore. The 2026 reality of "developing new kernels" is composition and performance work on top of existing ones: the job title "kernel engineer" in 2026 is closer to "kernel composer + perf engineer" than to "CUTLASS author".

The kernels we'll see next

RL-specific kernel R&D in 2026 is concentrated around four areas:

  1. Cache-aware speculative decoding for agentic rollouts. Standard speculative decoding doesn't compose with mid-trajectory tool calls. New work makes the draft proposer tool-aware.
  2. Selective recompute for very long trajectories. Activation memory dominates for 32k+ token trajectories; finer-grained recompute strategies (per-operation, not per-layer) are emerging.
  3. Cross-engine numerics standardization. A push to make rollout engines and trainers produce byte-identical logprobs, eliminating the strict-recompute tax. Requires shared kernel implementations or a verified numerical contract.
  4. MoE-routing kernels for RL. Sparse MoE models (DeepSeek-V3, Mixtral, Qwen-MoE) need router + AllToAll dispatch on both rollout and trainer sides, with the same load-balance behaviour. RL adds two new headaches: (a) the load-balance auxiliary loss interacts with the policy-gradient loss in non-obvious ways; (b) async setups can see the router weights drift across engines, biasing which experts a token reaches. Active area of frontier-lab engineering.

Interview prompts you should be ready for

  1. "You're given a 7B RL run with 60% rollout / 15% train forward / 10% ref / 10% backward / 5% sync. Where do you start?" (Rollout. Specifically: is prefix caching on? If yes, is K-rollout sharing being detected? Then continuous batching status, then KV-cache fragmentation. Don't touch the train kernels until rollout is at theoretical bandwidth limit.)
  2. "Chunked CE — derive the memory saving." (Naive CE materialises (B, T, V) logits and (B, T, V) softmax intermediates. Chunked CE computes log_softmax + gather + NLL on chunks of V, keeping only (B, T, V_chunk). For V=128k, V_chunk=4k: 32× peak memory reduction on logits — typically 50%+ peak trainer-forward memory.)
  3. "Walk through the log-prob match problem and how you'd fix it." (Rollout and trainer compute logπ on the same model but with different kernels → small numerical drift → biased PPO ratio. Two fixes: strict-recompute on trainer (default; adds 10–15% step time) or kernel parity guarantee (hard). Always log ρ histogram per step.)
  4. "What's the difference between paged attention and radix attention?" (Paged: KV stored in non-contiguous blocks via a block table — for variable-length / preempted sequences. Radix: a radix tree on top of the KV blocks indexes shared prefixes across requests — for K-rollout sampling. The two are orthogonal; SGLang combines them.)
  5. "Speculative decoding — why does it not always help in RL?" (Acceptance rate depends on the draft model agreeing with the target on each token. At high sampling temperature (typical RL exploration), the target's next-token distribution is wide, so acceptance is low and the K-token-verify overhead doesn't pay back. Helps most at T ≤ 0.5; barely helps at T = 1.)
  6. "You see KV cache fragmentation in production. What do you do?" (Either: implement paged allocator if not yet, or compact (move blocks to defragment, expensive), or restart the engine on a schedule. Production fix is paged; restart is the band-aid.)
  7. "Design a kernel to compute the PPO loss + KL anchor in one fused pass." (Inputs: logπ[B,T], logπ_old[B,T], logπ_ref[B,T], advantage[B,T], mask[B,T], clip_eps. One pass: ρ = exp(logπ-logπ_old); kl = exp(logπ_ref-logπ) - (logπ_ref-logπ) - 1; ppo_loss = -min(ρ·A, clamp(ρ,1-eps,1+eps)·A); total = (ppo_loss + β·kl) · mask; reduce. Backward: chain-rule through exp and clamp; the PPO term's gradient is zero on tokens where the clipped branch is the active min (ρ > 1+eps with A > 0, or ρ < 1-eps with A < 0), so only the KL term contributes there.)
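
For the fused PPO + KL prompt, a scalar reference implementation with an analytic gradient is a useful thing to have before writing any GPU code, because the fused kernel can be tested against it. The sketch below is one such reference under stated assumptions: a flat list of tokens, a masked mean as the reduction, and illustrative defaults for clip_eps and beta. The function name and inputs are hypothetical.

```python
import math

def fused_ppo_kl(logp, logp_old, logp_ref, adv, mask, clip_eps=0.2, beta=0.01):
    """Single pass over tokens; returns (loss, d loss / d logp per token)."""
    n = max(sum(mask), 1)                      # masked-mean normaliser
    loss, grads = 0.0, []
    for lp, lo, lr, a, m in zip(logp, logp_old, logp_ref, adv, mask):
        rho = math.exp(lp - lo)
        clipped = min(max(rho, 1 - clip_eps), 1 + clip_eps)
        unclipped_active = rho * a <= clipped * a
        d = lr - lp
        kl = math.exp(d) - d - 1               # k3 estimator, non-negative
        loss += m * (-min(rho * a, clipped * a) + beta * kl) / n
        # d(rho)/d(logp) = rho; the clamped branch has zero gradient.
        g_ppo = -rho * a if unclipped_active else 0.0
        # d(kl)/d(logp) = -exp(d) + 1.
        g_kl = beta * (1 - math.exp(d))
        grads.append(m * (g_ppo + g_kl) / n)
    return loss, grads

logp = [-1.0, -2.0, -0.5, -1.5]
logp_old = [-1.1, -1.5, -0.5, -1.0]
logp_ref = [-1.2, -2.1, -0.4, -1.4]
adv = [1.0, -0.5, 2.0, 1.0]
mask = [1, 1, 0, 1]
loss, grads = fused_ppo_kl(logp, logp_old, logp_ref, adv, mask)
print(loss, grads)
```

In this example the second token has ρ below 1−ε with a negative advantage, so its PPO gradient is zero and only the KL term remains, while the third token is masked out entirely. Checking the analytic gradient against a finite-difference estimate is a cheap way to catch sign errors in the backward pass.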
Takeaway — the two duties, reframed
"Improve RL infra for agentic use" and "develop new kernels for rollout + training" are both core duties. They cover surfaces 1, 3, and 6 above. The complete kernel surface area for a modern post-train RL engineer is six surfaces: rollout/inference, log-prob match, training, weight sync, memory/scheduling, agentic primitives. The first kernel to write is almost always prefix caching. The second is chunked cross-entropy. The hardest one to get right is the log-prob match. The job is not "writing kernels in isolation" — it's identifying which kernel surface dominates this week's profile and shipping a composition or fusion that moves the needle. Kernels are the last 30% of speed but the most measurable; they're how individual engineers are visibly fast.