Skip to content

English · Español

02 — AdamW + warmup + cosine decay + gradient clipping

Cuatro piezas que parecen detalles y deciden si tu loss curve es una pendiente suave o una sierra. Aquí derivamos cada una desde la Fase 4, las ensamblamos, y vemos por qué la receta moderna funciona en este orden y no en otro.


La Fase 4 derivó las matemáticas del optimizador desde cero. La Fase 9 implementó SGD y Adam en minitorch/optim.py. Este fichero es la referencia de implementación: reformula las ecuaciones exactamente en la forma en que Borja las tecleará en src/minitrain/loop.py, en el orden exacto en que se aplican, con cada variable nombrada como la nombra el código.

AdamW — las ecuaciones tal y como las escribes

Para cada tensor de parámetro \(\theta\) con gradiente \(g_t\) en el paso \(t\):

\[ m_t \leftarrow \beta_1 \, m_{t-1} + (1 - \beta_1) \, g_t \]
\[ v_t \leftarrow \beta_2 \, v_{t-1} + (1 - \beta_2) \, g_t^2 \]
\[ \hat m_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat v_t = \frac{v_t}{1 - \beta_2^t} \]
\[ \theta_t \leftarrow \theta_{t-1} - \eta_t \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda \, \theta_{t-1} \right) \]

Hiperparámetros por defecto (recomendados para la Fase 18 salvo que tus runs digan lo contrario):

  • \(\beta_1 = 0.9\) — decaimiento del primer momento
  • \(\beta_2 = 0.95\) — decaimiento del segundo momento (elección moderna; el Adam original usaba 0.999, el entrenamiento moderno de LLM usa 0.95)
  • \(\epsilon = 10^{-8}\)
  • \(\lambda = 0.1\) — coeficiente de weight decay
  • \(\eta_{\max} = 3 \times 10^{-4}\) — LR pico (tras warmup)
  • \(\eta_{\min} = 3 \times 10^{-5}\) — LR mínima al final del cosine decay

Tres cosas que parecen detalles pero no lo son

1. \(\beta_2 = 0.95\) vs \(0.999\). Con un corpus de 600 formas y ~103k parámetros, la estimación del segundo momento \(v_t\) converge rápido. \(\beta_2 = 0.999\) tarda ~1000 pasos en "ver" la magnitud del gradiente de cualquier parámetro dado. \(\beta_2 = 0.95\) tarda ~20. No tenemos 1000 pasos de sobra — el run de entrenamiento entero es del orden de unos pocos miles de pasos.

2. La corrección de sesgo \(\hat m_t = m_t / (1 - \beta_1^t)\) no es opcional. En el paso \(t=1\), \(m_1 = (1 - \beta_1) g_1 = 0.1 g_1\). Sin corrección de sesgo, los primeros 100 pasos actualizan \(\theta\) ~10× menos de lo que deberían. El warmup cosine lo enmascara algo, pero el optimizador debe seguir corrigiendo el sesgo internamente. Bug común: implementar la actualización con \(m_t\) en lugar de \(\hat m_t\) y sorprenderse de que "el warmup es demasiado agresivo". No es el warmup — es el optimizador que nunca se calienta.

3. \(\lambda \theta_{t-1}\) es weight decay desacoplado. La "W" en AdamW (vs Adam vanilla) es el desacople: el término de weight decay se añade a la actualización, no al gradiente. Acoplar el weight decay al gradiente (g_t += λ θ_{t-1}) hace que AdamW colapse a Adam-con-L2-reg y rompe la geometría — véase Loshchilov & Hutter (2019). El theory/04-optimizers.md de la Fase 4 ya derivó esto; si está borroso, relee ese fichero.

Schedule cosine con warmup lineal

Dos regímenes:

  • Warmup durante los primeros \(W\) pasos: $\(\eta_t = \eta_{\max} \cdot \frac{t}{W}, \quad t \in [0, W)\)$
  • Cosine decay para \(t \in [W, T]\): $\(\eta_t = \eta_{\min} + \tfrac{1}{2} (\eta_{\max} - \eta_{\min}) \left( 1 + \cos\frac{\pi (t - W)}{T - W} \right)\)$

