
The transformer block — full forward pass

Assemble normalization, attention, positional encoding, and a feedforward block into the single repeating unit. Count its parameters, count its FLOPs, count its memory. Once you can do that on paper, "how big can I serve?" stops being a question for the systems team.

The full forward pass — pre-norm, modern style

For input residual stream x ∈ ℝ^{B × N × d}:

x'  = x  + Attention(RMSNorm(x))
x'' = x' + FFN(RMSNorm(x'))

Two sub-blocks, each with its own norm, each adding a residual contribution to the stream. That's the entire transformer block. Stack L of them, prepend an embedding, append a final norm + LM head, and you have a language model.

┌─────────────────────────────────┐  ← block input x_l
│                                 │
│       ┌─────────┐               │
│  ┌─▶  │ RMSNorm │ ─▶ MHA ─┐     │
│  │    └─────────┘         ▼     │
│  └──────────────────────(+)     │  ← residual add
│                                 │
│       ┌─────────┐               │
│  ┌─▶  │ RMSNorm │ ─▶ FFN ─┐     │
│  │    └─────────┘         ▼     │
│  └──────────────────────(+)     │  ← residual add
│                                 │
└─────────────────────────────────┘  ← block output x_{l+1}
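The two residual updates can be written out in a few lines of NumPy. This is a minimal sketch under stated assumptions, not a reference implementation: it uses plain causal multi-head attention with no RoPE and no GQA, a vanilla FFN with ReLU standing in for the non-linearity σ, and hypothetical parameter names (`w_q`, `g1`, and so on) chosen only for illustration.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # Divide each token vector by its root-mean-square, then apply the gain gamma.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(x, w_q, w_k, w_v, w_o, n_heads):
    # Causal multi-head attention (assumption: no RoPE, no GQA, for brevity).
    b, n, d = x.shape
    hd = d // n_heads

    def split(t):
        # (B, N, d) -> (B, h, N, d/h)
        return t.reshape(b, n, n_heads, hd).transpose(0, 2, 1, 3)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(hd)      # (B, h, N, N)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)      # True above the diagonal
    scores = np.where(future, -np.inf, scores)              # block attention to the future
    out = softmax(scores) @ v                                # (B, h, N, d/h)
    out = out.transpose(0, 2, 1, 3).reshape(b, n, d)         # concatenate heads
    return out @ w_o

def ffn(x, w1, w2):
    # Vanilla two-layer FFN; ReLU is an assumed choice for sigma.
    return np.maximum(x @ w1, 0.0) @ w2

def transformer_block(x, p, n_heads):
    # x' = x + Attention(RMSNorm(x)); x'' = x' + FFN(RMSNorm(x'))
    x = x + attention(rms_norm(x, p["g1"]), p["w_q"], p["w_k"], p["w_v"], p["w_o"], n_heads)
    x = x + ffn(rms_norm(x, p["g2"]), p["w1"], p["w2"])
    return x

# Tiny usage example with random weights (B=2, N=5, d=16, 4 heads).
rng = np.random.default_rng(0)
d, n_heads = 16, 4
params = {
    "g1": np.ones(d), "g2": np.ones(d),
    "w_q": rng.normal(size=(d, d)) * 0.1, "w_k": rng.normal(size=(d, d)) * 0.1,
    "w_v": rng.normal(size=(d, d)) * 0.1, "w_o": rng.normal(size=(d, d)) * 0.1,
    "w1": rng.normal(size=(d, 4 * d)) * 0.1, "w2": rng.normal(size=(4 * d, d)) * 0.1,
}
x = rng.normal(size=(2, 5, d))
y = transformer_block(x, params, n_heads)
print(y.shape)  # (2, 5, 16)
```

The residual stream keeps its shape through the block, which is what lets L copies be stacked without any glue code.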

The FFN — wider than you'd guess, and why

The feedforward block (originally in "Attention Is All You Need") is two linear layers with a non-linearity in between:

FFN(x) = W₂ · σ(W₁ x)   — W₁: (d, 4d), W₂: (4d, d)

The hidden dim is conventionally 4d. Three reasons that ratio survives:

  1. Parameter accounting. Attention has ~4d² params (QKV + O). FFN has ~8d² params (W₁ + W₂ at 4d hidden). So FFN is ~2× attention. Total per block: ~12d² params.
  2. Expressivity. The FFN is the only place each token gets non-linearly transformed. Attention is linear in V (V is the input; A is data-dependent but still a linear combination). The 4d hidden is enough capacity to absorb most of the model's "memory" — empirically, knowledge probes localise facts in FFN weights, not attention.
  3. Arithmetic intensity. FFN matmuls (B·N·d → B·N·4d → B·N·d) are big tensor-core friendly matmuls. They're the workhorse compute of the model. Attention has the awkward N×N shape.

SwiGLU — the gated FFN

Modern LLMs (LLaMA, PaLM, Mistral) replace the FFN with a gated variant: SwiGLU (Shazeer 2020). Instead of one hidden activation, compute two:

FFN(x) = W₃ · ( SiLU(W₁ x) ⊙ (W₂ x) )      (⊙ is elementwise multiply)
SiLU(z) = z · σ(z) = z / (1 + e^{-z})

Three matrices instead of two: W₁, W₂ both (d, h_dim), W₃ is (h_dim, d). The gate SiLU(W₁ x) is multiplied elementwise by W₂ x, then the result is projected back.
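As a concrete illustration of the gated form, here is a minimal NumPy sketch. It assumes row-vector inputs (so the products are `x @ w1` rather than W₁ x) and uses the shapes given above; the function names are hypothetical.

```python
import numpy as np

def silu(z):
    # SiLU(z) = z * sigmoid(z) = z / (1 + exp(-z))
    return z / (1.0 + np.exp(-z))

def swiglu_ffn(x, w1, w2, w3):
    # x: (..., d); w1, w2: (d, h_dim); w3: (h_dim, d)
    gate = silu(x @ w1)      # data-dependent gate, one value per hidden channel
    up = x @ w2              # ungated "up" projection
    return (gate * up) @ w3  # elementwise gate, then project back to d

# Usage: d=8, h_dim=24, a batch of 3 token vectors.
rng = np.random.default_rng(0)
d, h_dim = 8, 24
w1 = rng.normal(size=(d, h_dim))
w2 = rng.normal(size=(d, h_dim))
w3 = rng.normal(size=(h_dim, d))
out = swiglu_ffn(rng.normal(size=(3, d)), w1, w2, w3)
print(out.shape)  # (3, 8)
```

A hidden channel whose gate is near zero contributes almost nothing to the output, whatever its "up" value is.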

The FLOP / parameter accounting:

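FFN                                             | Params | FLOPs / token    | Comment
------------------------------------------------|--------|------------------|------------------------
Vanilla (4d hidden)                             | 8d²    | 16d² (2 matmuls) | Baseline
SwiGLU (4d hidden, 3 matmuls)                   | 12d²   | 24d²             | 1.5× cost vs vanilla
SwiGLU (8/3·d ≈ 2.67d hidden, to match params)  | ~8d²   | ~16d²            | "SwiGLU at iso-params"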

The modern recipe: SwiGLU with hidden dim 8d/3 (rounded to a multiple of 256 for tensor core friendliness) — same parameter count and FLOPs as vanilla FFN, but typically 0.5–1% better perplexity.
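The sizing rule is easy to check numerically. The sketch below assumes the rounding goes up to the next multiple of 256; the text only says "rounded", so the direction is an assumption, and both helper names are hypothetical.

```python
def swiglu_hidden_dim(d, multiple_of=256):
    # 8d/3, rounded UP to a multiple of `multiple_of` (rounding direction is an assumption).
    raw = (8 * d) // 3
    return multiple_of * ((raw + multiple_of - 1) // multiple_of)

def ffn_params(d, hidden, gated):
    # Gated (SwiGLU) FFN has three (d x hidden) matrices, vanilla has two.
    return (3 if gated else 2) * d * hidden

d = 4096
h = swiglu_hidden_dim(d)
print(h)  # 11008
ratio = ffn_params(d, h, gated=True) / ffn_params(d, 4 * d, gated=False)
print(round(ratio, 3))  # 1.008
```

For d=4096 the rule gives 11008, and the SwiGLU block lands within about 1% of the vanilla 8d² budget. For d=8192 the same rule gives 22016, so a real configuration with a larger d_f is using a wider FFN than iso-params sizing would.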

The gating intuition
Vanilla FFN: every channel of the hidden state contributes to the output, weighted by W₂. SwiGLU: each channel has its own data-dependent gate. The model can learn to "turn off" channels it doesn't need for the current token. This is a soft mixture-of-paths and it costs negligible extra compute when sized for iso-params.

Parameter count of one block — derive this

For a transformer block with hidden dim d, heads h, FFN hidden d_f:

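Component                             | Parameters                          | For d=8192, d_f=28672 (LLaMA-2-70B layer, real values)
--------------------------------------|-------------------------------------|-------------------------------------------------------
Q projection W_Q                      | d²                                  | 67M
K projection W_K (GQA: d²/8 if g=8)   | d²/8 (GQA)                          | 8.4M
V projection W_V (same as K)          | d²/8 (GQA)                          | 8.4M
Output projection W_O                 | d²                                  | 67M
FFN W₁ (SwiGLU gate)                  | d · d_f                             | 235M
FFN W₂ (SwiGLU up)                    | d · d_f                             | 235M
FFN W₃ (down)                         | d_f · d                             | 235M
2 × RMSNorm γ                         | 2d                                  | 16k (negligible)
Total per block                       | ~2.25d² + 3·d·d_f (GQA-8, SwiGLU)   | ~856M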

Multiply by L=80 layers: ~68B params in the body. Plus embedding (V=32k × d=8192 ≈ 262M) + final norm + LM head (262M, untied): ~0.5B more. Total ~69B → 70B. The arithmetic works.
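The same bookkeeping as a small calculator, sketched in plain Python. It assumes no bias terms, one RMSNorm gain vector per norm, and `g` as the GQA group size (query heads per KV head); the function names are hypothetical. The LLaMA-2-7B configuration (d=4096, L=32, d_f=11008, V=32000, no GQA) is the one quoted in the interview prompts of this lesson.

```python
def block_params(d, d_f, g=1):
    # Attention: Q and O are (d, d); K and V shrink by the GQA group size g.
    attn = d * d + 2 * (d * d // g) + d * d
    # SwiGLU FFN: gate, up, and down matrices.
    ffn = 3 * d * d_f
    # Two RMSNorm gain vectors.
    norms = 2 * d
    return attn + ffn + norms

def model_params(d, d_f, n_layers, vocab, g=1, tied=False):
    embedding = vocab * d
    lm_head = 0 if tied else vocab * d
    final_norm = d
    return n_layers * block_params(d, d_f, g) + embedding + lm_head + final_norm

llama2_70b = model_params(d=8192, d_f=28672, n_layers=80, vocab=32000, g=8)
llama2_7b = model_params(d=4096, d_f=11008, n_layers=32, vocab=32000, g=1)
print(f"{block_params(8192, 28672, g=8) / 1e6:.0f}M per block")  # 856M per block
print(f"{llama2_70b / 1e9:.2f}B")  # 68.98B
print(f"{llama2_7b / 1e9:.2f}B")   # 6.74B
```

Passing `tied=True` drops the separate LM head, which is the vocab·d saving discussed under weight tying.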

Weight tying — read the spec carefully
Some models tie the embedding matrix and the LM-head weights (they're both (vocab, d)). Modern LLMs typically do not tie (separate embedding and output head). Tying saves ~vocab·d parameters (~1.05B for 128k vocab, d=8192). The trade-off: tied weights constrain the input and output spaces to be the same — usually a mild regression in quality.

FLOP count per token — the 6N rule

For a forward pass at sequence length N, per token, per block:

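Component                   | FLOPs/token                            | Comment
----------------------------|----------------------------------------|--------------------------------------
QKV projections             | (2+2/g+2/g) · d² ≈ 2.5·d² (GQA g=8)    | Factor 2 per matmul: 1 MAC = 2 FLOPs
Attention (Q·K, A·V)        | 4 · N · d                              | Scales with N
Output projection           | 2·d²                                   | Same as Q
FFN (SwiGLU)                | 3 · 2 · d · d_f = 6·d·d_f              | ~16·d² at d_f=8d/3
Total per layer per token   | ~20.5·d² + 4·N·d (GQA g=8)             | ~24·d² + 4·N·d with full MHA (g=1)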

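The forward pass costs ~2 FLOPs per parameter per token (one MAC per weight, ignoring the 4·N·d attention term). For a full forward + backward, multiply by ~3 (forward + 2× for backward gradients on both inputs and weights). For a full training step over D tokens: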

FLOPs ≈ 6 · N_params · D

This is the "6N" rule; N here is the parameter count N_params, not the sequence length N used in the table. Memorise it: it's the load-bearing number for compute estimates.
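Summing the table rows in code makes it easy to try other configurations. This is a sketch in plain Python with hypothetical function names; `g` is the GQA group size and `n_ctx` the sequence length N.

```python
def forward_flops_per_token_per_layer(d, d_f, n_ctx, g=1):
    qkv = 2 * d * d + 4 * d * d // g   # (2 + 2/g + 2/g) * d^2
    scores = 4 * n_ctx * d             # Q.K^T and A.V, 2*N*d each
    out_proj = 2 * d * d
    ffn = 6 * d * d_f                  # three SwiGLU matmuls at 2*d*d_f each
    return qkv + scores + out_proj + ffn

def training_flops(n_params, n_tokens):
    # The 6N rule: ~2 FLOPs/param/token forward, ~4 backward.
    return 6 * n_params * n_tokens

# d_f = 8d/3 exactly when d=3072, d_f=8192; n_ctx=0 isolates the d^2 terms.
d, d_f = 3072, 8192
print(forward_flops_per_token_per_layer(d, d_f, n_ctx=0, g=1) / d**2)  # 24.0
print(forward_flops_per_token_per_layer(d, d_f, n_ctx=0, g=8) / d**2)  # 20.5
print(f"{training_flops(70e9, 1.4e12):.2e}")  # 5.88e+23
```

With full multi-head attention the projections and FFN come to 24·d² per token per layer, which is 2× the ~12d² parameter count and therefore consistent with the 2-FLOPs-per-parameter forward cost behind the 6N rule. With GQA at g=8 the same sum is 20.5·d².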

Memory accounting — what fits where

Per token of activation in training (saved for backward, in bf16):

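Activation                                      | Bytes / token / layer      | For d=8192
------------------------------------------------|----------------------------|--------------------------------------------
Residual stream (x)                             | 2·d                        | 16 KB
RMSNorm output                                  | 2·d                        | 16 KB
Q, K, V (or just QK for FlashAttention)         | ~6·d (full) or ~2·d (FA)   | 16–48 KB
Attention output                                | 2·d                        | 16 KB
FFN hidden (SwiGLU has two, fused stores one)   | 2·d_f                      | 44 KB at d_f = 8d/3 (56 KB at d_f = 28672)
Per layer subtotal (FlashAttention + fused)     | ~10·d                      | ≈ 80 KB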

For training a 70B model at N=4k, B=1, in bf16, with FlashAttention and SwiGLU:

activations ≈ 80 layers × 4096 tokens × 80 KB ≈ 25 GB

Add weights, gradients, and Adam state (~16 bytes/param in mixed precision → ~1.1 TB for 70B), and activation checkpointing plus ZeRO-3 sharding become mandatory.
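The two estimates above reduce to a pair of multiplications. The sketch below takes the lesson's ~10·d bytes per token per layer and ~16 bytes per parameter as given inputs; the function names are hypothetical.

```python
def activation_bytes(n_layers, n_tokens, bytes_per_token_per_layer, batch=1):
    # Activations saved for the backward pass, summed over layers and tokens.
    return n_layers * n_tokens * batch * bytes_per_token_per_layer

def optimizer_bytes(n_params, bytes_per_param=16):
    # ~16 bytes/param is the lesson's figure for Adam training state.
    return n_params * bytes_per_param

d = 8192
act = activation_bytes(n_layers=80, n_tokens=4096, bytes_per_token_per_layer=10 * d)
opt = optimizer_bytes(70 * 10**9)
print(f"activations: {act / 2**30:.0f} GiB")    # activations: 25 GiB
print(f"optimizer:   {opt / 1e12:.2f} TB")      # optimizer:   1.12 TB
```

Activation memory is linear in batch size and sequence length, while the per-parameter state is fixed, so at B=1 the per-parameter state is roughly 40× larger than the activations.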

The whole forward pass in one diagram

tokens ──▶ embedding (V × d)
                │
                ▼   (with RoPE position info)
       ┌──────────────────┐
       │ L × transformer  │  ← stack of blocks, each:
       │     blocks       │      pre-norm + GQA attn + residual
       │                  │      pre-norm + SwiGLU + residual
       └──────────────────┘
                │
                ▼
          final RMSNorm
                │
                ▼
          LM head (d × V)
                │
                ▼
             logits ──▶ softmax ──▶ p(next token)

The interview probes

  1. Why is the FFN ~2× the attention parameter count? Attention has 4 matrices of (d, d) each → 4d². FFN has 3 matrices of (d, 8d/3) → 8d². Roughly 2:1 split. This means most of the model's parameters are in the FFN, even though attention does the position-mixing.
  2. Why does the FFN need a wider hidden than the residual stream? Because the FFN is the universal function approximator that has to capture per-token non-linearities, and its capacity grows with the hidden width. The 4× factor (or 8/3 with SwiGLU) is empirically what it takes to fit memorisation tasks. Smaller doesn't have enough capacity; larger gives diminishing returns.
  3. Why is the residual stream not normalised in pre-norm models? Because the residual is the "highway" the gradient flows on. Normalising it would scale the gradient with the residual's local variance, which is what pre-norm is designed to avoid.
  4. What's the role of W_O in attention? It mixes information across heads. Each head produces a d/h-dim output independently; concatenation gives d. W_O then linearly combines the heads. Without W_O, each subsequent layer sees only the head-wise concatenation; the model can't compose heads.
  5. How many parameters in the embedding table vs the body of a LLaMA-2-70B? Embedding: 32k × 8192 ≈ 262M (+ untied LM head, another 262M). Body: 80 layers × ~856M ≈ 68B. Body dominates by ~250×. For very small models (~1B), embeddings can be ~25% of parameters — there the choice of vocab size matters disproportionately.

Common transformer block design choices in 2026

Choice                | Modern default        | Why
----------------------|-----------------------|-----------------------------------------------------------
Norm location         | Pre-norm              | Trains deep without warmup tricks
Norm type             | RMSNorm               | Cheaper, bf16-friendlier, no quality loss
Attention variant     | GQA (g=8)             | 8× KV cache reduction, ~no quality loss
Positional encoding   | RoPE                  | Relative position for free, extension-friendly
FFN                   | SwiGLU at d_f = 8d/3  | Gated mixture, iso-params
Activation            | SiLU (or GeLU)        | Smooth, no dead-ReLU
Bias terms            | None on Q,K,V,O,FFN   | Don't help; cost params; some bf16 instability
Tied embeddings       | No                    | Small quality gain for ~vocab·d extra params (~1.05B at 128k vocab, d=8192); usually keep separate

Interview prompts you should be ready for

  1. "Derive the parameter count of LLaMA-2-7B given d=4096, L=32, h=32, d_f=11008, V=32000." (Embedding: 4096 × 32000 = 131M. Per block: 4·d² + 3·d·d_f = 67M + 135M ≈ 202M. Body: 32 × 202M = 6.5B. LM head (untied): 131M. Final norm: negligible. Total: 6.7B. Matches.)
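  2. "You're told a 70B model takes 6 weeks on 1k H100s. Sanity check." (6·N·D rule. Chinchilla: D = 20·N = 1.4T tokens. FLOPs = 6 × 70e9 × 1.4e12 ≈ 5.9e23. H100 bf16 peak ≈ 990 TFLOP/s; at MFU ~45% that is ~450 TFLOP/s per GPU sustained. 1000 GPUs × 450 TFLOP/s × 86400 s/day × 42 days ≈ 1.6e24 FLOPs, about 2.8× what the run needs, so at 45% MFU it would finish in ~15 days. Six weeks corresponds to an MFU of only ~16%. The claim is feasible with plenty of slack; it implies low utilisation or significant downtime.)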
  3. "Where in the transformer block are most of the FLOPs spent?" (For short sequences the FFN dominates: ~16d² per token versus ~8d² for attention's projections, the same 2:1 split as the parameters. The attention-score term 4·N·d catches up with the attention projections around N ≈ 2d and with the FFN around N ≈ 4d. At LLaMA-2's 4k context with d=8192, the FFN still dominates. At long context (N ≥ 32k ≈ 4d for this model), attention dominates.)
  4. "Why don't modern LLMs use biases?" (Three reasons: 1) bias adds parameters with no expressivity gain — the residual stream and norm together can shift any constant. 2) bf16 with bias can amplify catastrophic cancellation in the residual sum. 3) Empirical: removing biases is neutral or slightly positive on perplexity.)
  5. "You want to scale your 7B model to 70B. What stays the same and what changes?" (Changes: layers scale up (~32 → 80), d scales up (4096 → 8192), head count scales up, and d_f scales with d. Most modern recipes (μP) say LR stays roughly fixed; batch size scales with model size. With Chinchilla-optimal data (D ∝ N_params), total FLOPs scale as N_params², so training time scales as N_params²/H, where H is the number of GPUs.)
  6. "What's a 'residual stream'?" (The path of accumulated activations from layer to layer, through the additive residual connections. Each block reads from it (via norm) and adds back to it. Mechanistic interpretability calls this the "communication bus" — circuits read/write specific features at specific positions of the residual stream.)
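The wall-clock sanity check in prompt 2 is worth scripting, since the arithmetic is easy to fumble under interview pressure. This sketch assumes the 6·N·D rule and the ~990 TFLOP/s bf16 peak quoted in that prompt; the function names are hypothetical.

```python
SECONDS_PER_DAY = 86400

def training_days(n_params, n_tokens, n_gpus, peak_flops, mfu):
    # Total training FLOPs (6*N*D) divided by the sustained cluster throughput.
    total = 6 * n_params * n_tokens
    sustained = n_gpus * peak_flops * mfu
    return total / sustained / SECONDS_PER_DAY

def required_mfu(n_params, n_tokens, n_gpus, peak_flops, days):
    # Utilisation needed to finish exactly within the given number of days.
    total = 6 * n_params * n_tokens
    return total / (n_gpus * peak_flops * days * SECONDS_PER_DAY)

n_params, n_tokens = 70e9, 20 * 70e9   # Chinchilla: D = 20 * N
days = training_days(n_params, n_tokens, n_gpus=1000, peak_flops=990e12, mfu=0.45)
mfu = required_mfu(n_params, n_tokens, n_gpus=1000, peak_flops=990e12, days=42)
print(f"{days:.1f} days at 45% MFU")         # 15.3 days at 45% MFU
print(f"{mfu:.1%} MFU needed for 6 weeks")   # 16.4% MFU needed for 6 weeks
```

Solving for the required MFU rather than the elapsed time is usually the more useful direction: it turns a schedule claim into a utilisation number that can be compared against what the cluster is known to achieve.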
Takeaway
A transformer block is two sub-blocks (attention + FFN), each a residual addition to the stream, each with its own pre-norm. Modern recipe: RMSNorm, GQA, RoPE, SwiGLU, no bias. Parameter count per block ~12d² for vanilla, ~2.25d² + 3·d·d_f for GQA-8 + SwiGLU. The 6N rule estimates training compute in 30 seconds. Memory accounting tells you whether the model fits. These are the back-of-envelope skills every senior interview tests.