05 - AdamW vs Adam: the exact decoupling math, at §A13 scale
AdamW is not "Adam with weight decay turned on". It is Adam with weight decay decoupled from the gradient. The difference is a single line of algebra, but at the microscopic scale of the §A13 corpus (~600 forms, ~103k parameters) it decides whether your model memorizes or generalizes to the sixth irregular verb.
This file is the depth-pass companion to theory/02-optimizer-and-schedule.md. We restate the two update rules side by side, show the algebraic step that turns one into the other, walk through a numerical example small enough to do by hand, and explain why the difference matters more for our §A13 corpus than it does for GPT-2.
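The two updates, side by side
Vanilla Adam with \(L_2\) regularization adds \(\lambda \theta\) to the gradient before the moment update. Writing \(\tilde g_t = g_t + \lambda \theta_{t-1}\):
\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\,\tilde g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\,\tilde g_t^2, \\
\hat m_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat v_t = \frac{v_t}{1 - \beta_2^t}, \\
\theta_t &= \theta_{t-1} - \eta_t \, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}.
\end{aligned}
\]
AdamW computes moments on the task gradient only, then applies weight decay directly on the parameter as part of the update:
\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\,g_t^2, \\
\theta_t &= \theta_{t-1} - \eta_t \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda \theta_{t-1} \right),
\end{aligned}
\]
with \(\hat m_t, \hat v_t\) bias-corrected in the same way.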
The line that changes: what enters \(m_t, v_t\). AdamW keeps \(g_t\) clean; Adam-L2 contaminates it.
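The two rules fit in two scalar functions, one parameter at a time. This is a minimal sketch assuming standard bias correction; the names `adam_l2_step`, `adamw_step`, `lr` (for \(\eta_t\)) and `wd` (for \(\lambda\)) are illustrative, not any library's API, and the default \(\beta_2 = 0.95\) mirrors this file rather than a framework default.

```python
import math

def adam_l2_step(theta, g, m, v, t, lr, wd, beta1=0.9, beta2=0.95, eps=1e-8):
    """One Adam step with L2 regularization: lambda*theta is added to the gradient."""
    g_tilde = g + wd * theta                   # decay enters the gradient...
    m = beta1 * m + (1 - beta1) * g_tilde      # ...and therefore both moments
    v = beta2 * v + (1 - beta2) * g_tilde ** 2
    m_hat = m / (1 - beta1 ** t)               # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

def adamw_step(theta, g, m, v, t, lr, wd, beta1=0.9, beta2=0.95, eps=1e-8):
    """One AdamW step: moments see only the task gradient; decay is applied separately."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * theta)
    return theta, m, v
```

With `wd=0` the two functions return identical results; with `wd > 0` the only thing that changes inside Adam-L2 is what enters `m` and `v`, which is the one-line difference described above.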
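The algebraic step that shows the difference
Expand the AdamW update at convergence: assume \(g_t \approx 0\) and the moments have decayed so \(\hat m_t / (\sqrt{\hat v_t} + \epsilon) \approx 0\). Then:
\[
\theta_t \approx \theta_{t-1} - \eta_t \lambda \theta_{t-1} = (1 - \eta_t \lambda)\,\theta_{t-1}.
\]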
Clean geometric decay toward zero at rate \(\eta_t \lambda\) per step. This is the intended behavior of weight decay.
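Now expand Adam-L2 under the same assumption. The \(\lambda \theta_{t-1}\) term lives inside the moments, so even when \(g_t = 0\) they are fed \(\tilde g_t = \lambda \theta_{t-1}\). Once they track that term (\(\theta\) changes slowly relative to the averaging windows), \(\hat m_t \approx \lambda \theta_{t-1}\) and \(\sqrt{\hat v_t} \approx \lambda |\theta_{t-1}|\), so:
\[
\theta_t \approx \theta_{t-1} - \eta_t \, \frac{\lambda \theta_{t-1}}{\lambda |\theta_{t-1}| + \epsilon} \approx \theta_{t-1} - \eta_t \, \operatorname{sign}(\theta_{t-1}).
\]
The decay is no longer proportional to \(\theta_{t-1}\): it is a constant-magnitude step of size \(\eta_t\) toward zero, with only its direction set by the sign. Parameters with \(|\theta_{t-1}| < \eta_t\) are pushed past zero in a single step; parameters with \(|\theta_{t-1}| \gg \eta_t\) lose only a small fraction of their magnitude per step. The effective regularizer is closer to \(L_1\) than \(L_2\), and because it runs through the adaptive normalization, each parameter's decay rate is now coupled to its gradient history. Loshchilov & Hutter (2019) argued that this coupling is a main reason Adam-L2 underperforms SGD-L2 on image classification; the same effect appears in language modeling.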
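To see the two convergence regimes play out, the sketch below iterates both rules with the task gradient held at zero and zero initial moments. It is a hypothetical scalar simulation, assuming \(\eta = 3 \times 10^{-4}\), \(\lambda = 0.1\), \(\beta_1 = 0.9\), \(\beta_2 = 0.95\); `run` and its arguments are illustrative names.

```python
import math

def run(theta, steps, decoupled, lr=3e-4, wd=0.1, beta1=0.9, beta2=0.95, eps=1e-8):
    """Apply `steps` updates with the task gradient fixed at zero; return the final theta."""
    m = v = 0.0
    for t in range(1, steps + 1):
        g = 0.0                                  # converged: no task signal left
        if not decoupled:
            g += wd * theta                      # Adam-L2: decay enters the moments
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        update = m_hat / (math.sqrt(v_hat) + eps)
        if decoupled:
            update += wd * theta                 # AdamW: decay acts on theta directly
        theta -= lr * update
    return theta

for label, decoupled in (("AdamW", True), ("Adam-L2", False)):
    big = run(0.4, 100, decoupled)     # |theta| >> lr
    tiny = run(1e-4, 1, decoupled)     # |theta| < lr
    print(f"{label:8s} 0.4 -> {big:.5f}   1e-4 -> {tiny:+.2e}")
```

Starting from \(\theta = 0.4\), 100 AdamW steps remove about \(1.2 \times 10^{-3}\) (geometric decay), while Adam-L2 removes close to \(100\,\eta = 0.03\), roughly \(1/(\lambda\theta) = 25\times\) more. Starting from \(\theta = 10^{-4} < \eta\), a single Adam-L2 step lands on the other side of zero, while AdamW barely moves it.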
Numerical example, by hand
Take a single parameter \(\theta_{t-1} = 0.4\). Set \(\eta_t = 3 \times 10^{-4}\), \(\lambda = 0.1\), \(\beta_1 = 0.9\), \(\beta_2 = 0.95\), \(\epsilon = 10^{-8}\), \(t = 100\) (bias corrections are essentially 1). Suppose \(g_t = 0.01\) (small late-stage gradient) and \(m_{t-1} = v_{t-1} = 0\) for clarity.
AdamW step:
- \(m_t = 0.1 \cdot 0.01 = 10^{-3}\)
- \(v_t = 0.05 \cdot 10^{-4} = 5 \times 10^{-6}\)
- \(\hat m_t \approx 10^{-3}\), \(\hat v_t \approx 5 \times 10^{-6}\), \(\sqrt{\hat v_t} \approx 2.24 \times 10^{-3}\)
- AdamW update on \(\theta\): \(\theta_t = 0.4 - 3 \times 10^{-4} \cdot \left(\frac{10^{-3}}{2.24 \times 10^{-3}} + 0.1 \cdot 0.4\right) = 0.4 - 3 \times 10^{-4} \cdot (0.446 + 0.04) = 0.4 - 1.46 \times 10^{-4} \approx 0.39985\)
The 0.4 → 0.39985 movement has two parts: the task term contributes \(\approx 1.34 \times 10^{-4}\), the decay term contributes \(\approx 1.2 \times 10^{-5}\). Decay is small but proportional to \(\theta\).
Adam-L2 step:
- \(\tilde g_t = 0.01 + 0.1 \cdot 0.4 = 0.05\) (decay is 4× larger than the task gradient!)
- \(m_t = 0.1 \cdot 0.05 = 5 \times 10^{-3}\)
- \(v_t = 0.05 \cdot 2.5 \times 10^{-3} = 1.25 \times 10^{-4}\)
- \(\sqrt{\hat v_t} \approx 1.12 \times 10^{-2}\)
- Update on \(\theta\): \(\theta_t = 0.4 - 3 \times 10^{-4} \cdot \frac{5 \times 10^{-3}}{1.12 \times 10^{-2}} = 0.4 - 3 \times 10^{-4} \cdot 0.446 \approx 0.39987\)
The two outputs look almost identical (0.39985 vs 0.39987), and that is the whole problem. With zero prior moments, \(\hat m_t / \sqrt{\hat v_t} = (1 - \beta_1)/\sqrt{1 - \beta_2} \approx 0.447\) whatever the gradient's size, so a single step cannot reveal the difference. The damage is in the moment estimates: \(v_t\) is now 25× larger than AdamW's, because \(\tilde g_t\) is 5× \(g_t\) and \(v_t\) scales with its square. Since \(\lambda \theta_{t-1}\) enters every step, not just this one, the inflation persists: for as long as the decay term dominates \(g_t\), every later task gradient is divided by a \(\sqrt{\hat v_t}\) about 5× too large. The task signal therefore trains with a smaller learning rate than the schedule says.
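The hand calculation is easy to check numerically. This sketch recomputes both steps with the setup's values (keeping the near-1 bias corrections) and the ratio of the two second moments; variable and function names are illustrative.

```python
import math

# Setup from the text: one scalar parameter, zero prior moments, t = 100
theta, lr, wd = 0.4, 3e-4, 0.1
beta1, beta2, eps, t = 0.9, 0.95, 1e-8, 100
g = 0.01
m_prev = v_prev = 0.0

def bias_corrected_moments(grad):
    """Return (m_hat, v_hat, raw v_t) for one step from the zero prior moments."""
    m = beta1 * m_prev + (1 - beta1) * grad
    v = beta2 * v_prev + (1 - beta2) * grad ** 2
    return m / (1 - beta1 ** t), v / (1 - beta2 ** t), v

# AdamW: moments see only g; decay is a separate, theta-proportional term
m_hat, v_hat, v_adamw = bias_corrected_moments(g)
task_step = lr * m_hat / (math.sqrt(v_hat) + eps)
decay_step = lr * wd * theta
theta_adamw = theta - task_step - decay_step

# Adam-L2: decay is folded into the gradient before the moments
g_tilde = g + wd * theta
m_hat, v_hat, v_l2 = bias_corrected_moments(g_tilde)
theta_l2 = theta - lr * m_hat / (math.sqrt(v_hat) + eps)

print(f"AdamW   theta={theta_adamw:.5f} (task {task_step:.3e}, decay {decay_step:.2e})")
print(f"Adam-L2 theta={theta_l2:.5f}")
print(f"v_t inflation {v_l2 / v_adamw:.1f}x, so sqrt(v_t) is {math.sqrt(v_l2 / v_adamw):.1f}x larger")
```

It reproduces the hand numbers up to rounding: \(\theta \approx 0.39985\) for AdamW, \(\theta \approx 0.39987\) for Adam-L2, and a 25× inflation of \(v_t\), i.e. a 5× larger denominator for the task gradient.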
At §A13 scale, where the typical gradient on the 600-form corpus is \(\sim 10^{-2}\) and \(\lambda \theta \sim 10^{-2}\) on a healthy weight (\(\theta \sim 0.1\)), the two terms are comparable. Decay corruption is not a tiny perturbation; it shifts \(v_t\) by 2–10×, depending on the weight's magnitude. That's a fundamentally different optimizer.
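The 2–10× range follows from a one-line ratio: when the task gradient and the decay term share a sign, Adam-L2's squared gradient is \(((g + \lambda\theta)/g)^2\) times AdamW's. A sketch assuming \(g = 10^{-2}\) and \(\lambda = 0.1\) as above; with opposite signs the decay term partly cancels \(g\) instead, so this covers only the inflating case. `v_inflation` is an illustrative name.

```python
def v_inflation(theta, g=1e-2, wd=0.1):
    """Factor by which Adam-L2 inflates the squared gradient fed to v_t, vs AdamW.

    Assumes g and wd * theta have the same sign; with opposite signs they partly cancel.
    """
    return ((g + wd * theta) / g) ** 2

for theta in (0.04, 0.1, 0.2):
    print(f"theta = {theta:4.2f}   lambda*theta = {0.1 * theta:.3f}   v_t inflated {v_inflation(theta):.1f}x")
```

Weights of 0.04, 0.1 and 0.2 give inflation factors of about 2×, 4× and 9×.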
Why this matters more at §A13 scale than at GPT-2 scale
GPT-2-class models have \(d_\text{model} \approx 768\), vocab \(\approx 50k\), \(\lambda \approx 0.1\), weights initialized at \(\sim 0.02\). Task gradients during the meaty middle of training are \(\sim 10^{-3}\); \(\lambda \theta \sim 2 \times 10^{-3}\). The decay-to-gradient ratio is \(\sim 2\). Adam-L2 vs AdamW shows up at the second decimal of validation PPL — measurable, not life-changing.
Our §A13 corpus has 600 forms, vocab \(\sim 512\), \(d_\text{model} = 64\). What changes is where the small gradients come from. Embedding rows for rare verbs (e.g., write) see \(\sim 6\) gradient updates per epoch on average, because the word appears in \(\sim 1\%\) of the corpus (1% of 600 forms). On all other steps those rows get little or no task gradient: the gradient is small because the verb is rare, not because the model has converged. Adam-L2 still feeds \(\lambda \theta_{t-1}\) into \(\tilde g_t\) on those steps, so for most of the epoch the moment estimates see mostly decay and learn that "decay is the signal". The rare-verb embedding then moves toward zero faster than the occasional task gradient can pull it back. This is the §A13 failure mode AdamW prevents.
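A toy simulation of the rare-verb case. All assumptions are hypothetical: one embedding weight starting at \(\theta = 0.1\), batch size 1 so an epoch is 600 steps, and a task gradient of \(-0.01\) (pushing the weight up) on one step in 100, i.e. ~6 updates per epoch; \(\eta\), \(\lambda\), \(\beta_1\), \(\beta_2\) match the numerical example. It models only the optimizer arithmetic on one coordinate, not the §A13 model itself.

```python
import math

def rare_row(decoupled, steps=600, every=100, g_hit=-0.01, theta=0.1,
             lr=3e-4, wd=0.1, beta1=0.9, beta2=0.95, eps=1e-8):
    """Simulate one weight whose task gradient is nonzero on 1 step in `every`.

    g_hit < 0 means the task wants the weight to grow (e.g. a rare verb's row).
    """
    m = v = 0.0
    for t in range(1, steps + 1):
        g = g_hit if t % every == 0 else 0.0    # rare verb present only on these steps
        if not decoupled:
            g += wd * theta                      # Adam-L2: decay folded into the gradient
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        update = m_hat / (math.sqrt(v_hat) + eps)
        if decoupled:
            update += wd * theta                 # AdamW: decay kept out of m and v
        theta -= lr * update
    return theta

final_adamw = rare_row(decoupled=True)
final_adam_l2 = rare_row(decoupled=False)
print(f"AdamW:   {final_adamw:.4f}")
print(f"Adam-L2: {final_adam_l2:.4f}")
```

In this setup AdamW ends the epoch above its starting value of 0.1, because each rare gradient's normalized step outweighs the proportional decay between hits. Adam-L2 walks the weight down by roughly \(\eta\) per step and leaves it near zero: the occasional task gradient cannot pull it back.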
This is also why we use \(\beta_2 = 0.95\) instead of \(0.999\) at this scale (see theory/02-optimizer-and-schedule.md): the second-moment average spans roughly \(1/(1 - \beta_2)\) steps, about 1000 at 0.999 versus 20 at 0.95, and we cannot afford to average gradients over that long a history, because the rare-verb gradient is the signal we want to amplify.
Implementation gotchas you will hit
- Excluding biases and LayerNorms from decay. Standard practice: weight decay applies to 2-D+ tensors (the actual "weights"), not 1-D bias / LN-scale tensors. In code, that's a `param_group` split. If you forget, biases drift to zero and the network's expressive capacity shrinks. Phase 18's reference `loop.py` will partition the optimizer state into two groups.
- Decay applied to embeddings. Some recipes exclude embeddings from decay; others include them. At §A13 scale, include them: the embedding table is half the parameter count, and not decaying it gives it disproportionate freedom to memorize the train set. Phase 19 lab will sweep this and confirm.
- The "Adam with weight_decay=" PyTorch trap. `torch.optim.Adam(..., weight_decay=0.1)` does Adam-L2, not AdamW. To get AdamW, use `torch.optim.AdamW`. This catches people. Phase 25 will dissect the dispatcher and you'll see the two are genuinely different ops.
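The `param_group` split and the Adam/AdamW choice fit in a few lines. A minimal sketch, assuming the rule stated in these gotchas: decay every tensor with `ndim >= 2` (which includes the embedding table) and skip 1-D biases and LayerNorm parameters. The helper `decay_param_groups` is hypothetical; it only needs objects with an `ndim` attribute, so it accepts `model.named_parameters()` from a PyTorch module.

```python
def decay_param_groups(named_params, weight_decay=0.1):
    """Split parameters into a decayed group (ndim >= 2) and a non-decayed group (ndim < 2).

    2-D+ tensors (linear weights and the embedding table) get decay;
    1-D tensors (biases, LayerNorm scales and shifts) do not.
    """
    decay, no_decay = [], []
    for _name, p in named_params:
        if not getattr(p, "requires_grad", True):
            continue                              # frozen tensors need no optimizer state
        (decay if p.ndim >= 2 else no_decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# With PyTorch installed, pass the groups to AdamW (not Adam with weight_decay=...):
#   optimizer = torch.optim.AdamW(decay_param_groups(model.named_parameters()),
#                                 lr=3e-4, betas=(0.9, 0.95))
```

Setting `weight_decay` per group overrides the optimizer-level default, so the non-decayed group really gets zero decay regardless of what the constructor's default is.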
Citation
Loshchilov, I., & Hutter, F. (2019). Decoupled Weight Decay Regularization. ICLR 2019. https://arxiv.org/abs/1711.05101 — Sections 2 (the algebraic decoupling) and 4.2 (the small-model experiments most analogous to §A13 scale).
One-paragraph recap
Adam-L2 folds \(\lambda \theta_{t-1}\) into the gradient before computing \(m, v\), contaminating both moments with the decay term. AdamW computes \(m, v\) on the task gradient only, then applies the \(\lambda \theta_{t-1}\) decay as a separate term in the parameter update. At convergence, AdamW recovers the clean geometric decay \(\theta \to (1 - \eta \lambda) \theta\); Adam-L2 degenerates to a sign-based step whose magnitude is independent of \(\theta\). At §A13 scale, where \(\lambda \theta\) and \(g_t\) are within an order of magnitude on rare-verb embeddings, the difference is not cosmetic — it determines whether the rare-verb signal survives the decay pressure.
Cross-refs: theory/02-optimizer-and-schedule.md (the recipe), theory/03-mixed-precision-preview.md (how decay interacts with loss scaling), Phase 19 lab/02-break-it.md (one engineered break is exactly Adam-L2 swapped in for AdamW).