Donde \(T\) son los pasos totales de entrenamiento.

Defaults: - \(W = 100\) (aproximadamente 5% de los pasos totales) - \(T = 2000\) (aproximadamente 50 epochs sobre el train set de 240 formas con batch size ~6)

Por qué el warmup es no-opcional para transformers

En el paso 0, el modelo está inicializado aleatoriamente. La pérdida es alta (~\(\ln V\) donde \(V\) es el tamaño del vocabulario — para \(V = 512\), eso es ~6.2). El gradiente es grande y mal condicionado: el Hessiano está lejos del cuadrático local, así que un paso de LR normal sobrepasa salvajemente. Sin warmup:

  1. Paso 1: los pesos se empujan en alguna dirección con magnitud \(\eta_{\max} \cdot \|g_1\|\).
  2. Paso 2: los gradientes explotan porque el modelo está ahora lejos de cualquier sitio razonable.
  3. NaN para el paso 50.

Este modo de fallo es el bug #2 de los tres breaks intencionados de la Fase 19. Lo verás en el dashboard. El warmup lo elimina rampando \(\eta\) linealmente de 0 a \(\eta_{\max}\) a lo largo de \(W\) pasos, dando al optimizador tiempo para estimar \(v_t\) (la escala por parámetro) antes de dar pasos de tamaño completo.

Por qué cosine específicamente

Tres schedules alternativos: - Constante (sin decaimiento): puede igualar a cosine en runs cortos pero pierde 1-3% de PPL en runs largos (la LR es "demasiado alta" cerca del final, impidiendo convergencia fina). - Decaimiento lineal: coincide con cosine dentro del 1% pero la LR cae demasiado rápido cerca del final. - Decaimiento por pasos: las discontinuidades en la LR causan picos de pérdida en las transiciones.

Cosine combina decaimiento suave (sin picos) con una cola lenta (LR pequeña durante muchos pasos tardíos, permitiendo convergencia fina). No es magia — es una forma de curva razonable. La Fase 4 las dibujó todas.

Gradient clipping

Tras calcular los gradientes, antes del paso del optimizador, clip la norma L2 global:

\[ \|g\|_2 = \sqrt{\sum_{\text{all params}} \|g_\theta\|_F^2} \]

Si \(\|g\|_2 > c\) (donde \(c\) es el umbral de clip, default \(c = 1.0\)):

\[ g \leftarrow g \cdot \frac{c}{\|g\|_2} \]

Esto reescala todos los tensores de gradiente uniformemente. El clipping por tensor está mal: cambia la dirección de la actualización entre parámetros, no solo la magnitud. El clipping de norma global preserva la dirección.

¿Por qué clipear?

Dos razones:

  1. Defiende contra batches outlier raros. La mayoría de batches tienen \(\|g\|_2 < 1\). Ocasionalmente un batch con una predicción muy confiadamente errónea produce \(\|g\|_2 \approx 50\). Ese único paso desestabiliza el optimizador (los momentos ahora creen que el gradiente típico es 50× mayor de lo que es, y los pasos futuros se quedan sin tamaño). El clipping impide que un mal batch envenene las estimaciones de momentos.
  2. Seguro barato. \(c = 1.0\) rara vez se excede en entrenamiento sano. Cuando se hace, quieres saberlo — registra \(\|g\|_2\) cada paso y vigila los picos. El dashboard de la Fase 19 lo plotea.

El umbral de clip \(c\) es un hiperparámetro, pero \(c = 1.0\) es casi siempre adecuado. Poner \(c < 0.1\) silenciosamente estrangula el entrenamiento; \(c > 10\) no clipea de verdad.

Ensamblándolo: el paso del optimizador

