Masked & discrete diffusion
The discrete analog of DDPM. MaskGIT’s parallel decoding. How to sample N tokens in ~10 forward passes instead of N.
The setup, by analogy
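Continuous DDPM corrupts data by adding Gaussian noise in T steps; the model learns to denoise. Discrete diffusion does the same with a categorical corruption. Define a forward Markov chain on token sequences $z_0, z_1, \dots, z_T$ that acts independently at each position with

$$q(z_t \mid z_{t-1}) = \mathrm{Cat}\big(z_t;\; p = z_{t-1} Q_t\big)$$

where $z_{t-1}$ is written as a one-hot row vector and $[Q_t]_{ij} = q(z_t = j \mid z_{t-1} = i)$. The transition matrix $Q_t$ says “with probability $1 - \beta_t$ keep the token unchanged; with probability $\beta_t$, do something corrupting.” The three common choices for “something”: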
| Choice | What corruption looks like | Used in |
|---|---|---|
| Absorbing (mask) | token becomes a special [MASK] symbol; once masked, stays masked | MaskGIT, MDLM, MD4, most image work |
| Uniform (swap) | token becomes a uniformly random vocab entry | D3PM and SEDD uniform variants; less used in practice |
| Edit distance (swap to neighbor) | token swaps to a similar one (e.g. a nearest neighbor in embedding space) | D3PM nearest-neighbor variant, text-focused work |
For images, absorbing (mask-only) wins almost always. The reason: image patches don’t have meaningful “edit distance” in token space, while “present vs. missing” is a clean signal the model can condition on.
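A minimal NumPy sketch of the one-step transition matrix for the first two rows of the table is below. It assumes the K real tokens are indices 0..K−1 and that [MASK] is an extra index K (a convention chosen for this sketch); the function names are hypothetical. Row i of each matrix is the distribution of $z_t$ given $z_{t-1} = i$, so indexing a row by the current token id is the same as multiplying a one-hot row vector by the matrix.

```python
import numpy as np

def absorbing_Q(beta, K):
    """One-step transition matrix over K real tokens plus [MASK] at index K.
    Row i is the distribution q(z_t = . | z_{t-1} = i)."""
    Q = (1.0 - beta) * np.eye(K + 1)
    Q[:, K] += beta      # any token jumps to [MASK] with prob beta
    Q[K, K] = 1.0        # [MASK] is absorbing: once masked, stays masked
    return Q

def uniform_Q(beta, K):
    """With prob beta, replace the token by a uniformly random vocab entry."""
    return (1.0 - beta) * np.eye(K) + beta * np.full((K, K), 1.0 / K)

def step(z_prev, Q, rng):
    """Sample z_t ~ Cat(z_t; z_{t-1} Q) independently at every position."""
    rows = Q[z_prev]     # indexing by token id = one-hot row vector times Q
    return np.array([rng.choice(len(p), p=p) for p in rows])

rng = np.random.default_rng(0)
z = np.array([0, 1, 2, 3, 4])            # a clean sequence over K = 5 tokens
Q = absorbing_Q(0.5, K=5)
for t in range(3):
    z = step(z, Q, rng)                   # value 5 is [MASK]
    print(t + 1, z)
```

Running the absorbing chain for a few steps shows that a position holding 5 never returns to a real token, which is the “once masked, stays masked” property in the table.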
The closed-form marginal — discrete Eq. ⋆
For absorbing corruption, the marginal probability of token $z_t^i$ at position $i$ given $z_0^i$ is:

$$q(z_t^i \mid z_0^i) = (1 - \gamma_t)\,\delta_{z_t^i,\, z_0^i} + \gamma_t\,\delta_{z_t^i,\, \text{[MASK]}}$$

where $\gamma_t \in [0, 1]$ is the cumulative mask probability at step $t$, the discrete analog of the noise-mass fraction $(1 - \bar\alpha_t)$ from DDPM (Eq. ⋆ in lesson 2). Both are probabilities (fractions in $[0, 1]$) measuring accumulated corruption; neither is a standard deviation. Pick any monotone $\gamma_t$ with $\gamma_0 = 0$ and $\gamma_T = 1$: linear, cosine, and sigmoid schedules all work. To sample $z_t$, independently mask each token with probability $\gamma_t$. This is closed-form and simulation-free, exactly like DDPM’s Eq. ⋆.
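A sketch of the closed-form sampler, plus a check that it agrees with running the per-step chain. It assumes the cosine-shaped schedule $\gamma_t = 1 - \cos(\pi t / 2T)$ (one of the monotone choices above) and a hypothetical MASK id of −1. The per-step $\beta_t$ come from $1 - \gamma_t = \prod_{s \le t}(1 - \beta_s)$, which holds because a token is still unmasked at step $t$ only if it dodged every earlier masking step.

```python
import numpy as np

MASK = -1  # hypothetical id for [MASK]; real tokens are 0..V-1

def gamma_cosine(t, T):
    """One monotone schedule with gamma_0 = 0 and gamma_T = 1."""
    return 1.0 - np.cos(0.5 * np.pi * np.asarray(t) / T)

def q_sample(z0, gamma, rng):
    """Closed-form z_t ~ q(z_t | z_0): mask each token independently with prob gamma."""
    return np.where(rng.random(z0.shape) < gamma, MASK, z0)

def betas_from_gammas(gammas):
    """Per-step beta_t satisfying 1 - gamma_t = prod_{s<=t} (1 - beta_s)."""
    survive = 1.0 - gammas                 # prob. a token is still unmasked at step t
    return 1.0 - survive[1:] / survive[:-1]

def run_chain(z0, betas, rng):
    """Simulate the forward chain one step at a time (what the closed form lets us skip)."""
    z = z0.copy()
    for beta in betas:
        z = np.where(rng.random(z.shape) < beta, MASK, z)   # masked tokens stay masked
    return z

rng = np.random.default_rng(0)
T, t = 10, 5
z0 = rng.integers(0, 50, size=100_000)
gammas = gamma_cosine(np.arange(T + 1), T)
closed = q_sample(z0, gammas[t], rng)
chained = run_chain(z0, betas_from_gammas(gammas)[:t], rng)
print(gammas[t], (closed == MASK).mean(), (chained == MASK).mean())
```

Both empirical mask rates land on $\gamma_t$; the closed form gets there with one random draw per token instead of $t$.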
The training objective
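Like DDPM’s ELBO collapsing to MSE, the discrete ELBO collapses to a sum of cross-entropies over the masked positions:

$$\mathcal{L}(\theta) = \mathbb{E}_{t,\; z_0,\; z_t \sim q(z_t \mid z_0)}\Big[\, w(t) \sum_{i\,:\, z_t^i = \text{[MASK]}} -\log p_\theta\big(z_0^i \mid z_t\big) \Big]$$

i.e. predict the clean token at every masked position, given the partially-masked sequence. The weight $w(t)$ depends on the schedule $\gamma_t$ when the ELBO is taken exactly; MaskGIT-style training simply drops it ($w(t) = 1$). That’s BERT’s masked-language-modeling objective with a single twist: the mask rate is sampled per example, so the model sees every fraction from 0% to 100% masked, not just a fixed 15%.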
If you train at a fixed mask rate, you’ve trained BERT. If you train at a variable mask rate, you’ve trained a discrete diffusion model. The architecture is the same.
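Here is a sketch of one Monte Carlo estimate of that objective in NumPy. Assumptions: a linear schedule $\gamma_t = t$ with $t$ drawn uniformly from $[0, 1]$ per example, no per-example weight (MaskGIT-style), and a hypothetical `model(z_t, t)` callable returning logits of shape (B, N, V). The only line that differs from a BERT loss is the one drawing `t` per example.

```python
import numpy as np

def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def masked_diffusion_loss(model, z0, mask_id, rng):
    """Masked CE summed over masked positions, averaged over the batch.
    z0: (B, N) int token ids; model(z_t, t) -> (B, N, V) logits (hypothetical)."""
    B, N = z0.shape
    t = rng.random((B, 1))                       # per-example mask rate (the twist vs. BERT)
    masked = rng.random((B, N)) < t              # gamma_t = t: linear schedule, an assumption
    zt = np.where(masked, mask_id, z0)
    logp = log_softmax(model(zt, t))             # (B, N, V)
    nll = -np.take_along_axis(logp, z0[..., None], axis=-1)[..., 0]   # (B, N)
    return float((nll * masked).sum(axis=1).mean())  # only masked positions count

rng = np.random.default_rng(0)
V, B, N = 8, 4, 16
z0 = rng.integers(0, V, size=(B, N))
uniform_model = lambda zt, t: np.zeros(zt.shape + (V,))   # predicts every token equally
print(masked_diffusion_loss(uniform_model, z0, mask_id=V, rng=rng))
```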
Sampling: the MaskGIT loop
Given a trained $p_\theta(z_0 \mid z_t)$, here is the whole sampler:
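```python
import math
import numpy as np
def maskgit_sample(model, N, R=12, mask_id=-1, rng=None):
    """model(z, t) returns (N, V) logits; cells equal to mask_id are [MASK]."""
    rng = np.random.default_rng() if rng is None else rng
    z = np.full(N, mask_id)                          # start with everything masked
    for r in range(R):                               # R rounds ~ 10
        logits = model(z, t=(R - r) / R)             # one forward pass over the whole sequence
        p = np.exp(logits - logits.max(axis=-1, keepdims=True))
        p /= p.sum(axis=-1, keepdims=True)           # softmax per position
        pred = np.array([rng.choice(len(row), p=row) for row in p])  # categorical per position
        conf = p[np.arange(N), pred]                 # confidence of the sampled token
        known = z != mask_id                         # cells committed in earlier rounds
        pred[known] = z[known]                       # keep their tokens ...
        conf[known] = np.inf                         # ... and make sure they survive the cut
        # cosine mask schedule: mask_frac shrinks from cos(pi/2R) toward 0 as r grows
        mask_frac = math.cos((r + 1) / R * math.pi / 2)
        # cells still masked after this round; forced to 0 on the last round (cos(pi/2) ~ 6e-17)
        n_masked = 0 if r == R - 1 else math.ceil(N * mask_frac)
        n_keep = N - n_masked                        # cells revealed after this round
        # take the n_keep most confident positions; mask the rest back to [MASK]
        keep = np.argsort(-conf)[:n_keep]
        z = np.full(N, mask_id)
        z[keep] = pred[keep]
    return z
```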
Reading the schedule the right way: n_masked is what stays masked after round r. Cosine starts near 1 (almost everything is still masked after round 0), shrinks slowly at first and fastest in the final rounds, and reaches 0 after the last round (r = R − 1). The dual quantity n_keep = N − n_masked is the total the model has committed after the round, including earlier rounds. The original MaskGIT paper writes this as a “mask schedule” rather than an “unmask schedule”; both are valid as long as you keep the polarity straight.
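Each round costs one forward pass, and after R rounds the sequence is fully filled. For a typical image at 256 tokens and R = 12, that’s 12 forward passes instead of 256, roughly 21× fewer than autoregression. How much of that shows up as wall-clock speedup depends on the hardware: each parallel pass processes all 256 positions, while a KV-cached autoregressive step processes only one new token.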
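The rounds-vs-quality trade-off can be reproduced without a trained model. The sketch below is a toy under stated assumptions: an 8×8 grid where every cell has a hypothetical “true” color, and an oracle whose confidence at a cell is 0.3 + 0.7 × (fraction of its four neighbors already committed) and whose prediction is correct with exactly that probability. Committed cells therefore raise their neighbors’ confidence, and more rounds let later cells wait for that context. Cells are committed with the cosine schedule; the constants (0.3, 0.7, 4 colors) are arbitrary.

```python
import math
import numpy as np

def cosine_n_keep(N, R, r):
    """Cells revealed after round r under the cosine mask schedule (all N on the last round)."""
    return N if r == R - 1 else N - math.ceil(N * math.cos((r + 1) / R * math.pi / 2))

def toy_maskgit_error(R, side=8, n_colors=4, seed=0):
    """Fraction of wrongly committed cells for a hypothetical oracle on a side x side grid."""
    rng = np.random.default_rng(seed)
    truth = rng.integers(n_colors, size=(side, side))   # each cell's "true preferred color"
    filled = np.zeros((side, side), dtype=bool)
    out = np.full((side, side), -1)
    N = side * side
    for r in range(R):
        # fraction of the 4 neighbors already committed (off-grid counts as empty)
        f = np.pad(filled, 1).astype(float)
        nbr = (f[:-2, 1:-1] + f[2:, 1:-1] + f[1:-1, :-2] + f[1:-1, 2:]) / 4
        conf = 0.3 + 0.7 * nbr                           # assumed: more context, more confidence
        correct = rng.random((side, side)) < conf        # the oracle is right with prob conf
        wrong = (truth + rng.integers(1, n_colors, size=truth.shape)) % n_colors
        pred = np.where(correct, truth, wrong)
        # committed cells always survive; tiny noise breaks ties randomly
        score = np.where(filled, np.inf, conf + 1e-6 * rng.random((side, side)))
        keep = np.zeros(N, dtype=bool)
        keep[np.argsort(-score, axis=None)[:cosine_n_keep(N, R, r)]] = True
        new = keep.reshape(side, side) & ~filled          # cells committed this round
        out[new] = pred[new]
        filled |= new
    return float((out != truth).mean())

for R in (1, 2, 4, 8, 16):
    errs = [toy_maskgit_error(R, seed=s) for s in range(20)]
    print(f"R={R:2d}  mean error={np.mean(errs):.3f}")
```

With R = 1 every cell is committed with no context, so roughly 70% of cells come out wrong; larger R lets later cells wait until some of their neighbors are committed.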
What the schedule does
The unmask schedule decides how fast we commit. Three common choices:
- Linear: after round r, a fraction (r+1)/R of cells is revealed. Even pacing.
- Cosine (default in MaskGIT): reveal slowly at first, then fast. Lets the model take many small bites early when context is poor, large bites late when context is rich.
- Sqrt: reveal fast at first, slow at the end. This front-loads commitments into the rounds where context is poorest, which is rarely the right choice for images.
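The per-round reveal counts make the three shapes concrete. The sketch assumes N = 256 tokens and R = 12 rounds, and writes each schedule as the fraction still masked after round r, as a function of x = (r + 1)/R: 1 − x (linear), cos(πx/2) (cosine), and 1 − √x (an assumed form of the sqrt schedule). The last round is forced to reveal everything, since floating-point cos(π/2) is not exactly 0.

```python
import math

def n_revealed(N, R, schedule):
    """Cumulative cells revealed after each round, given mask_frac = schedule(x), x = (r+1)/R."""
    counts = []
    for r in range(R):
        x = (r + 1) / R
        n_masked = 0 if r == R - 1 else math.ceil(N * schedule(x))  # force 0 on the last round
        counts.append(N - n_masked)
    return counts

schedules = {
    "linear": lambda x: 1 - x,                      # even pacing
    "cosine": lambda x: math.cos(x * math.pi / 2),  # slow first, fast late
    "sqrt": lambda x: 1 - math.sqrt(x),             # fast first, slow late (assumed form)
}
for name, fn in schedules.items():
    c = n_revealed(256, 12, fn)
    per_round = [b - a for a, b in zip([0] + c[:-1], c)]   # new cells committed each round
    print(f"{name:>6}: {per_round}")
```

Cosine commits only a couple of cells in round 0 and saves its largest bites for the end; sqrt does the opposite.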
The trade-off is “number of rounds” vs. “quality.” A common operating point: R = 12 with cosine, matching the original MaskGIT paper.
Discrete diffusion beyond masking
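D3PM (Austin et al. 2021) generalized to arbitrary discrete corruption matrices, not just masking. The math is identical to DDPM but with categorical KLs instead of Gaussian KLs:

$$\mathcal{L}_{\mathrm{vb}} = \mathbb{E}_{q}\Big[D_{\mathrm{KL}}\big(q(z_T \mid z_0)\,\|\,p(z_T)\big) + \sum_{t=2}^{T} D_{\mathrm{KL}}\big(q(z_{t-1} \mid z_t, z_0)\,\|\,p_\theta(z_{t-1} \mid z_t)\big) - \log p_\theta(z_0 \mid z_1)\Big]$$

For absorbing corruption, each KL term reduces to a weighted version of the masked cross-entropy above; for other corruptions you compute the categorical KL directly. Recent text-focused work (SEDD, MD4) has narrowed the gap to autoregressive language models, with its strongest results also coming from the absorbing corruption; for image tokens, the absorbing case dominates.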
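For a general transition matrix, the posterior inside each KL term is $q(z_{t-1} \mid z_t, z_0) \propto q(z_t \mid z_{t-1})\, q(z_{t-1} \mid z_0)$: one elementwise product of a column of $Q_t$ and a row of the cumulative matrix $\bar Q_{t-1} = Q_1 \cdots Q_{t-1}$. The sketch below computes it and the categorical KL for a small absorbing example with arbitrary β values, using the x0-parameterization of $p_\theta$ described in D3PM (mix the joints $q(z_{t-1}, z_t \mid \tilde z_0)$ by a predicted distribution over $z_0$). Function names and numbers are illustrative.

```python
import numpy as np

def posterior(zt, z0, Q_t, Qbar_prev):
    """q(z_{t-1} | z_t, z_0) at one position, for row-stochastic matrices where
    Q_t[i, j] = q(z_t = j | z_{t-1} = i) and Qbar_prev = Q_1 @ ... @ Q_{t-1}."""
    joint = Q_t[:, zt] * Qbar_prev[z0]        # q(z_t | z_{t-1}=i) * q(z_{t-1}=i | z_0)
    return joint / joint.sum()

def model_posterior(zt, p0, Q_t, Qbar_prev):
    """x0-parameterization: mix the joints q(z_{t-1}, z_t | z_0 = k) by a predicted p0(k)."""
    joint = sum(p0[k] * Q_t[:, zt] * Qbar_prev[k] for k in range(len(p0)))
    return joint / joint.sum()

def categorical_kl(p, q):
    """KL(p || q); entries with p = 0 contribute nothing."""
    nz = p > 0
    return float(np.sum(p[nz] * np.log(p[nz] / q[nz])))

K = 4                                      # real tokens 0..3; [MASK] is index 4
def absorbing_Q(beta):
    Q = (1 - beta) * np.eye(K + 1)
    Q[:, K] += beta
    Q[K, K] = 1.0
    return Q

Qs = [absorbing_Q(b) for b in (0.1, 0.2, 0.3)]   # beta_1, beta_2, beta_3 (arbitrary)
Qbar_2 = Qs[0] @ Qs[1]                           # cumulative matrix up to t - 1 = 2
true_post = posterior(K, 2, Qs[2], Qbar_2)       # z_3 = [MASK], z_0 = 2
guess = np.full(K, 1.0 / K)                      # a model unsure of z_0
kl = categorical_kl(true_post, model_posterior(K, guess, Qs[2], Qbar_2))
print(true_post.round(3), kl)
```

With $z_t$ = [MASK], the model’s posterior puts exactly the same mass on [MASK] as the true one, so the KL collapses to a weight times $-\log \tilde p(z_0)$: the weighted masked cross-entropy. When $z_t$ is unmasked, the posterior is one-hot on $z_0$ and the term contributes nothing.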
Why this matters for the multimodal stack
Parallel decoding is the standard research choice for token-based image generators that prioritize throughput (Muse, MAGVIT-v2 are the canonical examples). Flagship product systems — Nano Banana Pro, the GPT-Image-2 family — don’t publicly document their decoder strategy; reasonable inferences from latency profiles and reported capabilities suggest hybrid approaches (parallel decoding for the bulk of image tokens, with autoregressive or diffusion stages around the edges). The architectural option is what matters here: any transformer trained with a masked image-token objective can run either decoder. Reasons to lean parallel for the image span:
- Throughput. 12 forwards beats 256 forwards by a lot, especially when image generation latency dominates user-perceived latency.
- Editing. Parallel decoders trivially support “regenerate this region”: just mask the region and run a few more rounds. Autoregression has to either start over or use a clever inpainting recipe.
- Bidirectional context. Image patches are not naturally ordered. Parallel decoding lets every position attend to every other from round 1; autoregression has to scan a fixed order.
The text branch of these models usually stays autoregressive (that’s where chain-of-thought lives; lesson 14). The image branch flips to parallel. A single transformer can do both: what changes is the attention pattern over the image span (causal vs. bidirectional) and whether you decode left-to-right or all-at-once.
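One way to let a single transformer serve both decoders is a single attention mask: causal over the text prompt, bidirectional inside the image span. The sketch assumes the prompt comes first and uses a boolean “may attend” convention; it is an illustration, not the documented recipe of any system named above.

```python
import numpy as np

def hybrid_attention_mask(n_text, n_image):
    """M[q, k] is True when query position q may attend to key position k.
    Assumed layout: n_text prompt tokens followed by n_image image tokens."""
    n = n_text + n_image
    M = np.tril(np.ones((n, n), dtype=bool))   # causal: each position sees itself and the past
    M[n_text:, n_text:] = True                 # image span: every image token sees every other
    return M

print(hybrid_attention_mask(3, 4).astype(int))
```

Image tokens still see the whole prompt (it lies in their causal past), while prompt tokens never see the image, so text can keep decoding left-to-right.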
Trade-offs in summary
| Axis | Autoregressive | MaskGIT / parallel | Continuous diffusion |
|---|---|---|---|
| Forward passes per image | N (= seq length) | ~10–20 | ~20–50 (FM) or ~1000 (DDPM) |
| Editing | hard | trivial (remask region) | inpainting via masked sampling |
| Bidirectional context | no (causal mask) | yes (no mask in attention) | yes (the denoiser sees the whole image at once) |
| Reasoning + text in same model | natural (just one LM) | natural (same LM, parallel-decoded image span) | requires separate conditioning interface (lesson 15) |
| Token-level loss | cross-entropy per position | cross-entropy on masked positions | MSE (continuous target) |
| Sample quality (image, matched compute) | highest | nearly AR, much faster | highest on raw pixel diffusion; competitive in latent diffusion |