Skip to content

English · Español

05 — AdamW vs Adam: la matemática exacta del desacoplo, a escala §A13

🇪🇸 AdamW no es "Adam con weight decay activado". Es Adam con weight decay desacoplado del gradient. La diferencia es una sola línea de álgebra, pero a la escala microscópica del corpus §A13 (~600 formas, ~103k parámetros) decide si tu modelo memoriza o generaliza al sexto verbo irregular.

Este archivo es el complemento en profundidad de theory/02-optimizer-and-schedule.md. Reformulamos las dos reglas de actualización en paralelo, mostramos el paso algebraico que convierte una en la otra y, después, recorremos un ejemplo numérico lo bastante pequeño como para hacerlo a mano — y explicamos por qué la diferencia importa más para nuestro corpus §A13 que para GPT-2.


Las dos actualizaciones, en paralelo

Adam original con regularización \(L_2\) añade \(\lambda \theta\) al gradient antes de la actualización de momentos. Escribiendo \(\tilde g_t = g_t + \lambda \theta_{t-1}\):

\[ \text{Adam-L2:} \quad m_t = \beta_1 m_{t-1} + (1 - \beta_1) \tilde g_t, \quad v_t = \beta_2 v_{t-1} + (1 - \beta_2) \tilde g_t^{\,2}, \quad \theta_t = \theta_{t-1} - \eta_t \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \]

AdamW calcula los momentos solo sobre el gradient de la tarea, y después aplica el weight decay directamente sobre el parámetro como parte de la actualización:

\[ \text{AdamW:} \quad m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \quad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^{\,2}, \quad \theta_t = \theta_{t-1} - \eta_t \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda \theta_{t-1} \right) \]

La línea que cambia: qué entra en \(m_t, v_t\). AdamW mantiene \(g_t\) limpio; Adam-L2 lo contamina.

El paso algebraico que muestra la diferencia

Desarrolla la actualización de AdamW en la convergencia — supón \(g_t \approx 0\) y que los momentos se han decaído de forma que \(\hat m_t / (\sqrt{\hat v_t} + \epsilon) \approx 0\). Entonces:

\[ \theta_t \approx \theta_{t-1} - \eta_t \lambda \theta_{t-1} = (1 - \eta_t \lambda) \theta_{t-1} \]

Decaimiento geométrico limpio hacia cero a una tasa de \(\eta_t \lambda\) por step. Este es el comportamiento previsto del weight decay.

Ahora desarrolla Adam-L2 bajo la misma suposición. El término \(\lambda \theta_{t-1}\) vive dentro de los momentos, así que incluso cuando \(g_t = 0\):

\[ \hat m_t \approx \lambda \theta_{t-1}, \quad \hat v_t \approx \lambda^2 \theta_{t-1}^{\,2}, \quad \theta_t \approx \theta_{t-1} - \eta_t \frac{\lambda \theta_{t-1}}{\sqrt{\lambda^2 \theta_{t-1}^{\,2}} + \epsilon} \approx \theta_{t-1} - \eta_t \, \text{sign}(\theta_{t-1}) \]

El decay ya no es proporcional a \(\theta_{t-1}\) — es un step de magnitud constante \(\eta_t\) hacia cero, escalado por el signo. Los parámetros con \(|\theta_{t-1}| < \eta_t\) quedan anulados en un solo step; los parámetros con \(|\theta_{t-1}| \gg \eta_t\) apenas se decaen. El regularizador efectivo está más cerca de \(L_1\) que de \(L_2\), y depende de la normalización adaptativa, lo que significa que la tasa de decay por parámetro está ahora acoplada al historial del gradient. Loshchilov & Hutter (2019) mostraron que esta es la razón por la que Adam-L2 rinde peor que SGD-L2 en clasificación de imágenes; el mismo efecto aparece en modelado de lenguaje.

Ejemplo numérico, a mano

Toma un único parámetro \(\theta_{t-1} = 0.4\). Fija \(\eta_t = 3 \times 10^{-4}\), \(\lambda = 0.1\), \(\beta_1 = 0.9\), \(\beta_2 = 0.95\), \(\epsilon = 10^{-8}\), \(t = 100\) (las correcciones de sesgo son esencialmente 1). Supón \(g_t = 0.01\) (gradient pequeño en fase tardía) y \(m_{t-1} = v_{t-1} = 0\) para mayor claridad.

Step de AdamW:

  1. \(m_t = 0.1 \cdot 0.01 = 10^{-3}\)
  2. \(v_t = 0.05 \cdot 10^{-4} = 5 \times 10^{-6}\)
  3. \(\hat m_t \approx 10^{-3}\), \(\hat v_t \approx 5 \times 10^{-6}\), \(\sqrt{\hat v_t} \approx 2.24 \times 10^{-3}\)
  4. Actualización de AdamW sobre \(\theta\): \(0.4 - 3 \times 10^{-4} \cdot (10^{-3} / 2.24 \times 10^{-3} + 0.1 \cdot 0.4) = 0.4 - 3 \times 10^{-4} \cdot (0.446 + 0.04) = 0.4 - 1.46 \times 10^{-4} \approx 0.39985\)

El movimiento 0.4 → 0.39985 tiene dos partes: el término de la tarea aporta \(\approx 1.34 \times 10^{-4}\), el término de decay aporta \(\approx 1.2 \times 10^{-5}\). El decay es pequeño pero proporcional a \(\theta\).

