gpt_mini · Lesson 4 of 6

Chain-of-thought: a data change, not a loss change

Why a fixed-depth transformer can solve K-step problems only when it gets to write K tokens to think with.

The setup

Lesson 3 ended with a working SFT pipeline: prompt-masked cross-entropy, a chat template, and a held-out accuracy number that says the model imitates the gold response. This lesson keeps every line of that training code identical. The optimizer is the same. The loss is the same. The masking is the same. The chat template is the same. The model architecture — d=128, h=4, L=4 — is the same.

What changes is one field of one dataset class: the string we ask the model to predict. Where SFT produced response text like "18#", CoT produces response text like "{3+4=7;7+5=12;12+6=18}18#". The reasoning trace lives between two new sentinel tokens, { and }, and it is structured as a chain of single-step running totals: acc+next=new, separated by ;. The final integer after } is the answer; the # still marks end of turn.

The toy task is variable-length addition of single digits. A problem is a sequence of 2 to 6 digits, each drawn uniformly from 0..9. The prompt spells out the addition, and the response is either the answer alone (the direct regime) or a fully expanded running-total trace followed by the answer (the CoT regime). Two models, identical in every respect except the training-data string, give us a controlled experiment for what reasoning tokens actually buy.
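A minimal sketch of the data side, assuming the template from the worked example below: `<` and `>` around the prompt, `{` and `}` around the trace, `#` as end of turn, and no spaces between characters (the worked example spaces its tokens only for readability). `sample_digits` and `make_example` are hypothetical helpers, not the actual `CoTDataset` internals of 02_cot.py, and the distribution over K is an assumption.

```python
import random

def sample_digits(rng: random.Random, k_min: int = 2, k_max: int = 6) -> list[int]:
    """Draw K summands, each uniform over 0..9."""
    # Assumption: K itself is uniform over [k_min, k_max]; the lesson only says "2 to 6 digits".
    k = rng.randint(k_min, k_max)
    return [rng.randint(0, 9) for _ in range(k)]

def make_example(digits: list[int], use_cot: bool) -> tuple[str, str]:
    """Return (prompt, response) strings for one addition problem."""
    prompt = "<" + "+".join(str(d) for d in digits) + ">"
    answer = sum(digits)
    plain = f"{answer}#"
    # Running-total trace: acc+next=new steps separated by ';'.
    steps, acc = [], digits[0]
    for d in digits[1:]:
        steps.append(f"{acc}+{d}={acc + d}")
        acc += d
    cot = "{" + ";".join(steps) + "}" + f"{answer}#"
    response = cot if use_cot else plain  # the only regime-dependent line
    return prompt, response

print(make_example([3, 4, 5, 6], use_cot=False))  # ('<3+4+5+6>', '18#')
print(make_example([3, 4, 5, 6], use_cot=True))   # ('<3+4+5+6>', '{3+4=7;7+5=12;12+6=18}18#')
print(make_example(sample_digits(random.Random(0)), use_cot=True))
```

Both regimes share the prompt and the answer; only the string placed before the answer differs.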

Chat / reasoning template (single-char sentinels)

A worked example for the problem 3+4+5+6:

< 3 + 4 + 5 + 6 > { 3 + 4 = 7 ; 7 + 5 = 1 2 ; 1 2 + 6 = 1 8 } 1 8 #

The prompt tokens, from < through >, carry no loss. The response tokens, from { through #, carry loss at every position. The reasoning block and the final answer are both part of the response, and from the model's point of view they are indistinguishable: both contribute cross-entropy gradient with equal per-token weight at training time, and both are decoded one token at a time at inference. That is the entire delta from SFT. There is no separate “reasoning head”, no auxiliary loss, no `cot_weight=0.5` hyperparameter, no second forward pass on a planning model. The reasoning tokens are just response tokens whose characters happen to look like arithmetic.
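To make "loss everywhere in the response" concrete, here is a sketch of turning one (prompt, response) pair into next-token inputs and labels with character-level tokens. The value `IGNORE_INDEX = -100` is an assumption (it matches PyTorch's default `ignore_index` for `F.cross_entropy`); the vocabulary and the `encode_example` helper are hypothetical.

```python
IGNORE_INDEX = -100  # assumed value; PyTorch's default ignore_index for F.cross_entropy

# Hypothetical character vocabulary covering the template's symbols.
VOCAB = list("0123456789+=;{}<>#")
STOI = {ch: i for i, ch in enumerate(VOCAB)}

def encode_example(prompt: str, response: str) -> tuple[list[int], list[int]]:
    """Return (x, y) for next-token prediction with prompt targets masked."""
    ids = [STOI[ch] for ch in prompt + response]
    x, y = ids[:-1], ids[1:]
    # y[i] is the token at position i+1; mask it if that position is still in the prompt.
    y = [IGNORE_INDEX if i + 1 < len(prompt) else t for i, t in enumerate(y)]
    return x, y

x, y = encode_example("<3+4+5+6>", "{3+4=7;7+5=12;12+6=18}18#")
print("".join(VOCAB[t] for t in y if t != IGNORE_INDEX))  # {3+4=7;7+5=12;12+6=18}18#
```

Nothing in `y` marks a position as trace or answer: every unmasked target gets the same weight in the mean cross-entropy. A batched version would also set padding targets to `IGNORE_INDEX`.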

The first-principles argument: serial compute per forward pass

This is the heart of the lesson. To understand why writing the reasoning out as tokens helps a model that already has all the arithmetic facts in its weights, you have to be precise about what a single forward pass of a fixed-depth transformer can and cannot do.

A transformer of depth L processes its input through L structurally identical blocks. Each block does, in order: layer-norm, self-attention, a residual add, layer-norm, an MLP, another residual add. That is a fixed sequence of operations: every block runs exactly once, and information passes between layers only through the residual stream. Critically, within a single forward pass at one token position, the depth of the computation that produces the output logits is O(L). Attention lets a position read other positions, but only their states from the previous layer, so it adds breadth, not depth. There is no “loop until done” mechanism. Width helps for things you can compute in parallel; depth is the budget for things you must compute in sequence.

Now consider summing K single-digit numbers. The natural algorithm, a left fold over the running total, requires K-1 sequential additions. Each addition needs the result of the previous one as input, so you cannot vectorize across them: $\text{acc}_i$ literally depends on $\text{acc}_{i-1}$. With L=4 layers and K=6 summands, and under the simplifying assumption that each layer can carry out at most one of these dependent adds, the model has at most four serial “compute slots” in which to perform five serial adds. The pigeonhole math says the model cannot, in one forward pass, execute the canonical algorithm.
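The left fold is short enough to write out. This sketch records every running total so the K-1-long dependency chain is explicit; `direct_fits` encodes the lesson's idealized budget (at most one serial add per layer, L slots), which is a simplification rather than a measured property of real transformers. All names are hypothetical.

```python
def left_fold_chain(digits: list[int]) -> list[int]:
    """Running totals after each add; every add needs the previous total as input."""
    totals = []
    acc = digits[0]
    for d in digits[1:]:
        acc = acc + d  # cannot start until the previous add has finished
        totals.append(acc)
    return totals

def serial_adds(k: int) -> int:
    return k - 1  # length of the dependency chain for K summands

def direct_fits(k: int, n_layers: int) -> bool:
    # Idealized budget model: each layer contributes at most one serial add.
    return serial_adds(k) <= n_layers

print(left_fold_chain([3, 4, 5, 6, 2, 1]))  # [7, 12, 18, 20, 21]
print(direct_fits(6, 4))                     # False: 5 serial adds, 4 slots
print(direct_fits(3, 4))                     # True: 2 serial adds
```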

This is not a complaint about parameter count, dataset size, or training procedure. It is a statement about the function class. A transformer of fixed depth, viewed as a circuit, executes only O(L) sequential layer steps, no matter how wide the layers are or how cleverly you initialize them. A composition of K steps in which the i-th step genuinely depends on the (i-1)-th cannot be executed step by step in O(L) layers once K > L, unless the network finds some shortcut that avoids the chain.

Figure: one forward pass, K=3 summands, L=4 layers (fits comfortably). Layer 1: read and embed (slot 1 used). Layer 2: add 1, a+b (slot 2 used). Layer 3: add 2, +c (slot 3 used). Layer 4: project to logits (slot 4 free). Output: answer.

For K=3, the model needs two serial adds plus an embed/decode envelope, so there is room to spare. A model with L=4 can plausibly learn this case and reach high accuracy on it.

For K=8, the picture breaks. Seven serial adds do not fit in four layers. The model is forced to do something else: approximate. The most natural approximation is to memorize a lookup table in the MLP weights, mapping the entire input tuple $(d_1, d_2, \dots, d_K)$ directly to the sum. (An MLP layer can act as a key-value memory whose storage capacity scales roughly with its parameter count, that is, quadratically in width.) For K=2, the table has 100 entries, which is trivial to memorize. For K=6, the table has $10^6$ entries, and a 128-dimensional model trained for 2000 steps does not fit that table. Accuracy degrades sharply rather than gracefully, because the failure mode is “the model never saw this exact input combination during training” rather than “the model's algorithm got slightly off”.
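The memorization arithmetic can be checked directly. This sketch assumes, hypothetically, that the 3000 training problems are split evenly across K = 2..6 (the lesson does not give the length distribution) and that no tuple repeats, so the printed coverage is an upper bound on how much of each lookup table the direct model could have seen.

```python
N_TRAIN = 3000    # training problems, from the lesson's setup
KS = range(2, 7)  # K = 2..6 summands

def table_size(k: int) -> int:
    """Entries in a direct lookup table over ordered digit tuples (d_1, ..., d_K)."""
    return 10 ** k

def best_case_coverage(k: int) -> float:
    # Assumption: problems split evenly across K and never repeat a tuple,
    # which bounds the fraction of the table seen during training from above.
    per_k = N_TRAIN // len(KS)
    return min(per_k, table_size(k)) / table_size(k)

for k in KS:
    print(f"K={k}: table={table_size(k):>9,}  coverage<={best_case_coverage(k):.2%}")
```

Under these assumptions the direct model sees at most 0.06% of the K=6 table, so almost every K=6 test problem is an unseen key.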

Figure: one forward pass, K=8 summands, L=4 layers (overflows by 3 serial adds). Layer 1: embed. Layer 2: adds 1–2. Layer 3: adds 3–4. Layer 4: project. Overflow: adds 5, 6, 7 do not fit.

This is the surprising fact that makes CoT a genuine computational primitive rather than a prompting trick. The serial-compute ceiling is real and tight, and the gap between “the model knows single-digit addition cold” and “the model can compose 7 of them in one forward pass” is a gap of budget, not of knowledge.

The workaround: emit intermediate results as tokens

Here is the move. A response of R tokens is not one forward pass; it is R forward passes, one per token, each conditioned on a growing prefix. The i-th response token is produced by a full pass through all L layers, with the previous i-1 response tokens visible in the context. Because each pass can read the tokens written by earlier passes, the serial depth available across the response is $R \cdot L$ layer applications, not L.

If we structure the response so that each step is one bounded-difficulty subproblem, say adding one digit to the running total, the per-token serial-compute requirement stays well below L, and the total serial work scales with the response length, which the model controls. We have converted a hard ceiling on serial compute (set at train time by the architect's choice of L) into a soft budget on serial compute (set at inference time by however many tokens the model chooses to emit).

Figure: K=8 with CoT, one forward pass per reasoning step, each fitting trivially. Step 1: a+b=c. Step 2: c+d=e. Step 3: e+f=g. Step 4: g+h=i. Step 5: i+j=k. Step 6: k+l=m. Step 7: m+n=p. Answer: p. Each block is a full L=4 forward pass, so the total serial budget is 8 × L = 32 layer-ops against a per-step requirement of O(1) serial adds: the budget exceeds the need by 4×. (In the real template each step spans several tokens, such as 3+4=7;, so the actual budget is larger still.) The fixed-depth limit is gone, not because the model is bigger, but because the model gets to use more passes.

The key insight is that the new compute is allocated at inference, and the model's training data is what teaches it to allocate that compute productively. Without CoT data, even a model that has the capability to write reasoning tokens has no reason to: it has only ever been shown direct answers, so the conditional distribution it learned puts almost all its probability mass on emitting the answer immediately. With CoT data, the conditional distribution puts almost all its probability mass on emitting { first, then a step, then another step, then }, then the answer.

Two factorizations side by side

Probabilistically, the contrast is:

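no CoT:   $P_\theta(\text{ans} \mid \text{problem})$   (response length $R = |\text{ans}|$)
with CoT:   $P_\theta(\text{trace}, \text{ans} \mid \text{problem}) = P_\theta(\text{ans} \mid \text{trace}, \text{problem}) \cdot P_\theta(\text{trace} \mid \text{problem})$   (response length $R = |\text{trace}| + |\text{ans}|$)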

In the limit of an infinitely expressive model, the joint over (trace, ans) marginalizes to the same P(ans | problem) — trace is a latent we are summing over. So at the level of pure probability semantics, the two factorizations are equivalent.
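Written out, the marginalization is a sum over the latent trace. In this toy the gold trace is a deterministic function of the problem (written $t^\star(\text{problem})$ here, notation introduced only for this equation), so under the training distribution the sum collapses to a single term:

```latex
\begin{aligned}
P(\text{ans} \mid \text{problem})
  &= \sum_{\text{trace}} P(\text{trace}, \text{ans} \mid \text{problem}) \\
  &= \sum_{\text{trace}} P(\text{ans} \mid \text{trace}, \text{problem}) \, P(\text{trace} \mid \text{problem}) \\[6pt]
P_{\text{data}}(\text{trace} \mid \text{problem}) = \mathbb{1}\!\left[\text{trace} = t^\star(\text{problem})\right]
  &\;\Longrightarrow\;
  P_{\text{data}}(\text{ans} \mid \text{problem}) = P_{\text{data}}\!\left(\text{ans} \mid t^\star(\text{problem}), \text{problem}\right)
\end{aligned}
```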

The point is precisely that we are not in the limit of infinite expressivity. The two factorizations are equally usable only if the model can compute P(ans | problem) directly. For a fixed-depth model on a problem whose serial depth exceeds L, it cannot. The CoT factorization gives the model a path: sample each token of the trace autoregressively, conditioning the next on the partial trace already produced. By the time we reach ans, the relevant intermediate quantity (the final running total, the result of the last step) is sitting in the context as literal tokens, and computing P(ans | trace, problem) reduces to copying a number from a string that the model itself wrote.

Notice the shape of Pθ(ans | trace, problem): the trace is on the same side of the conditioning bar as the problem. From the perspective of computing the answer, the trace has been promoted from a latent variable to a piece of evidence. The model gets to read what it wrote. Each answer token is now conditioned on a much richer context than the raw problem alone — a context that has been specifically arranged, by the model's own preceding decoding steps, to make the answer easy to predict in one forward pass.

Slogan
CoT is a way of arranging the context so that the conditional distribution the model has to learn becomes shallow. The total work is no smaller; it is in fact larger, roughly in proportion to the extra response tokens, but it is decomposed into pieces each of which fits inside one forward pass.

Why CoT is a data choice, not a loss choice

Now the punchline. The training loop in 02_cot.py is, line for line:

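```python
import torch
import torch.nn.functional as F

# ds, model, tok, opt, B, device and IGNORE_INDEX come from the lesson 3 setup.
x, y = ds.get_batch(B, device)  # y holds IGNORE_INDEX at prompt and padding positions
logits, _ = model(x)            # logits: (B, T, vocab_size)
loss = F.cross_entropy(
    logits.view(-1, tok.vocab_size), y.view(-1),
    ignore_index=IGNORE_INDEX,  # masked targets contribute no loss and no gradient
)
opt.zero_grad(); loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # cap the global grad norm at 1.0
opt.step()
```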

This is byte-for-byte identical to the SFT training loop from lesson 3. There is no auxiliary loss term for “reasoning quality”. There is no special weighting of trace tokens versus answer tokens. There is no separate forward pass over the trace, no separate optimizer state. The model treats every token of the response — whether it lies inside {...} or after } — as a position to predict via cross-entropy, with the prompt masked out via IGNORE_INDEX and the padding masked out via the same mechanism.

The only delta is in the CoTDataset class:

```python
response = cot if use_cot else plain
```

That is the entire stage. One ternary, hidden inside the dataset constructor. The training loop, the model definition, the chat template parser, the eval harness — none of it knows that anything is different. From the optimizer's perspective, lesson 4 and lesson 3 are the same job; only the strings being learned have changed.

The lesson punchline
Adding chain-of-thought to a model is a curriculum decision, not a training-procedure decision. If your training pipeline can already do prompt-masked cross-entropy on (prompt, response) pairs, it can already do CoT. You just write better responses.

This matters for a reason that goes beyond engineering convenience. The fact that CoT does not require a new loss suggests that the ability to reason in steps is already latent in the SFT mechanism. For a base model pretrained on web text, which has learned a fluent prior over “reasoning-shaped” strings, SFT on direct answers actively trains against that prior by spending all the response-position gradient on the answer alone. SFT on CoT-formatted answers trains with that prior, by letting the response-position gradient flow through reasoning structure too. In that setting you are not teaching the model a new skill; you are removing a constraint you accidentally imposed in the previous stage.

Compute redistribution

Step back and look at where the capability is being paid for. There are two main ways a language model can “know more”:

  1. Parameters. More weights, deeper stacks, wider attention. Paid for at training time. Fixed at inference. Expensive (you have to actually train them, and they show up in every forward pass forever).
  2. Tokens. Longer responses, allowed to think through intermediate steps. Paid for at inference time. Variable per request. Cheap (the parameters are already paid for; you just generate more).

CoT is the second knob in its simplest form, and “think step by step” is its slogan: a request for the model to spend more inference-time compute, specifically more serial compute, rather than to be a bigger model. Empirically the trade is favorable for problems where serial depth is the bottleneck: doubling response length is cheap; doubling parameter count is not.

This connects to scaling laws in a clean way. Holding compute constant, you can spend it on parameters (Chinchilla-style scaling) or on response tokens (CoT, agentic loops, search). The first is a one-time cost amortized over all future queries; the second is a per-query cost. Modern reasoning systems are best understood as having shifted the budget meaningfully toward the second: a frontier-scale model is increasingly used as an engine that runs many forward passes per question, rather than one that answers in a single short completion.

The result shape

What the experiment in 02_cot.py measures is the per-difficulty accuracy of the two regimes. Both are trained for the same number of steps on the same problem distribution with the same architecture; the only difference is whether the response is ans# or {trace}ans#. Difficulty is parameterized by K, the number of summands. The expected shape of the result, with illustrative numbers:

```
K   direct   CoT   gap (points)
─   ──────   ───   ────────────
2   95%      98%   +3
3   80%      96%   +16
4   55%      94%   +39
5   30%      92%   +62
6   15%      90%   +75
```

These are illustrative numbers, not measured ones — the exact figures from any given run will differ depending on seed, optimizer noise, and how the random problem sampler happens to populate each bucket. What is stable across runs is the shape: direct-mode accuracy is a steeply decreasing function of K, while CoT-mode accuracy is roughly flat. The gap widens with difficulty.

Why this shape? The direct model is forced to compress all of the problem-specific computation into the single forward pass that emits the answer token. For K=2, a 100-entry lookup table, the MLPs in a 128-dimensional model trained for 2000 steps can memorize the table without breaking a sweat. For K=6, a million-entry lookup table, they cannot, and the model has no fallback because it never learned to use the natural algorithm. The CoT model, in contrast, only ever solves one-step problems: at every reasoning step, the subproblem is “add one digit to a running total, with both values visible in the context”, which is solidly inside the single-forward-pass budget. The model only has to be good at the one-step problem to be good at the K-step problem.

Serial-compute simulator

To make the budget picture tangible, consider a model of the compute available in each regime as a function of L and K. It does not simulate a transformer; it simulates the budget. As K grows, direct mode flips from “solvable” to “overflow” once its K-1 serial adds exceed the L available slots, while CoT mode stays solvable because it can spend as many forward passes as it needs.

Direct vs CoT: serial-compute budget. Direct mode has one forward pass to do K-1 serial adds; CoT mode spends a fresh forward pass on every response token. Example readout for K=6 and L=4: direct mode needs 5 serial adds but has only 4 slots, so it cannot solve the problem; CoT mode emits roughly 30 response tokens, for 120 layer-ops in total, so it can.
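A text-only stand-in for the simulator, using the lesson's idealized budget model: direct mode gets L slots for K-1 serial adds, and CoT mode gets a full L-layer pass for every response token. The `budget` function, its fields, and the default of about 6 tokens per reasoning step (the length of a step like `3+4=7;`) are assumptions for illustration; real trace lengths vary with the digits.

```python
from dataclasses import dataclass

@dataclass
class Budget:
    direct_adds_needed: int
    direct_slots: int
    direct_ok: bool
    cot_tokens: int
    cot_layer_ops: int
    cot_ok: bool

def budget(k: int, n_layers: int, tokens_per_step: int = 6) -> Budget:
    """Idealized serial-compute budget for K summands and an L-layer model."""
    adds = k - 1
    # Direct: one forward pass, at most one serial add per layer (simplification).
    direct_ok = adds <= n_layers
    # CoT: assume about tokens_per_step response tokens per add; each is a full L-layer pass.
    cot_tokens = adds * tokens_per_step
    cot_layer_ops = cot_tokens * n_layers
    # Each step needs a single add, which fits in one pass whenever L >= 1.
    cot_ok = n_layers >= 1
    return Budget(adds, n_layers, direct_ok, cot_tokens, cot_layer_ops, cot_ok)

for k in (3, 6, 8):
    print(k, budget(k, n_layers=4))
```

With K=6 and L=4 this gives 5 serial adds against 4 slots for direct mode, and 30 response tokens, or 120 layer-ops, for CoT.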

Where CoT breaks down in this toy

The model in 02_cot.py is tiny: L=4, d=128, h=4. The training set is 3000 problems for 2000 steps with batch 64 — on the order of 128k example exposures. This is enough capacity to memorize trivial lookup tables and to learn a small finite-state machine for the CoT template, but not enough to brute-force a K=6 lookup. The behavior you observe is consistent with the budget picture above, but it is worth being precise about which failure modes you are seeing.

The interesting failure mode of CoT in this toy is format brittleness: occasionally the model miscounts the number of steps it has done, fails to close } at the right place, or hallucinates an extra ;. When this happens, the answer-extraction regex returns nothing or a partial number, and the example is marked wrong even though the trace contains the correct intermediate result. In a small model trained briefly, this is the dominant residual error: not arithmetic errors, but bookkeeping errors. R1 makes the boundary between thinking and answering explicit with dedicated <think> and </think> delimiters; here the sentinels are ordinary characters from the same small vocabulary, so the model has to learn the structure from scratch.
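To separate bookkeeping errors from arithmetic errors when inspecting failed samples, one could parse the trace step by step. This is a hypothetical sketch (the lesson's extraction regex is not shown); it assumes the template `{a+b=c;...}ans#` and reports whether the answer is extractable, whether the trace was closed, and whether every step is arithmetically right and correctly chained.

```python
import re
from typing import Optional

ANSWER_RE = re.compile(r"\}(\d+)#$")           # final integer between '}' and '#'
STEP_RE = re.compile(r"^(\d+)\+(\d+)=(\d+)$")  # one running-total step

def extract_answer(response: str) -> Optional[int]:
    m = ANSWER_RE.search(response)
    return int(m.group(1)) if m else None

def check_trace(response: str, digits: list[int]) -> dict:
    """Diagnose a CoT response against the problem's digits."""
    closed = response.startswith("{") and "}" in response
    steps_ok = False
    if closed:
        steps = response[1:response.index("}")].split(";")
        steps_ok = len(steps) == len(digits) - 1  # one step per add
        acc = digits[0]
        for step, d in zip(steps, digits[1:]):
            m = STEP_RE.match(step)
            if m is None or tuple(map(int, m.groups())) != (acc, d, acc + d):
                steps_ok = False
                break
            acc += d
    return {"answer": extract_answer(response), "closed": closed, "steps_ok": steps_ok}

digits = [3, 4, 5, 6]
print(check_trace("{3+4=7;7+5=12;12+6=18}18#", digits))   # well-formed
print(check_trace("{3+4=7;7+5=12;12+6=18;}18#", digits))  # extra ';'
print(check_trace("{3+4=7;7+5=12;12+6=18 18#", digits))   # '}' never closed
print(check_trace("{3+4=7;7+5=13;13+6=19}19#", digits))   # arithmetic slip
```

The second and third responses are bookkeeping failures: every addition in them is correct, yet the format is broken, and in the third the answer cannot be extracted at all.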

Relationship to o1 and R1

The mechanism we are studying is the same basic mechanism behind OpenAI's o1, DeepSeek's R1, and the broader class of “reasoning models”. The differences are scale and origin:

The CoT-SFT setup of this lesson is the supervised analog of what RLVR (lesson 6) does under the hood. SFT-on-CoT requires you to have a gold reasoning trace for every problem; RLVR requires only a gold answer plus a verifier that can check whether the model's own trace ended at the right answer. Whenever you have access to verifiable rewards, the RL route is more scalable: you don't have to hand-write traces, and the model is free to discover styles of reasoning you would not have thought of.
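The verifier RLVR needs can be far simpler than a gold trace. A minimal sketch, assuming the same `...ans#` ending: the reward is 1.0 when the extracted final answer equals the true sum and 0.0 otherwise, and the trace itself is never inspected. The function names are hypothetical, and lesson 6 may define its reward differently.

```python
import re
from typing import Optional

def final_answer(response: str) -> Optional[int]:
    # Assumed format: the answer is the integer immediately before the closing '#'.
    m = re.search(r"(\d+)#$", response)
    return int(m.group(1)) if m else None

def verifiable_reward(response: str, digits: list[int]) -> float:
    """Binary reward from a checkable answer; any trace that ends correctly scores 1."""
    return 1.0 if final_answer(response) == sum(digits) else 0.0

digits = [3, 4, 5, 6]
print(verifiable_reward("{3+4=7;7+5=12;12+6=18}18#", digits))  # 1.0
print(verifiable_reward("{3+5=8;8+4=12;12+6=18}18#", digits))  # 1.0: different order, still rewarded
print(verifiable_reward("{3+4=7;7+5=13;13+6=19}19#", digits))  # 0.0
```

Because only the final answer is checked, a trace that adds the digits in a different order earns full reward, which is what leaves room for the model to discover its own reasoning styles.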

Trade-offs

CoT is not free. The places where it costs real money or real reliability:

The right way to frame the trade-off: CoT converts a hard compute ceiling (set at training time, applies to every query equally) into a soft cost surface (set at inference time, scales linearly per query). Whether that is a good trade depends on whether your bottleneck is “the model is fundamentally not able to do this” (use CoT) or “the model is able to do this and we want it cheap and fast” (don't).

Forward link to DPO

CoT-SFT still demands one gold trace per problem. For arithmetic that is easy; the running-total algorithm is unambiguous and our problem generator literally writes the trace itself. For real reasoning tasks — a math olympiad problem, a code-review comment, an essay critique — what counts as a “gold” trace is not obvious. There may be many acceptable solutions; the demonstrator may produce a trace that is correct but not what the model would naturally write; the model's own first attempt might already be 80% correct and only need a small nudge.

What if, instead of trying to write the perfect trace, we let the model produce two candidate traces, ranked one against the other? “Trace A is better than trace B for this problem” is much cheaper to obtain than “here is the perfect trace for this problem”. This is the data shape of preference learning. Closing the gap between SFT-style absolute targets and preference-style relative targets is what lesson 5 is about.