Skip to content

English · Español

03 — Optimizadores: SGD, momentum, Adam

🇪🇸 La Fase 4 derivó los optimizadores en abstracto, sobre un dict de arrays. La Fase 9 los reescribe sobre Parameters — la única diferencia es que ahora el "estado" del optimizador (los momentos para Adam, las velocidades para momentum) vive en un dict indexado por id(parameter), no por nombre. La matemática es idéntica. Lo nuevo es la interfaz step() + zero_grad() que se inserta en el bucle de entrenamiento sin que tengas que pensar.


La clase base Optimizer

class Optimizer:
    """Base class for all optimizers."""

    def __init__(self, params: Iterable[Parameter], lr: float) -> None:
        self.params: list[Parameter] = list(params)
        self.lr = lr
        # Per-parameter state, keyed by id(p). Subclasses populate this.
        self.state: dict[int, dict[str, np.ndarray]] = {id(p): {} for p in self.params}

    def step(self) -> None:
        raise NotImplementedError

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

Veinte líneas. Decisiones horneadas:

  • params se materializa como una lista una vez. Los iterables se agotan; la lista es el registro del optimizador de lo que posee. Documenta: los parámetros añadidos al modelo después de crear el optimizador NO son rastreados.
  • Estado indexado por id(p). Basado en identidad. El ciclo de vida del objeto Parameter y el ciclo de vida del optimizador están atados — si un Parameter es reemplazado (raro, pero posible), el estado del optimizador para el id antiguo se vuelve obsoleto. PyTorch lo maneja; aceptamos la limitación.
  • zero_grad vive también en el optimizador. Convención de PyTorch. optim.zero_grad() y model.zero_grad() hacen lo mismo (ambos recorren params poniendo grad = None).
  • Sin state_dict() para el optimizador todavía. La Fase 18 (bucle de entrenamiento) lo añade para entrenamiento reanudable.

SGD: lo más simple posible

SGD puro:

class SGD(Optimizer):
    def __init__(self, params, lr: float, momentum: float = 0.0) -> None:
        super().__init__(params, lr)
        self.momentum = momentum
        if momentum > 0:
            for p in self.params:
                self.state[id(p)]["velocity"] = np.zeros_like(p.data)

    def step(self) -> None:
        for p in self.params:
            if p.grad is None:
                continue  # parameter wasn't touched by this backward pass
            if self.momentum > 0:
                v = self.state[id(p)]["velocity"]
                v *= self.momentum
                v += p.grad
                p.data -= self.lr * v
            else:
                p.data -= self.lr * p.grad

Veinte líneas. Puntos clave:

  • Velocidad almacenada en self.state, no en el Parameter. Mantiene los Parameters limpios (solo data y grad).
  • Las actualizaciones in-place a p.data y v están bien. Ocurren fuera del grafo de autograd (entre forward passes). El DAG se recomputa en cada iteración.
  • if p.grad is None: continue. Algunos parámetros pueden no aparecer en la pérdida para un batch en particular (raro en MLPs pequeños; común en configuraciones mixture-of-experts de transformer). La Fase 9 los salta; la Fase 18 puede avisar.

Por qué esta regla de actualización

Para una pérdida L(θ), el gradiente ∇L es la dirección de ascenso más empinado. Queremos descender, así que tomamos θ ← θ - η · ∇L con tasa de aprendizaje η.

El momentum añade un suavizado del gradiente a lo largo del tiempo:

v_t = μ · v_{t-1} + ∇L
θ_t = θ_{t-1} - η · v_t

Equivalente a: en vez de dar el paso en la dirección del gradiente actual, da el paso en la dirección de una media móvil ponderada exponencialmente. Suaviza sobre gradientes ruidosos de minibatch; acelera la convergencia en direcciones de baja curvatura.

La Fase 4 derivó todo esto. La Fase 9 lo implementa.

Adam: estimación adaptativa de momentos

La fórmula completa (Fase 4 derivó; aquí solo la codificamos):

m_t = β₁ · m_{t-1} + (1 - β₁) · g_t           # first moment (running mean of gradients)
v_t = β₂ · v_{t-1} + (1 - β₂) · g_t²          # second moment (running variance, element-wise)
m̂_t = m_t / (1 - β₁^t)                        # bias correction
v̂_t = v_t / (1 - β₂^t)                        # bias correction
θ_t = θ_{t-1} - η · m̂_t / (√v̂_t + ε)

