Skip to content

English · Español

Lab 02 — Ablación de norm: BatchNorm vs LayerNorm vs RMSNorm

Objetivo: mostrar que LayerNorm y RMSNorm ambos entrenan; RMSNorm es medibles más barato.

Tiempo estimado: 90–120 minutos.

Prerrequisito: lab 01 commiteado; theory/02-normalization.md leído.


Qué produces

Un directorio experiments/10-norm-ablation/ que contiene:

  • train.py — script de entrenamiento con un flag CLI --norm.
  • losses.json — tres trayectorias de pérdida (sin-norm, LayerNorm, RMSNorm).
  • timings.json — tiempo wall-clock por paso para cada variante de norm.
  • loss_curves.png — tres curvas.
  • timing_bar.png — gráfico de barras del tiempo medio por paso por variante.
  • manifest.json.
  • README.md.

El montaje

Toma el MLP de 12 capas del lab 01. Fija la inicialización en Kaiming (la que entrena). Varía sólo la normalización:

  1. Sin norm (línea base; debería seguir entrenando con Kaiming, sólo que menos estable).
  2. LayerNorm antes de cada ReLU.
  3. RMSNorm antes de cada ReLU.

Estás mostrando dos cosas:

a. Tanto LayerNorm como RMSNorm entrenan, y producen una pérdida final similar. b. RMSNorm es más rápido por paso. Medible en el i5-8250U de Borja para una red de 12 capas con oculta 256.

TODOs

Bloque A — implementa los norms

  • Escribe src/minigrad/nn/norm.py con los módulos LayerNorm(d) y RMSNorm(d).
  • LayerNorm: \(\mu = \text{mean}(x, \text{axis=-1})\), \(\sigma^2 = \text{var}(x, \text{axis=-1})\), \(\hat x = (x - \mu)/\sqrt{\sigma^2 + \epsilon}\), \(y = \gamma \hat x + \beta\).
  • RMSNorm: \(\text{rms} = \sqrt{\text{mean}(x^2, \text{axis=-1}) + \epsilon}\), \(y = \gamma \cdot x / \text{rms}\).
  • ε va dentro del sqrt. Test unitario: la norm de una entrada toda-ceros no produce NaN.
  • Ambos módulos devuelven Tensors que retropropagan correctamente. Gradcheck a ambos.

Bloque B — tres ejecuciones

  • Mismo MLP de 12 capas. Inicialización Kaiming.
  • Variante 1: sin norm.
  • Variante 2: LayerNorm antes de cada ReLU. Aplica estilo Pre: out = ReLU(LayerNorm(x)).
  • Variante 3: RMSNorm antes de cada ReLU. Misma forma.
  • Cronometra cada paso con time.perf_counter_ns().

Bloque C — gráfico

  • Curvas de pérdida (3 líneas).
  • Gráfico de barras del tiempo medio por paso (3 barras). Incluye barras de error con la desviación típica del cronometraje.

Bloque D — interpreta

En README.md:

  1. ¿Entrenan las tres? Sí/no por variante. (Sí/sí/sí es lo esperado con Kaiming.)
  2. ¿Es la pérdida final la misma? Cuantifica.
  3. ¿Es RMSNorm más rápido que LayerNorm? ¿En qué porcentaje? Coteja con la predicción de la teoría 02 (~30–40% más rápido).
  4. ¿Cuál es la forma de la trayectoria de pérdida de la línea base sin-norm? ¿Suave o nerviosa? ¿Por qué?

Bloque E — manifest

Estándar. Incluye el valor de epsilon y la colocación de la norm (Pre vs Post).

Restricciones

  • Sólo estilo Pre. No estamos probando Pre vs Post aquí (eso es de hecho el lab 03 indirectamente vía residual + norm).
  • mypy --strict sobre src/minigrad/nn/norm.py.
  • Test de propiedad: para entradas fp32 aleatorias de formas diversas, el mean(x²) por fila de la salida de RMSNorm es igual a mean(γ²) (puesto que la salida es γ · x / rms y mean((γ·x/rms)²) = γ² · mean(x²)/rms² = γ² · 1 = mean(γ²) si γ es un vector de valores idénticos; para γ no uniforme, comprueba componente a componente).
  • Un único hilo, governor performance.

Condiciones de parada

Hecho cuando:

  1. Los siete archivos están commiteados.
  2. src/minigrad/nn/norm.py está limpio en mypy --strict y el gradcheck pasa.
  3. Las curvas de pérdida muestran RMSNorm y LayerNorm convergiendo a pérdida final similar.
  4. La barra de cronometraje muestra RMSNorm estrictamente más rápido que LayerNorm (el gap absoluto depende de tu máquina; debe ser > 0).
  5. El README responde las cuatro preguntas del Bloque D.

Trampas

  • LayerNorm sin γ, β aprendibles. Eso no es LayerNorm; eso es sólo estandarización. Los parámetros afines son parte de la definición.
  • RMSNorm con sustracción de media. Eso es LayerNorm. Bug típico de copy-paste.
  • NaN en el primer batch. Comprueba la colocación de ε (debe estar dentro del sqrt).
  • RMSNorm más lento que LayerNorm en tu medición. Probablemente sea un problema de overhead NumPy/minigrad, no coste real. Perfila una op de norm aislada. El ahorro debe ser visible con hidden_dim ≥ 256.
  • El gradcheck falla para RMSNorm. Suele faltar el cierre de 1/rms. El paso hacia atrás a través de 1/sqrt(mean(x²)+ε) es: $\(\frac{\partial}{\partial x_i} \frac{1}{\sqrt{m + \epsilon}} = -\frac{1}{2 (m+\epsilon)^{3/2}} \cdot \frac{2 x_i}{d} = -\frac{x_i}{d (m+\epsilon)^{3/2}}\)$ Asegúrate de que ese término está en tu paso hacia atrás.

Pista de último recurso

Si el gradcheck para LayerNorm falla: el gradiente de (x - mean(x)) / std(x) respecto a un único \(x_j\) depende de todos los \(x_i\) vía la media y la desviación típica. Hay tres rutas: por el \(x_j\) explícito, por la media, por la desviación típica. La implementación de LayerNorm de PyTorch tiene esto como una expresión en forma cerrada; recomponlo desde la definición y encontrarás tu término que falta.

Cuándo consultar solutions/

Tras los siete archivos. Solución: solutions/02-norm-ablation-ref.md (fase abierta).


Siguiente lab: lab/03-residual-depth.md.