Break — AdamW with weight_decay=0 vs the correct value

We turn weight decay off and watch the model memorize the §A13 corpus faster and generalize worse. It is the cleanest demonstration that decay is not "optional regularization": it is what keeps the weights in the regime where the optimizer's estimate of \(v_t\) stays useful.


Symptom Borja will see

Two training runs with identical seed, schedule, batch size, and architecture. Only one config value differs:

  • Run A (control): weight_decay = 0.1 (the §A13 recommended value).
  • Run B (break): weight_decay = 0.0.

After 2000 steps:

  • Run A: train loss \(\approx 1.85\), val loss \(\approx 2.05\), gap \(\approx 0.20\).
  • Run B: train loss \(\approx 1.55\), val loss \(\approx 2.40\), gap \(\approx 0.85\).

The dashboard panel "train vs val loss" shows Run B's train and val curves diverging from ~step 600 onward; in Run A they stay within 0.25 of each other throughout. Run B's weight-norm panel shows the embedding-table Frobenius norm climbing monotonically, reaching 2.5× its initial value by step 2000. Run A's stays within 1.2× of init.
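Both panels reduce to two simple quantities. The first is the val-minus-train loss gap at each logged step. The second is a weight matrix's Frobenius norm relative to its value at initialization.

Below is a minimal sketch of each, assuming the logged curves and the embedding table are available as plain Python lists. The helper names are hypothetical and are not the dashboard's actual code. The example values are made up only to exercise the helpers. The 0.25 tolerance is the band Run A stays inside.

```python
import math


def frobenius_norm(matrix):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def norm_ratio(current, initial):
    """How many times larger a weight matrix's norm is than at initialization."""
    return frobenius_norm(current) / frobenius_norm(initial)


def first_divergence_step(steps, train_loss, val_loss, tolerance=0.25):
    """First logged step where val loss exceeds train loss by more than
    `tolerance`; returns None if the curves never separate that far."""
    for step, tr, va in zip(steps, train_loss, val_loss):
        if va - tr > tolerance:
            return step
    return None


# Made-up toy values, not the runs' actual logs.
steps = [0, 100, 200, 300]
train = [2.0, 1.9, 1.8, 1.7]
val = [2.1, 2.1, 2.2, 2.3]
print(first_divergence_step(steps, train, val))  # 200
print(norm_ratio([[7.5, 10.0]], [[3.0, 4.0]]))   # 2.5
```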

The break, mechanically

In experiments/18-break-weight-decay/config.yaml:

# Run A (control)
optimizer:
  name: adamw
  weight_decay: 0.1

# Run B (the break)
optimizer:
  name: adamw
  weight_decay: 0.0   # <-- THIS LINE

Or equivalently in code: pass weight_decay=0.0 to the AdamW constructor in src/minitrain/loop.py.

No other changes. The whole break is one number.
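This sketch confirms that the whole break really is one number, which is the same check hint 1 asks for. It flattens two nested configs and reports every key whose value differs.

The dicts mirror the YAML above. Loading config.yaml into a dict is assumed to happen elsewhere. `flatten` and `config_diff` are hypothetical helpers, not part of src/minitrain.

```python
_MISSING = object()  # marks a key present in only one of the two configs


def flatten(cfg, prefix=""):
    """Flatten nested dicts into {"optimizer.weight_decay": value} form."""
    flat = {}
    for key, value in cfg.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def config_diff(a, b):
    """Map each differing dotted key to (value_in_a, value_in_b)."""
    fa, fb = flatten(a), flatten(b)
    diff = {}
    for key in sorted(fa.keys() | fb.keys()):
        va, vb = fa.get(key, _MISSING), fb.get(key, _MISSING)
        if va != vb:
            diff[key] = (va, vb)
    return diff


run_a = {"seed": 42, "optimizer": {"name": "adamw", "weight_decay": 0.1}}
run_b = {"seed": 42, "optimizer": {"name": "adamw", "weight_decay": 0.0}}
print(config_diff(run_a, run_b))  # {'optimizer.weight_decay': (0.1, 0.0)}
```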

Why this teaches the concept

At §A13 scale, the train set is 240 verb-form sentences and the model has ~103k parameters. Without any regularization, the optimization landscape contains many minima with low train loss that don't generalize. The model can memorize specific surface forms (e.g., the literal string he goes) instead of learning the conjugation rule (add -s/-es in the present-simple third-person singular, so go becomes goes).

Weight decay with \(\lambda = 0.1\) exerts a small, steady pull on every weight toward zero, proportional to the weight itself. Each step subtracts \(\eta_t \lambda \theta\) from \(\theta\), where \(\eta_t\) is the scheduled learning rate at step \(t\). In the regime where task gradients on rare verbs are small, this pull is the dominant force on those parameters. It keeps the embedding table from drifting into the high-norm memorization regime.
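Below, the standard AdamW update with decoupled weight decay (the form PyTorch's torch.optim.AdamW implements) is written out in this section's symbols. \(g_t\), \(m_t\), \(\beta_1\), \(\beta_2\), and \(\epsilon\) are the usual Adam notation and are not defined elsewhere on this page.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^{t}}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^{t}} \\
\theta_t &= \theta_{t-1} - \eta_t \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_{t-1} \right)
\end{aligned}
```

Consider a parameter whose Adam term has decayed to roughly zero because it has had no recent gradient. For it, the update reduces to \(\theta_t \approx (1 - \eta_t \lambda)\,\theta_{t-1}\). That is a geometric shrink toward zero when \(\lambda = 0.1\), and no pull at all when \(\lambda = 0\).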

The pedagogical point: weight decay is not "anti-overfitting magic". It's a force that keeps the optimizer in the well-conditioned region where AdamW's moment estimates are calibrated to the task gradient scale, not to the parameter drift scale.
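Here is a toy illustration of that balance. It is not a model of the real run. It uses a single scalar weight whose loss always rewards a larger value: a constant gradient of -1, loosely standing in for a logit the model keeps pushing up on a memorized example.

The learning rate of 0.01, the absence of a schedule, and the starting value of 1.0 are assumptions chosen for readability. Because Adam normalizes a constant gradient to a step of about \(\eta\), the weight grows by roughly \(\eta\) every step without decay. With \(\lambda = 0.1\), it settles toward the point where the decay term cancels that step, roughly \(1/\lambda\).

```python
import math


def adamw_scalar(theta0, grad, steps, lr, weight_decay,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """Run AdamW (decoupled weight decay) on one scalar parameter whose
    gradient is the same constant value at every step; return final theta."""
    theta, m, v = theta0, 0.0, 0.0
    for t in range(1, steps + 1):
        m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
        v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate v_t
        m_hat = m / (1 - beta1 ** t)                # bias corrections
        v_hat = v / (1 - beta2 ** t)
        adam_step = m_hat / (math.sqrt(v_hat) + eps)
        # Decoupled decay: lr * weight_decay * theta_{t-1}, added outside the
        # Adam normalization (theta on the right is still the previous value).
        theta = theta - lr * (adam_step + weight_decay * theta)
    return theta


for wd in (0.0, 0.1):
    print(wd, round(adamw_scalar(1.0, -1.0, 2000, 0.01, wd), 2))
# prints 0.0 21.0, then 0.1 8.78
```

In this toy the fixed point is about \(1/\lambda\), so it also shrinks as \(\lambda\) grows. That loosely mirrors the over-corrected \(\lambda = 0.5\) case in the diagnostic ladder, where weight norms are pulled toward zero.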

Diagnostic ladder Borja should walk

If Borja sees the val/train gap in Run B and wonders why:

  1. First check: the schedule, batch size, and seed are identical (they are). Eliminate "training got unlucky" as the explanation.
  2. Second check: the weight-norm panel. Run B's embedding norm rising 2.5× is the smoking gun. Run A's holding steady is the control.
  3. Third check: the per-slice eval (Phase 20). Run B's train accuracy is ~99% on the 12 regular verbs and ~95% on the 8 irregular verbs, but its val accuracy is only 78% and 62% respectively. Run A's: train 92% / 88%, val 84% / 75%. Run B memorized; Run A learned.
  4. Confirm by ablation: set weight_decay=0.5 (the over-corrected case) and observe the opposite failure — both train and val loss plateau higher, weight norms collapse toward zero. This shows the regime is bounded on both sides; the sweet spot \(\lambda = 0.1\) is not arbitrary.
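The per-slice numbers in step 3 are easiest to compare as train-minus-val gaps in percentage points. This minimal sketch uses exactly the accuracies quoted above; Run B's train figures are approximate. The dict layout and slice names are illustrative, not the Phase 20 eval output format.

```python
# Per-slice accuracies (%) quoted in step 3: slice -> (train, val).
RUNS = {
    "A (wd=0.1)": {"regular": (92, 84), "irregular": (88, 75)},
    "B (wd=0.0)": {"regular": (99, 78), "irregular": (95, 62)},
}


def slice_gaps(per_slice):
    """Train-minus-val accuracy gap, in percentage points, for each slice."""
    return {name: train - val for name, (train, val) in per_slice.items()}


for run, per_slice in RUNS.items():
    print(run, slice_gaps(per_slice))
# A (wd=0.1) {'regular': 8, 'irregular': 13}
# B (wd=0.0) {'regular': 21, 'irregular': 33}
```

Run B's gaps are more than twice Run A's on both slices.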

Reproducer

# Control
seed=42 weight_decay=0.1 just phase-18-train

# Break
seed=42 weight_decay=0.0 just phase-18-train

# Compare
just phase-18-compare experiments/18-control experiments/18-break-wd0

The compare script produces dashboard-compare.html overlaying the two runs.

Hint cascade (if Borja gets stuck)

  1. (Mild) "The two runs differ in one optimizer hyperparameter. Print both configs side by side."
  2. (Medium) "Look at the weight-norm panel. What is climbing in Run B that isn't climbing in Run A?"
  3. (Direct) "AdamW's weight_decay term subtracts \(\eta_t \lambda \theta\) each step. What's the implication if \(\lambda = 0\)?"

Fix

Restore weight_decay: 0.1 in Run B's config. Re-run. Confirm Run B's curves now match Run A.

What this break is NOT

  • Not a numerical-precision break (no fp16, no NaN).
  • Not a data-leak break (val and train splits are clean).
  • Not a model-capacity break (architecture is unchanged).

It is a regularization-removed break, and the failure surfaces only after enough steps for the embedding table to drift. That delay (~600 steps before divergence) is itself a lesson: regularization failures are slow, not catastrophic.

Cross-refs

  • theory/05-adamw-vs-adam-decoupling.md — the algebraic reason.
  • Phase 19 theory/03-three-failure-modes.md — sibling failures (init, warmup, mask).
  • Phase 20 theory/01-metrics-catalog.md — per-slice accuracy is how the memorization is quantified, not just visualized.