En código:

class Adam(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ) -> None:
        super().__init__(params, lr)
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.t = 0  # global step counter
        for p in self.params:
            self.state[id(p)]["m"] = np.zeros_like(p.data)
            self.state[id(p)]["v"] = np.zeros_like(p.data)

    def step(self) -> None:
        self.t += 1
        for p in self.params:
            if p.grad is None:
                continue
            s = self.state[id(p)]
            g = p.grad
            s["m"] = self.beta1 * s["m"] + (1 - self.beta1) * g
            s["v"] = self.beta2 * s["v"] + (1 - self.beta2) * g * g
            m_hat = s["m"] / (1 - self.beta1 ** self.t)
            v_hat = s["v"] / (1 - self.beta2 ** self.t)
            p.data -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

Treinta líneas. Notas:

  • betas = (0.9, 0.999) y eps = 1e-8 son los defaults de PyTorch. Coincídelos para que el port drill sea limpio.
  • Contador global self.t. Se incrementa una vez por cada llamada a step(). PyTorch mantiene t por parámetro; comportamiento idéntico en la práctica ya que todos los params ven el mismo número de pasos.
  • La corrección de sesgo importa más temprano. En el paso t=1, 1 - β₁^1 = 0.1, así que m̂ = m / 0.1 = 10 · m. Sin corrección, la actualización sería demasiado pequeña. Tras ~50 pasos, la corrección es despreciable.

¿Por qué corrección de sesgo?

Inicializa m_0 = 0. Tras un paso: m_1 = (1 - β₁) · g_1. Esto está sesgado hacia cero — la verdadera media móvil debería ser g_1, pero tenemos 0.1 · g_1. Dividir por 1 - β₁^1 = 0.1 restaura la estimación insesgada.

El argumento de la serie geométrica: m_t = (1 - β₁) · Σ_{k=1..t} β₁^{t-k} · g_k. Suma de pesos = (1 - β₁) · (1 - β₁^t) / (1 - β₁) = 1 - β₁^t. Así que m_t / (1 - β₁^t) es la media ponderada con pesos sumando 1 — la estimación insesgada de la media móvil.

Esta es la derivación que Borja debe reproducir en /quiz 09.

¿Por qué √v̂ en el denominador?

El denominador √v̂_t de Adam es una estimación de la magnitud RMS del gradiente por coordenada. Dividir la actualización por ello da tamaños de paso adaptativos por coordenada: las coordenadas con gradientes consistentemente grandes obtienen tasas de aprendizaje efectivas más pequeñas; las coordenadas con gradientes pequeños obtienen mayores. La intuición: cada coordenada ve el mismo paso relativo efectivo.

El + ε previene la división por cero (lección de cancelación catastrófica de la Fase 2: nunca dividas por algo que no hayas acotado).

Cross-check contra PyTorch

La corrección de Adam es difícil de comprobar a ojo — pequeñas diferencias numéricas se acumulan. La DoD requiere que Adam coincida con torch.optim.Adam a 1e-5 sobre 100 pasos en una cuadrática.

def test_adam_matches_pytorch():
    rng = np.random.default_rng(0)
    init = rng.standard_normal(5).astype(np.float64)
    target = rng.standard_normal(5).astype(np.float64)

    # Ours
    p = Parameter(init.copy())
    optim = Adam([p], lr=1e-2)
    for _ in range(100):
        optim.zero_grad()
        loss = ((p - Tensor(target)) ** 2).sum()
        loss.backward()
        optim.step()

    # PyTorch
    pt = torch.tensor(init, dtype=torch.float64, requires_grad=True)
    opt_t = torch.optim.Adam([pt], lr=1e-2)
    target_t = torch.tensor(target, dtype=torch.float64)
    for _ in range(100):
        opt_t.zero_grad()
        loss_t = ((pt - target_t) ** 2).sum()
        loss_t.backward()
        opt_t.step()

    np.testing.assert_allclose(p.data, pt.detach().numpy(), atol=1e-5)

Si este test pasa, Adam es correcto.

