Lab 02: Norm ablation (BatchNorm vs LayerNorm vs RMSNorm)
Goal: show that LayerNorm and RMSNorm both train, and that RMSNorm is measurably cheaper per step.
Estimated time: 90–120 minutes.
Prereq: lab 01 committed; `theory/02-normalization.md` read.
What you produce
A directory `experiments/10-norm-ablation/` containing:
- `train.py`: training script with a `--norm` CLI flag.
- `losses.json`: three loss trajectories (no-norm, LayerNorm, RMSNorm).
- `timings.json`: per-step wall-clock time for each norm variant.
- `loss_curves.png`: three curves.
- `timing_bar.png`: bar chart of mean step time per variant.
- `manifest.json`.
- `README.md`.
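A minimal sketch of the `--norm` flag for `train.py`, using the standard-library `argparse`. The choice names (`none`, `layernorm`, `rmsnorm`), the extra `--eps` and `--steps` flags, and their defaults are assumptions for illustration, not part of the lab spec.

```python
from __future__ import annotations

import argparse

# Hypothetical choice names for the three variants; rename to match your code.
NORM_CHOICES = ("none", "layernorm", "rmsnorm")


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Lab 02 norm ablation")
    # Required by the lab: which normalization goes before each ReLU.
    parser.add_argument("--norm", choices=NORM_CHOICES, required=True)
    # Assumed extras (not in the spec): epsilon inside the sqrt, and step count.
    parser.add_argument("--eps", type=float, default=1e-5)
    parser.add_argument("--steps", type=int, default=1000)
    return parser.parse_args(argv)
```

Call `parse_args()` at the top of your entry point and run `python train.py --norm rmsnorm`; passing an explicit list, as in `parse_args(["--norm", "layernorm"])`, keeps the parser testable without a shell.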
The setup
Take the 12-layer MLP from lab 01. Fix init to Kaiming (the one that trains). Vary only the normalization:
- No norm (baseline; should still train with Kaiming, just less stably).
- LayerNorm before each ReLU.
- RMSNorm before each ReLU.
You're showing two things:
a. Both LayerNorm and RMSNorm train and reach a similar final loss.
b. RMSNorm is faster per step, measurably so on Borja's i5-8250U for a 12-layer net at hidden width 256.
TODOs
Block A: implement the norms
- Write `src/minigrad/nn/norm.py` with `LayerNorm(d)` and `RMSNorm(d)` modules.
- LayerNorm: \(\mu = \text{mean}(x, \text{axis=-1})\), \(\sigma^2 = \text{var}(x, \text{axis=-1})\), \(\hat x = (x - \mu)/\sqrt{\sigma^2 + \epsilon}\), \(y = \gamma \hat x + \beta\).
- RMSNorm: \(\text{rms} = \sqrt{\text{mean}(x^2, \text{axis=-1}) + \epsilon}\), \(y = \gamma \cdot x / \text{rms}\).
- ε goes inside the sqrt. Unit test: normalizing an all-zero input must not produce NaN.
- Both modules return Tensors that backprop correctly. Gradcheck both.
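A plain-NumPy reference for the two forward passes can serve as ground truth for your minigrad modules. This is a sketch, assuming inputs of shape `(..., d)` and `γ`, `β` stored as length-`d` arrays; `layernorm_ref` and `rmsnorm_ref` are hypothetical names, not the `minigrad` API.

```python
import numpy as np


def layernorm_ref(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Per-row statistics over the last axis.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # population variance (ddof=0)
    x_hat = (x - mu) / np.sqrt(var + eps)  # eps inside the sqrt
    return gamma * x_hat + beta


def rmsnorm_ref(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # No mean subtraction: only the root-mean-square over the last axis.
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    return gamma * x / rms


if __name__ == "__main__":
    d = 256
    gamma, beta = np.ones(d), np.zeros(d)
    zeros = np.zeros((4, d), dtype=np.float32)
    # The Block A unit test: an all-zero input must not produce NaN.
    assert not np.isnan(layernorm_ref(zeros, gamma, beta)).any()
    assert not np.isnan(rmsnorm_ref(zeros, gamma)).any()
```

With ε inside the sqrt, an all-zero row divides by \(\sqrt{\epsilon}\) instead of zero, so both outputs are exactly zero rather than NaN.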
Block B: three runs
- Same 12-layer MLP. Kaiming init.
- Variant 1: no norm.
- Variant 2: LayerNorm before each ReLU, applied Pre-style: `out = ReLU(LayerNorm(x))`.
- Variant 3: RMSNorm before each ReLU, same placement.
- Time each step with `time.perf_counter_ns()`.
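A minimal sketch of the Pre-style layer and the per-step timing loop, written in plain NumPy rather than minigrad. It times only a forward pass through the 12 layers; a real `train.py` would time forward, backward, and the optimizer update together. Batch size, step count, and the helpers `make_layers`, `forward`, and `time_steps` are assumptions for illustration.

```python
from __future__ import annotations

import time
from typing import Callable

import numpy as np

Array = np.ndarray
EPS = 1e-5


def layernorm(h: Array) -> Array:
    # gamma = 1, beta = 0 (their init values), so the affine step is omitted here.
    mu = h.mean(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(h.var(axis=-1, keepdims=True) + EPS)


def rmsnorm(h: Array) -> Array:
    # One reduction (mean of squares), no mean subtraction.
    return h / np.sqrt((h * h).mean(axis=-1, keepdims=True) + EPS)


NORMS: dict[str, Callable[[Array], Array]] = {
    "none": lambda h: h,
    "layernorm": layernorm,
    "rmsnorm": rmsnorm,
}


def make_layers(depth: int, width: int, rng: np.random.Generator) -> list[Array]:
    # Kaiming (He) init for ReLU: std = sqrt(2 / fan_in). A Python float keeps fp32.
    std = (2.0 / width) ** 0.5
    return [rng.standard_normal((width, width), dtype=np.float32) * std for _ in range(depth)]


def forward(x: Array, weights: list[Array], norm: str) -> Array:
    f = NORMS[norm]
    for w in weights:
        x = np.maximum(f(x @ w), 0.0)  # Pre-style: ReLU(Norm(linear(x)))
    return x


def time_steps(norm: str, steps: int = 50, depth: int = 12, width: int = 256, batch: int = 64) -> list[int]:
    rng = np.random.default_rng(0)
    weights = make_layers(depth, width, rng)
    x = rng.standard_normal((batch, width), dtype=np.float32)
    samples: list[int] = []
    for _ in range(steps):
        t0 = time.perf_counter_ns()
        forward(x, weights, norm)
        samples.append(time.perf_counter_ns() - t0)  # nanoseconds per step
    return samples


if __name__ == "__main__":
    for name in NORMS:
        ns = time_steps(name)
        print(f"{name:>9}: mean {np.mean(ns) / 1e3:.1f} us/step")
```

The same loop, applied to `layernorm(h)` and `rmsnorm(h)` alone on a fixed `h` of width 256, is one way to profile a single norm op in isolation.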
Block C: plot
- Loss curves (3 lines).
- Bar chart of mean step time (3 bars). Include error bars showing the standard deviation of each variant's step times.
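The bar heights and error bars both come from `timings.json`. A stdlib-only sketch of that reduction follows; the JSON layout (variant name mapped to a list of per-step nanoseconds) is an assumption, so adapt `load_summary` to whatever your `train.py` writes.

```python
from __future__ import annotations

import json
import statistics
from pathlib import Path


def summarize(timings: dict[str, list[int]]) -> dict[str, tuple[float, float]]:
    """Map each variant to (mean, sample std) of its step time, in milliseconds."""
    summary: dict[str, tuple[float, float]] = {}
    for name, ns in timings.items():
        ms = [t / 1e6 for t in ns]
        # statistics.stdev needs at least two samples per variant.
        summary[name] = (statistics.fmean(ms), statistics.stdev(ms))
    return summary


def load_summary(path: Path) -> dict[str, tuple[float, float]]:
    # Assumed layout: {"none": [ns, ...], "layernorm": [...], "rmsnorm": [...]}
    return summarize(json.loads(path.read_text()))
```

The means become the bar heights and the standard deviations the error bars, for example via Matplotlib's `Axes.bar(names, means, yerr=stds)`.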
Block D: interpret
In `README.md`:
- Do all three train? Yes/no per variant. (Yes/yes/yes is expected with Kaiming.)
- Is the final loss the same? Quantify.
- Is RMSNorm faster than LayerNorm? By what percentage? Compare with theory 02's prediction (~30–40% faster).
- What's the no-norm baseline's loss trajectory shape? Smooth or jittery? Why?
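"Faster by X%" can mean time saved per step or extra steps per second, and the two numbers differ, so the sketch below computes both; say in the README which one you compare against theory 02. Averaging the last `window` losses as the "final loss" is an assumed convention, not something the lab prescribes.

```python
from __future__ import annotations

from statistics import fmean


def pct_time_saved(t_layernorm: float, t_rmsnorm: float) -> float:
    """Share of LayerNorm's mean step time that RMSNorm saves, in percent."""
    return 100.0 * (t_layernorm - t_rmsnorm) / t_layernorm


def pct_throughput_gain(t_layernorm: float, t_rmsnorm: float) -> float:
    """How many percent more steps per second RMSNorm runs."""
    return 100.0 * (t_layernorm / t_rmsnorm - 1.0)


def final_loss(losses: list[float], window: int = 50) -> float:
    # Average the tail so one noisy step does not decide the comparison.
    return fmean(losses[-window:])


def rel_gap(a: float, b: float) -> float:
    """Relative difference between two nonzero final losses."""
    return abs(a - b) / max(abs(a), abs(b))
```

For example, mean step times of 10 ms (LayerNorm) and 8 ms (RMSNorm) are a 20% time saving but a 25% throughput gain.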
Block E: manifest
Use the standard fields, plus the epsilon value and the norm placement (Pre vs Post).
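A sketch of writing the manifest with the two Block E fields. Every key name here is hypothetical; keep whatever "standard" fields your lab 01 manifest already uses and add `eps` and the placement next to them.

```python
from __future__ import annotations

import json
import platform
from pathlib import Path
from typing import Any


def write_manifest(out_dir: Path, eps: float, placement: str = "pre") -> dict[str, Any]:
    # Hypothetical key names; align them with your lab 01 manifest.
    manifest: dict[str, Any] = {
        "lab": "02-norm-ablation",
        "variants": ["none", "layernorm", "rmsnorm"],
        "init": "kaiming",
        "depth": 12,
        "hidden": 256,
        "eps": eps,                   # Block E: epsilon value
        "norm_placement": placement,  # Block E: "pre" or "post"
        "python": platform.python_version(),
        "machine": platform.processor() or platform.machine(),
    }
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2) + "\n")
    return manifest
```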
Constraints
- Pre-style only. Pre vs Post is not tested here (lab 03 covers it indirectly via residual + norm).
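- `mypy --strict` on `src/minigrad/nn/norm.py`.
- Property test: for random fp32 inputs of various shapes, RMSNorm's output \(y = \gamma \cdot x / \text{rms}\) satisfies, per row, \(\text{mean}((y/\gamma)^2) = \text{mean}(x^2)/\text{rms}^2 = \text{mean}(x^2)/(\text{mean}(x^2) + \epsilon)\), which is 1 up to ε (division by γ is componentwise; γ must be nonzero). For uniform γ = c this reduces to \(\text{mean}(y^2) \approx c^2 = \text{mean}(\gamma^2)\). Compare with a tolerance, not exact equality.
- Single thread, `performance` CPU governor.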
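A sketch of the RMSNorm property test. It uses a seeded loop over random shapes instead of a property-testing library such as Hypothesis, and checks the exact identity \(\text{mean}((y/\gamma)^2) = m/(m+\epsilon)\) with \(m = \text{mean}(x^2)\), so ε never forces a loose tolerance. The shape ranges are arbitrary choices.

```python
import numpy as np


def rmsnorm(x: np.ndarray, gamma: np.ndarray, eps: float) -> np.ndarray:
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    return gamma * x / rms


def check_rmsnorm_property(trials: int = 100, eps: float = 1e-5, seed: int = 0) -> None:
    rng = np.random.default_rng(seed)
    for _ in range(trials):
        # Random shape: 1 to 3 leading dims of size 1..4, last dim d in [2, 512].
        n_lead = int(rng.integers(1, 4))
        lead = tuple(int(n) for n in rng.integers(1, 5, size=n_lead))
        d = int(rng.integers(2, 513))
        x = rng.standard_normal((*lead, d), dtype=np.float32)
        m = (x * x).mean(axis=-1)
        expected = m / (m + eps)
        # Non-uniform, nonzero gamma: undo it componentwise, y / gamma = x / rms.
        gamma = rng.uniform(0.5, 2.0, size=d).astype(np.float32)
        y = rmsnorm(x, gamma, eps)
        np.testing.assert_allclose(((y / gamma) ** 2).mean(axis=-1), expected, rtol=1e-4)
        # Uniform gamma = c: mean(y^2) = c^2 * m / (m + eps).
        c = np.float32(1.5)
        y_u = rmsnorm(x, np.full(d, c, dtype=np.float32), eps)
        np.testing.assert_allclose((y_u ** 2).mean(axis=-1), c * c * expected, rtol=1e-4)
```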
Stop conditions
Done when:
- All seven files are committed.
- `src/minigrad/nn/norm.py` is `mypy --strict` clean and gradcheck passes.
- Loss curves show RMSNorm and LayerNorm converging to similar final loss.
- Timing bar shows RMSNorm strictly faster than LayerNorm (the absolute gap depends on your machine; should be > 0).
- README answers all four Block D questions.
Pitfalls
- LayerNorm without learnable γ, β. That's not LayerNorm; that's just standardization. The affine parameters are part of the definition.
- RMSNorm with mean subtraction. That's LayerNorm. A common copy-paste bug.
- NaN on the first batch. Check ε placement (must be inside sqrt).
- RMSNorm slower than LayerNorm in your measurement. This is likely NumPy/minigrad overhead, not a real difference in arithmetic cost. Profile a single norm op in isolation. The savings should be visible at `hidden_dim` ≥ 256.
- Gradcheck fails for RMSNorm. Often the backward closure for \(1/\text{rms}\) is missing. With \(m = \text{mean}(x^2) = \frac{1}{d}\sum_j x_j^2\), the backward through \(1/\sqrt{m + \epsilon}\) is:
\[ \frac{\partial}{\partial x_i} \frac{1}{\sqrt{m + \epsilon}} = -\frac{1}{2 (m+\epsilon)^{3/2}} \cdot \frac{2 x_i}{d} = -\frac{x_i}{d\,(m+\epsilon)^{3/2}} \]
Make sure that term is in your backward.
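The RMSNorm gradient above can be checked numerically. This is a NumPy sketch, not the minigrad backward: it writes \(r = 1/\text{rms} = (m+\epsilon)^{-1/2}\), so the derivative above becomes \(-x_i r^3 / d\), and compares the analytic gradient against central differences in float64 (fp32 finite differences are too noisy for a tight gradcheck). `numeric_grad` is a hypothetical helper.

```python
from __future__ import annotations

from typing import Callable

import numpy as np


def rmsnorm_forward(x: np.ndarray, gamma: np.ndarray, eps: float) -> tuple[np.ndarray, np.ndarray]:
    # r = 1 / rms = (m + eps)^(-1/2), kept for the backward pass.
    r = 1.0 / np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    return gamma * x * r, r


def rmsnorm_backward(
    x: np.ndarray, gamma: np.ndarray, r: np.ndarray, gy: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    d = x.shape[-1]
    g = gy * gamma  # upstream gradient pulled back through the gamma scale
    # Direct path: g * r. Path through 1/rms: d(1/rms)/dx_i = -x_i r^3 / d,
    # weighted by sum_j g_j x_j over the row. Dropping it breaks gradcheck.
    dx = g * r - x * r**3 * (g * x).sum(axis=-1, keepdims=True) / d
    dgamma = (gy * x * r).reshape(-1, d).sum(axis=0)
    return dx, dgamma


def numeric_grad(f: Callable[[np.ndarray], float], a: np.ndarray, h: float = 1e-6) -> np.ndarray:
    # Central differences; perturbs `a` in place and restores each entry.
    grad = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + h
        fp = f(a)
        a[idx] = old - h
        fm = f(a)
        a[idx] = old
        grad[idx] = (fp - fm) / (2 * h)
    return grad


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x, gamma, gy = rng.standard_normal((3, 5)), rng.standard_normal(5), rng.standard_normal((3, 5))
    _, r = rmsnorm_forward(x, gamma, 1e-5)
    dx, _ = rmsnorm_backward(x, gamma, r, gy)
    num = numeric_grad(lambda a: float((rmsnorm_forward(a, gamma, 1e-5)[0] * gy).sum()), x.copy())
    print("max |dx - numeric| =", np.abs(dx - num).max())
```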
Hint of last resort
If gradcheck for LayerNorm fails: the gradient of (x - mean(x)) / std(x) with respect to a single \(x_j\) depends on all \(x_i\) via the mean and the std. There are three pathways: through the explicit \(x_j\), through the mean, through the std. PyTorch's LayerNorm implementation has this as a closed-form expression; recompute it from the definition and you'll find your missing term.
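Recomputing from the definition, with \(s = 1/\sqrt{\sigma^2 + \epsilon}\) and \(g = \gamma \odot \partial L/\partial y\), gives the standard closed form \(\partial L/\partial x = s\,(g - \text{mean}(g) - \hat x \cdot \text{mean}(g \hat x))\), one term per pathway. The NumPy sketch below is this textbook derivation checked by central differences in float64, not PyTorch's source; the function names are hypothetical.

```python
from __future__ import annotations

from typing import Callable

import numpy as np


def layernorm_forward(
    x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    mu = x.mean(axis=-1, keepdims=True)
    s = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)  # s = 1 / std
    x_hat = (x - mu) * s
    return gamma * x_hat + beta, x_hat, s


def layernorm_backward(
    x_hat: np.ndarray, s: np.ndarray, gamma: np.ndarray, gy: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    d = gy.shape[-1]
    g = gy * gamma
    # One term per pathway:
    #   g                          explicit x_j
    #   - mean(g)                  through the mean
    #   - x_hat * mean(g * x_hat)  through the std
    dx = s * (g - g.mean(axis=-1, keepdims=True) - x_hat * (g * x_hat).mean(axis=-1, keepdims=True))
    dgamma = (gy * x_hat).reshape(-1, d).sum(axis=0)
    dbeta = gy.reshape(-1, d).sum(axis=0)
    return dx, dgamma, dbeta


def numeric_grad(f: Callable[[np.ndarray], float], a: np.ndarray, h: float = 1e-6) -> np.ndarray:
    # Central differences; perturbs `a` in place and restores each entry.
    grad = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + h
        fp = f(a)
        a[idx] = old - h
        fm = f(a)
        a[idx] = old
        grad[idx] = (fp - fm) / (2 * h)
    return grad


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    x, gy = rng.standard_normal((2, 6)), rng.standard_normal((2, 6))
    gamma, beta = rng.standard_normal(6), rng.standard_normal(6)
    _, x_hat, s = layernorm_forward(x, gamma, beta, 1e-5)
    dx, _, _ = layernorm_backward(x_hat, s, gamma, gy)
    num = numeric_grad(lambda a: float((layernorm_forward(a, gamma, beta, 1e-5)[0] * gy).sum()), x.copy())
    print("max |dx - numeric| =", np.abs(dx - num).max())
```

If your backward only has the first term, or the first two, the missing pathway is exactly the term you dropped.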
When to consult `solutions/`
After all seven files are committed. Solution: `solutions/02-norm-ablation-ref.md` (phase open).
Next lab: `lab/03-residual-depth.md`.