
DiT — transformer denoiser

Patch the image, run attention between patches, condition on time via adaLN-Zero. Why this is the modern default for both DDPM and FM.

The minimal moves

A diffusion / flow-matching model is any function fθ(x, t) whose output has the same shape as x. For 2D points the MLP from lessons 2–6 is enough. For images, we want a function class that:

  1. Treats the image as a sequence of patches (token structure that attention can mix), not raw pixels.
  2. Lets any patch attend to any other patch in a single layer — far-flung structure (e.g. circles spanning the whole image) shouldn’t require a deep stack of convs to integrate.
  3. Conditions on t in a way that’s strong enough to modulate every layer, but cheap and stable to train.

DiT (Peebles & Xie 2023) does this with three components: a patch embedding, an LLM-style transformer body with bidirectional attention, and the adaLN-Zero conditioning trick. diffusion_transformer.py is a faithful, pedagogically-stripped implementation.

Shape flow at a glance

x_t  (B, C, H, W)   noisy image
  → PatchEmbed: Conv2d(kernel = stride = p), each patch p²·C → d
  → tokens (B, N, d), N = (H/p)², + pos_embed (learned)

t  (B,)   scalar time
  → c = t_embed  (B, d)   SinEmb → 2-layer MLP

DiT block × L:
  (s₁, sh₁, g₁, s₂, sh₂, g₂) = chunk(Linear(SiLU(c)), 6)   last Linear in modulator zero-init
  x ← x + g₁ · Attn((1 + s₁) · LN(x) + sh₁)
  x ← x + g₂ · MLP((1 + s₂) · LN(x) + sh₂)

Final layer: adaLN (2·d: scale, shift only) + Linear (zero-init), (B, N, d) → (B, N, p²·C)
  no gate: there’s no residual left to gate

unpatch → (B, C, H, W)   predicted ε or v
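The two ends of this flow, the time embedding and the final unpatch, can be sketched in a few lines. This is a minimal sketch, not the code in diffusion_transformer.py: the function names, the `max_period` constant, the SiLU inside the 2-layer MLP, and the (C, p, p) ordering inside each token are all assumptions made for illustration.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embedding(t, dim, max_period=10000.0):
    """Map times t of shape (B,) to features of shape (B, dim); dim assumed even."""
    half = dim // 2
    # Geometric ladder of frequencies, from 1 down to roughly 1/max_period.
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]           # (B, half)
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

def unpatchify(tokens, C, H, W, p):
    """(B, N, p*p*C) -> (B, C, H, W), assuming each token is laid out as (C, p, p)."""
    B = tokens.shape[0]
    x = tokens.reshape(B, H // p, W // p, C, p, p)
    x = x.permute(0, 3, 1, 4, 2, 5)                      # (B, C, H/p, p, W/p, p)
    return x.reshape(B, C, H, W)

B, C, H, W, p, d = 2, 1, 16, 16, 4, 64

# c = t_embed: sinusoidal features followed by a 2-layer MLP (hypothetical layer sizes).
t = torch.rand(B)
time_mlp = nn.Sequential(nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d))
c = time_mlp(sinusoidal_embedding(t, d))                 # (B, d)

# Stand-in for the final layer's output, one p*p*C vector per token.
head_out = torch.randn(B, (H // p) * (W // p), p * p * C)
pred = unpatchify(head_out, C, H, W, p)                  # (B, C, H, W)
```

The only hard requirement on `unpatchify` is that it inverts whatever ordering the patch embedding used, so that token n lands back on patch n.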

Patch embedding — three steps fused into one GEMM

Conceptually, patch embedding is:

  1. Reshape (B, C, H, W) into non-overlapping patches of size p × p: (B, C, H/p, p, W/p, p).
  2. Flatten each patch to a vector of size p²·C: (B, N, p²·C) where N = (H/p)².
  3. Apply a Linear from p²·C to d: (B, N, d).

A Conv2d with kernel_size = stride = p does all three in one GEMM (PatchEmbed). Mathematically identical, more efficient on GPU.

For img_size=16, patch_size=4 (this repo’s defaults): N = (16/4)² = 16 tokens per image, embedding dim d = 64. Sixteen tokens is exactly enough for attention to do something interesting without making the toy slow.
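The equivalence between the three explicit steps and the strided convolution can be checked numerically. The sketch below uses this lesson's default sizes and copies the conv weights into a Linear; the variable names are illustrative and are not taken from the repo's PatchEmbed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, p, d = 2, 1, 16, 16, 4, 64

# Fused version: a strided conv whose kernel covers exactly one patch.
conv = nn.Conv2d(C, d, kernel_size=p, stride=p)

# Explicit version: a Linear that reuses the conv weights, flattened per patch.
linear = nn.Linear(p * p * C, d)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(d, C * p * p))
    linear.bias.copy_(conv.bias)

x = torch.randn(B, C, H, W)

# Step 1: cut into non-overlapping p x p patches.
patches = x.reshape(B, C, H // p, p, W // p, p)
# Step 2: bring the patch grid forward, then flatten each patch to p*p*C values.
patches = patches.permute(0, 2, 4, 1, 3, 5)               # (B, H/p, W/p, C, p, p)
patches = patches.reshape(B, (H // p) * (W // p), C * p * p)
# Step 3: project each flattened patch to d dimensions.
tokens_explicit = linear(patches)                         # (B, N, d)

# Fused: conv output (B, d, H/p, W/p) reshaped to (B, N, d).
tokens_fused = conv(x).flatten(2).transpose(1, 2)

max_diff = (tokens_explicit - tokens_fused).abs().max().item()
```

Both paths produce a (2, 16, 64) tensor, and `max_diff` should sit at floating-point rounding level.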

Attention — bidirectional, no mask

Standard transformer self-attention, but no causal mask. Each patch attends to every other patch. For images this is the correct inductive bias: pixels at the top can correlate with pixels at the bottom, and the model shouldn’t have to learn that connection through three convs and a downsample first.
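To make "no causal mask" concrete, here is a minimal single-head sketch that compares unmasked attention weights with the causal variant an autoregressive LLM would use. The random q, k, v stand in for projected patch tokens; a real block would use learned projections and several heads.

```python
import torch

torch.manual_seed(0)
B, N, d = 1, 16, 64
q = torch.randn(B, N, d)
k = torch.randn(B, N, d)
v = torch.randn(B, N, d)

scores = q @ k.transpose(-2, -1) / d ** 0.5               # (B, N, N)

# Bidirectional: softmax over all keys, nothing masked out.
attn_bidir = scores.softmax(dim=-1)

# Causal, for contrast only: token i may only see tokens j <= i.
causal_mask = torch.tril(torch.ones(N, N, dtype=torch.bool))
attn_causal = scores.masked_fill(~causal_mask, float("-inf")).softmax(dim=-1)

out = attn_bidir @ v                                      # (B, N, d)
```

Every entry of `attn_bidir` is strictly positive, whereas `attn_causal` is exactly zero above the diagonal. With a raster ordering of patches, the causal version would stop top-of-image patches from seeing anything below them.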

Attention sparsity — what does each patch look at?
Click a patch (orange highlight) to see a synthetic attention pattern of the kind a DiT-style model might learn on a 16×16 circle image. Patches inside the circle attend to other in-circle patches; background patches attend to themselves and their neighbours.

The left grid is a 16-token image (4×4 patches at p=4 on a 16×16 image). The right grid is the attention pattern from the selected token (orange box on the left). Brighter = more attention. This is illustrative, not from a real model — the point is to show the “every patch reaches everywhere” freedom.

Full attention matrix — every patch’s view at once

The previous widget shows one patch’s attention. Here’s the full N × N matrix: rows are query patches, columns are keys. Bright entry (i, j) means patch i looks at patch j. Click a column or row to highlight the corresponding patch on the image.

N × N attention matrix on the 16-patch image
Click any cell. The corresponding query patch (row) and key patch (column) light up on the 4×4 image. The diagonal is always bright — every patch attends to itself most.

Synthetic — based on a content-similarity + distance kernel, scaled by a sharpness knob (note: not the conventional softmax temperature, which is divisive; this one is multiplicative on the distance penalty, so larger sharpness = more concentrated). Softened with depth. Real DiT attention has rich layer-dependent structure; this widget’s point is the topology: every cell is non-zero, every patch can reach every other in a single layer.
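One way such a synthetic pattern could be generated is sketched below. The widget's actual formula is not given here, so the logit definition (content similarity minus sharpness times distance), the binary in-circle mask, and the function name are all assumptions chosen to match the description above.

```python
import torch

def synthetic_attention(in_circle, grid=4, sharpness=1.0):
    """Hypothetical kernel: softmax(content_similarity - sharpness * distance)."""
    ys, xs = torch.meshgrid(torch.arange(grid), torch.arange(grid), indexing="ij")
    pos = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()   # (N, 2) patch centres
    dist = (pos[:, None, :] - pos[None, :, :]).norm(dim=-1)           # (N, N) Euclidean
    # Content similarity: 1 if both patches are on the same side of the circle edge.
    same = (in_circle[:, None] == in_circle[None, :]).float()
    logits = same - sharpness * dist       # sharpness multiplies the distance penalty
    return logits.softmax(dim=-1)          # each row sums to 1

# 4x4 patch grid with the circle covering the four central patches.
in_circle = torch.tensor([0, 0, 0, 0,
                          0, 1, 1, 0,
                          0, 1, 1, 0,
                          0, 0, 0, 0], dtype=torch.bool)

attn_soft = synthetic_attention(in_circle, sharpness=0.5)
attn_sharp = synthetic_attention(in_circle, sharpness=2.0)
```

Under this kernel every entry stays non-zero, the diagonal is the largest entry in each row, and raising `sharpness` moves weight toward the diagonal, which is the behaviour the caption describes.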

adaLN-Zero — the conditioning trick

The hardest part of a denoiser/velocity-net is conditioning on t. The signal is just a scalar, but it has to modulate everything: at t = 0 the net should pass x through almost unchanged; at t = T the net should aggressively predict structure.

Three reasonable approaches:

| Approach | Cost | Issues |
|---|---|---|
| Concat (B, d) time embedding as extra token | +1 token per block | attention has to discover that token is special; conditioning is “soft” |
| FiLM at every layer | 2·d modulation params per layer | works, but no principled init; deep stacks can be unstable |
| adaLN-Zero | 6·d modulation params per block | each block starts as identity ⇒ stable training without warmup |

The adaLN-Zero formula, for one DiT block:

x ← x + gateᵢ ⊙ fᵢ( (1 + scaleᵢ) ⊙ LayerNorm(x) + shiftᵢ )

where f₁ is attention, f₂ is the MLP, and (scaleᵢ, shiftᵢ, gateᵢ) come from a small MLP applied to the time embedding c.

The “-Zero” trick
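Initialize the modulation MLP’s output layer to zero. At step 0 every scale, shift, and gate is then 0, so each block computes x ← x + 0 ⊙ fᵢ(LayerNorm(x)) = x, and the whole stack of blocks starts as the identity on the token stream. (Because the final Linear is also zero-init, the network’s very first prediction is all zeros.) Each block learns how much to contribute over training, gated by its own gate. Same effect as careful residual-scale inits (Fixup, etc.) but more interpretable.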
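The identity-at-init property is easy to verify in code. The block below is a minimal sketch of one adaLN-Zero block following the formula above; the class name, head count, MLP ratio, GELU activation, and the use of nn.MultiheadAttention are assumptions for illustration, not necessarily what diffusion_transformer.py does.

```python
import torch
import torch.nn as nn

class AdaLNZeroBlock(nn.Module):
    def __init__(self, d=64, heads=4, mlp_ratio=4):
        super().__init__()
        # LayerNorm without its own affine params: scale and shift come from c.
        self.norm1 = nn.LayerNorm(d, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(d, mlp_ratio * d), nn.GELU(), nn.Linear(mlp_ratio * d, d)
        )
        # Modulator: SiLU then Linear producing 6*d values per sample.
        self.modulator = nn.Sequential(nn.SiLU(), nn.Linear(d, 6 * d))
        # The "-Zero" part: the last Linear starts at zero, so every gate is 0.
        nn.init.zeros_(self.modulator[-1].weight)
        nn.init.zeros_(self.modulator[-1].bias)

    def forward(self, x, c):
        # (B, 6d) -> (B, 1, 6d) -> six tensors of shape (B, 1, d), broadcast over tokens.
        s1, sh1, g1, s2, sh2, g2 = self.modulator(c).unsqueeze(1).chunk(6, dim=-1)
        h = (1 + s1) * self.norm1(x) + sh1
        x = x + g1 * self.attn(h, h, h, need_weights=False)[0]
        h = (1 + s2) * self.norm2(x) + sh2
        x = x + g2 * self.mlp(h)
        return x

torch.manual_seed(0)
block = AdaLNZeroBlock()
x = torch.randn(2, 16, 64)     # (B, N, d) tokens
c = torch.randn(2, 64)         # (B, d) time embedding
y = block(x, c)
is_identity_at_init = torch.allclose(y, x)
```

Although the gates are zero, their gradients are not: the loss gradient reaches the modulator through g₁ · Attn(…) and g₂ · MLP(…), so the gates can move away from zero on the first update.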

Interactive · watch the adaLN-Zero gates open

Below is a tiny simulator of an adaLN-Zero training trajectory. We don’t train a real network; we just animate gate magnitudes growing from zero as training progresses, and show what the predicted output looks like (identity at step 0, eventually some shaped output).

adaLN-Zero gate magnitudes over training
L blocks. Each block has gate values (one per (scale, shift, gate) channel) initialized to 0. Watch them open up over “training”.

What changes between MLP and DiT

| Axis | MLP (2D toy) | DiT (image) |
|---|---|---|
| Net class | MLPDenoiser / MLPVelocity | DiT |
| Input/output shape | (B, 2) | (B, 1, 16, 16) |
| Time conditioning | concat | adaLN-Zero |
| Parameter count (this repo) | ~25k | ~250k (d=64, depth=4) |
| DDPM / FM wrapper | unchanged | unchanged |

Crucially, the wrapper (DDPM, FlowMatching) is shape-agnostic. The same DDPM.loss(x) and DDPM.sample(n, shape=…) work for both 2D and image data. q_sample uses while ab.dim() < x0.dim(): ab = ab.unsqueeze(-1) to broadcast the schedule scalar across whatever trailing dimensions the data has. That is the entire mechanism that lets the 2D and image cases share the same code.
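The broadcasting loop can be exercised on both data shapes with a small stand-alone sketch. Only the `while ab.dim() < x0.dim()` line is taken from the text; the function signature, the linear beta schedule, and T = 1000 are assumptions for illustration and may differ from the repo's DDPM class.

```python
import torch

def q_sample(x0, t, alpha_bar, noise=None):
    """Forward-noise x0 at integer timesteps t, for data of any trailing shape."""
    if noise is None:
        noise = torch.randn_like(x0)
    ab = alpha_bar[t]                       # (B,) cumulative signal level per sample
    while ab.dim() < x0.dim():              # (B,) -> (B, 1) or (B, 1, 1, 1), as needed
        ab = ab.unsqueeze(-1)
    return ab.sqrt() * x0 + (1 - ab).sqrt() * noise

# Hypothetical linear schedule.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1 - betas, dim=0)

torch.manual_seed(0)
t = torch.randint(0, T, (8,))
points = torch.randn(8, 2)                  # 2D toy data
images = torch.randn(8, 1, 16, 16)          # image data

noisy_points = q_sample(points, t, alpha_bar)
noisy_images = q_sample(images, t, alpha_bar)
```

The same function handles (8, 2) and (8, 1, 16, 16) inputs without a shape argument, because the schedule value simply gains as many trailing singleton dimensions as the data has.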

Why DiT vs. UNet?

| | UNet | DiT |
|---|---|---|
| Strong locality bias? | yes (convs) | no (attention is content-based, not spatial) |
| Long-range receptive field? | through downsample + conv stack | one attention layer |
| Conditioning style | FiLM, cross-attn, ad-hoc | adaLN-Zero (uniform across blocks) |
| Scales with (depth, width)? | middling; bottleneck dominates | clean LLM-like curves |
| Best when | data is small or strongly spatially local | data is large, structure is global, you want one architecture across modalities |

Concrete recent evidence: many of the large generative image and video models released since 2023 (Stable Diffusion 3, Flux, Sora, and reportedly Veo) use a DiT or DiT-derived backbone. UNets can still win at small scale and for very local denoising (e.g. very low-resolution toys with limited compute).

Punchline
DiT is “ε_θ / v_θ as a vision transformer.” PatchEmbed turns image into tokens; bidirectional attention mixes them globally; adaLN-Zero conditions on time at every block with a clean identity init. The loss machinery is unchanged. That clean factorization — architecture-independent objective, objective-independent architecture — is the takeaway.