Lab 02 — Norm ablation: BatchNorm vs LayerNorm vs RMSNorm

Goal: show that LayerNorm and RMSNorm both train; RMSNorm is measurably cheaper.

Estimated time: 90–120 minutes.

Prereq: lab 01 committed; theory/02-normalization.md read.


What you produce

A directory experiments/10-norm-ablation/ containing:

  • train.py — training script with a --norm CLI flag.
  • losses.json — three loss trajectories (no-norm, LayerNorm, RMSNorm).
  • timings.json — per-step wall-clock time for each norm variant.
  • loss_curves.png — three curves.
  • timing_bar.png — bar chart of mean step time per variant.
  • manifest.json.
  • README.md.

The setup

Take the 12-layer MLP from lab 01. Fix init to Kaiming (the one that trains). Vary only the normalization:

  1. No norm (baseline; should still train with Kaiming, just less stably).
  2. LayerNorm before each ReLU.
  3. RMSNorm before each ReLU.

You're showing two things:

  a. Both LayerNorm and RMSNorm train and reach a similar final loss.
  b. RMSNorm is faster per step. The difference is measurable on Borja's i5-8250U for a 12-layer net at hidden size 256.

TODOs

Block A — implement the norms

  • Write src/minigrad/nn/norm.py with LayerNorm(d) and RMSNorm(d) modules.
  • LayerNorm: \(\mu = \text{mean}(x, \text{axis=-1})\), \(\sigma^2 = \text{var}(x, \text{axis=-1})\), \(\hat x = (x - \mu)/\sqrt{\sigma^2 + \epsilon}\), \(y = \gamma \hat x + \beta\).
  • RMSNorm: \(\text{rms} = \sqrt{\text{mean}(x^2, \text{axis=-1}) + \epsilon}\), \(y = \gamma \cdot x / \text{rms}\).
  • ε goes inside the sqrt. Unit test: norm of an all-zero input does not produce NaN.
  • Both modules return Tensors that backprop correctly. Gradcheck both.
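
As a forward-only reference for the formulas above, here is a minimal NumPy sketch. It is not the minigrad module: the function names, the default `eps=1e-5`, and the use of plain arrays instead of Tensors are assumptions made for illustration. It also runs the all-zero input check.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm forward over the last axis (NumPy reference, no autograd)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)      # population variance (ddof=0)
    x_hat = (x - mu) / np.sqrt(var + eps)    # eps sits inside the sqrt
    return gamma * x_hat + beta

def rms_norm(x, gamma, eps=1e-5):
    """RMSNorm forward over the last axis: no mean subtraction, no beta."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return gamma * x / rms

# All-zero input: with eps inside the sqrt the denominator is sqrt(eps) > 0.
d = 8
zeros = np.zeros((4, d), dtype=np.float32)
gamma = np.ones(d, dtype=np.float32)
beta = np.zeros(d, dtype=np.float32)
ln_out = layer_norm(zeros, gamma, beta)
rms_out = rms_norm(zeros, gamma)
```

A reference like this is handy as an oracle: the minigrad modules should agree with it on the forward pass to within fp32 tolerance.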

Block B — three runs

  • Same 12-layer MLP. Kaiming init.
  • Variant 1: no norm.
  • Variant 2: LayerNorm before each ReLU. Apply Pre-style: out = ReLU(LayerNorm(x)).
  • Variant 3: RMSNorm before each ReLU. Same shape.
  • Time each step with time.perf_counter_ns().
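
One way to collect the per-step timings is a small helper around `time.perf_counter_ns()`. This is a sketch: `time_steps`, `step_fn`, and the warmup count are hypothetical names and choices, not part of the lab's codebase, and `dummy_step` stands in for a real training step.

```python
import time

def time_steps(step_fn, n_steps, warmup=5):
    """Return per-step wall-clock durations in nanoseconds.

    The warmup iterations are an assumption of this sketch: they are run
    but not recorded, so one-off setup costs do not skew the mean.
    """
    for _ in range(warmup):
        step_fn()
    durations_ns = []
    for _ in range(n_steps):
        start = time.perf_counter_ns()
        step_fn()
        durations_ns.append(time.perf_counter_ns() - start)
    return durations_ns

def dummy_step():
    # Placeholder workload standing in for forward + backward + update.
    sum(i * i for i in range(1000))

durations = time_steps(dummy_step, n_steps=20)
```

Keeping the raw per-step list, rather than only a running mean, is what makes the standard deviation for the error bars available later.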

Block C — plot

  • Loss curves (3 lines).
  • Bar chart of mean step time (3 bars). Include error bars from the timing's standard deviation.
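
A possible shape for the bar chart step, assuming timings.json maps each variant name to a list of per-step nanosecond timings (that layout, the variant keys, and both function names are assumptions). The numbers in `example` are placeholders, not measurements.

```python
import statistics

def summarize_timings(timings_ns):
    """Map each variant to (mean_ms, std_ms) using the sample standard deviation."""
    summary = {}
    for variant, samples in timings_ns.items():
        ms = [t / 1e6 for t in samples]
        summary[variant] = (statistics.mean(ms), statistics.stdev(ms))
    return summary

def plot_timing_bar(summary, path="timing_bar.png"):
    """Draw one bar per variant with std-dev error bars (needs matplotlib)."""
    import matplotlib
    matplotlib.use("Agg")  # headless backend, no display required
    import matplotlib.pyplot as plt

    names = list(summary)
    means = [summary[n][0] for n in names]
    stds = [summary[n][1] for n in names]
    fig, ax = plt.subplots()
    ax.bar(names, means, yerr=stds, capsize=4)
    ax.set_ylabel("mean step time (ms)")
    fig.savefig(path)
    plt.close(fig)

# Placeholder data only, to show the expected input shape.
example = {
    "no-norm": [1_000_000, 3_000_000],
    "layernorm": [2_000_000, 4_000_000],
    "rmsnorm": [1_500_000, 2_500_000],
}
summary = summarize_timings(example)
```

Calling `plot_timing_bar(summary)` would then write the figure. The plotting import is kept inside the function so the summary code runs even where matplotlib is not installed.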

Block D — interpret

In README.md:

  1. Do all three train? Yes/no per variant. (Yes/yes/yes is expected with Kaiming.)
  2. Is the final loss the same? Quantify.
  3. Is RMSNorm faster than LayerNorm? By what percentage? Match against theory 02's prediction (~30–40% faster).
  4. What's the no-norm baseline's loss trajectory shape? Smooth or jittery? Why?
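
For questions 2 and 3, it helps to fix the definitions before quoting numbers. The sketch below uses one possible convention, reduction in step time relative to LayerNorm. Whether theory 02 states its prediction as a time reduction or as a throughput gain is not specified here, so say in the README which one you report. The function names and the input values are illustrative only.

```python
def percent_faster(t_baseline, t_candidate):
    """Percent reduction in mean step time of the candidate vs the baseline."""
    return 100.0 * (t_baseline - t_candidate) / t_baseline

def relative_gap(loss_a, loss_b):
    """Symmetric relative difference between two final losses."""
    return abs(loss_a - loss_b) / max(abs(loss_a), abs(loss_b))

# Illustrative inputs, not measurements.
speedup = percent_faster(2.0, 1.5)   # 25.0: candidate takes 25% less time
gap = relative_gap(0.50, 0.48)       # about 0.04, i.e. a 4% gap
```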

Block E — manifest

Standard. Include the epsilon value and the norm placement (Pre vs Post).

Constraints

  • Pre-style only. Pre vs Post is not tested here (lab 03 covers it indirectly via residual + norm).
  • mypy --strict on src/minigrad/nn/norm.py.
  • Property test: for random fp32 inputs of various shapes, check the per-row mean square of the RMSNorm output y = γ · x / rms. If γ is a vector of identical values, mean(y²) = γ² · mean(x²)/rms². Because rms² = mean(x²) + ε, that ratio is mean(x²)/(mean(x²) + ε), which equals 1 only up to ε, so compare against γ² with a tolerance. For non-uniform γ, check componentwise: the per-row mean of (y/γ)² should be ≈ 1.
  • Single thread, performance governor.
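
The RMSNorm property test can be sketched in plain NumPy as follows. This is an illustration rather than the required test: the shapes, the seed, the γ range, and `eps=1e-5` are arbitrary assumptions, and a real test would call the minigrad module instead of this stand-in `rms_norm`.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-5):
    """Stand-in RMSNorm forward over the last axis."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return gamma * x / rms

eps = 1e-5
rng = np.random.default_rng(0)
max_err = 0.0
for shape in [(1, 4), (3, 16), (2, 5, 64), (7, 256)]:
    d = shape[-1]
    x = rng.standard_normal(shape).astype(np.float32)
    gamma = rng.uniform(0.5, 2.0, size=d).astype(np.float32)  # non-uniform
    y = rms_norm(x, gamma, eps)
    # Componentwise check: dividing by gamma recovers x / rms.
    got = np.mean((y / gamma) ** 2, axis=-1)
    m = np.mean(x * x, axis=-1)
    expected = m / (m + eps)   # exactly 1 only when eps == 0
    max_err = max(max_err, float(np.max(np.abs(got - expected))))
```

Comparing against `m / (m + eps)` rather than a hard-coded 1 keeps the test meaningful for rows with small norm, where ε is no longer negligible.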

Stop conditions

Done when:

  1. All seven files are committed.
  2. src/minigrad/nn/norm.py is mypy --strict clean and gradcheck passes.
  3. Loss curves show RMSNorm and LayerNorm converging to similar final loss.
  4. Timing bar shows RMSNorm strictly faster than LayerNorm (the absolute gap depends on your machine; should be > 0).
  5. README answers all four Block D questions.

Pitfalls

  • LayerNorm without γ, β learnable. That's not LayerNorm; that's just standardization. The affine params are part of the definition.
  • RMSNorm with mean subtraction. That's LayerNorm. Common copy-paste bug.
  • NaN on the first batch. Check ε placement (must be inside sqrt).
  • RMSNorm slower than LayerNorm in your measurement. Likely a NumPy/minigrad overhead issue, not real cost. Profile a single norm op in isolation. The savings should be visible at hidden_dim ≥ 256.
  • Gradcheck fails for RMSNorm. Often the backward closure for the 1/rms factor is missing. With \(m = \text{mean}(x^2, \text{axis=-1})\) and \(d\) the size of the last axis, the backward through \(1/\sqrt{m + \epsilon}\) is \(\frac{\partial}{\partial x_i} \frac{1}{\sqrt{m + \epsilon}} = -\frac{1}{2 (m+\epsilon)^{3/2}} \cdot \frac{2 x_i}{d} = -\frac{x_i}{d (m+\epsilon)^{3/2}}\). Make sure that term is in your backward.
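
To see where the 1/rms term from the last pitfall lands in the full input gradient, here is a NumPy sketch checked against central finite differences in float64. The function names, shapes, step size, and tolerance are assumptions, and minigrad's own gradcheck may be organized differently.

```python
import numpy as np

def rms_norm_forward(x, gamma, eps):
    r = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return gamma * x / r

def rms_norm_backward_x(x, gamma, grad_y, eps):
    """Gradient of sum(grad_y * y) with respect to x, where y = gamma * x / r."""
    d = x.shape[-1]
    r = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    g = grad_y * gamma
    direct = g / r                                   # through the explicit x
    # Through 1/r: d(1/r)/dx_j = -x_j / (d * r**3), scaled by sum_i g_i * x_i.
    via_rms = -x * np.sum(g * x, axis=-1, keepdims=True) / (d * r**3)
    return direct + via_rms

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))
gamma = rng.uniform(0.5, 2.0, size=8)
grad_y = rng.standard_normal((3, 8))
eps = 1e-5

analytic = rms_norm_backward_x(x, gamma, grad_y, eps)
numeric = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(*x.shape):
    xp = x.copy(); xp[idx] += h
    xm = x.copy(); xm[idx] -= h
    diff = rms_norm_forward(xp, gamma, eps) - rms_norm_forward(xm, gamma, eps)
    numeric[idx] = np.sum(grad_y * diff) / (2 * h)
max_abs_err = float(np.max(np.abs(analytic - numeric)))
```

Dropping `via_rms` from the return value makes the check fail, which is the symptom described above.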

Hint of last resort

If gradcheck for LayerNorm fails: the gradient of (x - mean(x)) / std(x) with respect to a single \(x_j\) depends on all \(x_i\) via the mean and the std. There are three pathways: through the explicit \(x_j\), through the mean, through the std. PyTorch's LayerNorm implementation has this as a closed-form expression; recompute it from the definition and you'll find your missing term.
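
The three pathways can be written out and checked numerically. The sketch below is one way to express the input gradient in NumPy, verified against central finite differences in float64. The function names and test values are assumptions, and it is not a transcription of PyTorch's or minigrad's code.

```python
import numpy as np

def layer_norm_forward(x, gamma, beta, eps):
    mu = x.mean(axis=-1, keepdims=True)
    sigma = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return gamma * (x - mu) / sigma + beta

def layer_norm_backward_x(x, gamma, grad_y, eps):
    """Gradient of sum(grad_y * y) with respect to x."""
    mu = x.mean(axis=-1, keepdims=True)
    sigma = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    x_hat = (x - mu) / sigma
    g = grad_y * gamma                      # gradient arriving at x_hat
    direct = g                              # pathway 1: the explicit x_j
    via_mean = -g.mean(axis=-1, keepdims=True)                   # pathway 2
    via_std = -x_hat * (g * x_hat).mean(axis=-1, keepdims=True)  # pathway 3
    return (direct + via_mean + via_std) / sigma

rng = np.random.default_rng(1)
x = rng.standard_normal((3, 8))
gamma = rng.uniform(0.5, 2.0, size=8)
beta = rng.standard_normal(8)
grad_y = rng.standard_normal((3, 8))
eps = 1e-5

analytic = layer_norm_backward_x(x, gamma, grad_y, eps)
numeric = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(*x.shape):
    xp = x.copy(); xp[idx] += h
    xm = x.copy(); xm[idx] -= h
    diff = (layer_norm_forward(xp, gamma, beta, eps)
            - layer_norm_forward(xm, gamma, beta, eps))
    numeric[idx] = np.sum(grad_y * diff) / (2 * h)
max_abs_err = float(np.max(np.abs(analytic - numeric)))
```

Zeroing out `via_mean` or `via_std` in turn shows which pathway a failing implementation is missing.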

When to consult solutions/

After all seven files. Solution: solutions/02-norm-ablation-ref.md (phase open).


Next lab: lab/03-residual-depth.md.