The transformer block — full forward pass
Assemble normalization, attention, positional encoding, and a feedforward block into the single repeating unit. Count its parameters, count its FLOPs, count its memory. Once you can do that on paper, "how big can I serve?" stops being a question for the systems team.
The full forward pass — pre-norm, modern style
For input residual stream x ∈ ℝ^{B × N × d}:

h = x + Attention(RMSNorm₁(x))   (RoPE rotates Q and K inside Attention)
y = h + FFN(RMSNorm₂(h))
Two sub-blocks, each with its own norm, each adding a residual contribution to the stream. That's the entire transformer block. Stack L of them, prepend an embedding, append a final norm + LM head, and you have a language model.
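To make the two sub-blocks concrete, here is a minimal sketch of one pre-norm block in plain Python. It processes a single sequence (no batch dimension) at toy sizes. To keep it short, it assumes plain multi-head causal attention, leaves out RoPE and GQA, and uses random weights. `block`, `init_params`, and the other helpers are hypothetical names for illustration, not any library's API. Each sub-block sits behind its own RMSNorm and adds its output back into the residual stream.

```python
import math
import random

def matmul(a, b):
    """(n, k) @ (k, m) -> (n, m) for nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def rmsnorm(x, gamma, eps=1e-6):
    """Divide each token vector by its RMS, then scale by the learned gain (no mean, no bias)."""
    out = []
    for row in x:
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([g * v / rms for g, v in zip(gamma, row)])
    return out

def silu(z):
    return z / (1.0 + math.exp(-z))

def causal_attention(x, wq, wk, wv, wo, n_heads):
    """Multi-head causal self-attention: token i attends to tokens 0..i."""
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    n, d = len(x), len(wq[0])
    hd = d // n_heads
    out = [[0.0] * d for _ in range(n)]
    for h in range(n_heads):
        lo, hi = h * hd, (h + 1) * hd
        for i in range(n):
            scores = [sum(a * b for a, b in zip(q[i][lo:hi], k[j][lo:hi])) / math.sqrt(hd)
                      for j in range(i + 1)]
            m = max(scores)  # subtract the max for a numerically stable softmax
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            for j in range(i + 1):
                for c in range(lo, hi):
                    out[i][c] += (w[j] / z) * v[j][c]
    return matmul(out, wo)  # W_O maps the concatenated heads back into the stream

def swiglu_ffn(x, w1, w2, w3):
    """(SiLU(x W1) * x W2) W3 with gate W1, up W2, down W3."""
    gate, up = matmul(x, w1), matmul(x, w2)
    hidden = [[silu(g) * u for g, u in zip(gr, ur)] for gr, ur in zip(gate, up)]
    return matmul(hidden, w3)

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def block(x, p, n_heads):
    """Pre-norm block: h = x + Attn(RMSNorm(x)); y = h + FFN(RMSNorm(h))."""
    attn = causal_attention(rmsnorm(x, p["g1"]), p["wq"], p["wk"], p["wv"], p["wo"], n_heads)
    h = add(x, attn)
    return add(h, swiglu_ffn(rmsnorm(h, p["g2"]), p["w1"], p["w2"], p["w3"]))

def init_params(d, d_f, seed=0, scale=0.1):
    """Hypothetical random init, only for demonstration."""
    rng = random.Random(seed)
    def mat(rows, cols):
        return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]
    return {"g1": [1.0] * d, "g2": [1.0] * d,
            "wq": mat(d, d), "wk": mat(d, d), "wv": mat(d, d), "wo": mat(d, d),
            "w1": mat(d, d_f), "w2": mat(d, d_f), "w3": mat(d_f, d)}

rng = random.Random(1)
x = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(5)]  # N=5 tokens, d=8
y = block(x, init_params(d=8, d_f=16), n_heads=2)                # same shape as x
```

Stacking L of these, with an embedding in front and a final RMSNorm plus LM head behind, gives the full model. If every projection weight is zero, both sub-blocks contribute nothing and the block reduces to the identity. That property is the residual "highway" in its simplest form.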
The FFN — wider than you'd guess, and why
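The feedforward block (originally in "Attention Is All You Need") is two linear layers with a non-linearity in between:

FFN(x) = σ(x W₁ + b₁) W₂ + b₂,   W₁ ∈ ℝ^{d × d_f}, W₂ ∈ ℝ^{d_f × d}

Here σ was ReLU in the original paper, and d_f is the hidden (inner) dimension.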
The hidden dim is conventionally 4d. Three reasons that ratio survives:
- Parameter accounting. Attention has ~4d² params (QKV + O). FFN has ~8d² params (W₁ + W₂ at 4d hidden). So FFN is ~2× attention. Total per block: ~12d² params.
- Expressivity. The FFN is the only place where each token's features pass through an elementwise non-linearity. Attention output is a data-dependent weighted average of value vectors. The softmax weights A depend non-linearly on Q and K, but for fixed A the output is linear in V, which is itself a linear projection of the input. The 4d hidden gives enough capacity to absorb much of the model's "memory": knowledge-probing studies tend to localise factual associations in FFN weights rather than in attention.
- Arithmetic intensity. FFN matmuls (B·N·d → B·N·4d → B·N·d) are large, tensor-core-friendly GEMMs, and they do most of the model's compute. Attention has the awkward N×N shape.
SwiGLU — the gated FFN
Modern LLMs (LLaMA, PaLM, Mistral) replace the FFN with a gated variant: SwiGLU (Shazeer 2020). Instead of one hidden activation, compute two and multiply them:

SwiGLU(x) = (SiLU(x W₁) ⊙ x W₂) W₃,   SiLU(z) = z · sigmoid(z)

This uses three matrices instead of two: W₁ and W₂ are both (d, d_f), and W₃ is (d_f, d). The gate SiLU(x W₁) is multiplied elementwise by the "up" projection x W₂, and W₃ projects the product back to d.
The FLOP / parameter accounting:
| FFN | Params | FLOPs / token | Comment |
|---|---|---|---|
| Vanilla (4d hidden) | 8d² | 16d² (2 matmuls) | Baseline |
| SwiGLU (4d hidden, 3 matmuls) | 12d² | 24d² | 1.5× cost vs vanilla |
| SwiGLU (8/3·d ≈ 2.67d hidden, to match params) | ~8d² | ~16d² | "SwiGLU at iso-params" |
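The table's rows can be reproduced with a few lines of arithmetic. `ffn_cost` is a hypothetical helper that assumes bias-free layers and counts 2 FLOPs per multiply-accumulate.

```python
def ffn_cost(d, d_f, gated):
    """(params, forward FLOPs per token) of one bias-free FFN.

    gated=False: vanilla, W1 (d, d_f) and W2 (d_f, d).
    gated=True:  SwiGLU, gate (d, d_f), up (d, d_f), down (d_f, d).
    Each parameter is used in one multiply-accumulate per token = 2 FLOPs.
    """
    params = (3 if gated else 2) * d * d_f
    return params, 2 * params

d = 4096
for name, d_f, gated in [("vanilla, 4d", 4 * d, False),
                         ("SwiGLU, 4d", 4 * d, True),
                         ("SwiGLU, 8d/3", 8 * d // 3, True)]:
    params, flops = ffn_cost(d, d_f, gated)
    print(f"{name:13s} params = {params / d**2:5.2f} d^2, FLOPs/token = {flops / d**2:5.2f} d^2")
```

The iso-params row lands a hair under 8d² only because 8d/3 is not an integer and the floor division rounds it down.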
The modern recipe: SwiGLU with hidden dim 8d/3 (rounded to a multiple of 256 for tensor core friendliness) — same parameter count and FLOPs as vanilla FFN, but typically 0.5–1% better perplexity.
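Here is a sketch of the rounding rule. It is modelled on the formula in Meta's reference LLaMA code, but treat the parameter names and the 70B settings below (`ffn_dim_multiplier=1.3`, `multiple_of=4096`) as assumptions, since this text does not state them.

```python
def swiglu_hidden_dim(d, multiple_of=256, ffn_dim_multiplier=None):
    """SwiGLU hidden width: 8d/3 (iso-params with a 4d vanilla FFN), rounded up.

    ffn_dim_multiplier (hypothetical knob here) optionally widens the layer before rounding.
    """
    hidden = int(8 * d / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    # Round UP to the next multiple so matmul shapes tile cleanly on tensor cores.
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(swiglu_hidden_dim(4096))                                            # LLaMA-2-7B shape
print(swiglu_hidden_dim(8192, multiple_of=4096, ffn_dim_multiplier=1.3))  # assumed 70B settings
```

These two calls return 11008 and 28672, the d_f values used for LLaMA-2-7B and LLaMA-2-70B in this section. Note that 28672 = 3.5d, so the 70B layer is deliberately wider than the iso-params 8d/3.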
Parameter count of one block — derive this
Consider a transformer block with hidden dim d, h query heads, g query heads per KV head (GQA), FFN hidden d_f, and no bias terms:
| Component | Parameters | For d=8192, d_f=28672 (LLaMA-2-70B layer, real values) |
|---|---|---|
| Q projection W_Q | d² | 67M |
| K projection W_K (GQA: d²/8 if g=8) | d²/8 (GQA) | 8.4M |
| V projection W_V (same as K) | d²/8 (GQA) | 8.4M |
| Output projection W_O | d² | 67M |
| FFN W₁ (SwiGLU gate) | d · d_f | 235M |
| FFN W₂ (SwiGLU up) | d · d_f | 235M |
| FFN W₃ (down) | d_f · d | 235M |
| 2 × RMSNorm γ | 2d | 16k (negligible) |
| Total per block | ~2.25d² + 3·d·d_f (GQA-8, SwiGLU) | ~856M |
Multiply by L=80 layers: ~68B params in the body. Plus embedding (V=32k × d=8192 ≈ 262M) + final norm + LM head (262M, untied): ~0.5B more. Total ~69B → 70B. The arithmetic works.
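The table and the 7B interview derivation can both be checked with a small counter. `block_params` and `model_params` are hypothetical helpers. They assume no biases, a SwiGLU FFN, optional GQA, and an untied LM head by default.

```python
def block_params(d, d_f, g=1):
    """Parameters of one pre-norm block: attention + SwiGLU FFN + two RMSNorm gains, no biases.

    g = query heads per KV head (g=1 is plain MHA, g=8 is GQA-8).
    """
    attn = 2 * d * d + 2 * d * (d // g)   # W_Q, W_O are (d, d); W_K, W_V are (d, d/g)
    ffn = 3 * d * d_f                     # gate W1, up W2, down W3
    norms = 2 * d                         # two RMSNorm gamma vectors
    return attn + ffn + norms

def model_params(d, d_f, n_layers, vocab, g=1, tied=False):
    """Embedding + L blocks + final RMSNorm + LM head (skipped if tied to the embedding)."""
    embed = vocab * d
    head = 0 if tied else vocab * d
    return embed + n_layers * block_params(d, d_f, g) + d + head

llama2_7b = model_params(d=4096, d_f=11008, n_layers=32, vocab=32000)
llama2_70b = model_params(d=8192, d_f=28672, n_layers=80, vocab=32000, g=8)
print(f"{llama2_7b:,}  {llama2_70b:,}")
```

With these shapes the totals come out at about 6.74B and 68.98B, in line with the hand derivations. Passing `tied=True` removes the V·d head, which barely moves the 70B total but matters a lot for small models.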
FLOP count per token — the 6N rule
For a forward pass at sequence length N, per token, per block:
| Component | FLOPs/token | Comment |
|---|---|---|
| QKV projections | (2+2/g+2/g) · d² ≈ 2.5·d² (GQA g=8) | Factor 2: one multiply-accumulate (MAC) = 2 FLOPs |
| Attention (Q·Kᵀ, A·V) | 4 · N · d | Scales with N; causal masking roughly halves it |
| Output projection | 2·d² | Same as Q |
| FFN (SwiGLU) | 3 · 2 · d · d_f = 6·d·d_f | ~16·d² at d_f=8d/3 |
| Total per layer per token | ~24·d² + 4·N·d (MHA, d_f=8d/3); ~20.5·d² + 4·N·d with GQA-8 | The d² part is ≈ 2 × the block's parameter count |
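For a full forward + backward, multiply by ~3. The backward pass costs ~2× the forward, because each matmul needs one gradient with respect to its input and one with respect to its weights. The forward pass is ~2 FLOPs per parameter per token (the d² terms above, ignoring the sequence-dependent attention-score term). So a full training run over D tokens costs:

C ≈ 6 · N · D FLOPs

Here N is the parameter count, not the sequence length. This is the "6N" rule. Memorise it: it's the load-bearing number for compute estimates.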
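The per-layer FLOP table and the 6N rule fit in two short functions. Both are hypothetical helpers. They assume a SwiGLU FFN, count 2 FLOPs per MAC, and use the non-causal 4·N·d attention-score count.

```python
def layer_flops_per_token(d, d_f, seq_len, g=1):
    """Forward FLOPs per token for one pre-norm block with a SwiGLU FFN.

    g = query heads per KV head (g=1 is MHA). Every matmul counts 2 FLOPs per MAC.
    """
    q_and_o = 2 * (2 * d * d)           # W_Q and W_O
    k_and_v = 2 * (2 * d * (d // g))    # W_K and W_V, narrowed by GQA
    scores = 4 * seq_len * d            # Q.K^T plus A.V (non-causal count)
    ffn = 2 * (3 * d * d_f)             # gate, up, down
    return q_and_o + k_and_v + scores + ffn

def training_flops(n_params, n_tokens):
    """6N rule: ~2 FLOPs/param/token forward, ~4 backward (input grads + weight grads)."""
    return 6 * n_params * n_tokens

# LLaMA-2-70B-like layer: compare the d^2 terms with the 4k-context score term.
d, d_f = 8192, 28672
proj = layer_flops_per_token(d, d_f, seq_len=0, g=8)
total = layer_flops_per_token(d, d_f, seq_len=4096, g=8)
print(f"projections: {proj / d**2:.1f} d^2, with 4k scores: {total / d**2:.1f} d^2")
print(f"Chinchilla-style 70B run: {training_flops(70e9, 1.4e12):.2e} FLOPs")
```

With `seq_len=0` the result is exactly twice the block's matmul parameter count. That equality is where the "2 FLOPs per parameter" in the 6N rule comes from.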
Memory accounting — what fits where
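Activations saved for backward in training, per token per layer, in bf16 (2 bytes per element):
| Activation | Bytes / token / layer | For d=8192, d_f=28672 |
|---|---|---|
| Residual stream (x) | 2·d | 16 KB |
| RMSNorm output | 2·d | 16 KB |
| Q, K, V (FlashAttention recomputes the N×N scores instead of storing them) | 2·d + 4·d/g (≈ 2.5·d at GQA-8; 6·d for MHA) | 20–48 KB |
| Attention output | 2·d | 16 KB |
| FFN hidden, per saved d_f-wide tensor (unfused SwiGLU saves several) | 2·d_f | 56 KB |
| Per layer subtotal (FlashAttention, GQA-8, one FFN tensor) | ~8.5·d + 2·d_f | ~124 KB |
For training a 70B model at N=4k, B=1, in bf16, with FlashAttention and SwiGLU: 80 layers × 4096 tokens × ~124 KB ≈ 39 GiB (≈ 42 GB) of saved activations per sequence, before any recomputation.
On top of that come weights, gradients, and optimizer state. Mixed-precision Adam needs ~16 bytes/param (bf16 weights and gradients, an fp32 master copy, and two fp32 Adam moments), or ~1.1 TB for 70B. Activation checkpointing plus ZeRO-3 sharding is therefore mandatory.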
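The memory estimate can be sketched by summing the per-token rows: the residual, one norm output, Q/K/V under GQA, the attention output, and one saved d_f-wide FFN tensor. The sum is multiplied by layers and tokens, and the 16 bytes/param for weights, gradients, and Adam state is added on top. Both helpers are hypothetical. This is a lower bound, because real frameworks usually save more (for example, the second norm output).

```python
def activation_bytes_per_token_layer(d, d_f, g=8, saved_ffn_tensors=1, elem_bytes=2):
    """Saved activations per token per layer, assuming FlashAttention (no N x N scores kept)."""
    residual = elem_bytes * d
    norm_out = elem_bytes * d
    qkv = elem_bytes * (d + 2 * (d // g))   # Q is d wide; K and V are d/g wide each
    attn_out = elem_bytes * d
    ffn = saved_ffn_tensors * elem_bytes * d_f
    return residual + norm_out + qkv + attn_out + ffn

def training_memory_bytes(n_params, n_layers, d, d_f, seq_len, batch=1, g=8):
    """(saved activations, weights + grads + Adam state) in bytes, before any sharding."""
    acts = n_layers * seq_len * batch * activation_bytes_per_token_layer(d, d_f, g)
    states = 16 * n_params   # bf16 weights + bf16 grads + fp32 master copy + fp32 m and v
    return acts, states

acts, states = training_memory_bytes(70e9, n_layers=80, d=8192, d_f=28672, seq_len=4096)
print(f"activations ~{acts / 2**30:.0f} GiB per sequence, states ~{states / 1e12:.2f} TB")
```

Setting `saved_ffn_tensors` to 2 or more shows how quickly an unfused SwiGLU inflates activation memory at d_f = 3.5d.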
The interview probes
- Why is the FFN ~2× the attention parameter count? Attention has 4 matrices of (d, d) each → 4d². FFN has 3 matrices of (d, 8d/3) → 8d². Roughly 2:1 split. This means most of the model's parameters are in the FFN, even though attention does the position-mixing.
- Why does the FFN need a wider hidden than the residual stream? A one-hidden-layer MLP is a universal approximator only as its width grows, so the hidden width sets how many per-token non-linear features the FFN can represent. The 4× factor (or 8/3 with SwiGLU) is an empirical compromise. Smaller doesn't have enough capacity for memorisation-heavy tasks, and larger gives diminishing returns.
- Why is the residual stream not normalised in pre-norm models? Because the residual is the "highway" the gradient flows on. Normalising it would scale the gradient with the residual's local variance, which is what pre-norm is designed to avoid.
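- What's the role of W_O in attention? It maps the concatenated head outputs back into the residual stream. Each head produces a d/h-dim output independently, concatenation gives d, and W_O then linearly combines the heads. Without W_O, head i could only write into its own fixed d/h-dimensional slice of the residual stream. With W_O, each head's combined W_V·W_O map (its "OV circuit") can write to any direction of the full d-dimensional stream. That is what lets later layers read and compose the outputs of several heads flexibly.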
- How many parameters in the embedding table vs the body of a LLaMA-2-70B? Embedding: 32k × 8192 ≈ 262M (+ untied LM head, another 262M). Body: 80 layers × ~856M ≈ 68B. Body dominates by ~250×. For very small models (~1B), embeddings can be ~25% of parameters — there the choice of vocab size matters disproportionately.
Common transformer block design choices in 2026
| Choice | Modern default | Why |
|---|---|---|
| Norm location | Pre-norm | Trains deep without warmup tricks |
| Norm type | RMSNorm | Cheaper, bf16-friendlier, no quality loss |
| Attention variant | GQA (g=8) | 8× KV cache reduction, ~no quality loss |
| Positional encoding | RoPE | Relative position for free, extension-friendly |
| FFN | SwiGLU at d_f = 8d/3 | Gated mixture, iso-params |
| Activation | SiLU (or GeLU) | Smooth, no dead-ReLU |
| Bias terms | None on Q,K,V,O,FFN | Don't help; cost params; some bf16 instability |
| Tied embeddings | No (at large scale) | Untying costs V·d extra params (262M for LLaMA-2-70B) for a small quality gain; small models often tie to save parameters |
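The "8× KV cache reduction" in the GQA row follows from a one-line formula. This sketch assumes LLaMA-2-70B's public head layout: 64 query heads of dimension 128 and 8 KV heads. The text itself only gives d = 8192 and g = 8. `kv_cache_bytes_per_token` is a hypothetical helper.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, elem_bytes=2):
    """Bytes of K and V cached per token across all layers (bf16 by default)."""
    return 2 * n_layers * n_kv_heads * head_dim * elem_bytes   # 2 = one K plus one V

gqa = kv_cache_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128)   # GQA-8
mha = kv_cache_bytes_per_token(n_layers=80, n_kv_heads=64, head_dim=128)  # full MHA
print(f"GQA-8: {gqa / 2**10:.0f} KiB/token, MHA: {mha / 2**10:.0f} KiB/token, ratio {mha // gqa}x")
print(f"4k-token sequence with GQA-8: {4096 * gqa / 2**30:.2f} GiB")
```

This per-token number, times context length and concurrent sequences, is what caps serving batch size once the weights are resident.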
Interview prompts you should be ready for
- "Derive the parameter count of LLaMA-2-7B given d=4096, L=32, h=32, d_f=11008, V=32000." (Embedding: 4096 × 32000 = 131M. Per block: 4·d² + 3·d·d_f = 67M + 135M ≈ 202M. Body: 32 × 202M = 6.5B. LM head (untied): 131M. Final norm: negligible. Total: 6.7B. Matches.)
- "You're told a 70B model takes 6 weeks on 1k H100s. Sanity check." (Use the 6N·D rule. Chinchilla: D = 20·N = 1.4T tokens, so FLOPs = 6 × 70e9 × 1.4e12 ≈ 5.9e23. H100 dense bf16 peak ≈ 990 TFLOPs; at ~45% MFU that's ~450 TFLOPs/GPU sustained. 1000 GPUs × 450 TFLOPs × 86400 s/day × 42 days ≈ 1.6e24 FLOPs, about 2.8× what is needed. So 6 weeks is ample: at 45% MFU the Chinchilla run takes ~15 days. Six weeks corresponds to either ~16% MFU or ~3.9T tokens (≈ 2.8× Chinchilla) at 45%. Numbers check.)
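- "Where in the transformer block are most of the FLOPs spent?" (For short sequences the FFN dominates, since it is ~2× the attention projections in params and FLOPs. Per token, an MHA block with a 4d FFN spends ~8·d² + 4·N·d on attention versus ~16·d² on the FFN, so attention catches up around N ≈ 2d. For LLaMA-2-70B (GQA-8, d_f = 3.5d) the split is ~4.5·d² + 4·N·d versus ~21·d², which moves the crossover to N ≈ 4.1d ≈ 34k. At LLaMA-2's 4k context the FFN still dominates; beyond ~34k, attention does. Causal masking roughly halves the 4·N·d term and doubles these crossover lengths.)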
- "Why don't modern LLMs use biases?" (Three reasons: 1) Biases add parameters with little measured expressivity gain at scale. 2) Stability: PaLM reported that removing biases from dense layers and layer norms increased training stability for large models. 3) Empirical: removing biases is neutral or slightly positive on perplexity.)
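- "You want to scale your 7B model to 70B. What stays the same and what changes?" (Layers scale up (32 → 80), d scales up (4096 → 8192), heads scale up (32 → 64, with GQA-8 for the KV heads), and d_f scales with d (11008 → 28672). Learning rate: under μP the tuned base LR transfers across width; in standard parameterisation it is usually lowered with size (LLaMA-2 used 3e-4 at 7B and 1.5e-4 at 70B). Batch size often grows with scale. With Chinchilla-style D ∝ N, total FLOPs scale as N² (N = parameter count), so ~10× the parameters means ~100× the compute. Wall-clock time then scales as N² / (number of GPUs) at fixed MFU.)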
- "What's a 'residual stream'?" (The path of accumulated activations from layer to layer, through the additive residual connections. Each block reads from it (via norm) and adds back to it. Mechanistic interpretability calls this the "communication bus": circuits read and write specific features in specific subspaces (directions) of the residual stream.)
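The arithmetic behind the training-time sanity check and the attention-versus-FFN crossover answers can be sketched as follows. All functions are hypothetical helpers. They assume the 6N·D rule, ~990 TFLOPs dense bf16 peak as quoted above, and per-token FLOPs of 2 per MAC for the projections plus 4·N·d (non-causal) for the scores.

```python
SECONDS_PER_DAY = 86_400

def training_days(n_params, n_tokens, n_gpus, peak_flops, mfu):
    """Wall-clock days for 6*N*D FLOPs at a given model-FLOPs utilisation."""
    return 6 * n_params * n_tokens / (n_gpus * peak_flops * mfu) / SECONDS_PER_DAY

def required_mfu(n_params, n_tokens, n_gpus, peak_flops, days):
    """MFU needed to finish 6*N*D FLOPs in the given number of days."""
    return 6 * n_params * n_tokens / (n_gpus * peak_flops * days * SECONDS_PER_DAY)

def attention_ffn_crossover(d, d_f, g=1, gated=True, causal=False):
    """Context length N at which the attention sub-block's per-token FLOPs match the FFN's."""
    attn_proj = 2 * d * d + 2 * d * d + 2 * 2 * d * (d // g)   # W_Q, W_O, then W_K + W_V
    ffn = (6 if gated else 4) * d * d_f                         # SwiGLU: 3 matmuls; vanilla: 2
    per_context_token = (2 if causal else 4) * d                # Q.K^T + A.V per attended token
    return (ffn - attn_proj) / per_context_token

n_params = 70e9
n_tokens = 20 * n_params                 # Chinchilla-style 20 tokens per parameter
days = training_days(n_params, n_tokens, 1000, 990e12, mfu=0.45)
mfu = required_mfu(n_params, n_tokens, 1000, 990e12, days=42)
print(f"{days:.1f} days at 45% MFU; 6 weeks needs only {mfu:.0%} MFU")

print(attention_ffn_crossover(4096, 4 * 4096, gated=False))   # MHA + vanilla 4d FFN: 2d
print(attention_ffn_crossover(8192, 28672, g=8))              # LLaMA-2-70B-like: about 4.1d
```

The crossover formula also shows why GQA pushes the attention-dominated regime out to longer contexts: shrinking W_K and W_V lowers the attention projections, which widens the gap to the FFN.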