Skip to content

English · Español

02 — Normalización: BatchNorm, LayerNorm, RMSNorm

🇪🇸 Normalizar = re-escalar (y a veces re-centrar) las activaciones para que la capa siguiente vea una distribución estable. BatchNorm lo hace sobre el batch (visión); LayerNorm sobre las features (transformers clásicos); RMSNorm es LayerNorm sin la media, y es lo que usan los LLMs modernos.


Lo que todas resuelven

Una vez que la inicialización es correcta (teoría 01 de la Fase 10), la varianza del forward pass se preserva aproximadamente en la capa 0. Tras unos pocos pasos de entrenamiento, los pesos se mueven; la distribución de activaciones deriva; la siguiente capa ve un objetivo en movimiento. Esto es el internal covariate shift — un nombre de paper para una molestia real. La normalización lo arregla re-normalizando las activaciones en cada capa, independientemente de cómo hayan derivado los pesos.

Hay tres variantes de normalización que debes conocer:

  • BatchNorm (Ioffe & Szegedy 2015) — visión.
  • LayerNorm (Ba et al. 2016) — transformers.
  • RMSNorm (Zhang & Sennrich 2019) — LLMs modernos.

Difieren en sobre qué eje normalizan. La matemática es por lo demás casi idéntica.

BatchNorm

Dado un batch de activaciones \(X \in \mathbb{R}^{B \times d}\) (batch × features), BatchNorm normaliza sobre el eje del batch, por feature:

\[ \mu_k = \frac{1}{B} \sum_{i=1}^{B} X_{ik}, \qquad \sigma_k^2 = \frac{1}{B} \sum_{i=1}^{B} (X_{ik} - \mu_k)^2 \]
\[ \hat{X}_{ik} = \frac{X_{ik} - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}}, \qquad Y_{ik} = \gamma_k \hat{X}_{ik} + \beta_k \]

Los parámetros aprendibles \(\gamma, \beta \in \mathbb{R}^d\) permiten a la red deshacer la normalización si quiere (inicializados \(\gamma = 1, \beta = 0\)).

Entrenamiento vs inferencia. En entrenamiento, \(\mu_k, \sigma_k\) vienen del batch actual. En inferencia, te gustaría ser invariante al tamaño del batch, así que BN mantiene medias móviles de \(\mu_k, \sigma_k\) durante el entrenamiento y las usa en inferencia. Esta es la fuente de bugs: olvidar model.eval() es uno de los 10 footguns principales del aprendizaje profundo. Lo mencionamos aunque nuestro currículo no usará BN.

Por qué BN funciona en visión pero no en transformers. Los batches de visión son grandes, i.i.d., y las features (posiciones espaciales × canales) son aproximadamente estacionarias entre muestras. Los batches de lenguaje son pequeños, las secuencias están correlacionadas dentro de un ejemplo, y la semántica de las features cambia entre posiciones de la secuencia. El ruido de las estadísticas de batch de BatchNorm domina la señal para lenguaje. El paper del transformer cambió a LayerNorm por esa razón.

LayerNorm

LayerNorm normaliza sobre el eje de features, por muestra:

\[ \mu^{(i)} = \frac{1}{d} \sum_{k=1}^{d} X_{ik}, \qquad {\sigma^{(i)}}^2 = \frac{1}{d} \sum_{k=1}^{d} (X_{ik} - \mu^{(i)})^2 \]
\[ \hat{X}_{ik} = \frac{X_{ik} - \mu^{(i)}}{\sqrt{{\sigma^{(i)}}^2 + \epsilon}}, \qquad Y_{ik} = \gamma_k \hat{X}_{ik} + \beta_k \]

Fíjate en el índice. \(\mu, \sigma\) dependen de la muestra i, no del batch. Por tanto:

  • Misma computación en entrenamiento e inferencia. Sin cambio .eval() necesario. Sin fuente de bugs.
  • Sin dependencia del batch. Funciona con tamaño de batch 1 (inferencia, chat).

LayerNorm se convirtió en el default del transformer, sin más, durante una década.

RMSNorm

En 2019, Zhang & Sennrich propusieron eliminar la sustracción de la media:

\[ \text{RMS}(X^{(i)}) = \sqrt{\frac{1}{d} \sum_{k=1}^{d} X_{ik}^2 + \epsilon} \]
\[ Y_{ik} = \gamma_k \cdot \frac{X_{ik}}{\text{RMS}(X^{(i)})} \]

Sin media. Sin parámetro de sesgo \(\beta\) (algunas implementaciones mantienen \(\beta\), la mayoría no). Dos efectos:

  1. La mitad de FLOPs y la mitad de memory traffic. Sin cálculo de media, sin sustracción de media, sin \(\beta\). En el i5-8250U de Borja, espera ~30–40% de reducción del wall-time para la propia operación de norm.
  2. Estabilidad de entrenamiento empíricamente marginal o mejor. El paper original de RMSNorm mostró convergencia comparable; el trabajo posterior en LLMs (Llama, T5-1.1, PaLM) lo confirmó: quitar la media no perjudica.

¿Por qué quitar la media no perjudica? Sin demostración definitiva. El argumento empírico: tras una suma residual x + f(x) con f de media cero, la media de x + f(x) es aproximadamente la media de x — ya controlada por la normalización de la capa anterior. Re-centrar en cada capa es trabajo redundante. El re-escalado es lo que realmente importa para la preservación de varianza en la siguiente capa.

Aceptamos esto empíricamente. RMSNorm es el default en los stacks de LLM modernos; la implementamos como la norm primaria del currículo.

