system_ml / 14 · mixed precision lesson 14 / 19

Mixed precision — bf16, fp16, fp8

Lower precision halves memory and roughly doubles tensor-core throughput. Doing it without losing convergence is a small set of careful rules — and the rules are different for each format.

What "precision" buys, exactly

A floating-point number has three pieces: a sign bit, an exponent (range), and a mantissa (precision within an exponent). The IEEE/standard ML formats:

Format     Bits   Sign / exp / mantissa   Max            Min (norm)                                     Approx decimal digits
fp32       32     1 / 8 / 23              ~3.4 × 10³⁸    ~1.2 × 10⁻³⁸                                   7
fp16       16     1 / 5 / 10              ~6.5 × 10⁴     ~6.1 × 10⁻⁵                                    3–4
bf16       16     1 / 8 / 7               ~3.4 × 10³⁸    ~1.2 × 10⁻³⁸                                   2–3
fp8 E4M3   8      1 / 4 / 3               448            ~1.6 × 10⁻² (smallest subnormal ~1.95 × 10⁻³)  ~1
fp8 E5M2   8      1 / 5 / 2               57344          ~6.1 × 10⁻⁵                                    <1
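
The Max and Min columns follow from the exponent and mantissa widths alone. The sketch below recomputes them; `format_stats` and the `ieee_like` flag are names made up for this illustration. It assumes the usual conventions: an IEEE-style layout reserves the top exponent code for inf/NaN, while E4M3 gives up only one bit pattern to NaN and so reaches 448. For E4M3 the smallest normal value is 2⁻⁶ and the smallest subnormal is 2⁻⁹ ≈ 1.95 × 10⁻³.

```python
import math

def format_stats(exp_bits, man_bits, ieee_like=True):
    """Return (max, min_normal, min_subnormal, decimal_digits) for a binary float layout."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        # Top exponent code is reserved for inf/NaN.
        max_val = (2 - 2.0 ** -man_bits) * 2.0 ** ((2 ** exp_bits - 2) - bias)
    else:
        # E4M3 convention: only exponent=all-ones with mantissa=all-ones is NaN,
        # so the top exponent code still holds finite values.
        max_val = (2 - 2 * 2.0 ** -man_bits) * 2.0 ** ((2 ** exp_bits - 1) - bias)
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = 2.0 ** (1 - bias - man_bits)
    digits = (man_bits + 1) * math.log10(2)  # +1 for the implicit leading bit
    return max_val, min_normal, min_subnormal, digits

FORMATS = {
    "fp32": (8, 23, True),
    "fp16": (5, 10, True),
    "bf16": (8, 7, True),
    "fp8 E4M3": (4, 3, False),
    "fp8 E5M2": (5, 2, True),
}

for name, (e, m, ieee) in FORMATS.items():
    mx, mn, sub, d = format_stats(e, m, ieee)
    print(f"{name:9s} max={mx:.3g}  min_normal={mn:.3g}  min_subnormal={sub:.3g}  digits~{d:.1f}")
```

The decimal-digit column is just (mantissa bits + 1) × log₁₀ 2, which is why bf16 lands near 2.4 digits and fp16 near 3.3.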

Two observations make most of the rest of this lesson:

2D · bit-format visualiser

Side-by-side bit layouts. Type a number and watch how each format encodes it: the sign bit, the exponent bits (shifted by bias), and the mantissa bits. Notice how bf16 truncates fp32's mantissa but keeps the exponent, while fp16 keeps mantissa precision at the cost of range.

Float formats · bit-level view
Each row is one format. Blue = sign, orange = exponent, green = mantissa. The number on the right is the actual value that gets stored (after rounding/overflow). Try 0.0001, 65000, 70000, 1e-10.
Readouts: the value as stored in bf16, fp16, fp8 E4M3, and fp8 E5M2.

Why bf16 won

In 2018, the only widespread 16-bit format on GPUs was fp16, and people invented elaborate machinery (loss scaling, gradient unscaling, dynamic scaler adjustment) to make it work. NVIDIA's A100 (2020) added bf16 hardware support, and by 2022 bf16 had become the default for training. The reason is the range: with bf16 you can take most fp32 training recipes, swap the activations and gradients to bf16, and they work unchanged. With fp16 you have to manage a scaler, and even then rare outlier activations or gradients can still overflow.

The cost of bf16 is precision. With 7 mantissa bits, accumulating a long sum loses information rapidly — adding a small number to a large one rounds to the large one (the small number is below the LSB of the large one's mantissa). The workaround is one of the foundational rules:

The accumulate-in-fp32 rule
Matmul inputs are bf16; the accumulator inside the tensor core is fp32. Each multiply contributes a bf16×bf16 product (two 8-bit significands multiply to at most 16 bits, so the product fits exactly in fp32's 24-bit significand), followed by an fp32 add to the running sum. The result tensor is cast back to bf16. NVIDIA's bf16 tensor cores do this without you asking; you'd only break it by manually using bf16 throughout. Reductions in cuDNN/cuBLAS follow the same rule.
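
To see why the accumulator dtype matters, here is a minimal pure-Python simulation of the rule. The helpers `to_bf16` and `to_fp32` are written for this illustration (they emulate round-to-nearest-even via the `struct` module and are not a library API), and the loop stands in for what a tensor core does in hardware.

```python
import struct

def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round to the nearest bf16 value (round-to-nearest-even on the fp32 bit pattern)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias; ties go to the even mantissa
    bits &= 0xFFFF0000                    # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def sum_with_bf16_accumulator(values):
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + to_bf16(v))   # running sum is rounded to bf16 every step
    return acc

def sum_with_fp32_accumulator(values):
    acc = 0.0
    for v in values:
        acc = to_fp32(acc + to_bf16(v))   # bf16 inputs, fp32 running sum
    return to_bf16(acc)                   # cast the final result back to bf16

ones = [1.0] * 4096
print(sum_with_bf16_accumulator(ones))    # 256.0: at 256 the bf16 spacing is 2, so +1 is lost
print(sum_with_fp32_accumulator(ones))    # 4096.0
```

The bf16 accumulator stalls at 256 because 257 sits exactly halfway between 256 and 258 and the tie rounds back to 256.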

The master-weights trick (recap from lesson 01)

Even with accumulate-in-fp32 inside ops, the parameter update step still loses precision in bf16:

θ_new  =  θ_old  −  η · g

For a typical LR η = 10⁻⁴ and gradient g ~ 1, the update is 10⁻⁴. A bf16 weight near 1.0 has LSB ≈ 2⁻⁷ ≈ 8 × 10⁻³, so the update is far below half an LSB and θ_old − η · g rounds straight back to θ_old. Updates never accumulate, and training stalls.

The fix: keep a master copy of weights in fp32. The bf16 weight is a cast of the master. Forward and backward use bf16. The optimizer step updates the fp32 master, then the bf16 working copy is refreshed from it. This is the "12 bytes of optimizer state + master" in lesson 01's bytes-per-param accounting.
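
The stall and the fix can be reproduced in a few lines. This is a toy sketch with a constant gradient, not an optimizer; `to_bf16`, `to_fp32`, and `train` are helper names invented for the example.

```python
import struct

def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round to the nearest bf16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def train(steps, lr, grad, use_master):
    """Apply theta -= lr * grad repeatedly; return the final bf16 working weight."""
    master = to_fp32(1.0)   # fp32 master copy
    w = to_bf16(1.0)        # bf16 working copy used by forward/backward
    for _ in range(steps):
        if use_master:
            master = to_fp32(master - lr * grad)  # update happens in fp32
            w = to_bf16(master)                   # refresh the working copy
        else:
            w = to_bf16(w - lr * grad)            # update happens directly in bf16
    return w

print(train(1000, 1e-4, 1.0, use_master=False))  # 1.0: every update rounded away
print(train(1000, 1e-4, 1.0, use_master=True))   # close to 0.9, the expected value
```

With the master copy, the thousand small updates add up in fp32 and the bf16 weight moves once the accumulated change exceeds its own spacing.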

fp16 and the loss scaler — the historical wart

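fp16 has trouble at both ends of its narrow range. Small gradients underflow: precision degrades below ~6.1 × 10⁻⁵, and anything under ~6 × 10⁻⁸ (the smallest subnormal) flushes to zero. Large values overflow past ~65k. The fix for the underflow is the loss scaler: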

  1. Multiply the loss by a scale factor S, a large power of two (for example 2¹⁵; PyTorch's GradScaler starts at 2¹⁶), before backward.
  2. Backward proceeds; gradients are S× their natural magnitude — but now in a healthy fp16 range, away from the underflow floor.
  3. Before the optimizer step, "unscale" by dividing gradients by S.
  4. Optimizer step in fp32 as usual.
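
The four steps above can be checked numerically with the standard library, whose `struct` module supports IEEE half precision through the `"e"` format. `to_fp16` is a helper written for this sketch, and the gradient value 1e-8 is an arbitrary illustration.

```python
import math
import struct

def to_fp16(x):
    """Round to the nearest fp16 value; values past the fp16 range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:          # struct refuses to pack out-of-range values
        return math.copysign(math.inf, x)

S = 2.0 ** 15                      # loss scale
g = 1e-8                           # a "natural" gradient below fp16's smallest subnormal

unscaled_in_fp16 = to_fp16(g)      # flushes to zero: the gradient is lost
scaled_in_fp16 = to_fp16(g * S)    # about 3.3e-4, comfortably inside the normal range
recovered = scaled_in_fp16 / S     # unscale in higher precision before the optimizer step

print(unscaled_in_fp16, scaled_in_fp16, recovered)
print(to_fp16(3.0 * S))            # inf: a gradient of 3.0 overflows once scaled
```

The last line shows the flip side: a large S rescues small gradients but pushes ordinary-sized ones past 65504, which is why the scale has to be adjusted dynamically.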

If any gradient overflows to inf during step 2, PyTorch's GradScaler detects it, skips the step entirely, and halves S for next time. If enough consecutive steps pass without overflow (2000 by default in GradScaler), S is doubled back up. This is the dynamic loss-scaling algorithm. bf16 deletes all of this machinery.
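
The skip/halve/double policy is small enough to write out. The class below is a hypothetical stand-in for illustration, not PyTorch's implementation; its parameter names are loosely modelled on GradScaler's `init_scale`, `growth_factor`, `backoff_factor`, and `growth_interval`.

```python
class DynamicLossScaler:
    """Toy dynamic loss scaler: back off on overflow, grow after a clean streak."""

    def __init__(self, init_scale=2.0 ** 15, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def update(self, found_inf):
        """Record one step. Returns True if the optimizer step should be applied."""
        if found_inf:
            self.scale *= self.backoff_factor   # halve S and skip this step
            self.clean_steps = 0
            return False
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            self.scale *= self.growth_factor    # clean streak: double S again
            self.clean_steps = 0
        return True

scaler = DynamicLossScaler(growth_interval=3)   # short interval just for the demo
for overflowed in [True, False, False, False]:
    applied = scaler.update(overflowed)
    print(f"overflow={overflowed}  step applied={applied}  scale={scaler.scale}")
```

In the demo the scale drops from 2¹⁵ to 2¹⁴ on the overflow and returns to 2¹⁵ after three clean steps.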

fp8 — two formats, training vs inference

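fp8 is the next step down in precision. NVIDIA's Hopper (H100) and Blackwell (B100/B200) have tensor cores that consume fp8 inputs and accumulate in fp16 or fp32. There are two fp8 formats: E4M3 (4 exponent bits, 3 mantissa bits, max 448), typically used for weights and activations, and E5M2 (5 exponent bits, 2 mantissa bits, max 57344), typically used for gradients.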

The asymmetry is intentional. Weights and activations have small dynamic range (post-LayerNorm), so a tight-range high-precision format works. Gradients can span many orders of magnitude (back-propagated through long networks), so a wide-range low-precision format works. NVIDIA's Transformer Engine library manages the choice.
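
The trade-off is easy to see by rounding the same numbers into both formats. The `quantize` function below is a simplified model written for this sketch: it rounds to nearest-even, handles subnormals, and saturates at the format maximum instead of modelling inf/NaN. The sample values 0.7 and 1e-4 are arbitrary stand-ins for an activation and a small gradient.

```python
import math

def quantize(x, exp_bits, man_bits, max_val):
    """Round x to the nearest value of a small float format, saturating at max_val."""
    if x == 0.0:
        return 0.0
    bias = 2 ** (exp_bits - 1) - 1
    sign = math.copysign(1.0, x)
    a = abs(x)
    _, e2 = math.frexp(a)              # a = m * 2**e2 with 0.5 <= m < 1
    e = max(e2 - 1, 1 - bias)          # clamp to the minimum normal exponent (subnormals)
    step = 2.0 ** (e - man_bits)       # spacing between neighbours at this exponent
    q = round(a / step) * step         # Python's round() rounds halves to even
    return sign * min(q, max_val)

E4M3 = (4, 3, 448.0)
E5M2 = (5, 2, 57344.0)

for label, x in [("activation", 0.7), ("small gradient", 1e-4)]:
    a = quantize(x, *E4M3)
    b = quantize(x, *E5M2)
    print(f"{label:15s} x={x:g}  E4M3={a:g} (rel err {abs(a - x) / x:.1%})  "
          f"E5M2={b:g} (rel err {abs(b - x) / x:.1%})")
```

For 0.7, E4M3 lands on 0.6875 and E5M2 on 0.75, so E4M3 is the more accurate of the two. For 1e-4, E4M3 underflows to zero while E5M2 still returns about 1.07e-4.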

fp8 training (DeepSeek-V3, recent Megatron) typically combines:

The serving win is bigger than the training win: fp8 weights are half the size of bf16, so the weight bytes read from HBM per decode token halve, and decode latency drops by close to half wherever decode is HBM-bound (lesson 11's HBM-bound math). H100's fp8 throughput is roughly 2× bf16; B200 pushes it further with fp4 in places.
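
As a back-of-the-envelope check on the halving claim, the sketch below computes a lower bound on per-token decode time from weight traffic alone. The model size (70B parameters) and the HBM bandwidth (3.35 TB/s) are illustrative assumptions, not measurements, and the estimate ignores KV-cache reads, activations, and batching.

```python
def decode_weight_read_ms(n_params, bytes_per_param, hbm_bytes_per_s):
    """Time to stream every weight from HBM once, in milliseconds."""
    return n_params * bytes_per_param / hbm_bytes_per_s * 1e3

N_PARAMS = 70e9     # assumed model size
HBM_BW = 3.35e12    # assumed HBM bandwidth in bytes per second

bf16_ms = decode_weight_read_ms(N_PARAMS, 2, HBM_BW)
fp8_ms = decode_weight_read_ms(N_PARAMS, 1, HBM_BW)
print(f"bf16: {bf16_ms:.1f} ms/token   fp8: {fp8_ms:.1f} ms/token")
```

Under these assumptions the floor is roughly 42 ms per token in bf16 and 21 ms in fp8; real systems land above both numbers, but the 2× ratio is what the HBM-bound argument predicts.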

2D · E4M3 vs E5M2, the same value placed on each

Two log-scale axes side-by-side. Every representable positive value of each fp8 format is shown as a dot. Drop a value with the slider and see which format keeps it in range, which one overflows, and what the rounding error looks like. E4M3 has more dots clustered in a tight range; E5M2 covers more orders of magnitude but spaces them further apart.

fp8 range vs precision · the two formats, on log axes
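Black dots are the positive normal values each format can represent. Yellow line is your input. Green dot is the nearest representable value (rounded). The red zone past the max means overflow: ±inf in E5M2. E4M3 as commonly specified has no infinity encoding, so a real cast either saturates to ±448 or yields NaN.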
Readouts: E4M3 rounded value and relative error; E5M2 rounded value and relative error.

Autocast — the developer ergonomic

You don't manually cast every tensor. PyTorch's torch.autocast(device_type="cuda", dtype=torch.bfloat16) context manager wraps ops in a dispatcher layer (the "Autocast" key from lesson 13) that automatically casts each op's inputs to the right dtype:

The cast is cheap (it can often be fused into a neighbouring kernel), but the rules are opinionated: PyTorch decides for you which ops are precision-sensitive. Most of the time it's right. For fp8 training, NVIDIA's Transformer Engine extends this with per-tensor scales.
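
A per-tensor scale is the fp8 analogue of the loss scaler: before casting, the tensor is multiplied by a factor that maps its largest magnitude onto the format's maximum. The sketch below shows the idea with a simplified E4M3 rounding model. All names (`quantize_e4m3`, `per_tensor_scale`, `fake_quantize`) are invented for the example; this is not Transformer Engine's API, and its actual scaling recipe differs in detail.

```python
import math

E4M3_MAX = 448.0

def quantize_e4m3(x):
    """Round x to the nearest E4M3 value (simplified: saturates, no NaN handling)."""
    if x == 0.0:
        return 0.0
    a = abs(x)
    _, e2 = math.frexp(a)                # a = m * 2**e2 with 0.5 <= m < 1
    e = max(e2 - 1, -6)                  # -6 is E4M3's minimum normal exponent
    step = 2.0 ** (e - 3)                # 3 mantissa bits
    q = min(round(a / step) * step, E4M3_MAX)
    return math.copysign(q, x)

def per_tensor_scale(values):
    """Scale that maps the largest magnitude in the tensor onto the fp8 maximum."""
    amax = max(abs(v) for v in values)
    return E4M3_MAX / amax if amax > 0 else 1.0

def fake_quantize(values):
    """Scale, round to E4M3, then divide the scale back out (as a consumer would)."""
    s = per_tensor_scale(values)
    return [quantize_e4m3(v * s) / s for v in values]

tensor = [0.001, 0.0025, -0.004]         # small values, near E4M3's underflow floor
naive = [quantize_e4m3(v) for v in tensor]
scaled = fake_quantize(tensor)
for v, n, s in zip(tensor, naive, scaled):
    print(f"x={v:+.4f}  naive={n:+.5f} (err {abs(n - v) / abs(v):.0%})  "
          f"scaled={s:+.5f} (err {abs(s - v) / abs(v):.0%})")
```

Without the scale, 0.001 snaps to the nearest subnormal and is off by roughly 95%. With it, every element stays within E4M3's normal-range error bound of about 6%.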

Animated · autocast data flow through one step

Forward + backward of a single layer with autocast enabled. Each arrow is a tensor; its color encodes the dtype the framework actually stores. Watch the matmul take bf16 inputs but produce an fp32 accumulator before casting back, the gradient flowing back in bf16, and the optimizer step happening on the fp32 master weights.

Autocast flow · one fwd+bwd+step
Click play to walk the data flow. Tensor colors: blue=fp32, orange=bf16, green=fp32 accumulator, red=overflow risk. The "step" lane shows the optimizer reading the bf16 grad and writing the fp32 master.
Readouts: phase, current tensor, dtype in flight, bytes / param.

The numerics traps that survive autocast

Even with autocast on, a few classes of bugs reappear:

The decision table

Goal                                     Format                                  Master weights?            Scaler?
Pretraining (default)                    bf16                                    fp32                       No
Pretraining on older GPUs (V100, T4)     fp16                                    fp32                       Yes (dynamic)
Frontier-scale pretraining (H100/B100)   fp8 (E4M3 / E5M2)                       fp32 + per-tensor scales   Different: per-tensor scale calibration
Serving / inference                      fp8 or int8 weights, bf16 activations
Tiny models / debugging                  fp32 throughout

Interactive · feel the range & precision tradeoff

Pick a value and a format; the widget shows where it lands (the nearest representable number), the relative error, and what happens to the sum a + b when b ≪ a. Try entering 65000 in fp16: it survives. Try 70000: overflows to inf. Try summing 1.0 + 1e-4 in bf16: the small term disappears.
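
The same experiments can be run without the widget. This sketch uses the standard library's half-precision support (`struct` format `"e"`) for fp16 and emulates bf16 by rounding the fp32 bit pattern; `to_fp16` and `to_bf16` are helpers written for the example.

```python
import math
import struct

def to_fp16(x):
    """Round to the nearest fp16 value; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def to_bf16(x):
    """Round to the nearest bf16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_fp16(65000.0))       # 64992.0: survives, rounded to the nearest multiple of 32
print(to_fp16(70000.0))       # inf: past fp16's max of 65504
print(to_bf16(70000.0))       # 70144.0: in range for bf16, but coarsely rounded
print(to_bf16(1.0 + 1e-4))    # 1.0: the small term disappears
print(to_fp16(1.0 + 1e-4))    # 1.0: fp16 loses it too, its spacing near 1.0 is about 1e-3
```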

Floating-point precision · the actual representable values
Each format's representable values are quantised. Slide your number and watch which format still captures it. Then watch a + b when the formats disagree about how to round.
Readout columns: format, repr(a), repr(b), repr(a+b), err vs fp64. Status flags: overflows in fp8 E4M3, overflows in fp16, b disappears in bf16, b disappears in fp16.

Takeaway
bf16 = same range as fp32, less mantissa — drop-in. fp16 = less range, more mantissa — needs a loss scaler. fp8 = even less range, used with per-tensor scales and an asymmetric weight/activation vs gradient format split. In all three, the optimizer keeps fp32 master weights — that's why "16 bytes per param" in lesson 01 was robust across precisions. Don't try to "go lower" without understanding what each step costs.