Mixed precision — bf16, fp16, fp8
Halving the bit width halves memory and roughly doubles tensor-core throughput. Doing it without hurting convergence comes down to a small set of careful rules, and the rules differ for each format.
What "precision" buys, exactly
A floating-point number has three pieces: a sign bit, an exponent (range), and a mantissa (precision within an exponent). The formats that matter for ML (fp32 and fp16 are IEEE 754 formats; bf16 and the two fp8 formats are not):
| Format | Bits | Sign / exp / mantissa | Max | Min (norm) | Approx decimal digits |
|---|---|---|---|---|---|
| fp32 | 32 | 1 / 8 / 23 | ~3.4 × 10³⁸ | ~1.2 × 10⁻³⁸ | 7 |
| fp16 | 16 | 1 / 5 / 10 | ~6.5 × 10⁴ | ~6.1 × 10⁻⁵ | 3–4 |
| bf16 | 16 | 1 / 8 / 7 | ~3.4 × 10³⁸ | ~1.2 × 10⁻³⁸ | 2–3 |
| fp8 E4M3 | 8 | 1 / 4 / 3 | 448 | 2⁻⁶ ≈ 1.56 × 10⁻² | ~1 |
| fp8 E5M2 | 8 | 1 / 5 / 2 | 57344 | ~6.1 × 10⁻⁵ | <1 |
Two observations make most of the rest of this lesson:
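- bf16 and fp32 have the same exponent width (8 bits), so they have the same range; bf16 just has far less precision within each exponent. This is why bf16 is nearly a drop-in replacement for fp32: any gradient that fits in fp32's range also fits in bf16's, so there is no new overflow or underflow to manage.
- fp16 has only 5 exponent bits: its largest finite value is 65504 and its smallest normal is about 6.1 × 10⁻⁵. Small gradients routinely fall below that floor and flush to zero, while large activations or gradient sums can exceed 65k and overflow to infinity. fp16 therefore needs a "loss scaler" to shift gradients into its narrow window. bf16 doesn't.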
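A minimal pure-Python sketch makes the table concrete: it rounds a value onto each format's grid, assuming round-to-nearest-even and gradual underflow (subnormals). `FORMATS` and `quantize` are illustrative names rather than a library API, and the sketch flags overflow as `inf` even for E4M3, which has no infinity encoding in the variant whose max is 448.

```python
import math

# (mantissa bits, exponent of the smallest normal, largest finite value)
FORMATS = {
    "fp32": (23, -126, 3.4028234663852886e38),
    "fp16": (10, -14, 65504.0),
    "bf16": (7, -126, 3.3895313892515355e38),
    "e4m3": (3, -6, 448.0),
    "e5m2": (2, -14, 57344.0),
}


def quantize(x: float, fmt: str) -> float:
    """Round x to the nearest value representable in `fmt` (ties to even).

    Out-of-range results are returned as +/-inf. Real E4M3 has no inf
    encoding; converters saturate or produce NaN instead.
    """
    man_bits, min_exp, max_val = FORMATS[fmt]
    if x == 0.0 or math.isnan(x) or math.isinf(x):
        return x
    _, e = math.frexp(abs(x))      # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, min_exp)      # below the smallest normal the spacing stays fixed (subnormals)
    ulp = 2.0 ** (exp - man_bits)  # gap between neighbouring representable values
    q = round(abs(x) / ulp) * ulp  # Python's round() breaks ties to even
    if q > max_val:
        return math.copysign(math.inf, x)
    return math.copysign(q, x)


for fmt in FORMATS:
    print(fmt, quantize(70000.0, fmt), quantize(1.0001, fmt))
```

For example, `quantize(65000.0, "fp16")` returns `64992.0` and `quantize(70000.0, "fp16")` returns `inf`, while bf16 keeps 70000 in range as `70144.0` but rounds `1.0001` back to `1.0`.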
2D · bit-format visualiser
Side-by-side bit layouts. Type a number and watch how each format encodes it: the sign bit, the exponent bits (shifted by bias), and the mantissa bits. Notice how bf16 truncates fp32's mantissa but keeps the exponent, while fp16 keeps more mantissa bits (10 vs 7) at the cost of range.
Why bf16 won
In 2018, the only widespread 16-bit format on GPUs was fp16, and people invented elaborate machinery (loss scaling, gradient unscaling, dynamic scaler adjustment) to make it work. NVIDIA's A100 (2020) added bf16 hardware support. By 2022, bf16 became the default for training. The reason is just the range: with bf16 you can take any fp32 training recipe, swap the activations and gradients to bf16, and it works. With fp16 you have to manage a scaler — and even then, the loss landscape's tail events sometimes overflow.
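The cost of bf16 is precision. With 7 mantissa bits, a long running sum loses information quickly: adding a small number to a large one returns the large one unchanged whenever the small number is less than half the spacing between bf16 values at the large one's magnitude. The workaround is one of the foundational rules: accumulate in fp32. Matmuls and reductions read bf16 inputs but keep their running sums in fp32, casting back to bf16 only at the end.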
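Why fp32 accumulation matters is easy to see in a pure-Python emulation. Assumption: bf16 and fp32 are simulated by rounding Python floats through `struct`, and `to_bf16`, `to_fp32` and `running_sum` are hypothetical helpers. Summing 10,000 copies of 0.01 with a bf16 running total stalls at 4.0, because once the total reaches 4 the bf16 spacing is 2⁻⁵ ≈ 0.031 and adding 0.01 rounds back down; the fp32 total lands on the exact sum and loses precision only in the final cast.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bf16: keep the top 16 bits of the fp32 pattern, ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the 16 dropped bits
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def running_sum(values, acc_round):
    """Sum `values`, rounding the running total with `acc_round` after every add."""
    acc = 0.0
    for v in values:
        acc = acc_round(acc + v)
    return acc


values = [to_bf16(0.01)] * 10_000  # bf16 inputs; their exact sum is about 100.1

sum_bf16_acc = running_sum(values, to_bf16)  # accumulator kept in bf16
sum_fp32_acc = running_sum(values, to_fp32)  # accumulator kept in fp32
print(sum_bf16_acc, to_bf16(sum_fp32_acc))   # second value: cast to bf16 only at the end
```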
The master-weights trick (recap from lesson 01)
Even with accumulate-in-fp32 inside ops, the parameter update step still loses precision in bf16:
For a typical learning rate η = 10⁻⁴ and gradient g ~ 1, the update ηg is about 10⁻⁴. A bf16 weight near 1.0 has an LSB of 2⁻⁷ ≈ 8 × 10⁻³, so w + 10⁻⁴ rounds straight back to w. Updates never accumulate, and training stalls.
The fix: keep a master copy of weights in fp32. The bf16 weight is a cast of the master. Forward and backward use bf16. The optimizer step updates the fp32 master, then the bf16 working copy is refreshed from it. This is the "12 bytes of optimizer state + master" in lesson 01's bytes-per-param accounting.
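The stall, and the fix, can be emulated in a few lines of pure Python. This is a sketch under simplifying assumptions: plain SGD with a constant gradient, and bf16/fp32 simulated by rounding Python floats through `struct` (`to_bf16` and `to_fp32` are hypothetical helpers).

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bf16 (top 16 bits of the fp32 pattern, ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


lr, grad, steps = 1e-4, -1.0, 1000  # constant gradient that pushes the weight up

# Without master weights: the optimizer writes straight into the bf16 weight.
w_bf16_only = to_bf16(1.0)
for _ in range(steps):
    w_bf16_only = to_bf16(w_bf16_only - lr * grad)

# With master weights: update the fp32 master, then refresh the bf16 working copy.
w_master = to_fp32(1.0)
for _ in range(steps):
    w_master = to_fp32(w_master - lr * grad)
    w_working = to_bf16(w_master)  # the copy forward/backward would use next step

print(w_bf16_only, w_master, w_working)
```

Writing straight into bf16 leaves the weight at 1.0 after 1000 steps; the fp32 master reaches about 1.1, and its bf16 working copy (1.1015625) tracks it to within one bf16 step.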
fp16 and the loss scaler — the historical wart
fp16's narrow range (smallest normal ≈ 6 × 10⁻⁵, max ≈ 65k) means many small gradients underflow to zero during backward. The fix is the loss scaler:
- Multiply the loss by some scale factor S (PyTorch's GradScaler starts at 2¹⁶ by default) before backward.
- Backward proceeds; gradients are S× their natural magnitude — but now in a healthy fp16 range, away from the underflow floor.
- Before the optimizer step, "unscale" by dividing gradients by S.
- Optimizer step in fp32 as usual.
If any gradient overflows to inf (or becomes NaN) during backward, PyTorch's GradScaler detects it, skips the optimizer step entirely, and halves S for next time. If enough consecutive steps pass without overflow (2000 by default), S is doubled back up. This is the dynamic loss-scaling algorithm. bf16 deletes all of this machinery.
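A toy version of the loop can be written in pure Python, with `struct`'s half-precision format standing in for fp16 and a made-up three-element gradient vector standing in for backward. `DynamicLossScaler` is a hypothetical class whose defaults copy PyTorch `GradScaler`'s documented ones (initial scale 2¹⁶, backoff 0.5, growth 2.0, growth interval 2000); unlike GradScaler, it folds unscaling and the scale update into one call.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to fp16; anything past fp16's range becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class DynamicLossScaler:
    """Toy dynamic loss scaler (hypothetical class, not torch's GradScaler).

    Defaults follow GradScaler's: start at 2**16, halve on inf/NaN,
    double after 2000 consecutive clean steps.
    """

    def __init__(self, scale=2.0**16, backoff=0.5, growth=2.0, growth_interval=2000):
        self.scale = scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def unscale_and_update(self, scaled_grads):
        """Return unscaled grads, or None if the optimizer step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale *= self.backoff  # overflow: skip this step and shrink S
            self.clean_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]  # unscale with the S used in backward
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            self.scale *= self.growth  # long clean streak: try a larger S
            self.clean_steps = 0
        return grads


true_grads = [1e-8, 1e-3, 2.0]  # 1e-8 is well below fp16's smallest subnormal (~6e-8)
print([to_fp16(g) for g in true_grads])  # unscaled fp16: 1e-8 flushes to 0.0

scaler = DynamicLossScaler()
for _ in range(32):
    # Stand-in for backward on (loss * S): every fp16 gradient comes out S times larger.
    scaled = [to_fp16(g * scaler.scale) for g in true_grads]
    grads = scaler.unscale_and_update(scaled)
    if grads is not None:
        break
print(scaler.scale, grads)
```

Here the 2.0 gradient overflows at S = 2¹⁶ and at 2¹⁵, so two steps are skipped; at S = 2¹⁴ everything fits, and the 1e-8 gradient, which plain fp16 flushes to zero, comes back within a fraction of a percent.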
fp8 — two formats, training vs inference
fp8 is the next step down in precision. NVIDIA's Hopper (H100) and Blackwell (B100/B200) GPUs have tensor cores that consume fp8 inputs and accumulate in fp16 or fp32. Two fp8 formats:
- E4M3 (4 exponent, 3 mantissa bits). Max 448, one more mantissa bit than E5M2. Used for weights and activations.
- E5M2 (5 exponent, 2 mantissa bits). Max 57344; the same exponent width as fp16, so a similar range, but less mantissa. Used for gradients.
The asymmetry is intentional. Weights and activations have small dynamic range (post-LayerNorm), so a tight-range high-precision format works. Gradients can span many orders of magnitude (back-propagated through long networks), so a wide-range low-precision format works. NVIDIA's Transformer Engine library manages the choice.
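To see the asymmetry on a few numbers, this sketch rounds the same values onto both fp8 grids with no scaling applied. Assumptions: round-to-nearest-even, subnormals, and out-of-range values flagged as `inf`; `to_fp8` is an illustrative helper, not a library call.

```python
import math

# (mantissa bits, exponent of the smallest normal, largest finite value)
FP8 = {"E4M3": (3, -6, 448.0), "E5M2": (2, -14, 57344.0)}


def to_fp8(x: float, fmt: str) -> float:
    """Round x onto the fp8 grid (ties to even); out-of-range becomes inf as a flag."""
    man_bits, min_exp, max_val = FP8[fmt]
    if x == 0.0:
        return 0.0
    exp = max(math.frexp(abs(x))[1] - 1, min_exp)  # clamp to the subnormal exponent
    ulp = 2.0 ** (exp - man_bits)
    q = round(abs(x) / ulp) * ulp
    return math.copysign(math.inf if q > max_val else q, x)


# An activation-like value, a large gradient-like value, a tiny gradient-like value.
for x in (1.1, 1000.0, 3e-5):
    print(x, {fmt: to_fp8(x, fmt) for fmt in FP8})
```

E4M3 resolves 1.1 better (1.125 vs 1.0) but cannot hold 1000 or 3e-5; E5M2 keeps both (as 1024 and 2⁻¹⁵) at coarser resolution.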
fp8 training (DeepSeek-V3, recent Megatron) typically combines:
- fp8 matmul (E4M3 × E4M3 with fp32 accumulator for forward; E4M3 × E5M2 for backward).
- fp32 master weights and Adam state still dominate the bytes: that's 12 bytes per param (4 master + 4 m + 4 v) on its own. Adding fp8 weights and fp8 grads (1 B each) gives ~14 B/param; the per-tensor scale factors add only a few bytes per tensor, which is negligible per parameter. So fp8 training trims roughly 2 B/param off the lesson-01 mixed-precision figure, not half of it. The optimizer state is what makes the parameter budget mostly format-independent.
- Per-tensor scaling factors that adjust dynamically (otherwise the limited range causes weight or activation overflow, or flushes small values to zero).
- Some "stay in higher precision" ops: norms, the LM head, the embedding table.
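A minimal sketch of per-tensor scaling, assuming the simplest rule: scale the tensor so its absolute maximum (amax) lands on E4M3's largest value, quantize, then divide the scale back out. Production recipes differ in the details (Transformer Engine, for instance, can derive the scale from a history of past amax values). `to_e4m3` and `fp8_roundtrip` are hypothetical helpers, and this version saturates at ±448 rather than producing NaN.

```python
import math

E4M3_MAX = 448.0


def to_e4m3(x: float) -> float:
    """Round onto the E4M3 grid (3 mantissa bits, min normal 2**-6), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    exp = max(math.frexp(abs(x))[1] - 1, -6)
    ulp = 2.0 ** (exp - 3)
    return math.copysign(min(round(abs(x) / ulp) * ulp, E4M3_MAX), x)


def fp8_roundtrip(tensor, use_scale=True):
    """Quantize a list of floats to E4M3 and back, optionally with one per-tensor scale.

    The scale maps the tensor's absolute maximum (amax) onto E4M3's largest value.
    """
    amax = max(abs(v) for v in tensor)
    scale = E4M3_MAX / amax if (use_scale and amax > 0) else 1.0
    return [to_e4m3(v * scale) / scale for v in tensor]


activations = [0.002, -0.0013, 0.0005, 0.0001]  # a small-magnitude tensor
unscaled = fp8_roundtrip(activations, use_scale=False)
scaled = fp8_roundtrip(activations, use_scale=True)
print(unscaled)
print(scaled)
```

Without the scale, the two smallest entries flush to zero and the other two snap to E4M3's smallest subnormal (2⁻⁹); with it, every entry comes back within E4M3's worst-case relative rounding error of 2⁻⁴.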
The serving win is bigger than the training win: fp8 weights are half the size of bf16, so the weight bytes read from HBM per decode token halve, and when decode is bound by those reads (lesson 11's HBM-bound math) latency roughly halves too. H100's fp8 throughput is roughly 2× bf16; B200 pushes it further with fp4 in places.
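The halving is simple arithmetic. A back-of-the-envelope sketch, assuming batch-1 decode where every weight is read from HBM once per token and ignoring KV-cache traffic and compute; the 70B parameter count and the 3.35 TB/s bandwidth (roughly an H100 SXM) are illustrative assumptions, and `weight_read_time_ms` is a hypothetical helper.

```python
def weight_read_time_ms(n_params: float, bytes_per_param: float, hbm_bytes_per_s: float) -> float:
    """Lower bound on per-token decode time if every weight is read once from HBM."""
    return n_params * bytes_per_param / hbm_bytes_per_s * 1e3


N_PARAMS = 70e9   # hypothetical 70B-parameter model
HBM_BW = 3.35e12  # assumed HBM bandwidth in bytes/s (roughly an H100 SXM)

bf16_ms = weight_read_time_ms(N_PARAMS, 2, HBM_BW)  # 2 bytes per bf16 weight
fp8_ms = weight_read_time_ms(N_PARAMS, 1, HBM_BW)   # 1 byte per fp8 weight
print(f"bf16: {bf16_ms:.1f} ms/token, fp8: {fp8_ms:.1f} ms/token")
```

That gives about 41.8 ms/token in bf16 versus 20.9 ms/token in fp8 as a lower bound; real latency also pays for KV-cache reads, which only shrink if the cache is quantized too.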
2D · E4M3 vs E5M2, the same value placed on each
Two log-scale axes side-by-side. Every representable positive value of each fp8 format is shown as a dot. Drop a value with the slider and see which format keeps it in range, which one overflows (E5M2 to infinity; E4M3 has no infinity encoding, so it saturates or produces NaN), and what the rounding error looks like. E4M3 has more dots clustered in a tight range; E5M2 covers more orders of magnitude but spaces them further apart.
Autocast — the developer ergonomic
You don't manually cast every tensor. PyTorch's `torch.autocast(device_type="cuda", dtype=torch.bfloat16)` context manager wraps ops in a dispatcher layer (the "Autocast" key from lesson 13) that automatically casts inputs to the right type:
- bf16-safe ops (matmul, conv, attention): inputs cast to bf16.
- fp32-needed ops (softmax, LayerNorm, sum reductions, loss functions): inputs cast to fp32.
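- Other ops (e.g. elementwise add, ReLU): no autocast rule; they run in whatever dtype their inputs already have. A few multi-input ops instead promote to the widest input dtype.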
The cast is cheap (a fused kernel can do it for free), but the rules are opinionated: PyTorch decides for you which ops are precision-sensitive. Most of the time it's right. For fp8 training, NVIDIA's Transformer Engine extends this with per-tensor scales.
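The three-way policy can be modelled as a lookup. This is a toy sketch, not PyTorch's implementation: the op names and sets are a small illustrative subset of the documented autocast lists, dtypes are plain strings, and the widest-type promotion rule is left out.

```python
# Toy model of autocast's per-op casting policy (hypothetical; not PyTorch's code).
LOWER_PRECISION_OPS = {"matmul", "linear", "conv2d"}            # floating inputs cast down
FP32_OPS = {"softmax", "log_softmax", "layer_norm", "sum"}      # floating inputs cast up


def autocast_input_dtypes(op: str, input_dtypes: list[str], low: str = "bfloat16") -> list[str]:
    """Return the dtypes an op's floating-point inputs would be cast to under this toy policy."""
    if op in LOWER_PRECISION_OPS:
        return [low for _ in input_dtypes]
    if op in FP32_OPS:
        return ["float32" for _ in input_dtypes]
    return list(input_dtypes)  # everything else runs in the dtype it was handed


# A bf16 autocast region: fp32 params feed a matmul, whose bf16 output feeds softmax.
print(autocast_input_dtypes("matmul", ["float32", "float32"]))
print(autocast_input_dtypes("softmax", ["bfloat16"]))
print(autocast_input_dtypes("relu", ["bfloat16"]))
```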
Animated · autocast data flow through one step
Forward + backward of a single layer with autocast enabled. Each arrow is a tensor; its color encodes the dtype the framework actually stores. Watch the matmul take bf16 inputs but produce an fp32 accumulator before casting back, the gradient flowing back in bf16, and the optimizer step happening on the fp32 master weights.
The numerics traps that survive autocast
Even with autocast on, a few classes of bugs reappear:
- Reductions over many elements in low precision. A custom reduction like `x.sum(dim=-1)` that keeps its running total in bf16 over a 10k-element row drops low-order bits once the total grows large relative to each addend. Cast to fp32 first.
- Math you assume is fp32 but isn't. Anything you compute with raw tensor ops in user code, especially RoPE / position-embedding angle math, runs in whatever dtype its inputs have and may not be covered by autocast's op lists. Do the math in fp32 and cast the result back.
- Layer norms. The accumulator inside LayerNorm should be fp32; modern PyTorch does this, but third-party implementations occasionally don't.
- Custom CUDA / Triton kernels. You're responsible for the accumulator dtype yourself.
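A concrete instance of the RoPE trap, as a pure-Python sketch. Assumptions: bf16 is emulated via `struct` (`to_bf16` is a hypothetical helper), and the standard RoPE formulation is used, whose fastest-rotating channel has an inverse frequency of 1, so its angle equals the position.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round to bf16 (top 16 bits of the fp32 pattern, ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# bf16 keeps 8 significant bits, so consecutive integers are exact only up to 256.
print([to_bf16(p) for p in (255.0, 256.0, 257.0, 4097.0)])

# Fastest-rotating RoPE channel: inv_freq = 1.0, so the angle is the position itself.
pos, inv_freq = 4097.0, 1.0
angle_fp32 = pos * inv_freq                    # exact: 4097 fits in fp32's 24-bit significand
angle_bf16 = to_bf16(to_bf16(pos) * inv_freq)  # the same math done in bf16
print(angle_fp32 - angle_bf16)                 # angle error in radians
print(math.cos(angle_fp32), math.cos(angle_bf16))
```

Position 4097 becomes 4096 in bf16, so the angle for that channel is off by a full radian.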
The decision table
| Goal | Format | Master weights? | Scaler? |
|---|---|---|---|
| Pretraining (default) | bf16 | fp32 | No |
| Pretraining on older GPUs (V100, T4) | fp16 | fp32 | Yes (dynamic) |
| Frontier-scale pretraining (H100/B100) | fp8 (E4M3 / E5M2) | fp32 | No loss scaler; per-tensor scale factors instead |
| Serving / inference | fp8 or int8 weights, bf16 activations | — | — |
| Tiny models / debugging | fp32 throughout | — | — |
Interactive · feel the range & precision tradeoff
Pick a value and a format; the widget shows where it lands (the nearest representable number), the relative error, and what happens to the sum a + b when b ≪ a. Try entering 65000 in fp16: it survives. Try 70000: overflows to inf. Try summing 1.0 + 1e-4 in bf16: the small term disappears.