Step de Adam-L2:

  1. \(\tilde g_t = 0.01 + 0.1 \cdot 0.4 = 0.05\) (¡el decay es 4× mayor que el gradient de la tarea!)
  2. \(m_t = 0.1 \cdot 0.05 = 5 \times 10^{-3}\)
  3. \(v_t = 0.05 \cdot 2.5 \times 10^{-3} = 1.25 \times 10^{-4}\)
  4. \(\sqrt{\hat v_t} \approx 1.12 \times 10^{-2}\)
  5. Actualización sobre \(\theta\): \(0.4 - 3 \times 10^{-4} \cdot (5 \times 10^{-3} / 1.12 \times 10^{-2}) = 0.4 - 3 \times 10^{-4} \cdot 0.446 = 0.39987\)

Los dos resultados parecen similares (0.39985 vs 0.39987) — y ese es todo el problema. La trayectoria de Adam-L2 puede aproximar a AdamW en steps individuales, pero las estimaciones de los momentos se han corrompido: \(v_t\) es ahora 25× mayor de lo que debería (porque \(\tilde g_t\) quedó inflado por el término de decay). En los siguientes 100 steps, cada gradient de la tarea se divide entre un \(\sqrt{\hat v_t}\) que sobreponderá el historial de decay. El modelo entrena efectivamente con un learning rate efectivo menor que el que dice el schedule.

A escala §A13, donde el gradient típico sobre el corpus de 600 formas es \(\sim 10^{-2}\) y \(\lambda \theta \sim 10^{-2}\) en un peso saludable (\(\theta \sim 0.1\)), los dos términos son comparables. La corrupción por decay no es una perturbación minúscula; desplaza \(v_t\) en 2–10×, dependiendo de la magnitud del peso. Eso es un optimizer fundamentalmente diferente.

Por qué esto importa más a escala §A13 que a escala GPT-2

Los modelos de clase GPT-2 tienen \(d_\text{model} \approx 768\), vocab \(\approx 50k\), \(\lambda \approx 0.1\), pesos inicializados en \(\sim 0.02\). Los gradients de la tarea durante el centro denso del entrenamiento son \(\sim 10^{-3}\); \(\lambda \theta \sim 2 \times 10^{-3}\). El ratio decay-a-gradient es \(\sim 2\). Adam-L2 frente a AdamW aparece en el segundo decimal de la PPL de validación — medible, no decisivo.

Nuestro corpus §A13 tiene 600 formas, vocab \(\sim 512\), \(d_\text{model} = 64\). La señal-a-decay es distinta: los pesos de la tabla de embeddings para verbos poco frecuentes (p. ej. write) reciben \(\sim 6\) actualizaciones de gradient por época en promedio (porque la palabra aparece en \(\sim 1\%\) del corpus). Su norma de gradient es pequeña porque el verbo es raro, no porque el modelo haya convergido. Si Adam-L2 infla el término de decay hasta que domina a \(g_t\), la estimación del momento se convierte en "el decay es la señal" — y el embedding del verbo raro empieza a moverse hacia cero más rápido de lo que la señal de la tarea puede arrastrarlo de vuelta. Este es el modo de fallo §A13 que AdamW evita.

Por eso también usamos \(\beta_2 = 0.95\) en lugar de \(0.999\) a esta escala (ver theory/02-optimizer-and-schedule.md): no podemos permitirnos promediar gradients durante cientos de steps, porque el gradient del verbo raro es la señal que queremos amplificar.

Trampas de implementación con las que vas a tropezar

  1. Excluir biases y LayerNorms del decay. Práctica estándar: el weight decay se aplica a tensores 2-D+ (los "pesos" reales), no a tensores 1-D de bias / escala de LN. En código, eso es una división de param_group. Si lo olvidas, los biases derivan hacia cero y se reduce la capacidad expresiva de la red. El loop.py de referencia de la Fase 18 partirá el optimizer state en dos grupos.
  2. Decay aplicado a los embeddings. Algunas recetas excluyen los embeddings del decay; otras los incluyen. A escala §A13, inclúyelos — la tabla de embeddings es la mitad del conteo de parámetros, y no decaerla le da una libertad desproporcionada para memorizar el train set. El lab de la Fase 19 barrerá esto y lo confirmará.
  3. La trampa "Adam con weight_decay=" de PyTorch. torch.optim.Adam(..., weight_decay=0.1) hace Adam-L2, no AdamW. Para obtener AdamW, usa torch.optim.AdamW. Esto pilla a mucha gente. La Fase 25 diseccionará el dispatcher y verás que las dos son operaciones genuinamente diferentes.

Cita

Loshchilov, I., & Hutter, F. (2019). Decoupled Weight Decay Regularization. ICLR 2019. https://arxiv.org/abs/1711.05101 — Secciones 2 (el desacoplo algebraico) y 4.2 (los experimentos con modelos pequeños más análogos a la escala §A13).

Resumen en un párrafo

Adam-L2 mete \(\lambda \theta_{t-1}\) dentro del gradient antes de calcular \(m, v\), contaminando ambos momentos con el término de decay. AdamW calcula \(m, v\) solo sobre el gradient de la tarea, y después aplica el decay \(\lambda \theta_{t-1}\) como término separado en la actualización del parámetro. En convergencia, AdamW recupera el decaimiento geométrico limpio \(\theta \to (1 - \eta \lambda) \theta\); Adam-L2 degenera en un step basado en signo cuya magnitud es independiente de \(\theta\). A escala §A13, donde \(\lambda \theta\) y \(g_t\) están dentro de un orden de magnitud sobre embeddings de verbos raros, la diferencia no es cosmética — determina si la señal del verbo raro sobrevive a la presión del decay.


Referencias cruzadas: theory/02-optimizer-and-schedule.md (la receta), theory/03-mixed-precision-preview.md (cómo interactúa el decay con el loss scaling), Fase 19 lab/02-break-it.md (una de las roturas intencionales es exactamente Adam-L2 sustituido por AdamW).