05 — AdamW vs Adam: the exact decoupling math, at §A13 scale

AdamW is not "Adam with weight decay turned on." It is Adam with weight decay decoupled from the gradient. The difference is a single line of algebra, but at the microscopic scale of the §A13 corpus (~600 forms, ~103k parameters) it decides whether your model memorizes or generalizes to the sixth irregular verb.

This file is the depth-pass companion to theory/02-optimizer-and-schedule.md. We restate the two update rules side-by-side, show the algebraic step that turns one into the other, then walk a numerical example small enough to do by hand — and explain why the difference matters more for our §A13 corpus than it does for GPT-2.


The two updates, side by side

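Vanilla Adam with \(L_2\) regularization adds \(\lambda \theta\) to the gradient before the moment update. Writing \(\tilde g_t = g_t + \lambda \theta_{t-1}\), with the usual bias-corrected moments \(\hat m_t = m_t / (1 - \beta_1^t)\) and \(\hat v_t = v_t / (1 - \beta_2^t)\):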

\[ \text{Adam-L2:} \quad m_t = \beta_1 m_{t-1} + (1 - \beta_1) \tilde g_t, \quad v_t = \beta_2 v_{t-1} + (1 - \beta_2) \tilde g_t^{\,2}, \quad \theta_t = \theta_{t-1} - \eta_t \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \]

AdamW computes moments on the task gradient only, then applies weight decay directly on the parameter as part of the update:

\[ \text{AdamW:} \quad m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \quad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^{\,2}, \quad \theta_t = \theta_{t-1} - \eta_t \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda \theta_{t-1} \right) \]

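The line that changes is what enters \(m_t, v_t\). AdamW feeds the moments the clean task gradient \(g_t\) and applies \(\lambda \theta_{t-1}\) only in the parameter update; Adam-L2 feeds them the decay-contaminated \(\tilde g_t\).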
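For readers who want to poke at the two rules directly, here is a minimal scalar sketch, one parameter at a time. The function names and the defaults \(\beta_1 = 0.9\), \(\beta_2 = 0.95\), \(\epsilon = 10^{-8}\) are our assumptions (they match the numerical example in this file), not a library API; real optimizers apply the same arithmetic elementwise over tensors.

```python
import math


def adam_l2_step(theta, g, m, v, t, lr, lam, beta1=0.9, beta2=0.95, eps=1e-8):
    """One Adam step with L2 regularization folded into the gradient (sketch)."""
    g_tilde = g + lam * theta                 # decay enters the gradient...
    m = beta1 * m + (1 - beta1) * g_tilde     # ...and therefore both moments
    v = beta2 * v + (1 - beta2) * g_tilde ** 2
    m_hat = m / (1 - beta1 ** t)              # bias corrections
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


def adamw_step(theta, g, m, v, t, lr, lam, beta1=0.9, beta2=0.95, eps=1e-8):
    """One AdamW step: moments see only the task gradient; decay is separate (sketch)."""
    m = beta1 * m + (1 - beta1) * g           # clean task gradient only
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # lam * theta uses the old theta: the RHS is evaluated before assignment.
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + lam * theta)
    return theta, m, v
```

With \(\lambda = 0\) the two functions return identical results, so any difference you observe comes entirely from where \(\lambda \theta\) enters.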

The algebraic step that shows the difference

Expand the AdamW update at convergence — assume \(g_t \approx 0\) and the moments have decayed so \(\hat m_t / (\sqrt{\hat v_t} + \epsilon) \approx 0\). Then:

\[ \theta_t \approx \theta_{t-1} - \eta_t \lambda \theta_{t-1} = (1 - \eta_t \lambda) \theta_{t-1} \]

Clean geometric decay toward zero at rate \(\eta_t \lambda\) per step. This is the intended behavior of weight decay.
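Holding \(\eta_t = \eta\) constant (a simplifying assumption; the real schedule varies), the one-step rule compounds into a closed form for a weight that receives no task gradient. With the values used in the numerical example below (\(\eta = 3 \times 10^{-4}\), \(\lambda = 0.1\)), such a weight takes roughly 23,000 steps to halve:

\[ \theta_N = (1 - \eta\lambda)^N \theta_0 \approx e^{-\eta\lambda N}\,\theta_0, \qquad N_{1/2} = \frac{\ln 2}{-\ln(1 - \eta\lambda)} \approx \frac{\ln 2}{\eta\lambda} = \frac{0.693}{3 \times 10^{-5}} \approx 2.3 \times 10^{4} \text{ steps} \]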

Now expand Adam-L2 under the same assumption. The \(\lambda \theta_{t-1}\) term lives inside the moments, so even when \(g_t = 0\):

\[ \hat m_t \approx \lambda \theta_{t-1}, \quad \hat v_t \approx \lambda^2 \theta_{t-1}^{\,2}, \quad \theta_t \approx \theta_{t-1} - \eta_t \frac{\lambda \theta_{t-1}}{\sqrt{\lambda^2 \theta_{t-1}^{\,2}} + \epsilon} \approx \theta_{t-1} - \eta_t \, \text{sign}(\theta_{t-1}) \]

The decay is no longer proportional to \(\theta_{t-1}\): it is a step of roughly constant magnitude \(\eta_t\) toward zero, with direction set by \(\text{sign}(\theta_{t-1})\) (valid while \(\lambda |\theta_{t-1}| \gg \epsilon\)). Parameters with \(|\theta_{t-1}| < \eta_t\) are pushed to, or past, zero in a single step; parameters with \(|\theta_{t-1}| \gg \eta_t\) lose only a small fraction \(\eta_t / |\theta_{t-1}|\) of their value per step, so small weights are hit proportionally much harder than large ones. The effective regularizer is closer to \(L_1\) than \(L_2\), and because it passes through the adaptive normalization, the per-parameter decay rate is coupled to the gradient history. Loshchilov & Hutter (2019) identified this coupling as a major reason Adam-L2 generalizes worse than SGD-L2 on image classification; the same effect appears in language modeling.
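A toy run makes the two limits concrete. It is a sketch under our own assumptions: one parameter starting at \(\theta = 0.4\), the task gradient fixed at exactly zero, \(\eta = 3 \times 10^{-4}\) and \(\lambda = 0.1\) held constant (no schedule). AdamW should track the closed form \((1 - \eta\lambda)^N \theta_0\), while Adam-L2 should move by close to \(\eta\) per step.

```python
import math


def run_zero_grad(theta0, steps, lr=3e-4, lam=0.1, beta1=0.9, beta2=0.95, eps=1e-8):
    """Hypothetical toy: g_t = 0 on every step. Returns (theta_adamw, theta_adam_l2)."""
    th_w, th_l2 = theta0, theta0
    m_w = v_w = m_l2 = v_l2 = 0.0
    for t in range(1, steps + 1):
        g = 0.0
        # AdamW: moments see only g (= 0), so the whole update is the decoupled decay.
        m_w = beta1 * m_w + (1 - beta1) * g
        v_w = beta2 * v_w + (1 - beta2) * g ** 2
        step_w = (m_w / (1 - beta1 ** t)) / (math.sqrt(v_w / (1 - beta2 ** t)) + eps)
        th_w = th_w - lr * (step_w + lam * th_w)
        # Adam-L2: lam * theta is the only thing the moments ever see.
        g_tilde = g + lam * th_l2
        m_l2 = beta1 * m_l2 + (1 - beta1) * g_tilde
        v_l2 = beta2 * v_l2 + (1 - beta2) * g_tilde ** 2
        step_l2 = (m_l2 / (1 - beta1 ** t)) / (math.sqrt(v_l2 / (1 - beta2 ** t)) + eps)
        th_l2 = th_l2 - lr * step_l2
    return th_w, th_l2


th_w, th_l2 = run_zero_grad(0.4, 200)
print(f"AdamW:   {th_w:.6f}  (closed form {0.4 * (1 - 3e-4 * 0.1) ** 200:.6f})")
print(f"Adam-L2: {th_l2:.6f}  (sign-step estimate 0.4 - 200 * 3e-4 = {0.4 - 200 * 3e-4:.6f})")
```

Over 200 steps AdamW removes about 0.0024 of the 0.4, while Adam-L2 removes close to \(200\eta = 0.06\), roughly 25× more, because its step toward zero does not shrink with \(\lambda |\theta|\).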

Numerical example, by hand

Take a single parameter \(\theta_{t-1} = 0.4\). Set \(\eta_t = 3 \times 10^{-4}\), \(\lambda = 0.1\), \(\beta_1 = 0.9\), \(\beta_2 = 0.95\), \(\epsilon = 10^{-8}\), \(t = 100\) (bias corrections are essentially 1). Suppose \(g_t = 0.01\) (small late-stage gradient) and \(m_{t-1} = v_{t-1} = 0\) for clarity.

AdamW step:

  1. \(m_t = 0.1 \cdot 0.01 = 10^{-3}\)
  2. \(v_t = 0.05 \cdot 10^{-4} = 5 \times 10^{-6}\)
  3. \(\hat m_t \approx 10^{-3}\), \(\hat v_t \approx 5 \times 10^{-6}\), \(\sqrt{\hat v_t} \approx 2.24 \times 10^{-3}\)
  4. AdamW update on \(\theta\): \(0.4 - 3 \times 10^{-4} \cdot (10^{-3} / 2.24 \times 10^{-3} + 0.1 \cdot 0.4) = 0.4 - 3 \times 10^{-4} \cdot (0.446 + 0.04) = 0.4 - 1.46 \times 10^{-4} \approx 0.39985\)

The 0.4 → 0.39985 movement has two parts: the task term contributes \(\approx 1.34 \times 10^{-4}\), the decay term contributes \(\approx 1.2 \times 10^{-5}\). Decay is small but proportional to \(\theta\).

Adam-L2 step:

  1. \(\tilde g_t = 0.01 + 0.1 \cdot 0.4 = 0.05\) (decay is 4× larger than the task gradient!)
  2. \(m_t = 0.1 \cdot 0.05 = 5 \times 10^{-3}\)
  3. \(v_t = 0.05 \cdot 2.5 \times 10^{-3} = 1.25 \times 10^{-4}\)
  4. \(\sqrt{\hat v_t} \approx 1.12 \times 10^{-2}\)
  5. Update on \(\theta\): \(0.4 - 3 \times 10^{-4} \cdot (5 \times 10^{-3} / 1.12 \times 10^{-2}) = 0.4 - 3 \times 10^{-4} \cdot 0.446 = 0.39987\)
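The hand arithmetic above can be checked in a few lines. This sketch uses exactly the example's values and, like the text, treats both bias corrections as 1 at \(t = 100\) with \(m_{t-1} = v_{t-1} = 0\).

```python
import math

# Values from the hand example; bias corrections are taken as exactly 1.
theta, lr, lam, beta1, beta2, eps = 0.4, 3e-4, 0.1, 0.9, 0.95, 1e-8
g = 0.01          # small late-stage task gradient
m_prev = v_prev = 0.0

# AdamW: moments on g only, decay added to the update.
m_w = beta1 * m_prev + (1 - beta1) * g
v_w = beta2 * v_prev + (1 - beta2) * g ** 2
task_w = m_w / (math.sqrt(v_w) + eps)
theta_w = theta - lr * (task_w + lam * theta)

# Adam-L2: decay folded into the gradient before the moments.
g_tilde = g + lam * theta
m_l2 = beta1 * m_prev + (1 - beta1) * g_tilde
v_l2 = beta2 * v_prev + (1 - beta2) * g_tilde ** 2
theta_l2 = theta - lr * m_l2 / (math.sqrt(v_l2) + eps)

print(f"AdamW   theta = {theta_w:.5f}")        # 0.39985
print(f"Adam-L2 theta = {theta_l2:.5f}")       # 0.39987
print(f"v inflation   = {v_l2 / v_w:.1f}x")    # 25.0x
```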

The two outputs look similar (0.39985 vs 0.39987), and that's the whole problem. On a single step, Adam-L2 can land close to AdamW, but its moment estimates have been corrupted: \(v_t\) is now 25× larger than it should be, because \(\tilde g_t = 5 g_t\) and \(v_t\) tracks the square. The corruption is not a one-off. Since \(\lambda \theta_{t-1}\) re-enters \(\tilde g_t\) on every step, every subsequent task gradient is divided by a \(\sqrt{\hat v_t}\) dominated by the decay term; in this example the task gradient's contribution to the step shrinks by \(\sqrt{25} = 5\times\). The model trains with a smaller effective learning rate than the schedule says.
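To see where the "smaller effective learning rate" comes from, note that \(m_t\) is linear in \(\tilde g_t\), so the Adam-L2 step splits exactly into a task part and a decay part that share one (inflated) denominator. A sketch for the same single step from zero moments, with bias corrections again taken as 1:

```python
import math

theta, lr, lam, beta1, beta2, eps = 0.4, 3e-4, 0.1, 0.9, 0.95, 1e-8
g = 0.01

# AdamW from zero moments: the task step is normalized by sqrt(v) built from g alone.
v_w = (1 - beta2) * g ** 2
task_step_w = lr * (1 - beta1) * g / (math.sqrt(v_w) + eps)

# Adam-L2 from zero moments: split m = (1 - beta1) * (g + lam * theta) into its
# task and decay parts; both are divided by the same sqrt(v) built from g_tilde.
g_tilde = g + lam * theta
denom_l2 = math.sqrt((1 - beta2) * g_tilde ** 2) + eps
task_step_l2 = lr * (1 - beta1) * g / denom_l2
decay_step_l2 = lr * (1 - beta1) * lam * theta / denom_l2

print(f"task step,  AdamW:   {task_step_w:.3e}")
print(f"task step,  Adam-L2: {task_step_l2:.3e}  ({task_step_w / task_step_l2:.2f}x smaller)")
print(f"decay step, Adam-L2: {decay_step_l2:.3e}  vs AdamW lr*lam*theta = {lr * lam * theta:.3e}")
```

The task part shrinks by 5×, while the decay part grows to about 9× AdamW's \(\eta\lambda\theta\) term; the two totals only look alike because the shares have been swapped.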

At §A13 scale, where the typical gradient on the 600-form corpus is \(\sim 10^{-2}\) and \(\lambda \theta \sim 10^{-2}\) on a healthy weight (\(\theta \sim 0.1\)), the two terms are comparable. Decay corruption is not a tiny perturbation; it shifts \(v_t\) by 2–10×, depending on the weight's magnitude. That's a fundamentally different optimizer.
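The 2 to 10× range follows from one line of algebra: each step's contribution to \(v_t\) is inflated by \(((g_t + \lambda\theta_{t-1}) / g_t)^2\). A small sketch, assuming \(g_t\) and \(\lambda\theta_{t-1}\) have the same sign (with opposite signs they partly cancel instead); the helper names are ours:

```python
import math


def v_inflation(g, lam, theta):
    """Factor by which Adam-L2 inflates one step's g**2 contribution to v_t.

    (g + lam * theta)**2 / g**2; assumes g != 0.
    """
    return ((g + lam * theta) / g) ** 2


def decay_to_grad_ratio_for(factor):
    """Inverse for same-sign terms: r = lam * theta / g such that (1 + r)**2 == factor."""
    return math.sqrt(factor) - 1


print(v_inflation(0.01, 0.1, 0.1))   # healthy §A13 weight: lam * theta == g
print(v_inflation(0.01, 0.1, 0.4))   # the hand example above
print(decay_to_grad_ratio_for(2), decay_to_grad_ratio_for(10))
```

At exact parity the inflation is 4×, and a 2× to 10× inflation corresponds to \(\lambda\theta\) between roughly 0.4× and 2.2× the task gradient, which is the comparable-magnitude regime this paragraph describes.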

Why this matters more at §A13 scale than at GPT-2 scale

GPT-2-class models have \(d_\text{model} \approx 768\), vocab \(\approx 50k\), \(\lambda \approx 0.1\), weights initialized at \(\sim 0.02\). Task gradients during the meaty middle of training are \(\sim 10^{-3}\); \(\lambda \theta \sim 2 \times 10^{-3}\). The decay-to-gradient ratio is \(\sim 2\). Adam-L2 vs AdamW shows up at the second decimal of validation PPL — measurable, not life-changing.

Our §A13 corpus has 600 forms, vocab \(\sim 512\), \(d_\text{model} = 64\). The per-weight ratio \(\lambda\theta / g_t\) is of order 1 in both settings, so it is not what separates them; what differs is how often a given weight sees any task signal at all. Rows of the embedding table for rare verbs (e.g., write) see only \(\sim 6\) gradient updates per epoch on average (the word appears in \(\sim 1\%\) of the corpus), and their task gradient is exactly zero on all remaining steps. Their gradient is small because the verb is rare, not because the model has converged. Adam-L2 still feeds \(\lambda \theta_{t-1}\) into \(\tilde g_t\) on those zero-gradient steps, so the moment estimates come to treat decay as the signal, and the rare-verb embedding moves toward zero faster than the occasional task gradient can pull it back. This is the §A13 failure mode AdamW prevents.
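A hypothetical toy of our own construction (not the §A13 training loop) illustrates the rare-verb effect: a single weight starting at \(\theta_0 = 0.1\) gets a task gradient of \(-0.01\) (pushing it away from zero) on every 100th step and exactly zero otherwise, with \(\eta = 3 \times 10^{-4}\) and \(\lambda = 0.1\) held fixed. The spike schedule and magnitudes are illustrative assumptions.

```python
import math


def simulate(decoupled, steps=1000, every=100, g_task=-0.01, theta0=0.1,
             lr=3e-4, lam=0.1, beta1=0.9, beta2=0.95, eps=1e-8):
    """Hypothetical toy: one embedding weight that only sees a task gradient on
    every `every`-th step (a rare verb); g = 0 on all other steps."""
    theta, m, v = theta0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = g_task if t % every == 0 else 0.0   # negative g pushes theta up
        # AdamW: moments see g only. Adam-L2: decay is folded into the moments.
        grad = g if decoupled else g + lam * theta
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad ** 2
        step = (m / (1 - beta1 ** t)) / (math.sqrt(v / (1 - beta2 ** t)) + eps)
        if decoupled:
            step += lam * theta                  # decoupled decay term
        theta -= lr * step
    return theta


print(f"AdamW:   {simulate(decoupled=True):+.4f}")
print(f"Adam-L2: {simulate(decoupled=False):+.4f}")
```

In this toy, AdamW ends a little above where it started, since the rare pushes outweigh the slow proportional decay, while Adam-L2 ends near zero: on the zero-gradient steps its moments see only \(\lambda\theta\), so it takes the roughly \(\eta\)-sized sign step derived earlier, and each rare task gradient is averaged into a momentum that is still dominated by decay.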

This is also why we use \(\beta_2 = 0.95\) instead of \(0.999\) at this scale (see theory/02-optimizer-and-schedule.md): an exponential moving average with decay \(\beta_2\) averages over roughly \(1/(1-\beta_2)\) steps, about 20 at 0.95 versus 1000 at 0.999, and we cannot afford the long average, because the rare-verb gradient is the signal we want to amplify.
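To make the "rare gradient" part concrete, this sketch measures how much of one gradient's \(g^2\) is still sitting in \(v_t\) when the next rare gradient arrives. The 100-step gap is an illustrative assumption, and the helper names are ours.

```python
import math


def ema_window(beta):
    """Rough number of recent steps an EMA with decay `beta` averages over."""
    return 1.0 / (1.0 - beta)


def ema_half_life(beta):
    """Steps until an old contribution's weight in the EMA has halved."""
    return math.log(0.5) / math.log(beta)


for beta2 in (0.95, 0.999):
    retained = beta2 ** 100   # share of an old g**2 contribution left after 100 steps
    print(f"beta2={beta2}: window ~{ema_window(beta2):.0f} steps, "
          f"half-life ~{ema_half_life(beta2):.1f} steps, "
          f"retained after 100 steps {retained:.3f}")
```

At \(\beta_2 = 0.999\), about 90% of the previous spike's contribution is still in \(v_t\) 100 steps later, so the next rare gradient is divided by a stale, large \(\sqrt{\hat v_t}\); at \(0.95\), well under 1% remains.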

Implementation gotchas you will hit

  1. Excluding biases and LayerNorms from decay. Standard practice: weight decay applies to 2-D+ tensors (the actual "weights"), not to 1-D bias / LayerNorm-scale tensors. In code, that's a param_group split. If you forget, biases and LayerNorm scales are pulled toward zero and the network's expressive capacity shrinks. Phase 18's reference loop.py will partition the parameters into two optimizer param groups.
  2. Decay applied to embeddings. Some recipes exclude embeddings from decay; others include them. At §A13 scale, include them — the embedding table is half the parameter count, and not decaying it gives it disproportionate freedom to memorize the train set. Phase 19 lab will sweep this and confirm.
  3. The "Adam with weight_decay=" PyTorch trap. torch.optim.Adam(..., weight_decay=0.1) does Adam-L2, not AdamW. To get AdamW, use torch.optim.AdamW. This catches people. Phase 25 will dissect the dispatcher and you'll see the two are genuinely different ops.
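Gotchas 1 and 2 can be handled with one ndim-based split. This is a minimal sketch, not Phase 18's loop.py: the helper name split_decay_groups is hypothetical, and it assumes each tensor exposes .ndim (true of PyTorch tensors). Because embedding tables are 2-D, they land in the decay group, as gotcha 2 recommends.

```python
def split_decay_groups(named_params, weight_decay=0.1):
    """Hypothetical helper: build two optimizer param groups.

    Tensors with ndim >= 2 (linear weights, embedding tables) get `weight_decay`;
    1-D tensors (biases, LayerNorm scales/shifts) get 0.0. `named_params` is any
    iterable of (name, tensor) pairs, e.g. model.named_parameters() in PyTorch.
    """
    decay, no_decay = [], []
    for _name, p in named_params:
        if not getattr(p, "requires_grad", True):
            continue  # frozen tensors belong in neither group
        (decay if p.ndim >= 2 else no_decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Usage with PyTorch (decoupled decay needs AdamW, not Adam; see gotcha 3).
# lr and betas here are illustrative, taken from this file's example values:
#   optimizer = torch.optim.AdamW(split_decay_groups(model.named_parameters()),
#                                 lr=3e-4, betas=(0.9, 0.95))
```

PyTorch optimizers accept a list of dicts like this and apply each group's own weight_decay, so the 1-D tensors are never decayed.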

Citation

Loshchilov, I., & Hutter, F. (2019). Decoupled Weight Decay Regularization. ICLR 2019. https://arxiv.org/abs/1711.05101 — Sections 2 (the algebraic decoupling) and 4.2 (the small-model experiments most analogous to §A13 scale).

One-paragraph recap

Adam-L2 folds \(\lambda \theta_{t-1}\) into the gradient before computing \(m, v\), contaminating both moments with the decay term. AdamW computes \(m, v\) on the task gradient only, then applies the \(\lambda \theta_{t-1}\) decay as a separate term in the parameter update. At convergence, AdamW recovers the clean geometric decay \(\theta \to (1 - \eta \lambda) \theta\); Adam-L2 degenerates to a sign-based step whose magnitude is independent of \(\theta\). At §A13 scale, where \(\lambda \theta\) and \(g_t\) are within an order of magnitude on rare-verb embeddings, the difference is not cosmetic — it determines whether the rare-verb signal survives the decay pressure.


Cross-refs: theory/02-optimizer-and-schedule.md (the recipe), theory/03-mixed-precision-preview.md (how decay interacts with loss scaling), Phase 19 lab/02-break-it.md (one engineered break is exactly Adam-L2 swapped in for AdamW).