rl_lessons · lesson 5 of 8 · trainer

Trainer — where gradients flow

The smallest possible role, with the most responsibility per line of code.

What the trainer is

The trainer in this framework is the only place where .backward() is ever called. It takes a padded batch and a loss function (built in lesson 4). It then runs one forward pass through the policy, hands the resulting log-probs to the loss function, calls backward, clips gradients, and steps the optimizer. That's the whole job.

# From rl_framework/trainer.py — TrainerEngine.train_step
def train_step(self, batch, loss_fn, *, accumulate: bool = False):
    self.model.train()
    new_logp_full = per_token_logp(self.model, batch.full_ids)   # (B, T-1) — grad flows here
    loss, metrics = loss_fn(batch, new_logp_full)                # algorithm-specific
    loss.backward()
    metrics["loss"] = float(loss.detach().item())
    if not accumulate:                                           # accumulate=True: backward only, skip clip + step
        gn = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
        metrics["grad_norm"] = float(gn)
        self.opt.step()
        self.opt.zero_grad(set_to_none=True)
    return metrics
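The (B, T-1) shape in the comment on the per_token_logp line comes from the next-token shift. Position t's logits score token t+1, so the last position has no target. The sketch below shows that bookkeeping in plain Python and assumes the policy returns one logit vector per position. per_token_logp_sketch is a hypothetical stand-in. The framework's real per_token_logp presumably does the same in torch (log_softmax over the vocabulary, then gather at the shifted ids), with gradients flowing through it.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum_exp for x in row]


def per_token_logp_sketch(logits: List[List[List[float]]], ids: List[List[int]]) -> List[List[float]]:
    """Hypothetical stand-in for per_token_logp.

    logits[b][t] is the score vector at position t, which predicts ids[b][t + 1].
    The one-token shift turns T input tokens into T-1 log-probs per row: shape (B, T-1).
    """
    out = []
    for row_logits, row_ids in zip(logits, ids):
        out.append([log_softmax(row_logits[t])[row_ids[t + 1]] for t in range(len(row_ids) - 1)])
    return out


# Two sequences of T=3 tokens over a 4-token vocabulary, uniform logits everywhere.
uniform = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]
logp = per_token_logp_sketch(uniform, [[0, 1, 2], [3, 3, 0]])
print(len(logp), len(logp[0]))  # 2 2, i.e. (B, T-1)
```

With uniform logits, every entry is -log 4. A real model would produce different values per token, but the shape logic is the same.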

Why the trainer is algorithm-agnostic

The trainer never imports GRPO. It never knows whether the loss it just got back is REINFORCE, GRPO, or DAPO. The whole interface between trainer and algorithm is one line:

loss, metrics = loss_fn(batch, new_logp_full)

Here loss_fn is the algorithm's compute_loss method, passed in as a callable. Want to swap GRPO for DAPO? Build a different Algorithm instance and pass its compute_loss instead. The trainer doesn't change. This is the most important boundary in the framework. TRL expresses the same idea with a class hierarchy (PPOTrainer, RLOOTrainer, DPOTrainer); this framework uses a function.
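Here is a stripped-down sketch of that boundary. It uses plain Python with no autograd, so it runs anywhere. train_step only knows the callable signature, and two hypothetical algorithms plug in by passing their compute_loss: a plain REINFORCE loss and a mean-baseline variant. Neither is the framework's GRPO or DAPO. ToyBatch, fake_forward and both algorithm classes are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

Metrics = Dict[str, float]


@dataclass
class ToyBatch:
    """Hypothetical batch: one reward per row, standing in for a padded batch."""
    rewards: List[float]


# The whole trainer/algorithm contract: (batch, new_logp) -> (loss, metrics).
LossFn = Callable[[ToyBatch, List[float]], Tuple[float, Metrics]]


class Reinforce:
    """Hypothetical algorithm: loss = -mean(R * logp)."""

    def compute_loss(self, batch: ToyBatch, new_logp: List[float]) -> Tuple[float, Metrics]:
        terms = [-r * lp for r, lp in zip(batch.rewards, new_logp)]
        return sum(terms) / len(terms), {"mean_reward": sum(batch.rewards) / len(batch.rewards)}


class MeanBaselineReinforce:
    """Hypothetical algorithm: subtract the batch-mean reward before weighting logp."""

    def compute_loss(self, batch: ToyBatch, new_logp: List[float]) -> Tuple[float, Metrics]:
        baseline = sum(batch.rewards) / len(batch.rewards)
        terms = [-(r - baseline) * lp for r, lp in zip(batch.rewards, new_logp)]
        return sum(terms) / len(terms), {"baseline": baseline}


def train_step(batch: ToyBatch, loss_fn: LossFn, forward: Callable[[ToyBatch], List[float]]) -> Metrics:
    """Trainer side: run the forward, then delegate scoring. No algorithm is imported."""
    new_logp = forward(batch)
    loss, metrics = loss_fn(batch, new_logp)
    metrics["loss"] = loss
    return metrics


def fake_forward(batch: ToyBatch) -> List[float]:
    """Stand-in for the policy's per-row log-probs."""
    return [-1.0, -2.0]


batch = ToyBatch(rewards=[1.0, 0.0])
for algo in (Reinforce(), MeanBaselineReinforce()):
    print(type(algo).__name__, train_step(batch, algo.compute_loss, fake_forward)["loss"])
```

Swapping algorithms changes only the argument at the call site. In the real trainer, the same call is followed by loss.backward().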

Why all the grad-flow lives in one file
Mixed precision, FSDP autowrap, activation checkpointing: every production training setup wraps "the forward pass with grads." If that one forward lives in one file, you wrap it once. If the algorithm is allowed to call the model directly, you need N places to remember to apply the wrappers. Bugs like "fp32 leaked into the algorithm's forward" then keep coming back. Hence: the trainer forwards, the algorithm scores.
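The "wrap it once" argument can be sketched with contextlib.ExitStack. The trainer owns one callable, and every wrapper is entered around it in a fixed order. GradForward and traced are hypothetical names. With torch, a wrapper factory could be lambda: torch.autocast("cuda", dtype=torch.bfloat16). FSDP wrapping and activation checkpointing are usually applied to the model itself rather than as context managers, though. Read this as an illustration of the single choke point, not of those APIs.

```python
from contextlib import ExitStack, contextmanager
from typing import Callable, ContextManager, Iterator, List, Sequence


@contextmanager
def traced(name: str, log: List[str]) -> Iterator[None]:
    """Stand-in wrapper that records when it is entered and exited."""
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")


class GradForward:
    """Hypothetical: the single grad-carrying forward, with every wrapper applied in one place."""

    def __init__(self, forward_fn: Callable, wrappers: Sequence[Callable[[], ContextManager]]):
        self.forward_fn = forward_fn
        self.wrappers = list(wrappers)  # zero-arg factories, so each call gets fresh context managers

    def __call__(self, *args):
        with ExitStack() as stack:
            for make_ctx in self.wrappers:
                stack.enter_context(make_ctx())  # entered in list order, exited in reverse
            return self.forward_fn(*args)


log: List[str] = []
forward = GradForward(
    lambda ids: [len(row) for row in ids],  # toy "model"
    [lambda: traced("autocast", log), lambda: traced("checkpoint", log)],
)
print(forward([[1, 2, 3], [4]]))  # [3, 1]
print(log)
```

Adding or removing a wrapper touches one list in one file. Every algorithm automatically runs under the same wrappers because none of them calls the model directly.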

Interactive · piecing the loss together

Below: a batch with two trajectories. Each row is one response token. The trainer's forward fills in the new_logp column, and the algorithm uses every column to assemble the loss.

Loss assembly (one batch, one step)
Walk left-to-right: rollout filled old_logp, reference filled ref_logp, algorithm filled advantage, trainer just filled new_logp. Now compute the loss per token, then mask and average.
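[Interactive table, not reproduced: one row per response token with columns traj, tok, mask, old_logp, ref_logp, new_logp, advantage, ρ, pg_tok, kl_tok; batch-level outputs pg_loss (mean), kl_loss (β·KL), total loss, frac_clipped.]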
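The per-token arithmetic described above can be sketched in plain Python. Everything specific here is an assumption standing in for lesson 4's loss:

- pg_tok uses a PPO-style clipped surrogate with clip range eps.
- kl_tok uses the k3 estimator exp(d) - d - 1, with d = ref_logp - new_logp.
- The loss is a token-level mean over masked tokens.
- frac_clipped is the share of tokens whose ratio ρ left [1 - eps, 1 + eps].

TokenRow, assemble_loss, eps=0.2 and beta=0.04 are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class TokenRow:
    mask: float       # 1.0 for real response tokens, 0.0 for padding
    old_logp: float   # recorded by the rollout at sampling time
    ref_logp: float   # from the frozen reference model
    new_logp: float   # from the trainer's forward (the only column gradients flow through)
    advantage: float  # filled by the algorithm, one value per trajectory


def assemble_loss(rows: List[TokenRow], eps: float = 0.2, beta: float = 0.04) -> Dict[str, float]:
    """Hypothetical loss assembly: clipped surrogate + k3 KL, averaged over unmasked tokens."""
    pg_sum = kl_sum = n_clipped = n_tokens = 0.0
    for r in rows:
        if r.mask == 0.0:
            continue  # padding contributes nothing
        rho = math.exp(r.new_logp - r.old_logp)  # importance ratio pi_new / pi_old
        rho_clipped = min(max(rho, 1.0 - eps), 1.0 + eps)
        pg_tok = -min(rho * r.advantage, rho_clipped * r.advantage)
        d = r.ref_logp - r.new_logp
        kl_tok = math.exp(d) - d - 1.0  # k3 estimator, always >= 0
        pg_sum += pg_tok
        kl_sum += kl_tok
        n_clipped += 1.0 if abs(rho - 1.0) > eps else 0.0
        n_tokens += 1.0
    n = max(n_tokens, 1.0)  # avoid division by zero on an all-padding batch
    pg_loss = pg_sum / n
    kl_loss = beta * kl_sum / n
    return {"pg_loss": pg_loss, "kl_loss": kl_loss, "loss": pg_loss + kl_loss, "frac_clipped": n_clipped / n}


# Toy values where all three log-probs agree, so rho = 1 and every kl_tok = 0.
rows = [
    TokenRow(1.0, -1.0, -1.0, -1.0, 1.0),   # trajectory 0
    TokenRow(1.0, -2.0, -2.0, -2.0, 1.0),   # trajectory 0
    TokenRow(1.0, -0.5, -0.5, -0.5, -1.0),  # trajectory 1
    TokenRow(0.0, 0.0, 0.0, 0.0, -1.0),     # trajectory 1, padding
]
print(assemble_loss(rows))
```

In the framework these columns are presumably (B, T-1) tensors, and only new_logp carries grad. The same arithmetic written with torch ops is then what loss.backward() differentiates.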

One forward pass, two responsibilities

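Recall lesson 2: old_logp was captured during sampling. The trainer's forward pass computes new_logp on the same tokens, but with the current policy weights (after any optimizer steps taken since the rollout). So per token we have:

- old_logp: from the sampling policy, recorded at rollout time, no grad
- ref_logp: from the frozen reference model, no grad
- new_logp: from the trainer's forward with the current weights, the only one that carries grad
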
Three forward passes over (almost) the same tokens, by three different model copies. That repetition is what makes RL post-training expensive — and it's why most of the framework's complexity is about not doing more forwards than necessary (record at sampling time, batch the ref pass, share the trainer's forward across grad-accum chunks).

Gradient clipping & the AdamW choice

One detail you'll notice in trainer.py: the AdamW optimizer is built with weight_decay=0.0 and betas=(0.9, 0.95).
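The clip_grad_norm_ call in train_step clips by the global L2 norm. All gradients are treated as one long vector, and if its norm exceeds cfg.grad_clip, every gradient is scaled by the same factor. The pure-Python sketch below mirrors the documented behaviour of torch.nn.utils.clip_grad_norm_. That includes the small epsilon in the denominator and returning the pre-clip norm, which is what train_step logs as grad_norm. clip_grad_norm here is a hypothetical stand-in that operates on nested lists. The betas and weight_decay above correspond to the keyword arguments of the same names in torch.optim.AdamW.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Global-norm clipping in place; returns the norm measured before clipping."""
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))  # one norm over all params
    clip_coef = min(max_norm / (total_norm + eps), 1.0)                 # never scale up
    for grad in grads:
        for i in range(len(grad)):
            grad[i] *= clip_coef                                        # same factor for every param
    return total_norm


grads = [[3.0], [4.0]]                      # two "parameters", global L2 norm 5
print(clip_grad_norm(grads, max_norm=1.0))  # 5.0
print(grads)                                # roughly [[0.6], [0.8]]: direction kept, length ~1
```

Because one factor scales every parameter, clipping changes the step size but not the update direction.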

Gradient accumulation, briefly

Production RL runs with effective batch sizes in the thousands of trajectories. They don't fit in memory as one forward — so the trainer supports an accumulate=True mode that skips the optimizer step. Caller pattern:

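chunks = list(batch.chunks(microbatch))
for i, chunk in enumerate(chunks):
    is_last = i == len(chunks) - 1
    # .grad sums over chunks; scale each chunk's loss by its share of the batch to keep a mean
    trainer.train_step(chunk, loss_fn, accumulate=not is_last)   # clip + step only on the last chunk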

Same interface, lower memory. The toy in this folder fits in memory, so it never accumulates, but the option is exposed.
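Because .grad sums across backward calls, the accumulated gradient matches the full-batch gradient only if each chunk's loss is scaled by its share of the batch. The dependency-free check below uses a one-parameter least-squares loss with the gradient written out by hand. grad_of_mean_loss and accumulated_grad are hypothetical helpers. Whether the framework's loss_fn applies this scaling is not stated here.

```python
from typing import List, Tuple


def grad_of_mean_loss(w: float, xs: List[float], ys: List[float]) -> float:
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: List[float], ys: List[float], microbatch: int) -> Tuple[float, float]:
    """Return (weighted, naive) sums of per-chunk gradients."""
    total = len(xs)
    weighted = naive = 0.0
    for start in range(0, total, microbatch):
        cx, cy = xs[start:start + microbatch], ys[start:start + microbatch]
        g = grad_of_mean_loss(w, cx, cy)  # what one accumulate=True backward would add to .grad
        weighted += g * len(cx) / total   # chunk loss scaled by its share of the batch
        naive += g                        # no scaling: grads just pile up
    return weighted, naive


xs, ys, w = [1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 1.0, 0.0, 3.0, 4.0], 0.5
full = grad_of_mean_loss(w, xs, ys)
weighted, naive = accumulated_grad(w, xs, ys, microbatch=2)
print(full, weighted, naive)  # weighted ~= full (about -3.4); naive is -16.0
```

Clipping then runs once, on the summed gradient, which is why train_step only clips when accumulate is False.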

What the trainer is not

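The trainer is not the rollout (which fills old_logp), not the reference model (which fills ref_logp), and not the algorithm (which fills advantage and defines the loss).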

It is one forward, one backward, one step. In a real FSDP-sharded deployment it lives on its own GPUs with its own activation memory; in this in-process toy it's a single AdamW. The interface is the same.

Takeaway
The trainer holds the only grad-flowing forward pass and the only optimizer in the system. By making the loss a callable rather than baked-in, it stays algorithm-agnostic — and every wrapper you'd ever need (autocast, FSDP, checkpoint) wraps one line of code.