Escollos comunes

  1. Reutilizar una lista de params obsoleta. Añade una capa al modelo después de crear el optimizador → los parámetros de la nueva capa no se actualizan. Crea siempre el optimizador después de que el modelo esté completamente construido. Documenta esto de forma visible.
  2. Olvidar optim.zero_grad(). Los gradientes se acumulan (la Fase 8 + Fase 9 son por defecto estilo acumulador). Sin zero_grad, cada paso usa la suma de todos los gradientes pasados — la pérdida diverge inmediatamente. El Lab 02 reproduce este bug deliberadamente.
  3. Llamar a optim.step() antes de cualquier backward. Ningún p.grad existe aún. La guarda if p.grad is None: continue previene un crash, pero es un signo de un error de lógica en el bucle de entrenamiento.
  4. El contador t de Adam no incrementado. Si olvidas self.t += 1, la corrección de sesgo es siempre para t=0, dando 1 - β^0 = 0, división por cero. Atrapa con el test de cross-check.
  5. Mezclar parámetros FP64 con gradientes FP32. Posible porque Tensor no impone consistencia de dtype. La Fase 9 es FP64 en todo por limpieza; la Fase 26 (precisión mixta) revisa.

Notas de rendimiento

Para un MLP de 469 parámetros, el overhead de Adam es: - 2 arrays extra por parámetro (m, v): 2 · 469 · 8 bytes = 7.5 KB. - Por paso: ~5 ops aritméticas por elemento de parámetro. ~2.5K ops. Microsegundos.

Despreciable. Para el transformer de la Fase 17 (~10M params), el mismo cálculo: 80 MB de estado extra, ~50M ops/paso. Aún despreciable comparado con el matmul. Adam es esencialmente gratis mientras el modelo quepa en memoria.

El Adam de PyTorch es funcionalmente idéntico al nuestro pero escrito en kernels CUDA. El ejercicio de Adam de la Fase 25 los compara.

Escollos (morderán en el lab)

  • Off-by-one en t. Incrementa antes de calcular 1 - β^t, no después.
  • Defaults de beta. (0.9, 0.999) no (0.999, 0.9). Fácil de intercambiar.
  • Colocación de eps. (np.sqrt(v_hat) + eps), no np.sqrt(v_hat + eps). El primero es más estable (evita sqrt(0) exactamente); ambos se usan en la literatura; PyTorch usa el primero; coincídelo.
  • Reutilizar la misma instancia de Adam después de reemplazar parámetros. Las claves del state dict (ids) se vuelven obsoletas. No hagas esto; crea un nuevo optimizador.
  • Adam con lr muy pequeño (1e-6) + corrección de sesgo: el primer paso efectivo sigue siendo grande debido a la corrección 1 / (1 - β₁) ≈ 10. El lab lo mostrará.

Ancla temática (§A13)

El entrenamiento de TenseMLP de la Fase 9 usa Adam(lr=1e-2) (defaults betas=(0.9, 0.999), eps=1e-8). Convergencia: <30 epochs hasta >90% de accuracy de entrenamiento en el set de 250 triplas; >85% en el set de validación de 50 triplas. El Lab 03 reporta la curva de pérdida y la matriz de confusión.

La cuenta de parámetros del MLP (469) es lo bastante pequeña como para que también puedas probar SGD(lr=0.5) (lr alto porque pocos params, pérdida bien condicionada) y ver convergencia comparable — útil para el plot diagnóstico "SGD vs Adam".

Lo que esta página NO cubre

  • Weight decay / regularización L2. Fase 10 (AdamW vs Adam-con-L2 — difieren en implementación).
  • Programación de la tasa de aprendizaje. Fase 18 (warmup, cosine decay).
  • Clipping de gradientes. Fase 18.
  • Checkpointing del estado del optimizador. Fase 18.
  • Optimizadores fragmentados (ZeRO). Fase 35 (entrenamiento distribuido).

Recapitulación de un párrafo

SGD es una actualización de 5 líneas: p.data -= lr * p.grad, opcionalmente con momentum (v ← μv + g; p ← p - lr·v). Adam es el mismo bucle con estimaciones móviles por parámetro del primer momento (m) y segundo momento (v), corregidos de sesgo por 1 - β^t, y la actualización normalizada por √v̂. Ambos caben en un método step() siguiendo la interfaz Optimizer(params, lr). El estado se almacena en optim.state[id(p)]. La DoD requiere que Adam coincida con PyTorch a 1e-5 sobre 100 pasos — esto atrapa cada bug común (orden de beta incorrecto, falta de corrección de sesgo, colocación incorrecta de eps) de una vez.


Siguiente: lab/00-parameter-and-module-skeleton.md.