02 — Normalization: BatchNorm, LayerNorm, RMSNorm

Normalizing means re-scaling (and sometimes re-centering) the activations so that the next layer sees a stable distribution. BatchNorm does it over the batch (vision); LayerNorm does it over the features (classic transformers); RMSNorm is LayerNorm without the mean, and it is what modern LLMs use.


The thing they all solve

Once initialization is right (Phase 10 theory 01), forward-pass variance is roughly preserved at initialization. After a few training steps, the weights move, the activation distribution drifts, and each layer sees a moving target. The BatchNorm paper named this internal covariate shift; whatever you think of the name, the annoyance is real. Normalization addresses it by re-normalizing the activations at every layer, regardless of how far the weights have drifted.

There are three normalization variants you must know:

  • BatchNorm (Ioffe & Szegedy 2015) — vision.
  • LayerNorm (Ba et al. 2016) — transformers.
  • RMSNorm (Zhang & Sennrich 2019) — modern LLMs.

They differ in which axis they normalize over. The math is otherwise nearly identical.

BatchNorm

Given a batch of activations \(X \in \mathbb{R}^{B \times d}\) (batch × features), BatchNorm normalizes over the batch axis, per-feature:

\[ \mu_k = \frac{1}{B} \sum_{i=1}^{B} X_{ik}, \qquad \sigma_k^2 = \frac{1}{B} \sum_{i=1}^{B} (X_{ik} - \mu_k)^2 \]
\[ \hat{X}_{ik} = \frac{X_{ik} - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}}, \qquad Y_{ik} = \gamma_k \hat{X}_{ik} + \beta_k \]

The learnable parameters \(\gamma, \beta \in \mathbb{R}^d\) let the network undo the normalization if it wants to (initialized \(\gamma = 1, \beta = 0\)).
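A minimal NumPy sketch of the two formulas above in training mode; `batch_norm_train` is a hypothetical helper name. It takes one mean and one (biased, 1/B) variance per feature over axis 0, and the last lines check that choosing \(\gamma = \sqrt{\sigma^2 + \epsilon}\), \(\beta = \mu\) undoes the normalization for this batch.

```python
import numpy as np

def batch_norm_train(X, gamma, beta, eps=1e-5):
    """BatchNorm forward pass in training mode for X of shape (B, d).

    Statistics are taken over the batch axis (axis=0): one mean and one
    variance per feature, as in the mu_k / sigma_k^2 formulas.
    """
    mu = X.mean(axis=0)                      # shape (d,)
    var = X.var(axis=0)                      # shape (d,), biased (1/B) as in the formula
    X_hat = (X - mu) / np.sqrt(var + eps)    # (B, d) broadcast against (d,)
    return gamma * X_hat + beta

rng = np.random.default_rng(0)
B, d = 32, 4
X = 3.0 + 2.0 * rng.standard_normal((B, d))   # features far from mean 0 / var 1
Y = batch_norm_train(X, gamma=np.ones(d), beta=np.zeros(d))
print(Y.mean(axis=0).round(6))   # ~0 for every feature
print(Y.var(axis=0).round(3))    # ~1 for every feature (a hair below 1 because of eps)

# With gamma = sqrt(var + eps) and beta = mu, BN reproduces this batch exactly.
mu, var = X.mean(axis=0), X.var(axis=0)
X_back = batch_norm_train(X, gamma=np.sqrt(var + 1e-5), beta=mu)
print(np.allclose(X_back, X))    # True
```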

Train vs inference. During training, \(\mu_k, \sigma_k^2\) come from the current batch. At inference you want outputs that do not depend on the batch size or on which other samples share the batch, so BN keeps running averages of \(\mu_k\) and \(\sigma_k^2\) during training and uses those at inference. This is the bug source: forgetting model.eval() is a top-10 deep-learning footgun. We mention it even though our curriculum won't use BN.
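To make the train/eval switch concrete, here is a sketch of a hypothetical `BatchNorm1d` class in plain NumPy. The exponential running average with `momentum=0.1` is an assumption for illustration; real frameworks differ in details (for example, which variance estimator they store). The usage shows the footgun: a batch of size 1 fed in training mode has zero variance, so every output collapses to \(\beta\).

```python
import numpy as np

class BatchNorm1d:
    """Minimal NumPy BatchNorm for inputs of shape (B, d), with a train/eval switch."""

    def __init__(self, d, eps=1e-5, momentum=0.1):
        self.gamma, self.beta = np.ones(d), np.zeros(d)
        self.eps, self.momentum = eps, momentum
        self.running_mean, self.running_var = np.zeros(d), np.ones(d)
        self.training = True

    def eval(self):
        self.training = False
        return self

    def __call__(self, X):
        if self.training:
            # Batch statistics, plus an update of the running averages.
            mu, var = X.mean(axis=0), X.var(axis=0)
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mu
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Inference: frozen running statistics, independent of the batch.
            mu, var = self.running_mean, self.running_var
        return self.gamma * (X - mu) / np.sqrt(var + self.eps) + self.beta

rng = np.random.default_rng(0)
bn = BatchNorm1d(d=3)
for _ in range(200):                       # "training": running stats converge to mean 5, var ~4
    bn(5.0 + 2.0 * rng.standard_normal((64, 3)))

x_single = np.array([[7.0, 3.0, 5.0]])     # batch of size 1, as in chat-time inference
bn.eval()
y_eval = bn(x_single)                      # uses running stats: approximately [[1, -1, 0]]
bn.training = True                         # the footgun: forgot to stay in eval mode
y_train = bn(x_single)                     # var over one sample is 0, so output == beta == 0
print(y_eval, y_train)                     # (this call also pollutes the running stats)
```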

Why BN works on vision but not transformers. Vision batches are big and i.i.d., and features (spatial positions × channels) are roughly stationary across samples. Language batches are small, tokens within an example are correlated, and feature semantics shift across sequence positions. For language, the noise in BatchNorm's batch statistics dominates the signal. The original Transformer (Vaswani et al. 2017) used LayerNorm instead.
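A rough illustration of "batch-statistics noise": under the simplifying assumption of i.i.d. standard-normal features (it ignores the within-sequence correlation mentioned above), the per-feature batch mean that BN subtracts fluctuates from batch to batch with a standard deviation of about \(1/\sqrt{B}\). `batch_mean_noise` is a hypothetical helper.

```python
import numpy as np

def batch_mean_noise(B, trials=2000, seed=0):
    """Std, across `trials` random batches of size B, of one feature's batch mean.

    The true mean is 0, so any spread is pure sampling noise that BatchNorm
    would inject into its output through mu_k.
    """
    rng = np.random.default_rng(seed)
    means = rng.standard_normal((trials, B)).mean(axis=1)
    return means.std()

for B in (2, 8, 64, 256):
    print(B, round(batch_mean_noise(B), 3), round(1 / np.sqrt(B), 3))  # empirical vs 1/sqrt(B)
```

Small batches (or effectively small ones, when samples are correlated) push this noise up, which is the regime the paragraph describes.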

LayerNorm

LayerNorm normalizes over the feature axis, per-sample:

\[ \mu^{(i)} = \frac{1}{d} \sum_{k=1}^{d} X_{ik}, \qquad {\sigma^{(i)}}^2 = \frac{1}{d} \sum_{k=1}^{d} (X_{ik} - \mu^{(i)})^2 \]
\[ \hat{X}_{ik} = \frac{X_{ik} - \mu^{(i)}}{\sqrt{{\sigma^{(i)}}^2 + \epsilon}}, \qquad Y_{ik} = \gamma_k \hat{X}_{ik} + \beta_k \]

Note the index. \(\mu, \sigma\) depend on sample i, not on the batch. So:

  • Same computation at train and inference. No .eval() switch needed. No bug source.
  • No batch dependence. Works with batch size 1 (inference, chat).
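A sketch of LayerNorm in NumPy (hypothetical helper `layer_norm`, reducing over the last axis so it handles both (B, d) and (B, L, d)). The check at the end shows the batch independence claimed above: normalizing one row alone gives the same result as normalizing it inside a batch, with no mode switch involved.

```python
import numpy as np

def layer_norm(X, gamma, beta, eps=1e-5):
    """LayerNorm over the last (feature) axis; works for (B, d) or (B, L, d)."""
    mu = X.mean(axis=-1, keepdims=True)     # one mean per sample (or per token)
    var = X.var(axis=-1, keepdims=True)     # one variance per sample (or per token)
    return gamma * (X - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
d = 6
gamma, beta = np.ones(d), np.zeros(d)
X = rng.standard_normal((16, d))

full = layer_norm(X, gamma, beta)           # normalize the whole batch
single = layer_norm(X[3:4], gamma, beta)    # normalize sample 3 alone (batch size 1)
print(np.allclose(full[3:4], single))       # True: no dependence on the rest of the batch
```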

LayerNorm was the transformer default from the original 2017 Transformer until RMSNorm took over in most modern LLMs.

RMSNorm

In 2019, Zhang & Sennrich proposed dropping the mean subtraction:

\[ \text{RMS}(X^{(i)}) = \sqrt{\frac{1}{d} \sum_{k=1}^{d} X_{ik}^2 + \epsilon} \]
\[ Y_{ik} = \gamma_k \cdot \frac{X_{ik}}{\text{RMS}(X^{(i)})} \]

No mean. No \(\beta\) bias parameter (some implementations keep \(\beta\), most don't). Two effects:

  1. Fewer FLOPs and less memory traffic: no mean computation, no mean subtraction, no \(\beta\), and one fewer reduction over the features (roughly 40% fewer ops per element). On Borja's i5-8250U, expect ~30–40% wall-time reduction for the norm op itself.
  2. Comparable, sometimes slightly better, training stability in practice. The original RMSNorm paper showed comparable convergence, and subsequent LLM work (Llama, T5-1.1, PaLM) confirmed that removing the mean doesn't hurt.
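The RMSNorm formula above, sketched in NumPy (hypothetical helper `rms_norm`, no \(\beta\), \(\epsilon\) inside the square root as in \(\text{RMS}(X^{(i)})\)). The input deliberately has a large mean so you can see what is and is not normalized: each row's RMS becomes 1, but its mean is left alone.

```python
import numpy as np

def rms_norm(X, gamma, eps=1e-6):
    """RMSNorm over the last axis: divide by the root-mean-square, no centering, no beta."""
    rms = np.sqrt(np.mean(X * X, axis=-1, keepdims=True) + eps)
    return gamma * X / rms

rng = np.random.default_rng(0)
d = 8
X = 4.0 + rng.standard_normal((5, d))      # deliberately non-zero mean
Y = rms_norm(X, gamma=np.ones(d))

row_rms = np.sqrt(np.mean(Y ** 2, axis=-1))
print(row_rms)            # ~1 for every row: the scale is normalized
print(Y.mean(axis=-1))    # clearly non-zero: the mean is NOT removed
```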

Why doesn't removing the mean hurt? No definitive proof. The empirical argument: after a residual sum x + f(x) with mean-zero f, the mean of x + f(x) is approximately the mean of x — already controlled by the previous layer's normalization. Re-centering every layer is redundant work. The re-scaling is what actually matters for the next layer's variance preservation.

We accept this empirically. RMSNorm is the default in modern LLM stacks; we implement it as the curriculum's primary norm.
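One way to see exactly what RMSNorm gives up: with \(\gamma = 1\), \(\beta = 0\) and the same \(\epsilon\), LayerNorm is RMSNorm applied to the centered input, because the RMS of a centered vector is its standard deviation, \(\sqrt{\frac{1}{d}\sum_k (X_{ik} - \mu^{(i)})^2 + \epsilon} = \sqrt{{\sigma^{(i)}}^2 + \epsilon}\). A quick NumPy check, with hypothetical helper names:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def rms_norm(x, eps=1e-5):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(1)
x = 2.0 + rng.standard_normal((4, 16))

centered = x - x.mean(axis=-1, keepdims=True)
print(np.allclose(layer_norm(x), rms_norm(centered)))  # True: LN == RMSNorm after centering
print(np.allclose(layer_norm(x), rms_norm(x)))         # False: here the mean (~2) matters
```

So the bet behind RMSNorm is that the inputs it sees are close enough to mean-zero that the centering step buys little.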

Pre-LN vs Post-LN

Where you put the normalization in a residual block matters.

Post-LN (original transformer): \[ y = \text{LN}(x + f(x)) \]

Pre-LN (modern transformer): \[ y = x + f(\text{LN}(x)) \]

Pre-LN became dominant because it gives a cleaner gradient path:

  • In Pre-LN, the gradient ∂y/∂x = I + ∂(f ∘ LN)/∂x. The identity term is unconditional — the residual highway always has unit gain.
  • In Post-LN, the gradient ∂y/∂x = ∂LN/∂(x+f(x)) × (I + ∂f/∂x). The LayerNorm Jacobian scales the residual term; for deep networks this multiplicative factor compounds and destabilizes training.
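The two Jacobians can be checked numerically. This sketch uses a hypothetical sublayer \(f(h) = hW\) with a small random \(W\), LayerNorm with \(\gamma = 1, \beta = 0\), and central finite differences. In Pre-LN, \(\partial y/\partial x\) stays close to \(I\). In Post-LN, the LayerNorm Jacobian multiplies everything, identity path included; for example, LN outputs always sum to zero, so every column of the Post-LN Jacobian sums to zero and no input perturbation can move the output's mean.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # gamma = 1, beta = 0 for simplicity; normalize over the feature axis.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_ln_block(x, f):
    return layer_norm(x + f(x))       # y = LN(x + f(x))

def pre_ln_block(x, f):
    return x + f(layer_norm(x))       # y = x + f(LN(x))

def numerical_jacobian(g, x, h=1e-6):
    # Central differences: column k approximates dy/dx_k.
    J = np.zeros((x.size, x.size))
    for k in range(x.size):
        e = np.zeros_like(x)
        e[k] = h
        J[:, k] = (g(x + e) - g(x - e)) / (2 * h)
    return J

rng = np.random.default_rng(0)
d = 8
W = 0.01 * rng.standard_normal((d, d))

def f(h):
    return h @ W                      # hypothetical sublayer: a small linear map

x = rng.standard_normal(d)
J_pre = numerical_jacobian(lambda v: pre_ln_block(v, f), x)
J_post = numerical_jacobian(lambda v: post_ln_block(v, f), x)

print(np.abs(J_pre - np.eye(d)).max())   # small: identity path has unit gain
print(np.abs(J_post.sum(axis=0)).max())  # ~0: LN's Jacobian sits on the residual path
```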

The empirical consequence: Pre-LN transformers can train without learning-rate warmup (Xiong et al. 2020), while Post-LN transformers typically need warmup to avoid early divergence. Phase 17 (transformer blocks) goes deeper. For Phase 10, the rule is: use Pre-LN in everything we build.

ε placement (Numerical Stability subagent territory)

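The denominator is \(\sqrt{\sigma^2 + \epsilon}\), NOT \(\sqrt{\sigma^2} + \epsilon\). Why:

  • With \(\sqrt{\sigma^2} + \epsilon\), the floor on the denominator is \(\epsilon\) itself (\(10^{-5}\)) rather than \(\sqrt{\epsilon}\) (about \(3 \times 10^{-3}\)), so the tiny deviations of a near-constant input get amplified roughly 300× more. Worse, the backward pass differentiates \(\sqrt{\sigma^2}\), whose derivative \(1/(2\sqrt{\sigma^2})\) is infinite at \(\sigma^2 = 0\), so gradients turn into inf/NaN before \(\epsilon\) can save you.
  • \(\sqrt{\sigma^2 + \epsilon}\) never goes below \(\sqrt{\epsilon}\), and its derivative never exceeds \(1/(2\sqrt{\epsilon})\): the \(\epsilon\) acts as a variance floor, which is the intended semantics.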

A common implementation bug is 1.0 / (np.sqrt(var) + eps). This passes tests with non-degenerate inputs and fails silently on edge cases. Lab 02 has a unit test that triggers \(\sigma \to 0\) to expose this.
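A small sketch of both failure modes (this is not Lab 02's test, just an independent illustration; the helper names are hypothetical). A nearly constant input with \(\sigma^2 \approx 10^{-12}\) stays near zero with the correct placement but is blown up with the buggy one, and the derivative of the square root is infinite at \(\sigma^2 = 0\) unless \(\epsilon\) sits inside it.

```python
import numpy as np

def normalize_correct(x, eps=1e-5):
    mu, var = x.mean(), x.var()
    return (x - mu) / np.sqrt(var + eps)     # eps inside the sqrt: floor is sqrt(eps)

def normalize_buggy(x, eps=1e-5):
    mu, var = x.mean(), x.var()
    return (x - mu) / (np.sqrt(var) + eps)   # eps outside: floor is eps itself

x = 3.0 + 1e-6 * np.array([1.0, -1.0, 1.0, -1.0])   # nearly constant input, var ~ 1e-12
out_correct = normalize_correct(x)
out_buggy = normalize_buggy(x)
print(np.abs(out_correct).max())   # ~3e-4: near-constant input stays near 0
print(np.abs(out_buggy).max())     # ~0.09: roughly 300x larger

# Backward pass: d sqrt(v) / dv = 1 / (2 sqrt(v)), evaluated at v = 0.
with np.errstate(divide="ignore"):
    grad_buggy = np.divide(0.5, np.sqrt(0.0))       # inf: eps outside cannot help
grad_correct = 0.5 / np.sqrt(0.0 + 1e-5)            # ~158: bounded when eps is inside
print(grad_buggy, grad_correct)
```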

Typical \(\epsilon\) values: \(10^{-5}\) (LayerNorm), \(10^{-6}\) (RMSNorm). The LLM ecosystem isn't religious about this; check the model card.

A sanity diagram

    BatchNorm                    LayerNorm / RMSNorm
    ─────────                    ───────────────────
    normalize OVER:              normalize OVER:
    ┌──────────┐                 ┌──────────┐
    │  Batch   │                 │ Features │
    └──────────┘                 └──────────┘

    X.shape = (B, d)             X.shape = (B, d)
    μ.shape = (d,)               μ.shape = (B,)
    σ.shape = (d,)               σ.shape = (B,)

The shapes alone tell you the norm: if your μ has shape (d,), it's BatchNorm; (B,) means LayerNorm/RMSNorm. (Or, in a sequence model with shape (B, L, d), LayerNorm has μ.shape = (B, L).)
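The same shape rule, checked in NumPy: the reduction axis decides which statistic you get. Array sizes here are arbitrary; for BatchNorm on a sequence input one common convention (assumed here) is to reduce over both batch and length axes.

```python
import numpy as np

rng = np.random.default_rng(0)
B, L, d = 4, 10, 16

X2 = rng.standard_normal((B, d))
print(X2.mean(axis=0).shape)        # (16,)   BatchNorm: one statistic per feature
print(X2.mean(axis=-1).shape)       # (4,)    LayerNorm/RMSNorm: one per sample

X3 = rng.standard_normal((B, L, d))
print(X3.mean(axis=-1).shape)       # (4, 10) LayerNorm in a sequence model: one per token
print(X3.mean(axis=(0, 1)).shape)   # (16,)   BatchNorm over batch and length: one per feature
```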

Computational cost — actual numbers

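For shape (B, L, d) = (64, 512, 768) (a typical small-transformer training step), that is 64 × 512 × 768 ≈ 25.2M elements, normalized as 64 × 512 = 32,768 independent rows of d = 768. Counting each pass over the data as one op:

  • LayerNorm: 1 mean reduction + 1 variance reduction + 1 subtract + 1 divide + 1 affine ≈ 5 ops per element (≈ 5d per row), about 126M ops per call.
  • RMSNorm: 1 mean-of-squares reduction + 1 divide + 1 scale ≈ 3 ops per element (≈ 3d per row), about 75M ops per call.

Ratio: RMSNorm does about 60% of LayerNorm's compute. In an unfused implementation such as NumPy, where each step is a separate pass over the array, memory traffic shrinks in roughly the same ratio. Norm ops are memory-bound (Phase 1's roofline argument), so that is where the ~30–40% wall-clock saving comes from.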
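A rough way to measure this on your own machine: a best-of-N wall-clock comparison of unfused NumPy LayerNorm and RMSNorm. The default shape (8, 512, 768) is an assumption chosen to run quickly; pass (64, 512, 768) to match the text. The ratio you see depends on your CPU, NumPy build and memory bandwidth, so no numbers are claimed here.

```python
import time
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-6):
    return gamma * x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def best_time(fn, *args, repeats=5):
    # Best-of-N wall time in seconds; taking the minimum reduces noise from other processes.
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return min(times)

shape = (8, 512, 768)   # smaller than (64, 512, 768) so it runs quickly
x = np.random.default_rng(0).standard_normal(shape).astype(np.float32)
gamma = np.ones(shape[-1], dtype=np.float32)
beta = np.zeros(shape[-1], dtype=np.float32)

t_ln = best_time(layer_norm, x, gamma, beta)
t_rms = best_time(rms_norm, x, gamma)
print(f"LayerNorm {t_ln * 1e3:.1f} ms, RMSNorm {t_rms * 1e3:.1f} ms, ratio {t_rms / t_ln:.2f}")
```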

What this phase doesn't implement

  • GroupNorm: interpolates between LayerNorm and InstanceNorm by normalizing groups of channels per sample; common in vision models trained with small batches. Skip.
  • InstanceNorm — used in style transfer. Skip.
  • Weight standardization — orthogonal trick that normalizes weights, not activations. Phase 21.
  • DyT (dynamic tanh) — a 2024+ alternative to LayerNorm. Skim if curious; not in scope.

Drill problems

Solutions in solutions/02-normalization-ref.md (phase open).

  1. Show that for LayerNorm with \(d = 1\), the output is always exactly \(\beta_1\) (regardless of input). Why is d = 1 LayerNorm useless?
  2. Show that the gradient of LayerNorm with respect to a single input \(X_{ik}\) depends on all other features \(X_{ij}\) (j ≠ k) via the mean and variance. Contrast this with RMSNorm, where the dependency structure is the same (still all features) but the expression is simpler.
  3. A network uses Post-LN with no warmup at learning rate 1e-3 and diverges in 100 steps. Suggest two single-change fixes (other than reducing LR) and explain which failure mode each addresses.

One-paragraph recap

Normalization re-centers and/or re-scales activations so the next layer sees a stable distribution. BatchNorm normalizes over the batch axis (vision; train/inference divergence). LayerNorm normalizes over features per sample (transformers; no train/inference divergence). RMSNorm drops the mean subtraction from LayerNorm (modern LLMs; ~40% cheaper, empirically equivalent stability). Pre-LN (y = x + f(LN(x))) is the modern default because it gives a cleaner gradient highway than Post-LN. \(\epsilon\) goes inside the sqrt for numerical stability.


Next: theory/03-residuals.md.