DiT — transformer denoiser
Patch the image, run attention between patches, condition on time via adaLN-Zero. Why this is the modern default for both DDPM and FM.
The minimal moves
A diffusion / flow-matching model is any function fθ(x, t) whose output has the same shape as x. For 2D points the MLP from lessons 2–6 is enough. For images, we want a function class that:
- Treats the image as a sequence of patches (token structure that attention can mix), not raw pixels.
- Lets any patch attend to any other patch in a single layer — far-flung structure (e.g. circles spanning the whole image) shouldn’t require a deep stack of convs to integrate.
- Conditions on t in a way that’s strong enough to modulate every layer, but cheap and stable to train.
DiT (Peebles & Xie 2023) does this with three components: a patch embedding, an LLM-style transformer body with bidirectional attention, and the adaLN-Zero conditioning trick. diffusion_transformer.py is a faithful, pedagogically-stripped implementation.
Shape flow at a glance
Patch embedding — three steps fused into one GEMM
Conceptually, patch embedding is:
- Reshape (B, C, H, W) into non-overlapping patches of size p × p: (B, C, H/p, p, W/p, p).
- Flatten each patch to a vector of size p²·C: (B, N, p²·C) where N = (H/p)·(W/p), which is (H/p)² for square images.
- Apply a Linear from p²·C to d: (B, N, d).
A Conv2d with kernel_size = stride = p does all three in one GEMM (PatchEmbed). Mathematically identical, more efficient on GPU.
For img_size=16, patch_size=4 (this repo’s defaults): N = (16/4)² = 16 tokens per image, embedding dim d = 64. Sixteen tokens is exactly enough for attention to do something interesting without making the toy slow.
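The equivalence between the three conceptual steps and a strided convolution can be checked numerically. The sketch below is a NumPy stand-in, not the repo's PatchEmbed: the function names `patchify_linear` and `strided_conv` are hypothetical, the weights are random, and the convolution is written as explicit loops purely for clarity. Note that the flatten step needs an axis permutation so that each patch's values sit next to each other in memory.

```python
import numpy as np

def patchify_linear(x, weight, bias, p):
    """Reshape -> flatten -> linear. x: (B, C, H, W), weight: (d, C*p*p), bias: (d,)."""
    B, C, H, W = x.shape
    gh, gw = H // p, W // p
    # (B, C, gh, p, gw, p): split each spatial axis into (grid index, offset in patch)
    x = x.reshape(B, C, gh, p, gw, p)
    # (B, gh, gw, C, p, p): move the grid axes forward so each patch is one block
    x = x.transpose(0, 2, 4, 1, 3, 5)
    # (B, N, C*p*p): one flat vector per patch
    tokens = x.reshape(B, gh * gw, C * p * p)
    return tokens @ weight.T + bias  # (B, N, d)

def strided_conv(x, weight, bias, p):
    """Reference convolution with kernel_size = stride = p, as explicit loops."""
    B, C, H, W = x.shape
    d = weight.shape[0]
    kernel = weight.reshape(d, C, p, p)
    gh, gw = H // p, W // p
    out = np.zeros((B, d, gh, gw))
    for i in range(gh):
        for j in range(gw):
            window = x[:, :, i * p:(i + 1) * p, j * p:(j + 1) * p]  # (B, C, p, p)
            out[:, :, i, j] = np.einsum("bcuv,dcuv->bd", window, kernel) + bias
    # (B, d, gh, gw) -> (B, N, d), row-major over the patch grid
    return out.reshape(B, d, gh * gw).transpose(0, 2, 1)

rng = np.random.default_rng(0)
B, C, H, p, d = 2, 1, 16, 4, 64          # the defaults quoted above
x = rng.normal(size=(B, C, H, H))
weight = rng.normal(size=(d, C * p * p))
bias = rng.normal(size=(d,))

tokens = patchify_linear(x, weight, bias, p)
conv_tokens = strided_conv(x, weight, bias, p)
print(tokens.shape)                       # (2, 16, 64)
print(np.allclose(tokens, conv_tokens))   # True
```

The same weight tensor, viewed as (d, C·p·p) or as (d, C, p, p), drives both paths, which is why a framework convolution can replace the reshape and the Linear in a single call.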
Attention — bidirectional, no mask
Standard transformer self-attention, but no causal mask. Each patch attends to every other patch. For images this is the correct inductive bias: pixels at the top can correlate with pixels at the bottom, and the model shouldn’t have to learn that connection through three convs and a downsample first.
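To make the "no causal mask" point concrete, here is a minimal single-head self-attention in NumPy with an optional causal flag. It is an illustrative sketch under assumptions: the projection matrices are random rather than trained, there is one head, and the `causal` switch exists only to contrast with LLM-style masking; it is not a claim about the repo's attention module.

```python
import numpy as np

def self_attention(x, wq, wk, wv, causal=False):
    """Single-head self-attention. x: (N, d). Returns (output, attention weights)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])                # (N, N)
    if causal:
        # LLM-style mask: token i may not look at tokens j > i
        future = np.triu(np.ones_like(scores, dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
N, d = 16, 64                                              # 16 patch tokens of width 64
x = rng.normal(size=(N, d))
wq, wk, wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

_, bidir = self_attention(x, wq, wk, wv, causal=False)
_, masked = self_attention(x, wq, wk, wv, causal=True)

print(bidir.shape)                   # (16, 16)
print((bidir[0, 1:] > 0).all())      # True: the first patch sees every later patch
print((masked[0, 1:] == 0).all())    # True: a causal mask would hide all of them
```

Row i of `bidir` is the attention distribution of query patch i over all N key patches, so in the bidirectional case every entry of the N × N matrix is available to the model.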
Full attention matrix — every patch’s view at once
The previous widget shows one patch’s attention. Here’s the full N × N matrix: rows are query patches, columns are keys. Bright entry (i, j) means patch i looks at patch j. Click a column or row to highlight the corresponding patch on the image.
adaLN-Zero — the conditioning trick
The hardest part of a denoiser/velocity-net is conditioning on t. The signal is just a scalar, but it has to modulate everything: at t = 0 the net should pass x through almost unchanged; at t = T the net should aggressively predict structure.
Three reasonable approaches:
| Approach | Cost | Issues |
|---|---|---|
| Concat (B, d) time embedding as extra token | +1 token per block | attention has to discover that token is special; conditioning is “soft” |
| FiLM at every layer | 2·d modulation values per layer | works, but no principled init — deep stacks can be unstable |
| adaLN-Zero | 6·d modulation values per block | each block starts as identity ⇒ stable training without warmup |
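The adaLN-Zero formula, for one DiT block:

$$
\begin{aligned}
h &= x + \mathrm{gate}_1 \odot f_1\big((1 + \mathrm{scale}_1) \odot \mathrm{LN}(x) + \mathrm{shift}_1\big) \\
\mathrm{output} &= h + \mathrm{gate}_2 \odot f_2\big((1 + \mathrm{scale}_2) \odot \mathrm{LN}(h) + \mathrm{shift}_2\big)
\end{aligned}
$$

where f1 is attention, f2 is the MLP, and (scalei, shifti, gatei) come from a small MLP applied to the time embedding c. The last layer of that MLP is initialized to zero, so at the start of training scale = shift = gate = 0: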
- 1 + scale = 1 ⇒ LayerNorm is unmodified
- shift = 0 ⇒ no additive change
- gate = 0 ⇒ the residual contribution is multiplied by zero; the block degenerates to the identity
Hence output = x at initialization. Each block then learns how much to contribute over training, controlled by its own gate. This has the same effect as careful residual-scale initializations (Fixup, etc.) but is more interpretable.
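The identity-at-initialization property is easy to verify numerically. The following NumPy sketch makes several simplifying assumptions that are not taken from diffusion_transformer.py: the conditioning MLP is reduced to a single linear layer (`w_mod`, `b_mod`), the LayerNorm has no learned affine parameters, and `f1` / `f2` are arbitrary stand-in functions rather than real attention and MLP sublayers. All names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """LayerNorm over the channel axis, without learned scale or bias."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adaln_zero_block(x, c, w_mod, b_mod, f1, f2):
    """x: (B, N, d) tokens, c: (B, d) time embedding, w_mod: (d, 6d), b_mod: (6d,)."""
    mod = c @ w_mod + b_mod                                  # (B, 6d)
    # Split into six per-channel vectors and add a token axis for broadcasting.
    shift1, scale1, gate1, shift2, scale2, gate2 = (
        m[:, None, :] for m in np.split(mod, 6, axis=-1)
    )
    h = x + gate1 * f1((1 + scale1) * layer_norm(x) + shift1)
    return h + gate2 * f2((1 + scale2) * layer_norm(h) + shift2)

rng = np.random.default_rng(0)
B, N, d = 2, 16, 64
x = rng.normal(size=(B, N, d))
c = rng.normal(size=(B, d))
w1 = rng.normal(size=(d, d))
w2 = rng.normal(size=(d, d))

def f1(h):
    return np.tanh(h @ w1)   # stand-in for attention

def f2(h):
    return np.tanh(h @ w2)   # stand-in for the MLP

# Zero-initialized modulation layer: scale = shift = gate = 0 for every c.
out_init = adaln_zero_block(x, c, np.zeros((d, 6 * d)), np.zeros(6 * d), f1, f2)
print(np.allclose(out_init, x))            # True: the block is the identity

# Once the modulation weights move away from zero, the gates open.
w_mod = 0.01 * rng.normal(size=(d, 6 * d))
out_later = adaln_zero_block(x, c, w_mod, np.zeros(6 * d), f1, f2)
print(np.abs(out_later - x).max() > 0)     # True: the block now contributes
```

Because the gates multiply the whole sublayer output, the identity holds no matter how f1 and f2 are initialized; only the modulation layer needs the zero init.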
Interactive · watch the adaLN-Zero gates open
Below is a tiny simulator of an adaLN-Zero training trajectory. We don’t train a real network; we just animate gate magnitudes growing from zero as training progresses, and show what the predicted output looks like (identity at step 0, eventually some shaped output).
What changes between MLP and DiT
| Axis | MLP (2D toy) | DiT (image) |
|---|---|---|
| Net class | MLPDenoiser / MLPVelocity | DiT |
| Input/output shape | (B, 2) | (B, 1, 16, 16) |
| Time conditioning | concat | adaLN-Zero |
| Parameter count (this repo) | ~25k | ~250k (d=64, depth=4) |
| DDPM / FM wrapper | unchanged | unchanged |
Crucially, the wrapper (DDPM, FlowMatching) is shape-agnostic. The same DDPM.loss(x) and DDPM.sample(n, shape=…) work for both 2D and image data. q_sample uses while ab.dim() < x0.dim(): ab = ab.unsqueeze(-1) to broadcast the schedule scalar across whatever trailing dimensions the data has. That is the entire mechanism that lets the 2D and image cases share the same code.
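The broadcasting loop can be illustrated in isolation. This is a NumPy analogue of the PyTorch idiom quoted above (`ndim` and `[..., None]` in place of `dim()` and `unsqueeze(-1)`), and it assumes that `ab` holds one cumulative ᾱ_t value per batch element and that q_sample follows the standard DDPM forward process x_t = √ᾱ_t · x0 + √(1 − ᾱ_t) · ε. The function signature here is hypothetical and may differ from the repo's.

```python
import numpy as np

def q_sample(x0, alpha_bar_t, noise):
    """Noise x0 to level alpha_bar_t. alpha_bar_t has shape (B,); x0 is (B, ...)."""
    ab = alpha_bar_t
    # Append trailing singleton axes until ab broadcasts against x0.
    while ab.ndim < x0.ndim:
        ab = ab[..., None]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * noise

rng = np.random.default_rng(0)
ab = np.array([0.9, 0.5, 0.1])               # one schedule value per batch item

points = rng.normal(size=(3, 2))             # 2D toy data
images = rng.normal(size=(3, 1, 16, 16))     # image data

xt_points = q_sample(points, ab, rng.normal(size=points.shape))
xt_images = q_sample(images, ab, rng.normal(size=images.shape))
print(xt_points.shape, xt_images.shape)      # (3, 2) (3, 1, 16, 16)
```

For the 2D data the loop runs once and `ab` becomes (3, 1); for images it runs three times and `ab` becomes (3, 1, 1, 1). Nothing else in the function depends on the data's rank.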
Why DiT vs. UNet?
| | UNet | DiT |
|---|---|---|
| Strong locality bias? | yes (convs) | no (attention is content-based, not spatial) |
| Long-range receptive field? | through downsample + conv stack | one attention layer |
| Conditioning style | FiLM, cross-attn, ad-hoc | adaLN-Zero (uniform across blocks) |
| Scales with (depth, width)? | middling — bottleneck dominates | clean LLM-like curves |
| Best when | data is small or strongly spatially local | data is large, structure is global, you want one architecture across modalities |
Concrete recent evidence: most of the large generative image and video models released since 2023 (Stable Diffusion 3, Flux, Sora, Veo) are reported to use a DiT or DiT-derived backbone. UNet still wins at small scale and for very local denoising (e.g. very low-resolution toys with limited compute).