GRPO: drop the critic (DeepSeek-Math 2024 / R1 2025)
For verifiable-reward tasks, the value head can be replaced by a one-line statistic computed from the group itself. That's the entire algorithmic novelty.
What's wrong with PPO at scale
Two structural costs of the value head become painful as the policy gets bigger:
- Value head = +1 model. At 7B+ parameters, the critic roughly doubles training memory (forward + backward + optimizer state for a second copy of the trunk). On an 80GB GPU this alone can decide whether the experiment fits.
- Value head is unstable on sparse terminal rewards. For a verifier that returns 1 at the end of a 500-token response and nothing in between, Vφ has to learn "future success probability" from every position — a hard credit-assignment problem. Early in training Vφ is basically noise, which drags down PPO's gradient quality for the first few thousand steps.
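The memory point can be made concrete with a back-of-envelope estimate. This is a rough sketch, not a measurement: it assumes mixed-precision Adam at about 16 bytes per parameter (a commonly quoted rule of thumb, not a figure from this lesson), a critic the same size as the policy, and it ignores activations and the frozen reference model. The helper name `training_state_gb` is hypothetical.

```python
import math

# Assumption: 2 B fp16 weights + 2 B fp16 grads + 12 B fp32 master copy and Adam m, v.
BYTES_PER_PARAM = 16
GPU_GB = 80  # the 80 GB card mentioned in the text


def training_state_gb(n_params: float, n_models: int = 1) -> float:
    """Weights + grads + optimizer state in GB (1e9 bytes); activations ignored."""
    return n_params * BYTES_PER_PARAM * n_models / 1e9


policy_only = training_state_gb(7e9, n_models=1)    # GRPO-style: trainable policy only
policy_critic = training_state_gb(7e9, n_models=2)  # PPO-style: policy + same-size critic

gpus_policy_only = math.ceil(policy_only / GPU_GB)
gpus_policy_critic = math.ceil(policy_critic / GPU_GB)

print(f"policy only:     {policy_only:.0f} GB -> at least {gpus_policy_only} GPUs")
print(f"policy + critic: {policy_critic:.0f} GB -> at least {gpus_policy_critic} GPUs")
```

Under these assumptions the critic adds a full extra copy of the training state, which is the "roughly doubles" in the first bullet.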
GRPO's observation: for episodic, terminal-reward tasks we don't need a learned baseline at all. We can compute one on the fly from a group of rollouts.
The algorithm in two lines
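For each prompt x, sample K rollouts y_1, …, y_K and score each one, r_i = R(x, y_i). Compute the per-rollout advantage as

$$A_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_K)}{\operatorname{std}(r_1, \dots, r_K) + \delta}$$

where δ is a small constant for numerical stability (1e-6 in the code at the end of this lesson), then run PPO-clip with A_i broadcast to every response token in rollout i. No value head, no MSE target, no second network. That's GRPO.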
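The two steps above (group-relative advantage, then broadcast to tokens) can be sketched in plain Python. This is a minimal illustration under stated assumptions: the helper names `group_advantages` and `broadcast_to_tokens` are hypothetical, the rewards and response lengths are made up, and the sample standard deviation (dividing by K−1) is used to match the default of `torch.Tensor.std()`.

```python
from statistics import mean, stdev


def group_advantages(rewards, delta=1e-6):
    """A_i = (r_i - mean) / (std + delta) over one prompt's K rewards."""
    mu = mean(rewards)
    sigma = stdev(rewards)  # sample std (K-1 in the denominator)
    return [(r - mu) / (sigma + delta) for r in rewards]


def broadcast_to_tokens(advantages, response_lengths, max_len):
    """Repeat A_i over rollout i's response tokens; padded positions get 0."""
    return [
        [a if t < n else 0.0 for t in range(max_len)]
        for a, n in zip(advantages, response_lengths)
    ]


rewards = [1.0, 0.0, 0.0, 1.0]  # K = 4 verifier scores for one prompt (made up)
lengths = [3, 2, 4, 1]          # response length of each rollout (made up)

A = group_advantages(rewards)
A_tok = broadcast_to_tokens(A, lengths, max_len=4)

for i, row in enumerate(A_tok):
    print(i, [round(v, 3) for v in row])
```

Every token of a rollout carries the same scalar, so credit assignment within a response is uniform; only the comparison between rollouts carries signal.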
Why the group mean is a valid baseline
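The policy-gradient theorem allows any function of state (= prompt x only, not action y) as a baseline:

$$\nabla_\theta J(\theta) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\nabla_\theta \log \pi_\theta(y \mid x)\,\big(R(x, y) - b(x)\big)\Big]$$

The baseline term vanishes in expectation because Σ_y πθ(y|x) ∇θ log πθ(y|x) = ∇θ Σ_y πθ(y|x) = ∇θ 1 = 0, so b(x) changes the variance of the estimate but not its mean.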
The mean of rewards over K i.i.d. rollouts from πθ(·|x) is a Monte-Carlo estimate of 𝔼y∼πθ[R(x, y)], which depends only on x. Subtracting it removes the reward's per-prompt component — the noise without signal.
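The claim that a prompt-only baseline leaves the expected gradient unchanged can be checked exactly on a toy problem. The sketch below assumes a softmax policy over three possible "responses" to a single prompt, with made-up logits and rewards; it enumerates all responses instead of sampling, so there is no Monte-Carlo noise. The function names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_policy_gradient(logits, rewards, baseline=0.0):
    """Exact E_y[(R(y) - b) * d log pi(y) / d logits] for a softmax policy.

    For a softmax, d log pi(y) / d logit_j = 1[j == y] - pi_j.
    """
    pi = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for y in range(n):
        weight = pi[y] * (rewards[y] - baseline)
        for j in range(n):
            grad[j] += weight * ((1.0 if j == y else 0.0) - pi[j])
    return grad


logits = [0.2, -1.0, 0.5]   # toy policy parameters (made up)
rewards = [1.0, 0.0, 0.0]   # only response 0 passes the verifier (made up)

pi = softmax(logits)
expected_reward = sum(p * r for p, r in zip(pi, rewards))  # b(x) = E_y[R(x, y)]

g_plain = expected_policy_gradient(logits, rewards)
g_baselined = expected_policy_gradient(logits, rewards, baseline=expected_reward)

print([round(g, 6) for g in g_plain])
print([round(g, 6) for g in g_baselined])
```

The two printed gradients agree to floating-point precision. In GRPO the exact expectation is replaced by the mean of K sampled rewards, which is an estimate of the same quantity.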
Why divide by std?
Std-normalization makes the loss invariant to reward rescaling: change your reward from {0, 1} to {0, 100} and the advantages, and therefore the updates, stay the same (up to the tiny stabilizing constant in the denominator), with no need to retune the learning rate. Stability win at no apparent cost.
The catch, which we'll see explicitly in lesson 14, is that std-normalization amplifies gradients from low-variance groups. Low-variance groups come from the hardest and easiest prompts (mean reward near 0 or near 1), not the most informative ones. Dr.GRPO argues this is a bias and removes it. For now, just note that GRPO and DAPO both keep the normalization; only Dr.GRPO drops it.
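Both effects of the std division can be seen numerically. The sketch below is illustrative only: the reward groups are made up, the helper names are hypothetical, and the sample standard deviation is assumed (matching the default of `torch.Tensor.std()`).

```python
from statistics import mean, stdev

DELTA = 1e-6  # stabilizing constant in the denominator


def normalized_advantages(rewards):
    mu = mean(rewards)
    sigma = stdev(rewards)
    return [(r - mu) / (sigma + DELTA) for r in rewards]


def std_gain(rewards):
    """Factor by which std-normalization scales the centered rewards."""
    return 1.0 / (stdev(rewards) + DELTA)


# 1) Rescaling the reward leaves the advantages (almost exactly) unchanged.
binary = [0.0, 1.0, 1.0, 0.0]
scaled = [100.0 * r for r in binary]
adv_binary = normalized_advantages(binary)
adv_scaled = normalized_advantages(scaled)
print([round(a, 4) for a in adv_binary])
print([round(a, 4) for a in adv_scaled])

# 2) A lopsided group (1 success in 8) has lower variance than a balanced
#    one (4 in 8), so its centered rewards are multiplied by a larger factor.
balanced = [1.0] * 4 + [0.0] * 4
lopsided = [1.0] * 1 + [0.0] * 7
gain_balanced = std_gain(balanced)
gain_lopsided = std_gain(lopsided)
print(round(gain_balanced, 3), round(gain_lopsided, 3))
```

The second printout shows the amplification: the closer a group's rewards are to all-0 or all-1, the larger the multiplier applied to whatever deviation remains.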
Interactive · K rollouts, advantages, gradient signal
Set rewards for K=4 rollouts. The widget computes the GRPO advantages. Move sliders to find the cases where (a) all rewards are equal — the degenerate-group case from lesson 2 — and (b) one rollout is an outlier — watch its advantage dominate.
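Outside the interactive page, the two cases can be reproduced in a few lines. This is a sketch with made-up rewards for K=4; `group_advantages` is a hypothetical helper that uses the sample standard deviation and a 1e-6 stabilizer, mirroring the lesson's formula.

```python
from statistics import mean, stdev


def group_advantages(rewards, delta=1e-6):
    mu = mean(rewards)
    sigma = stdev(rewards)
    return [(r - mu) / (sigma + delta) for r in rewards]


# (a) Degenerate group: every rollout gets the same reward.
all_equal = group_advantages([1.0, 1.0, 1.0, 1.0])

# (b) One outlier: a single success among failures.
outlier = group_advantages([0.0, 0.0, 0.0, 1.0])

print("all equal:", all_equal)
print("outlier:  ", [round(a, 3) for a in outlier])
```

In case (a) every advantage is zero, so the prompt contributes no policy-gradient signal at all. In case (b) the lone success receives an advantage three times the magnitude of each failure's, with opposite sign.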
GRPO = PPO − value head
Everything in GRPO that isn't "use the group mean as the baseline" is literally PPO:
| Component | PPO | GRPO |
|---|---|---|
| Baseline | Learned Vφ(s) | Group mean of K rewards |
| Off-policy correction | ρ = πθ/πold | Same |
| Trust region | Clip ρ to [1−ε, 1+ε] + pessimistic min | Same |
| Anchor | β · KL(πθ ‖ πref) | Same |
| Multi-epoch | 4–10 SGD steps per rollout batch | Same (1 step in DeepSeek-R1, configurable) |
| Memory cost | ~2× the policy | ~1× the policy |
The cost trade is K sampled rollouts per prompt to build the baseline, against no second network to train or shard across devices. For verifiable-reward reasoning tasks, where rollout generation is already the dominant cost, the trade is strongly favorable, which is why DeepSeek-R1 and many open-source reasoning-RL pipelines since have chosen GRPO over PPO.
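The trust-region row of the table is shared by both algorithms, and its "pessimistic min" is easy to misread. The sketch below evaluates the per-token clipped surrogate for the four sign and direction combinations. It is an illustration only: `clipped_surrogate` is a hypothetical name, and the clip range of 0.2 is an assumed, commonly used value rather than one stated in this lesson.

```python
def clipped_surrogate(ratio, advantage, clip_eps=0.2):
    """PPO-clip objective for one token: min(rho * A, clip(rho, 1-eps, 1+eps) * A)."""
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped_ratio * advantage)


cases = [
    (1.5, +1.0),  # good token, probability already raised a lot -> clipped, no further push
    (1.5, -1.0),  # bad token whose probability went UP -> unclipped, full penalty
    (0.5, +1.0),  # good token whose probability went DOWN -> unclipped, full correction
    (0.5, -1.0),  # bad token, probability already lowered a lot -> clipped, no further push
]

results = [clipped_surrogate(r, a) for r, a in cases]
for (r, a), value in zip(cases, results):
    print(f"ratio={r:.1f}  A={a:+.1f}  surrogate={value:+.2f}")
```

Clipping only removes the incentive to move further in the direction the advantage already favors; moves in the wrong direction are always penalized at full strength, which is what "pessimistic" means here.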
The loss in code
```python
# From RL/algorithms/02_grpo.py, function grpo_step

# (a) Group-relative advantage.
A = (rewards - rewards.mean()) / (rewards.std() + 1e-6)            # (K,)

# (b) PPO-clip surrogate, with A_i broadcast to every response token.
log_ratio = new_logp - old_logp                                    # (K, T-1)
ratio = torch.exp(log_ratio)
A_tok = A.detach().unsqueeze(-1).expand_as(ratio)                  # (K, T-1)
s1, s2 = ratio * A_tok, torch.clamp(ratio, 1-ε, 1+ε) * A_tok
pg_tok = -torch.min(s1, s2) * target_mask

# (c) Sequence-level aggregation (original GRPO): mean over tokens within
#     a rollout, then mean over K. We'll meet two alternatives in lessons 13–14.
pg_loss = (pg_tok.sum(-1) / length_per).mean()

# (d) KL anchor, same aggregation.
kl_loss = β * (compute_kl_k3(new_logp, ref_logp) * target_mask).sum(-1) / length_per
loss = pg_loss + kl_loss.mean()
```
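The excerpt calls `compute_kl_k3` without showing it. Judging only from the name, it presumably computes the per-token "k3" estimator of KL(πθ ‖ πref); the sketch below shows that estimator together with the masked sequence-level mean from steps (c) and (d). This is an assumption about the helper, not the repository's code: `kl_k3` and `masked_sequence_mean` are hypothetical names, plain Python lists stand in for tensors, and the log-probabilities are made up.

```python
import math


def kl_k3(new_logp, ref_logp):
    """k3 estimate of KL(pi_theta || pi_ref) for one token sampled from pi_theta.

    With d = log pi_ref - log pi_theta, k3 = exp(d) - d - 1. It is always >= 0
    and is exactly 0 when the two log-probabilities agree.
    """
    d = ref_logp - new_logp
    return math.exp(d) - d - 1.0


def masked_sequence_mean(per_token, mask):
    """Mean over the response tokens of one rollout (mask is 1 on response tokens)."""
    length = sum(mask)
    return sum(v * m for v, m in zip(per_token, mask)) / length


# One rollout with three token positions; the last one is padding (made-up values).
new_logp = [-1.0, -2.0, -0.5]
ref_logp = [-1.2, -1.5, -0.5]
mask = [1, 1, 0]

kl_tok = [kl_k3(n, r) for n, r in zip(new_logp, ref_logp)]
kl_seq = masked_sequence_mean(kl_tok, mask)

print([round(k, 5) for k in kl_tok])
print(round(kl_seq, 5))
```

Unlike the naive estimate log πθ − log πref, this quantity is never negative for a single sample, which keeps the penalty term from occasionally rewarding drift away from the reference.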