Long-tail rollouts — the max-of-K problem, packing, and dynamic K
The throughput equation (lesson 22a) said rollout dominates at 60–80% of wall-clock. This lesson explains why most of that isn't raw decode: it's the tail. A small fraction of trajectories account for most of the wall-clock, and three patches (sequence packing, length capping, dynamic K) collectively cut τ_R by 2–4× without changing the model.
The first-principles observation: rollout latency is the max, not the sum
Per RL step you sample B · K trajectories: B prompts, each sampled K times for group-baseline algorithms (lessons 11–14). The trainer cannot start until every trajectory finishes, because the loss needs all K rollouts of each prompt to compute the group mean. So τ_R is the time until the slowest trajectory completes, not the average:
$$\tau_R \;\approx\; t_{\text{step}} \cdot \max_{i = 1, \dots, BK} L_i$$
where L_i is the length of trajectory i in tokens and t_step is the time of one decode step (one new token for every running sequence).
If trajectories all had the same length, the max would equal the mean, τ_R ≈ t_step · L̄, and the analysis would end here. They don't. Two structural reasons:
- Open-ended generation. Each trajectory stops at the first EOS the model samples, or at a hard length cap. With temperature > 0, the same prompt produces trajectories with wildly different lengths: a reasoning model might solve an AIME problem in 500 tokens on one rollout and 5,000 tokens on another.
- Heavy-tailed length distributions. Empirically, response length distributions look approximately log-normal — heavy-tailed, not Gaussian. A small fraction of trajectories (~5%) account for a disproportionate fraction (~30–50%) of total decoded tokens.
The max-of-K math, in one line
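How does τ_R grow with K when lengths are heavy-tailed? Take lengths i.i.d. log-normal with mean μ and coefficient of variation c = σ_L / μ. On the log scale, lengths are Gaussian with mean ln μ − s²/2 and standard deviation s = √(ln(1 + c²)). Using the leading-order extreme-value result that the maximum of K standard Gaussians grows like √(2 ln K), the typical maximum of K lengths is approximately:
$$\max_{i \le K} L_i \;\approx\; \mu \exp\!\left(s\sqrt{2\ln K} - \frac{s^2}{2}\right)$$
At practical K this asymptotic somewhat overstates the typical maximum, so read the numbers below as upper-end estimates.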
The exponential's argument grows like s · √(ln K) — slow in K, fast in the tail width s. Numbers from the formula:
| K | c = 0.5 (mild) | c = 1.0 (typical reasoning) | c = 2.0 (open-ended) |
|---|---|---|---|
| 4 | 2.0 μ | 2.8 μ | 3.7 μ |
| 16 | 2.7 μ | 5.0 μ | 8.9 μ |
| 64 | 3.5 μ | 7.8 μ | 17.4 μ |
| 256 | 4.3 μ | 11.3 μ | 30.6 μ |
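The table can be reproduced in a few lines. The sketch below assumes the i.i.d. log-normal model above with μ = 1; the function names are illustrative, and the Monte Carlo helper (built on the standard library's random.lognormvariate) is there only to show that the closed form sits above the simulated median maximum.

```python
import math
import random
import statistics


def log_sd(c: float) -> float:
    """Std. dev. s of ln(L) for a log-normal length with coefficient of variation c."""
    return math.sqrt(math.log(1.0 + c * c))


def approx_max_over_mean(k: int, c: float) -> float:
    """Extreme-value estimate of max(L_1..L_K) / mu: exp(s*sqrt(2 ln K) - s^2/2)."""
    s = log_sd(c)
    return math.exp(s * math.sqrt(2.0 * math.log(k)) - 0.5 * s * s)


def simulated_median_max(k: int, c: float, trials: int = 5_000, seed: int = 0) -> float:
    """Monte Carlo median of max(L_1..L_K) / mu, drawing lengths with mean mu = 1."""
    rng = random.Random(seed)
    s = log_sd(c)
    m = -0.5 * s * s  # log-scale mean chosen so that E[L] = exp(m + s^2/2) = 1
    maxima = [max(rng.lognormvariate(m, s) for _ in range(k)) for _ in range(trials)]
    return statistics.median(maxima)


if __name__ == "__main__":
    print("K     c=0.5  c=1.0  c=2.0")
    for k in (4, 16, 64, 256):
        cells = "  ".join(f"{approx_max_over_mean(k, c):5.1f}" for c in (0.5, 1.0, 2.0))
        print(f"{k:<5} {cells}")
```

Printing the loop recovers the table row by row; calling simulated_median_max for the same (K, c) pairs shows how far below the asymptotic the typical simulated maximum lands.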
Two consequences worth pinning to a wall:
- K=16, typical reasoning: τ_R is ~5× the mean trajectory length, not the mean. If your mean is 500 tokens, you wait for trajectories of ~2500 tokens. That gap is straggler tax.
- The growth is slow in K (slower than any power of K) but fast in c: at c = 1.0, doubling K from 16 to 32 adds about 25%, while doubling c from 1.0 to 2.0 multiplies the max by about 1.8× at K = 16 and 2.2× at K = 64. K is not the primary lever; the tail width is.
The three patches, in order of leverage
Patch 1 — Sequence packing
The original sin of naive batched inference is padding to the longest sequence in the batch. If your batch has lengths [120, 300, 5000, 80, ...], the short sequences are padded out to 5000 positions: their slots spend most of the step computing on pad tokens and cannot take new work until the 5000-token sequence finishes. With paged KV (lesson 20) and continuous batching (lesson 21), the rollout engine never pads: it manages each sequence's KV blocks independently and only attends over real tokens. The trainer needs its own version of this:
- Variable-length attention. FlashAttention's varlen kernel takes a cu_seqlens cumulative-length tensor and attends only within each segment. No padding mask, no wasted FLOPs.
- Concatenation along the sequence dimension. Pack multiple short trajectories into one "long" sequence with the boundaries explicit in cu_seqlens. The forward and backward see one batch of size 1 with sequence length sum(L_i), which is dense.
- Loss masking. Each token's contribution is scaled by 1/L_i (its trajectory's length) so the loss is per-trajectory, not per-token. This is the "token-level vs sequence-level loss" choice DAPO addresses (lesson 13).
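A minimal sketch of that bookkeeping, in plain Python with token lists standing in for tensors. The cu_seqlens layout (a leading 0 followed by cumulative lengths, one boundary per trajectory) is the form FlashAttention's varlen interface expects; the helper names here are hypothetical, and a real trainer would build these as int32 tensors on the GPU.

```python
from itertools import accumulate


def pack(trajectories):
    """Concatenate variable-length token lists into one packed sequence.

    Returns (flat_tokens, cu_seqlens, max_seqlen). Segment i occupies
    flat_tokens[cu_seqlens[i]:cu_seqlens[i + 1]]; attention must stay inside
    each segment (block-diagonal), which is what a varlen kernel enforces.
    """
    lengths = [len(t) for t in trajectories]
    flat = [tok for t in trajectories for tok in t]
    cu_seqlens = [0, *accumulate(lengths)]
    return flat, cu_seqlens, max(lengths)


def per_trajectory_loss(token_losses, cu_seqlens):
    """Scale each token by 1/L_i, then average over trajectories."""
    n_traj = len(cu_seqlens) - 1
    total = 0.0
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        total += sum(token_losses[start:end]) / (end - start)
    return total / n_traj


def per_token_loss(token_losses):
    """Plain mean over all packed tokens: long trajectories weigh more."""
    return sum(token_losses) / len(token_losses)
```

With one 2-token trajectory whose tokens all carry loss 1.0 and one 6-token trajectory at 0.0, per_trajectory_loss gives 0.5 while per_token_loss gives 0.25: the same packed batch, two different objectives.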
Sequence packing for the trainer typically buys 1.5–3× throughput on the trainer side. It does not help τ_R directly — the rollout engine already doesn't pad — but it makes the trainer side fast enough that disaggregated topology becomes worth running.
Patch 2 — Length cap with shaped penalty
The simplest patch, and the highest-leverage one for τ_R. Set a hard cap L_max; any trajectory that reaches it without emitting EOS is truncated and its reward is shaped down (DAPO calls this "soft overlong punishment"; lesson 13).
Two design choices that distinguish a good cap from a leaky one:
- Where to draw the line. Too tight and you're capping rollouts that would have succeeded — your gradient signal degrades because the model never gets to learn long reasoning chains. Too loose and the tail stays expensive. Rule of thumb: pick L_max at the 95th-to-98th percentile of successful rollouts in your current pool, not all rollouts.
- How to shape the reward of capped rollouts. Three options: (a) hard zero (the trajectory is treated as a failure), (b) soft penalty (reward scales down from 1.0 between L_p95 and L_max), (c) ignore (mask out the trajectory from the gradient). DAPO favors (b); it preserves gradient on long-but-successful rollouts while penalizing runaway ones. Hard zero is dangerous: it teaches the model that "I'm about to hit the cap" is an absorbing failure, which actively encourages early termination on hard prompts.
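A minimal sketch of both choices, under stated assumptions: L_max is taken at a nearest-rank percentile of successful rollout lengths, and the soft penalty is an additive linear ramp from 0 at a threshold l_soft down to −1 at L_max, in the spirit of DAPO's soft overlong punishment. The exact ramp shape and every name below are illustrative, not a reference implementation.

```python
import math


def nearest_rank_percentile(values, q):
    """Nearest-rank q-th percentile (0 < q <= 100) of a non-empty sequence."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100.0 * len(ordered)))
    return ordered[rank - 1]


def choose_length_cap(lengths, successes, q=97.0):
    """Hypothetical helper: L_max at the q-th percentile of *successful* rollouts."""
    ok = [n for n, s in zip(lengths, successes) if s]
    if not ok:
        raise ValueError("no successful rollouts to calibrate the cap against")
    return nearest_rank_percentile(ok, q)


def soft_overlong_penalty(length, l_soft, l_max):
    """0 up to l_soft, linear ramp to -1 at l_max, -1 at or past the cap."""
    if length <= l_soft:
        return 0.0
    if length >= l_max:
        return -1.0
    return -(length - l_soft) / (l_max - l_soft)


def shaped_reward(task_reward, length, l_soft, l_max):
    """Task reward (e.g. 1.0 correct / 0.0 wrong) plus the length penalty."""
    return task_reward + soft_overlong_penalty(length, l_soft, l_max)
```

For example, shaped_reward(1.0, 900, l_soft=800, l_max=1000) returns 0.5: a correct rollout inside the ramp keeps part of its reward, which is exactly what option (a) throws away.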
Empirically a well-tuned cap cuts τ_R by 30–50% with negligible accuracy cost on verifiable tasks.
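A back-of-envelope check under the log-normal max-of-K model, assuming c = 1.0 (so s = √(ln 2) ≈ 0.833), K = 16, and a cap at the 95th percentile of all rollout lengths. The percentile of successful rollouts and the policy's response to the cap are not modeled:
$$
L_{p95} = \mu\, e^{-s^2/2 + 1.645\,s} \approx \mu\, e^{-0.347 + 1.370} \approx 2.78\,\mu,
\qquad
\frac{\tau_R^{\text{capped}}}{\tau_R^{\text{uncapped}}} \approx \frac{2.78\,\mu}{5.02\,\mu} \approx 0.55
$$
That is a cut of about 45%, inside the 30–50% range. Because the extreme-value formula overstates the uncapped maximum, treat the figure as an upper-end estimate.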
Patch 3 — Dynamic K and oversampling
The other axis. Standard RL fixes K (say, K=16) and waits for all K trajectories per prompt. The DAPO observation (lesson 13) is that some prompts produce all-equal rewards anyway — those groups contribute zero gradient regardless of how many rollouts you collected. Dynamic sampling oversamples prompts speculatively and drops degenerate groups before they enter the loss.
Two flavors:
- Speculative oversample. Launch K' > K rollouts per prompt; as rewards come back, keep the first K to finish (provided their outcomes are mixed) and abort the rest. The straggler trajectories that haven't finished yet are killed: wasted decode, but wasted decode that wasn't going to matter anyway because the group was already settled.
- Early-termination per group. Once K rollouts of a prompt have come back and at least one is a success and one a failure, kill the remaining rollouts of that prompt — the group baseline is already estimable from the K you have.
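A sketch of the speculative-oversample flavor for a single prompt, assuming each rollout's finish time and verifier reward are known up front (in a live engine they arrive incrementally). The Rollout type and settle_group function are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Rollout:
    finish_time: float  # when this rollout's decode would finish
    reward: float       # verifier reward, e.g. 0.0 or 1.0


def settle_group(rollouts, k):
    """Speculative oversample for one prompt with K' = len(rollouts) launched.

    Keeps the first k rollouts to finish if their rewards are mixed; otherwise
    drops the group (all-equal rewards give zero group-relative advantage).
    Either way the prompt is decided at the k-th finish time, and any rollout
    still running at that point would be aborted.
    """
    if not 0 < k <= len(rollouts):
        raise ValueError("need 0 < k <= number of launched rollouts")
    first_k = sorted(rollouts, key=lambda r: r.finish_time)[:k]
    settle_time = first_k[-1].finish_time
    mixed = len({r.reward for r in first_k}) > 1
    return (first_k if mixed else None), settle_time
```

With K' = 4, k = 3 and finish times [1, 2, 3, 9], the prompt is settled at t = 3 instead of t = 9; the rollout that would have run until t = 9 is the straggler being killed.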
The trade is wasted decode FLOPs for wall-clock. On a cluster where rollout dominates τ_step, that trade is almost always worth it: you throw away decode work on stragglers to avoid leaving the trainer idle while it waits for them.
Measuring the straggler tax
For one rollout step of K trajectories drawn from a log-normal length distribution, three numbers summarize the tail: τ_R (the longest trajectory), mean utilization (in the simplest static model, mean / max, the average fraction of the step each slot spends decoding), and the "straggler tax", max / mean − 1. Each depends on K, the mean length, the coefficient of variation, and which patches are enabled.
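The same kind of step can be simulated in a few lines of Python. This sketch assumes one token per decode step per sequence, i.i.d. log-normal lengths, and no refill of finished slots; simulate_step and its arguments are illustrative names, and the only patch modeled is a hard length cap.

```python
import math
import random
from statistics import fmean


def simulate_step(k, mean_len, cv, cap=None, seed=0):
    """One rollout step of k log-normal trajectory lengths (in decode steps).

    Returns (tau_r, mean, utilization, straggler_tax) where tau_r is the
    longest trajectory, utilization = mean / tau_r (the average fraction of
    the step a slot spends decoding, with no slot refill), and
    straggler_tax = tau_r / mean - 1.
    """
    rng = random.Random(seed)
    s = math.sqrt(math.log(1.0 + cv * cv))
    m = math.log(mean_len) - 0.5 * s * s  # makes E[length] = mean_len
    lengths = [rng.lognormvariate(m, s) for _ in range(k)]
    if cap is not None:  # hard length cap: truncate at cap
        lengths = [min(n, cap) for n in lengths]
    tau_r = max(lengths)
    mean = fmean(lengths)
    return tau_r, mean, mean / tau_r, tau_r / mean - 1.0
```

Averaging straggler_tax over many seeds for k=16, mean_len=500, cv=1.0, and comparing cap=None against a cap near the 95th percentile (about 1400 tokens here), gives a quick feel for how much the cap buys.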
Sequence packing — the rollout side and the trainer side
Packing means different things on each side of the disaggregated topology, and conflating them is a common source of pipeline-design confusion.
| Side | What "packing" means | Kernel | Throughput win |
|---|---|---|---|
| Rollout (inference) | Continuous batching: as soon as a sequence finishes, start a new one in its KV slot. No padding ever exists, and the batch stays full as long as requests remain queued. | PagedAttention + scheduler (lesson 21) | 2–20× vs static batching |
| Trainer (forward+backward) | Concatenate variable-length trajectories into one sequence with cu_seqlens; FlashAttention varlen computes block-diagonal attention without padding waste. | FlashAttention varlen, chunked CE | 1.5–3× vs padded batch |
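A toy scheduler makes the rollout-side row concrete. It assumes one token per decode step per sequence, a fixed number of slots, and that every sequence must finish before the step ends; the function names are hypothetical, and real engines (with prefill, memory limits, and preemption) behave less neatly.

```python
import heapq


def static_batching_steps(lengths, slots):
    """Each batch of `slots` sequences runs until its longest member finishes."""
    return sum(max(lengths[i:i + slots]) for i in range(0, len(lengths), slots))


def continuous_batching_steps(lengths, slots):
    """A finished sequence's slot is refilled from the queue immediately."""
    free_at = [0] * slots  # decode step at which each slot frees up (min-heap)
    done = 0
    for n in lengths:  # requests are admitted in queue order
        start = heapq.heappop(free_at)
        end = start + n
        done = max(done, end)
        heapq.heappush(free_at, end)
    return done
```

On the lengths [120, 300, 5000, 80, 120, 300, 80, 90] with 4 slots, static batching takes 5300 steps and continuous batching 5000: refilling slots recovers the padding waste, but the step still ends when the 5000-token trajectory does, which is the straggler problem in miniature.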
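One non-obvious interaction: trainer-side packing changes the gradient's natural unit from "per padded slot" to "per token in the packed sequence." A plain mean over the packed tokens weights every token equally, so longer trajectories carry proportionally more of the gradient; scaling each token by 1/L_i makes every trajectory count equally instead. Which normalization you want is an algorithmic choice (DAPO's token-level loss, lesson 13, deliberately picks the per-token one), but packing must not make that choice for you silently. Sequence packing without an explicit normalization is a footgun: throughput goes up, but the algorithm subtly changes.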
Dynamic K and the kept-fraction tradeoff
The exact stat to log when you run dynamic K is the kept fraction: k = (rollouts that reached the loss) / (rollouts that ran decode). Three regimes:
- k ≈ 1.0: no dynamic K active. Every rollout contributes — but you're paying for degenerate groups. Useful when the policy is in a regime where most groups have mixed outcomes anyway.
- k ∈ [0.5, 0.95]: dynamic K is working — you're trading wasted decode for lower τ_R. This is the operating regime for a healthy RL run mid-training.
- k < 0.5: more than half of your decoded tokens never make it to the loss. Either the policy is in a bad pass-rate regime (data pipeline issue — lesson 18a), or the dynamic-K thresholds are too aggressive.
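Logging it takes a couple of lines; the regime labels below simply encode the thresholds from the list above, and kept_fraction and regime are hypothetical helper names.

```python
def kept_fraction(reached_loss: int, ran_decode: int) -> float:
    """k = rollouts that reached the loss / rollouts that ran decode."""
    if ran_decode <= 0:
        raise ValueError("no rollouts ran decode")
    return reached_loss / ran_decode


def regime(k: float) -> str:
    """Map a kept fraction onto the three regimes described above."""
    if k > 0.95:
        return "no dynamic K: paying for degenerate groups"
    if k >= 0.5:
        return "dynamic K working: trading decode for lower tau_R"
    return "too much waste: check pass rates or dynamic-K thresholds"
```

For instance, 64 prompts launched with K' = 24 each, of which 56 groups survive with K = 16 kept rollouts, gives kept_fraction(56 * 16, 64 * 24) ≈ 0.58, inside the healthy band.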
One way to think about it: the data pipeline (lesson 18a) keeps the offline kept-fraction high by stratifying prompts; dynamic K keeps the online kept-fraction high by killing groups that didn't pan out. The two patches compose — applying both gets you closer to a regime where every decoded token is useful.
FP8 / mixed-precision asymmetry
One last lever specific to the rollout side. The trainer and the rollout engine don't need the same precision:
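- Trainer: BF16 weights with FP32 master weights and optimizer state. Going below BF16 in the backward pass risks numerical instability on most pretraining-derived models unless you adopt a dedicated low-precision training recipe, so the trainer usually stays at BF16.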
- Rollout: FP8 weights and KV are now standard on H100+. Decode is memory-bandwidth-bound; FP8 halves the bytes per token and ~doubles tok/s. The cost is a small KL drift between the trainer's policy (BF16) and the rollout's policy (FP8) — which the importance ratio in PPO/GRPO is designed to correct, up to ~5–10% per-token KL.
Making that correction work takes one of two approaches. The strict one is to recompute old_logp on the trainer side after the rollout returns, which adds 10–15% to step time. The kernel-parity alternative, matching the trainer's and rollout's numerics closely enough that the gap is negligible, also works but is harder to maintain. The throughput benefit of FP8 rollout only holds if one of these is engineered in; otherwise you will see the clip fraction climbing and gradient quality dropping, with no obvious cause.
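To catch the mismatch before it shows up as a mystery, compare the two sets of log-probs for the same sampled tokens. This is a sketch under assumptions: trainer_logp and rollout_logp are per-token log-probs of the sampled tokens from the BF16 trainer and the FP8 engine, clip_eps = 0.2 is a typical PPO setting, and the KL figure is the (r − 1) − log r estimator of KL(rollout ‖ trainer) for tokens sampled from the rollout policy. If old_logp came straight from the rollout engine, the clip fraction computed here is what the first minibatch would see before any parameter update.

```python
import math


def mismatch_stats(trainer_logp, rollout_logp, clip_eps=0.2):
    """Per-token trainer/rollout mismatch on tokens sampled by the rollout engine.

    r_t = pi_trainer(y_t) / pi_rollout(y_t) = exp(trainer_logp - rollout_logp).
    """
    if len(trainer_logp) != len(rollout_logp) or not trainer_logp:
        raise ValueError("need equal-length, non-empty log-prob lists")
    diffs = [t - r for t, r in zip(trainer_logp, rollout_logp)]  # log r_t
    n = len(diffs)
    clip_frac = sum(abs(math.exp(d) - 1.0) > clip_eps for d in diffs) / n
    # k3 estimator: E_rollout[(r - 1) - log r] = KL(rollout || trainer); each term >= 0.
    kl = sum(math.expm1(d) - d for d in diffs) / n
    return {"clip_frac": clip_frac, "kl_rollout_trainer": kl, "max_ratio": math.exp(max(diffs))}
```

Identical log-probs give zero clip fraction and zero KL; a steadily rising clip_frac on otherwise unchanged data is the signature of numerics drift between the two engines.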
Putting the patches together
For a representative 7B run with K=16, mean L=800, and CV≈1.0, the patches compose roughly as follows. Each row keeps every patch from the rows above and adds one more.
| Configuration | τ_R (relative to baseline) | What changed |
|---|---|---|
| Baseline (static pad, no cap, fixed K) | 1.00× | — |
| + paged KV + continuous batch | 0.30× | no padding waste; reuse slots |
| + length cap at p95 with soft penalty | 0.18× | tail truncated, gradient preserved on success |
| + dynamic K (kill on settled groups) | 0.13× | online filter of degenerate groups |
| + FP8 rollout w/ log-prob recompute | 0.08× | halve bytes/token on decode |
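Two pieces of arithmetic help read the table: the incremental factor each row contributes, and what the final τ_R does to the whole step. The second part assumes rollout and training run back to back without overlap, trainer time is unchanged, and rollout is a fraction f of the baseline step (lesson 22a's 60–80%); the function names are illustrative.

```python
def incremental_factors(cumulative):
    """Ratio of each row's relative tau_R to the row above it."""
    return [cur / prev for prev, cur in zip(cumulative, cumulative[1:])]


def step_speedup(rollout_frac, rollout_factor):
    """Whole-step speedup if only the rollout share of a serial step shrinks."""
    return 1.0 / ((1.0 - rollout_frac) + rollout_frac * rollout_factor)


if __name__ == "__main__":
    rows = [1.00, 0.30, 0.18, 0.13, 0.08]
    print([round(x, 2) for x in incremental_factors(rows)])
    for f in (0.6, 0.7, 0.8):
        print(f, round(step_speedup(f, rows[-1]), 2))
```

With f = 0.7, the 0.08× rollout turns into roughly a 2.8× faster step under these assumptions; once τ_R shrinks this much, the trainer side becomes the dominant term.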
That is roughly an order of magnitude (1 / 0.08 = 12.5×) on the dominant term of τ_step, without changing the model architecture or the core policy-gradient algorithm. This is what "throughput optimization" looks like on the rollout side.