Masked & discrete diffusion

The discrete analog of DDPM. MaskGIT’s parallel decoding. How to sample N tokens in ~10 forward passes instead of N.

The setup, by analogy

Continuous DDPM corrupts data by adding Gaussian noise in T steps; the model learns to denoise. Discrete diffusion does the same with a categorical corruption. Define a forward Markov chain on token sequences z0, z1, …, zT with

q(zt | zt−1) = Cat( zt ; transition_matrix(βt) )

where the transition matrix says “with probability 1 − βt keep the token unchanged; with probability βt, do something corrupting.” The three main choices for “something”:

Choice | What corruption looks like | Used in
Absorbing (mask) | token becomes a special [MASK] symbol; once masked, stays masked | MaskGIT, MDLM, MD4, most image work
Uniform (swap) | token becomes a uniformly random vocab entry | D3PM uniform variant, less used in practice
Edit distance (swap to neighbor) | token swaps to a similar one (e.g. semantically near in embedding) | D3PM nearest-neighbor (embedding-distance) variant, text-focused work

For images, absorbing (mask-only) wins almost always. The reason: image patches don’t have meaningful “edit distance” in token space, while “present vs. missing” is a clean signal the model can condition on.
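To make the first two rows concrete, here is a minimal sketch of the one-step transition matrices for a toy vocabulary. The function names, the convention that [MASK] occupies the last index, and the toy values of β are assumptions made for this illustration.

```python
def absorbing_matrix(vocab_size, beta):
    """Rows index z_{t-1}, columns index z_t; the last index is [MASK]."""
    k = vocab_size + 1                      # real tokens plus [MASK]
    mask = vocab_size
    q = [[0.0] * k for _ in range(k)]
    for i in range(vocab_size):
        q[i][i] = 1.0 - beta                # keep the token
        q[i][mask] = beta                   # or absorb it into [MASK]
    q[mask][mask] = 1.0                     # once masked, stays masked
    return q


def uniform_matrix(vocab_size, beta):
    """With probability beta, resample uniformly from the vocabulary."""
    q = [[beta / vocab_size] * vocab_size for _ in range(vocab_size)]
    for i in range(vocab_size):
        q[i][i] += 1.0 - beta               # otherwise keep the token
    return q


Q_abs = absorbing_matrix(vocab_size=3, beta=0.2)
Q_uni = uniform_matrix(vocab_size=3, beta=0.3)
for row in Q_abs:
    print(row)
```

Every row is a probability distribution over the next state. The [MASK] row of the absorbing matrix puts all its mass on [MASK], which is what “absorbing” means.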

The closed-form marginal — discrete Eq. ⋆

For absorbing corruption, the marginal probability of token zt given z0 is:

q(zt = MASK | z0) = γt,   q(zt = z0 | z0) = 1 − γt

where γt ∈ [0, 1] is the cumulative mask probability at step t — the discrete analog of the noise-mass fraction (1 − ᾱt) from DDPM (Eq. ⋆ in lesson 2). Both are probabilities / fractions in [0, 1] measuring accumulated corruption — neither is a standard deviation. Pick any monotone γt with γ0 = 0, γT = 1; linear, cosine, sigmoid all work. To sample zt, independently mask each token with probability γt. Closed-form, simulation-free, exactly like DDPM’s Eq. ⋆.
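The closed-form forward sample can be sketched in a few lines. The MASK sentinel value, the function names, and the particular cosine form of γt are assumptions for this example; any monotone schedule from 0 to 1 would do.

```python
import math
import random

MASK = -1  # hypothetical sentinel id for [MASK]


def gamma_linear(t, T):
    return t / T


def gamma_cosine(t, T):
    # 0 at t = 0, (numerically) 1 at t = T, monotone increasing in between
    return 1.0 - math.cos(0.5 * math.pi * t / T)


def sample_zt(z0, gamma, rng):
    """Draw z_t ~ q(z_t | z_0): mask each token independently with prob gamma."""
    return [MASK if rng.random() < gamma else tok for tok in z0]


rng = random.Random(0)
z0 = [5, 1, 4, 4, 2, 7, 0, 3]
zt = sample_zt(z0, gamma_cosine(t=6, T=10), rng)
print(zt)
```

No intermediate steps z1, …, zt−1 are simulated: one pass over the tokens at the target mask rate is enough.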

The training objective

Like DDPM’s ELBO collapsing to MSE, the discrete ELBO collapses to a sum of cross-entropies over the masked positions (up to a schedule-dependent weight per step, which is often dropped in practice):

L = 𝔼t, z0, mask [ Σi : masked −log pθ(z0,i | zt, t) ]

i.e. predict the clean token at every masked position, given the partially-masked sequence. That’s BERT’s masked-language-modeling objective with a single twist: the mask rate is sampled per example (so the model sees every fraction from 0% to 100% masked, not just a fixed 15%).

If you train at a fixed mask rate you’ve trained BERT. If you train at variable mask rate, you’ve trained a discrete diffusion model. The architecture is the same.
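One Monte Carlo term of this objective for a single sequence might look like the sketch below. It assumes a linear schedule (the mask rate equals t, drawn uniformly), omits any per-step weighting, and uses a hypothetical `predict_probs` callable in place of the network.

```python
import math
import random

MASK = -1  # hypothetical sentinel id for [MASK]


def masked_ce_loss(z0, predict_probs, rng):
    """Unweighted masked cross-entropy for one sequence.

    predict_probs(zt, t) is a hypothetical stand-in for the network: it
    returns, for every position, a list of probabilities over the vocabulary.
    """
    t = rng.random()                       # mask rate drawn per example
    zt = [MASK if rng.random() < t else tok for tok in z0]
    probs = predict_probs(zt, t)
    loss = 0.0
    for i, tok in enumerate(z0):
        if zt[i] == MASK:                  # only masked positions are scored
            loss += -math.log(probs[i][tok])
    return loss, zt


def uniform_model(zt, t, vocab_size=4):
    # Toy predictor that knows nothing: uniform over the vocabulary.
    return [[1.0 / vocab_size] * vocab_size for _ in zt]


loss, zt = masked_ce_loss([0, 3, 1, 2, 2, 0], uniform_model, random.Random(0))
print(loss, zt)
```

Replacing the line that draws `t` with a constant such as 0.15 turns this into the fixed-rate BERT objective; nothing else changes.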

Sampling: the MaskGIT loop

Given a trained pθ(z0 | zt), here is the sampler as a short loop:

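import numpy as np

def maskgit_sample(model, N, R=12, MASK=-1, seed=0):
    # model(z, t=...) is assumed to return logits as an array of shape (N, vocab)
    rng = np.random.default_rng(seed)
    z = np.full(N, MASK)                          # start with everything masked
    for r in range(R):                            # R rounds ~ 10
        logits = model(z, t=(R - r) / R)          # one forward pass over the whole sequence
        probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)                 # softmax
        pred = np.array([rng.choice(len(p), p=p) for p in probs])  # categorical per position
        conf = probs.max(axis=-1)                 # confidence per position
        done = z != MASK
        pred[done], conf[done] = z[done], np.inf  # committed cells stay committed
        # cosine mask schedule: mask_frac shrinks from cos(pi/(2R)) to 0 as r grows
        mask_frac = np.cos((r + 1) / R * np.pi / 2)
        # floor, not ceil: cos(pi/2) is ~6e-17 in floating point and must round to 0
        n_masked = int(np.floor(N * mask_frac))   # cells still masked after this round
        n_keep = N - n_masked                     # cells revealed after this round
        # take the n_keep most confident positions; mask the rest back to [MASK]
        z = pred.copy()
        z[np.argsort(-conf)[n_keep:]] = MASK
    return z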

Reading the schedule the right way: n_masked is what stays masked after round r. The cosine starts near 1 (almost everything still masked after round 0), shrinks slowly at first and fastest in the last rounds, and hits 0 after the final round (r = R − 1). The dual quantity n_keep = N − n_masked is what the model has committed to after the round. The original MaskGIT paper writes this as a “mask schedule” rather than an “unmask schedule”; both are valid as long as you keep the polarity straight.
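To see the polarity in numbers, the sketch below tabulates the schedule for N = 256 and R = 12. The function name is an assumption, and the count is rounded down so that the floating-point value of cos(π/2), about 6e-17, yields exactly 0 masked cells in the final round.

```python
import math


def mask_schedule(n_tokens, n_rounds):
    """n_masked after each round r = 0 .. n_rounds - 1 (cosine, rounded down)."""
    out = []
    for r in range(n_rounds):
        frac = math.cos((r + 1) / n_rounds * math.pi / 2)
        out.append(math.floor(n_tokens * frac))
    return out


n_masked = mask_schedule(n_tokens=256, n_rounds=12)
n_revealed = [256 - m for m in n_masked]
newly_committed = [b - a for a, b in zip([0] + n_revealed, n_revealed)]

for r, (m, k, new) in enumerate(zip(n_masked, n_revealed, newly_committed)):
    print(f"round {r:2d}: masked={m:3d} revealed={k:3d} newly committed={new:2d}")
```

With these settings the first round commits 3 tokens and the last round commits 33.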

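Each round costs one forward pass. After R rounds, the sequence is fully filled. For a typical image at 256 tokens and R = 12, that’s 12 forwards instead of 256, roughly 21× fewer forward passes than autoregression. (The wall-clock gain is smaller, because an autoregressive step with a KV cache is cheaper than a full bidirectional pass.)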

Why “keep the most confident”?
At each round the model gives you a guess for every masked cell. Some guesses are confident (sharp posterior), some are diffuse. If you commit to all of them at once, the diffuse ones are essentially random. By committing only to the confident ones and remasking the rest, you give the model another round of conditioning on the committed cells: the diffuse cells’ posteriors sharpen because they now have more context. This is similar in spirit to how beam search defers uncertain decisions, although here nothing is kept in parallel; the uncertain cells are simply predicted again.

Interactive · MaskGIT rounds on a synthetic 64-token grid

Below, an 8×8 token grid representing a small image. The “model” is an oracle: every position has a true preferred color, and the predicted-confidence is a function of how much of the neighborhood is already filled (so committed cells help neighbors become more confident). Drag rounds from low to high and watch the trade-off.

[Interactive demo: MaskGIT parallel decoding on an 8×8 grid. Each round predicts all masked cells, keeps the top-K most confident, and remasks the rest, so the picture locks in from confident regions outward. Readouts: round, cells revealed (of 64), forwards so far, errors vs. target.]
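The demo above can be approximated offline with a small script. This is only a sketch: the quadrant target, the confidence rule (0.4 plus 0.6 times the fraction of filled neighbors), the jitter, and every name here are assumptions chosen for illustration, not the page's actual oracle.

```python
import math
import random

SIDE, N, COLORS, MASK = 8, 64, 4, -1   # all toy values chosen for this sketch


def target_grid():
    # Four colored quadrants, flattened row-major.
    return [2 * (r >= SIDE // 2) + (c >= SIDE // 2)
            for r in range(SIDE) for c in range(SIDE)]


def filled_neighbors(z, i):
    """Fraction of the in-bounds 4-neighbors of cell i that are committed."""
    r, c = divmod(i, SIDE)
    nbrs = [(r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)]
    nbrs = [(a, b) for a, b in nbrs if 0 <= a < SIDE and 0 <= b < SIDE]
    return sum(z[a * SIDE + b] != MASK for a, b in nbrs) / len(nbrs)


def toy_model(z, target, rng):
    """Hypothetical oracle: more filled neighbors -> higher confidence."""
    pred, conf = [], []
    for i in range(N):
        p_right = 0.4 + 0.6 * filled_neighbors(z, i)
        tok = target[i] if rng.random() < p_right else rng.randrange(COLORS)
        pred.append(tok)
        conf.append(p_right + 0.05 * rng.random())   # jitter breaks ties
    return pred, conf


def decode(n_rounds, seed=0):
    rng = random.Random(seed)
    target = target_grid()
    z = [MASK] * N
    for r in range(n_rounds):
        pred, conf = toy_model(z, target, rng)       # one "forward pass"
        n_masked = math.floor(N * math.cos((r + 1) / n_rounds * math.pi / 2))
        masked = [i for i in range(N) if z[i] == MASK]
        # Commit the most confident masked cells; earlier commits are kept.
        masked.sort(key=lambda i: conf[i], reverse=True)
        for i in masked[:len(masked) - n_masked]:
            z[i] = pred[i]
    errors = sum(a != b for a, b in zip(z, target))
    return z, errors


for rounds in (1, 4, 12):
    grid, errors = decode(rounds)
    print(rounds, "rounds ->", errors, "errors out of", N)
```

In this toy, a single round commits every cell with no context at all, while later rounds commit cells whose neighbors are already filled, so more rounds will typically leave fewer errors. The exact counts depend on the seed.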

What the schedule does

The unmask schedule decides how fast we commit. Linear and cosine are common choices; the MaskGIT paper also compares other concave and convex shapes.

The trade-off is “number of rounds” vs. “quality.” A common operating point is a cosine schedule with R around 8 to 12, in line with the original MaskGIT paper.

Discrete diffusion beyond masking

D3PM (Austin et al. 2021) generalized to arbitrary discrete corruption matrices — not just masking. The math is identical to DDPM but with categorical KLs instead of Gaussian KLs:

Lt = 𝔼z0, zt [ KL( q(zt−1 | zt, z0) ‖ pθ(zt−1 | zt) ) ]

For absorbing corruption the KL reduces to the masked cross-entropy above, scaled by a schedule-dependent weight; for other corruptions you compute the categorical KL directly. Recent work (SEDD, MDLM, MD4) makes discrete diffusion competitive on text, largely with absorbing corruption as well; for image tokens, the absorbing case dominates.
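The absorbing-case reduction can be checked numerically for a single masked position. The sketch assumes the usual x0-parameterization, in which pθ(zt−1 | zt) is the true posterior with the one-hot z0 replaced by the network's predicted distribution; the γ values and the prediction vector are toy numbers.

```python
import math


def absorbing_posterior(z0_dist, gamma_prev, gamma_t):
    """Distribution of z_{t-1} given z_t = [MASK], as a list over
    vocabulary entries followed by one final slot for [MASK].

    z0_dist is a distribution over the clean token: one-hot for the true
    posterior q, the network's prediction for p_theta.
    """
    a = (gamma_t - gamma_prev) / gamma_t     # chance the token is revealed
    return [a * p for p in z0_dist] + [1.0 - a]


def categorical_kl(q, p):
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0.0)


gamma_prev, gamma_t = 0.5, 0.6               # toy schedule values
true_token = 2
one_hot = [0.0, 0.0, 1.0, 0.0]
model_pred = [0.1, 0.2, 0.6, 0.1]            # hypothetical network output

kl = categorical_kl(
    absorbing_posterior(one_hot, gamma_prev, gamma_t),
    absorbing_posterior(model_pred, gamma_prev, gamma_t),
)
weight = (gamma_t - gamma_prev) / gamma_t
weighted_ce = weight * -math.log(model_pred[true_token])
print(kl, weighted_ce)
```

The two printed numbers agree up to rounding: the [MASK] slot contributes nothing to the KL, and the remaining term is the cross-entropy on the true token times (γt − γt−1) / γt.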

Why this matters for the multimodal stack

Parallel decoding is the standard research choice for token-based image generators that prioritize throughput (Muse, MAGVIT-v2 are the canonical examples). Flagship product systems — Nano Banana Pro, the GPT-Image-2 family — don’t publicly document their decoder strategy; reasonable inferences from latency profiles and reported capabilities suggest hybrid approaches (parallel decoding for the bulk of image tokens, with autoregressive or diffusion stages around the edges). The architectural option is what matters here: any transformer trained with a masked image-token objective can run either decoder. Reasons to lean parallel for the image span:

  1. Throughput. 12 forwards beats 256 forwards by a lot, especially when image generation latency dominates user-perceived latency.
  2. Editing. Parallel decoders trivially support “regenerate this region”: just mask the region and run a few more rounds. Autoregression has to either start over or use a clever inpainting recipe.
  3. Bidirectional context. Image patches are not naturally ordered. Parallel decoding lets every position attend to every other from round 1; autoregression has to scan a fixed order.
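The editing recipe in point 2 can be sketched as “remask, then refill.” Everything here is illustrative: `predict` is a hypothetical stand-in for the network that returns a (token, confidence) pair per position, and the toy predictor always proposes token 9.

```python
import math
import random

MASK = -1  # hypothetical sentinel id for [MASK]


def remask_region(z, region):
    """Return a copy of z with the positions in `region` set to MASK."""
    region = set(region)
    return [MASK if i in region else tok for i, tok in enumerate(z)]


def refill(z, predict, n_rounds):
    """Run a few confidence-ordered rounds over the masked positions only."""
    z = list(z)
    n_start = z.count(MASK)
    for r in range(n_rounds):
        guesses = predict(z)                 # one forward pass
        masked = [i for i, tok in enumerate(z) if tok == MASK]
        keep_masked = math.floor(
            n_start * math.cos((r + 1) / n_rounds * math.pi / 2))
        masked.sort(key=lambda i: guesses[i][1], reverse=True)
        for i in masked[:len(masked) - keep_masked]:
            z[i] = guesses[i][0]             # untouched cells are never rewritten
    return z


_rng = random.Random(0)


def toy_predict(z):
    # Toy predictor: always proposes token 9 with a random confidence.
    return [(9, _rng.random()) for _ in z]


image = [1, 2, 3, 4, 5, 6, 7, 8]
edited = refill(remask_region(image, region=[2, 3, 4]), toy_predict, n_rounds=3)
print(edited)
```

Only the three remasked positions change; the rest of the sequence acts as fixed context for every round.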

The text branch of these models usually stays autoregressive (that’s where chain-of-thought lives; see lesson 14). The image branch flips to parallel. A single transformer can do both: what changes is the attention mask (causal vs. bidirectional) and whether you decode left-to-right or all-at-once.

Trade-offs in summary

Axis | Autoregressive | MaskGIT / parallel | Continuous diffusion
Forward passes per image | N (= seq length) | ~10–20 | ~20–50 (FM) or ~1000 (DDPM)
Editing | hard | trivial (remask region) | inpainting via masked sampling
Bidirectional context | no (causal mask) | yes (no mask in attention) | n/a (operates on whole image at once)
Reasoning + text in same model | natural (just one LM) | natural (same LM, parallel-decoded image span) | requires separate conditioning interface (lesson 15)
Token-level loss | cross-entropy per position | cross-entropy on masked positions | MSE (continuous target)
Sample quality (image, matched compute) | highest | nearly AR, much faster | highest on raw pixel diffusion; competitive in latent diffusion

Punchline
Discrete diffusion is “BERT-style masked prediction with a variable mask rate at training, plus an iterative top-K-confident unmask loop at sampling.” It runs in ~10 forwards instead of N, gives you bidirectional context, and supports editing for free. That makes it a standard sampling backbone for throughput-oriented token-based image generators such as Muse and MAGVIT-v2.