def step(self, params, grads):
    self.t += 1
    g_norm_sq = sum((g * g).sum() for g in grads.values())
    g_norm = np.sqrt(g_norm_sq)

    # 1. clip
    clip_factor = min(1.0, self.clip / (g_norm + 1e-12))

    # 2. learning rate
    if self.t < self.warmup:
        lr = self.lr_max * self.t / self.warmup
    else:
        progress = (self.t - self.warmup) / (self.total - self.warmup)
        lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + math.cos(math.pi * progress))

    # 3. AdamW update per parameter
    for name, p in params.items():
        g = grads[name] * clip_factor
        self.m[name] = self.beta1 * self.m[name] + (1 - self.beta1) * g
        self.v[name] = self.beta2 * self.v[name] + (1 - self.beta2) * (g * g)
        m_hat = self.m[name] / (1 - self.beta1 ** self.t)
        v_hat = self.v[name] / (1 - self.beta2 ** self.t)
        p -= lr * (m_hat / (np.sqrt(v_hat) + self.eps) + self.weight_decay * p)

El orden es: norma → clip → schedule → actualización de momento → corrección de sesgo → paso. Equivoca el orden y verás uno de:

  • Clipping tras la actualización AdamW: los momentos siguen viendo el gradiente sin clipear, así que un batch futuro queda desestabilizado.
  • Corrección de sesgo saltada o aplicada a \(m\) pero no a \(v\): warmup asimétrico que sesga las actualizaciones tempranas.
  • Weight decay aplicado a los gradientes en lugar de a la actualización: AdamW colapsa a Adam-con-L2.

Problemas de drill

  1. AdamW con \(\beta_1 = 0.9\), \(\beta_2 = 0.95\). En el paso \(t = 10\), los factores de corrección de sesgo son \((1 - 0.9^{10}) \approx 0.65\) y \((1 - 0.95^{10}) \approx 0.40\). ¿Qué fracción del "verdadero" primer momento hay en \(m_{10}\)? ¿Y en \(v_{10}\)? ¿Por qué son tan distintos?
  2. El warmup es \(W = 100\) y \(\eta_{\max} = 3 \times 10^{-4}\). En el paso 25, ¿cuánto es \(\eta_{25}\)?
  3. El entrenamiento completo es \(T = 2000\) pasos con \(W = 100\). En el paso 1500, ¿cuánto es \(\eta_{1500}\)? (El cosine progresa en \((1500 - 100) / (2000 - 100) = 0.737\) a través del decay; \(\cos(\pi \cdot 0.737) \approx -0.69\).)
  4. La norma global del gradiente en el paso 50 es 12.0, el umbral de clip es 1.0. El tensor de gradiente para la MLP de la capa 3 tiene norma de Frobenius 4.0 antes del clipping. ¿Cuál es su norma después?

Si los cuatro están claros, sigue.

Recap de un párrafo

AdamW + warmup lineal + cosine decay + clipping de norma global es la receta moderna. AdamW difiere de Adam en que desacopla el weight decay hacia la actualización, no hacia el gradiente. El warmup rampa linealmente \(\eta\) de 0 a \(\eta_{\max}\) durante los primeros \(W\) pasos para que el optimizador pueda estimar \(v_t\) antes de dar actualizaciones de tamaño completo. El cosine decay baja \(\eta\) suavemente a \(\eta_{\min}\) durante los pasos restantes, permitiendo convergencia fina al final. El clipping de norma L2 global con \(c = 1.0\) impide que un único mal batch envenene las estimaciones de momentos. El orden de implementación es norma → clip → schedule → actualización de momento → corrección de sesgo → paso, y equivocarlo rompe silenciosamente una de las cuatro piezas.

Lo que esta sección NO cubre

  • EMA (media móvil exponencial de pesos). Stubbed en la Fase 18, implementación real en la Fase 26+.
  • LR por capas / LR por grupo de parámetros. La Fase 28 (LoRA) las usa.
  • Lookahead / Lion / otros optimizadores modernos. Fuera de alcance.
  • Scheduling atado al loss-scale para fp16. Fase 26.

Siguiente: theory/03-mixed-precision-preview.md.