Pre-LN vs Post-LN

Dónde pones la normalización en un bloque residual importa.

Post-LN (transformer original): $$ y = \text{LN}(x + f(x)) $$

Pre-LN (transformer moderno): $$ y = x + f(\text{LN}(x)) $$

Pre-LN se hizo dominante porque da un camino de gradiente más limpio:

  • En Pre-LN, el gradiente ∂y/∂x = I + ∂(f ∘ LN)/∂x. El término identidad es incondicional — la autopista residual siempre tiene ganancia unitaria.
  • En Post-LN, el gradiente ∂y/∂x = ∂LN/∂(x+f(x)) × (I + ∂f/∂x). El Jacobiano del LayerNorm escala el término residual; para redes profundas este factor multiplicativo se compone y desestabiliza el entrenamiento.

La consecuencia empírica: los transformers Pre-LN entrenan sin warmup de learning rate; los transformers Post-LN necesitan warmup para evitar divergencia temprana. La Fase 17 (bloques transformer) profundiza. Para la Fase 10, la regla es: usar Pre-LN en todo lo que construyamos.

Posición de ε (territorio del subagent Numerical Stability)

El denominador es \(\sqrt{\sigma^2 + \epsilon}\), NO \(\sqrt{\sigma^2} + \epsilon\). Por qué:

  • Para \(\sigma^2 \to 0\) muy pequeña, \(\sqrt{\sigma^2} \to 0\), así que la división explota antes de que \(\epsilon\) pueda salvarte.
  • \(\sqrt{\sigma^2 + \epsilon}\) nunca baja de \(\sqrt{\epsilon}\) — la \(\epsilon\) actúa como un suelo de varianza, que es la semántica pretendida.

Un bug de implementación común es 1.0 / (np.sqrt(var) + eps). Pasa los tests con entradas no degeneradas y falla silenciosamente en casos límite. El Lab 02 tiene un test unitario que dispara \(\sigma \to 0\) para exponer esto.

Valores típicos de \(\epsilon\): \(10^{-5}\) (LayerNorm), \(10^{-6}\) (RMSNorm). El ecosistema LLM no es religioso al respecto; revisa la model card.

Un diagrama de cordura

                BatchNorm                LayerNorm / RMSNorm
                ─────────                ─────────────────
                normalize OVER:           normalize OVER:
                ┌──────────┐              ┌──────────┐
                │  Batch   │              │ Features │
                └──────────┘              └──────────┘

X.shape = (B, d)                       X.shape = (B, d)
μ.shape = (d,)                          μ.shape = (B,)
σ.shape = (d,)                          σ.shape = (B,)

Sólo las shapes te dicen la norm: si tu μ tiene shape (d,), es BatchNorm; (B,) significa LayerNorm/RMSNorm. (O, en un modelo de secuencias con shape (B, L, d), LayerNorm tiene μ.shape = (B, L).)

Coste computacional — números reales

Para shape (B, L, d) = (64, 512, 768) (paso típico de entrenamiento de un transformer pequeño):

  • LayerNorm: 1 reducción de media + 1 reducción de varianza + 1 sustracción + 1 división + 1 afín = ~5d FLOPs por elemento.
  • RMSNorm: 1 reducción de media de cuadrados + 1 división + 1 escala = ~3d FLOPs por elemento.

Razón: RMSNorm hace alrededor del 60% del cómputo de LayerNorm. La razón de memory traffic es similar. En el argumento del roofline de la Fase 1 (las operaciones de norm están limitadas por memoria), los ahorros se traducen en ~30–40% de wall-clock.

Lo que esta fase no implementa

  • GroupNorm — entre BN y LN; se usa en algunos modelos de visión-lenguaje. Omitir.
  • InstanceNorm — se usa en transferencia de estilo. Omitir.
  • Weight standardization — truco ortogonal que normaliza pesos, no activaciones. Fase 21.
  • DyT (dynamic tanh) — una alternativa a LayerNorm de 2024+. Lectura rápida si te interesa; fuera de alcance.

Problemas de práctica

Soluciones en solutions/02-normalization-ref.md (apertura de fase).

  1. Muestra que para LayerNorm con \(d = 1\), la salida es siempre exactamente \(\beta_1\) (sin importar la entrada). ¿Por qué LayerNorm con d = 1 es inútil?
  2. Muestra que el gradiente de LayerNorm con respecto a una única entrada \(X_{ik}\) depende de todas las demás features \(X_{ij}\) (j ≠ k) a través de la media y la varianza. Contrástalo con RMSNorm, donde la estructura de dependencia es la misma (todavía todas las features) pero la expresión es más simple.
  3. Una red usa Post-LN sin warmup a learning rate 1e-3 y diverge en 100 pasos. Sugiere dos arreglos de un único cambio (que no sean reducir el LR) y explica qué modo de fallo aborda cada uno.

Resumen en un párrafo

La normalización re-centra y/o re-escala activaciones de modo que la siguiente capa vea una distribución estable. BatchNorm normaliza sobre el eje del batch (visión; divergencia entrenamiento/inferencia). LayerNorm normaliza sobre las features por muestra (transformers; sin divergencia entrenamiento/inferencia). RMSNorm elimina la sustracción de media de LayerNorm (LLMs modernos; ~40% más barato, estabilidad empíricamente equivalente). Pre-LN (y = x + f(LN(x))) es el default moderno porque da una autopista de gradiente más limpia que Post-LN. \(\epsilon\) va dentro de la raíz para estabilidad numérica.


Siguiente: theory/03-residuals.md.