Trainer — where gradients flow
The smallest possible role, with the most responsibility per line of code.
What the trainer is
The trainer in this framework is the only place where .backward() is ever called. It takes a padded batch and a loss function (lesson 4 gives us the loss function), runs one forward pass through the policy, hands the resulting log-probs to the loss function, calls backward, clips gradients, and steps the optimizer. That's the whole job.
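```python
# From rl_framework/trainer.py: TrainerEngine.train_step (excerpt; a method on TrainerEngine)
import torch
# per_token_logp is the framework's own log-prob helper (not shown in this lesson).

def train_step(self, batch, loss_fn, *, accumulate: bool = False):
    self.model.train()
    # Policy forward with grad enabled: log-prob of every next token in full_ids, shape (B, T-1).
    new_logp_full = per_token_logp(self.model, batch.full_ids)
    loss, metrics = loss_fn(batch, new_logp_full)  # algorithm-specific loss
    loss.backward()                                 # the only backward() in the framework
    metrics["loss"] = loss.item()
    if not accumulate:                              # gradient accumulation: no step until the last chunk
        gn = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
        metrics["grad_norm"] = float(gn)            # total norm measured before clipping
        self.opt.step()
        self.opt.zero_grad(set_to_none=True)
    return metrics
```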
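Where does the (B, T-1) shape come from? The lesson doesn't show per_token_logp, so here is a hypothetical pure-Python sketch of what such a helper computes, assuming standard next-token shifting: the logits at position t score token t+1, so a sequence of T tokens yields T-1 log-probs. The name per_token_logp_from_logits and the nested-list layout are illustrative only; a torch version would typically apply log_softmax to the shifted logits and gather the target ids.

```python
import math


def log_softmax(row):
    # Numerically stable: subtract the max before exponentiating.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def per_token_logp_from_logits(logits, ids):
    """Hypothetical helper. logits: B x T x V nested lists; ids: B x T token ids.
    Returns B x (T-1) values: log p(ids[b][t+1] | ids[b][:t+1])."""
    out = []
    for seq_logits, seq_ids in zip(logits, ids):
        row = []
        # Position t predicts token t+1, so the last position has no target.
        for t in range(len(seq_ids) - 1):
            row.append(log_softmax(seq_logits[t])[seq_ids[t + 1]])
        out.append(row)
    return out


# Usage: one sequence of 3 tokens, vocabulary of 2, uniform logits.
logp = per_token_logp_from_logits([[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]], [[0, 1, 0]])
print(len(logp[0]))  # 2 log-probs for 3 tokens, each equal to log(0.5)
```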
Why the trainer is algorithm-agnostic
The trainer never imports GRPO. It never knows whether the loss it just received came from REINFORCE, GRPO, or DAPO. The whole interface between trainer and algorithm is this one line:
```python
loss, metrics = loss_fn(batch, new_logp_full)
```
where loss_fn is the algorithm's compute_loss method, passed in as a callable. Want to swap GRPO for DAPO? Build a different Algorithm instance and pass its compute_loss instead; the trainer doesn't change. This is the most important boundary in the framework: TRL uses a class hierarchy (PPOTrainer, RLOOTrainer, DPOTrainer) to express the same idea, while this framework uses a function.
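The contract can be written down as a type. Below is a minimal sketch with no model or autograd: new_logp is passed in precomputed, standing in for the trainer's forward, and the step function only knows the call signature, so two unrelated losses plug in unchanged. trainer_step, reinforce_loss, ratio_loss and the dict-of-lists batch are hypothetical stand-ins, not the framework's classes.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-ins: a batch is a dict of per-token columns,
# and a loss function maps (batch, new_logp) -> (loss, metrics).
Batch = Dict[str, List[float]]
LossFn = Callable[[Batch, List[float]], Tuple[float, Dict[str, float]]]


def reinforce_loss(batch: Batch, new_logp: List[float]) -> Tuple[float, Dict[str, float]]:
    # Plain policy gradient: -mean(A * log pi(token)).
    terms = [-a * lp for a, lp in zip(batch["advantage"], new_logp)]
    loss = sum(terms) / len(terms)
    return loss, {"pg_loss": loss}


def ratio_loss(batch: Batch, new_logp: List[float]) -> Tuple[float, Dict[str, float]]:
    # Importance-weighted surrogate: -mean(exp(new - old) * A), no clipping.
    ratios = [math.exp(n - o) for n, o in zip(new_logp, batch["old_logp"])]
    terms = [-r * a for r, a in zip(ratios, batch["advantage"])]
    loss = sum(terms) / len(terms)
    return loss, {"mean_ratio": sum(ratios) / len(ratios)}


def trainer_step(batch: Batch, new_logp: List[float], loss_fn: LossFn) -> Dict[str, float]:
    # Trainer side: call whatever loss it was handed and record the scalar.
    # In the real train_step, backward/clip/step would follow here.
    loss, metrics = loss_fn(batch, new_logp)
    metrics["loss"] = float(loss)
    return metrics


batch = {"advantage": [1.0, -1.0], "old_logp": [-1.0, -2.0]}
new_logp = [-1.0, -2.0]
for fn in (reinforce_loss, ratio_loss):
    print(fn.__name__, trainer_step(batch, new_logp, fn))
```

Swapping algorithms is just passing a different function object; nothing inside trainer_step changes.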
[Interactive figure: piecing the loss together. A batch of two trajectories, one row per response token; the trainer's forward fills in the new_logp column, and the algorithm uses every column to assemble the loss.]
One forward pass, two responsibilities
Recall lesson 2: old_logp was captured during sampling. The trainer's forward pass computes new_logp on the same tokens, but with the current policy weights, which may already have been updated since the rollout was sampled. So, per token, we have:
- old_logp: from the rollout's forward; no gradient.
- new_logp: from the trainer's forward; gradient flows through it.
- ref_logp: from the reference's forward; no gradient.
Three forward passes over (almost) the same tokens, by three different model copies. That repetition is what makes RL post-training expensive — and it's why most of the framework's complexity is about not doing more forwards than necessary (record at sampling time, batch the ref pass, share the trainer's forward across grad-accum chunks).
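To see how the three columns combine, here is one common way an algorithm can assemble a per-token loss from them: a PPO-style clipped surrogate on the new/old ratio, plus a KL penalty against the reference using the k3 estimator exp(ref - new) - (ref - new) - 1, the form used in the GRPO paper. This is a sketch under stated assumptions: eps=0.2 and beta=0.04 are illustrative values, the advantage is already broadcast to every token of its trajectory, the loss is a plain token mean, and clipped_loss_with_kl is a hypothetical name. The framework's own compute_loss may differ on each of these points.

```python
import math


def clipped_loss_with_kl(old_logp, new_logp, ref_logp, adv, eps=0.2, beta=0.04):
    """Hypothetical per-token loss: clipped surrogate + beta * k3 KL, averaged over tokens.
    Only new_logp would carry gradient in a real trainer; the other columns are constants."""
    total = 0.0
    for o, n, r, a in zip(old_logp, new_logp, ref_logp, adv):
        ratio = math.exp(n - o)                      # pi_new / pi_old for this token
        clipped = min(max(ratio, 1 - eps), 1 + eps)  # ratio clamped to [1-eps, 1+eps]
        surrogate = min(ratio * a, clipped * a)      # pessimistic (clipped) objective
        kl = math.exp(r - n) - (r - n) - 1           # k3 estimate of KL(new || ref), >= 0
        total += -surrogate + beta * kl              # minimize: maximize surrogate, stay near ref
    return total / len(adv)


# Usage: two trajectories flattened to per-token rows (2 tokens, then 1),
# each token carrying its trajectory's advantage.
old = [-1.0, -0.5, -2.0]
new = [-0.9, -0.5, -2.1]
ref = [-1.0, -0.6, -2.0]
adv = [1.0, 1.0, -1.0]
print(clipped_loss_with_kl(old, new, ref, adv))
```

When new, old and ref agree, the ratio is 1 and the KL term is 0, so the loss reduces to minus the mean advantage; the clip only matters once the policy has moved away from the rollout policy.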
Gradient clipping & the AdamW choice
One detail you'll notice in trainer.py: weight_decay=0.0 and betas=(0.9, 0.95).
- No weight decay. Decay steadily pulls every parameter toward zero, not toward the pretrained weights. The KL anchor is computed in function space (over per-token distributions), so it never penalizes that parameter-space drift directly; it only reacts once the drift shows up in the distributions it measures, and by then θ may have quietly degraded. Set weight decay to 0 for RL fine-tuning unless you have a specific reason not to.
- β2 = 0.95 (not 0.999). RL gradients are spiky: one lucky high-reward rollout can swing the gradient hard. A lower β2 makes Adam's second-moment (variance) estimate forget old gradients faster, so the effective learning rate doesn't stay suppressed for long after a spike. The (0.9, 0.95) betas are the setting reported for GPT-3 and are common in LLM training and RL recipes.
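- Grad clip = 1.0. Belt-and-suspenders on top of PPO's ratio clipping. If the loss surprises you, the global gradient norm is rescaled to at most 1.0 before AdamW sees it, so one bad batch can't dominate Adam's moment estimates. This caps the gradient, not the distance the parameters move; the step size is still set mainly by the learning rate.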
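How much faster is "faster"? Adam's second-moment estimate is an exponential moving average, v_t = β2·v_{t-1} + (1 - β2)·g_t², so any single gradient's weight in v shrinks by a factor of β2 on every later step. The hypothetical helpers below compute that remaining weight and its half-life.

```python
import math


def remaining_weight(beta2: float, k: int) -> float:
    # Fraction of a gradient's initial contribution to v left after k more EMA updates.
    return beta2 ** k


def spike_half_life(beta2: float) -> float:
    # Steps h until that contribution has decayed to half: beta2 ** h == 0.5.
    return math.log(0.5) / math.log(beta2)


for b2 in (0.95, 0.999):
    print(f"beta2={b2}: half-life {spike_half_life(b2):.1f} steps, "
          f"weight left after 100 steps {remaining_weight(b2, 100):.3f}")
```

With β2 = 0.95 a spike's contribution halves roughly every 13.5 steps; with 0.999 it takes about 693 steps.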
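Clipping by global norm is a single rescale of all gradients together. The plain-Python sketch below (clip_by_global_norm is a hypothetical helper) follows the same rule as torch.nn.utils.clip_grad_norm_ with its default L2 norm: compute one norm over every parameter's gradient, multiply everything by min(1, max_norm / (norm + 1e-6)), and return the norm measured before clipping, which is the value train_step logs as grad_norm.

```python
import math
from typing import List, Tuple


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> Tuple[List[List[float]], float]:
    """grads holds one flattened gradient per parameter tensor.
    Returns (clipped grads, global L2 norm measured before clipping)."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # One scale factor for everything; exactly 1.0 (a no-op) when already under the limit.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [[g * scale for g in vec] for vec in grads], total_norm


clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm_before, clipped)  # norm 5.0 before; afterwards the gradient has length ~1.0
```

Because every coordinate is scaled by the same factor, the gradient's direction is unchanged; only its length is capped.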
Gradient accumulation, briefly
Production RL runs with effective batch sizes in the thousands of trajectories. They don't fit in memory as one forward — so the trainer supports an accumulate=True mode that skips the optimizer step. Caller pattern:
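```python
# Split the batch into microbatches; accumulate on every chunk except the last.
chunks = list(batch.chunks(microbatch))
for chunk in chunks[:-1]:
    trainer.train_step(chunk, loss_fn, accumulate=True)              # backward only; grads sum into .grad
metrics = trainer.train_step(chunks[-1], loss_fn, accumulate=False)  # last chunk: clip + step + zero_grad
```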
Same interface, lower memory. The toy in this folder fits in memory so we don't use it, but the surface is exposed.
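One thing to watch: each backward() adds into .grad, so accumulating chunks sums their gradients. If loss_fn returns a per-chunk mean, that sum is not the full-batch gradient unless each chunk is weighted by its share of the batch. The lesson doesn't say where this framework handles that scaling, so treat the following as an assumption-laden sketch: a one-parameter least-squares model, with hypothetical grad_mean_sq_err and accumulated_grad helpers, checks that share-weighted accumulation reproduces the full-batch gradient even with an uneven last chunk.

```python
from typing import List


def grad_mean_sq_err(w: float, xs: List[float], ys: List[float]) -> float:
    # d/dw of mean((w*x - y)^2) over the given examples.
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: List[float], ys: List[float], microbatch: int) -> float:
    """Sum per-microbatch gradients the way repeated backward() calls sum into .grad,
    weighting each chunk's mean-loss gradient by its share of the full batch."""
    n = len(xs)
    total = 0.0
    for i in range(0, n, microbatch):
        cx, cy = xs[i:i + microbatch], ys[i:i + microbatch]
        total += grad_mean_sq_err(w, cx, cy) * len(cx) / n
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.0, 4.0, 5.0, 8.0, 11.0]
print(grad_mean_sq_err(1.5, xs, ys), accumulated_grad(1.5, xs, ys, microbatch=2))
```

Both values agree (about -11.8). Dropping the len(cx) / n weight would instead sum the three chunk means (chunks of 2, 2 and 1 examples) to -47.0.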
What the trainer is not
The trainer is not:
- Where rewards get computed (environment).
- Where advantages get computed (algorithm).
- Where the reference is scored (reference).
- Where rollouts come from (rollout).
- Where the weights get pushed back to the rollout engine (weight-sync — next lesson).
It is one forward, one backward, one step. In a real FSDP-sharded deployment it lives on its own GPUs with its own activation memory; in this in-process toy it's a single AdamW. The interface is the same.