Skip to content

English · Español

01 — Inicialización que preserva la varianza (Xavier, Kaiming)

🇪🇸 Inicializar bien = elegir la varianza de los pesos para que la varianza de las activaciones no explote ni colapse al pasar por las capas. Xavier (lineal/tanh) y Kaiming (ReLU) son la misma derivación con un factor 2 de diferencia por culpa del cero en ReLU.


El montaje que queremos preservar

Considera una única capa:

\[ y = W x + b, \qquad W \in \mathbb{R}^{n_\text{out} \times n_\text{in}} \]

Seguida por una activación elemento a elemento \(\sigma\):

\[ z = \sigma(y) \]

Queremos que una cadena de estas se comporte "bien" en la inicialización — es decir, ambos:

  1. La varianza de la activación no explota ni colapsa capa por capa.
  2. La varianza del gradiente tampoco explota ni colapsa capa por capa.

Si ambas se cumplen, el forward pass es informativo (señal preservada) y el backward pass es informativo (el gradiente llega a cada capa). Ese es el objetivo.

Supuestos (decláralos para saber cuándo se rompen)

  1. \(W_{ij}\) son i.i.d., media cero, varianza \(\sigma_W^2\).
  2. \(b = 0\) en la inicialización.
  3. \(x_j\) tienen media cero, varianza \(\sigma_x^2\), independientes de \(W\) y entre sí a través de j.
  4. La activación \(\sigma\) es "razonable" — o lineal, o saturante simétrica como tanh, o unilateral como ReLU. (Los distintos casos aparecen abajo.)

Son supuestos de libro de texto para la derivación; en la práctica son aproximaciones. Los papers de Glorot y He muestran empíricamente que las aproximaciones son lo bastante buenas para predecir la escala correcta.

Derivación del forward pass (Xavier / Glorot)

Calcula \(\mathrm{Var}(y_i)\):

\[ y_i = \sum_{j=1}^{n_\text{in}} W_{ij} x_j \]

La varianza de una suma de términos independientes es la suma de varianzas:

\[ \mathrm{Var}(y_i) = \sum_{j=1}^{n_\text{in}} \mathrm{Var}(W_{ij} x_j) \]

Para variables aleatorias independientes con media cero, \(\mathrm{Var}(W_{ij} x_j) = \mathrm{Var}(W_{ij}) \cdot \mathrm{Var}(x_j) = \sigma_W^2 \sigma_x^2\). Por tanto:

\[ \boxed{\mathrm{Var}(y_i) = n_\text{in} \cdot \sigma_W^2 \cdot \sigma_x^2} \]

Para una activación lineal (o casi lineal, como la región lineal de tanh), \(\mathrm{Var}(z_i) \approx \mathrm{Var}(y_i)\). Para preservar la varianza (\(\mathrm{Var}(z) = \mathrm{Var}(x) = \sigma_x^2\)), fija:

\[ \sigma_W^2 = \frac{1}{n_\text{in}} \qquad \Rightarrow \qquad W_{ij} \sim \mathcal{N}\!\left(0, \frac{1}{n_\text{in}}\right) \]

Eso es Xavier (variante forward).

Derivación del backward pass

Ahora piensa en el gradiente en la entrada:

\[ \frac{\partial L}{\partial x_j} = \sum_{i=1}^{n_\text{out}} W_{ij} \frac{\partial L}{\partial y_i} \]

Argumento de la misma forma:

\[ \mathrm{Var}\!\left(\frac{\partial L}{\partial x_j}\right) = n_\text{out} \cdot \sigma_W^2 \cdot \mathrm{Var}\!\left(\frac{\partial L}{\partial y_i}\right) \]

Para preservar la varianza del gradiente: \(\sigma_W^2 = 1/n_\text{out}\).

El compromiso

El forward quiere \(\sigma_W^2 = 1/n_\text{in}\). El backward quiere \(\sigma_W^2 = 1/n_\text{out}\). Discrepan salvo que \(n_\text{in} = n_\text{out}\). El compromiso de Glorot es la media armónica:

\[ \boxed{\sigma_W^2 = \frac{2}{n_\text{in} + n_\text{out}}} \]

Eso es Xavier-Glorot (variante de compromiso), usada por defecto en xavier_normal_ de nn.Linear de PyTorch.

Kaiming / He para ReLU

Para \(\sigma = \text{ReLU}\), la mitad de las unidades (en esperanza bajo \(y_i\) simétrica) son cero. La varianza post-activación es aproximadamente la mitad de la varianza pre-activación:

\[ \mathrm{Var}(\text{ReLU}(y_i)) \approx \tfrac{1}{2} \mathrm{Var}(y_i) \]

(Para una distribución simétrica con media cero, exactamente la mitad por simetría.) Sustituye en la derivación del forward: para preservar la varianza a través de ReLU, necesitamos que \(\sigma_W^2 \cdot n_\text{in}\) sea el doble de grande que en el caso lineal:

\[ \boxed{\sigma_W^2 = \frac{2}{n_\text{in}} \qquad \text{(Kaiming/He, forward)}} \]

Esto es a lo que kaiming_normal_ apunta por defecto. La variante backward es \(2/n_\text{out}\); el compromiso es raro en la práctica — la mayoría del código usa la variante forward porque las redes ReLU normalmente se entrenan primero estables en forward.

Ejemplo trabajado — un MLP de 4 capas con varianza de entrada 1

Toma un MLP 784 → 256 → 256 → 64 → 10. Las entradas son píxeles MNIST normalizados a varianza 1.

Mala inicialización: W ~ N(0, 1) (es decir, σ_W² = 1).

Varianza forward tras cada capa lineal (activación lineal por ahora):

Var after layer 1 = 784 × 1 × 1   = 784
Var after layer 2 = 256 × 1 × 784 ≈ 2 × 10⁵
Var after layer 3 = 256 × 1 × 2e5 ≈ 5 × 10⁷
Var after layer 4 = 64  × 1 × 5e7 ≈ 3 × 10⁹

La varianza de los logits es ~10⁹. El softmax satura al instante; los gradientes de cross-entropy son cero en todas partes excepto en el argmax. La red no puede entrenar.

Inicialización Xavier: σ_W² = 1/n_in:

Var after layer 1 = 784 × (1/784) × 1 = 1
Var after layer 2 = 256 × (1/256) × 1 = 1
Var after layer 3 = 256 × (1/256) × 1 = 1
Var after layer 4 = 64  × (1/64)  × 1 = 1

Varianza preservada exactamente. Los logits son O(1). El softmax se comporta. Los gradientes son no triviales. La red entrena.

Kaiming para ReLU: el mismo cálculo con la corrección de factor 2; la varianza post-ReLU se mantiene en 1 a través de las capas.

El Lab 00 (variance-walk) reproduce esto con números medidos en la máquina de Borja.

Lo que este argumento no cubre

  1. Activaciones no lineales más allá de ReLU/tanh. GeLU, SiLU, Swish — cada una tiene su propia "ganancia" efectiva \(g\) tal que \(\mathrm{Var}(\sigma(y)) \approx g \cdot \mathrm{Var}(y)\). Usa \(\sigma_W^2 = 1/(g \cdot n_\text{in})\). La tabla de ganancias para GeLU ≈ 1.7; para SiLU ≈ 1.7; para tanh ≈ 5/3 (con escalado de entrada apropiado). nn.init.calculate_gain de PyTorch codifica esto.
  2. Convoluciones. \(n_\text{in}\) se convierte en \(\text{kernel\_size}^2 \cdot \text{in\_channels}\). La misma matemática, distinto fan-in.
  3. Embeddings. No hay fan-in en el mismo sentido; usamos una desviación típica fija pequeña (p. ej., \(0.02\), la convención de GPT-2). La Fase 13 lo discute.
  4. Inicialización del bias. Casi siempre cero. La inicialización con bias no nulo rompe el supuesto de simetría con media cero.

Qué sale mal si eliges la equivocada

Elección errónea Síntoma
Xavier con ReLU La varianza se reduce a la mitad cada capa → las activaciones colapsan → curva de pérdida plana.
Kaiming con tanh La varianza se duplica cada capa (el factor 2 no se necesita para tanh) → saturación.
Variante sólo forward con salida muy ancha El gradiente explota en el backward por la n_out grande.
Uniforme [-1, 1] sin importar n_in La varianza escala con n_in → explosión instantánea pasadas 4 capas.

El Lab 01 ablaciona las dos primeras en el MLP de Borja y muestra las curvas de pérdida.

Notas prácticas de implementación

  • Uniforme vs Normal. Ambas se usan en la práctica. Uniforme con cotas ±sqrt(3 σ_W²) iguala la varianza de la Normal. La forma de la distribución apenas importa comparada con su varianza.
  • Bias. Inicializa a 0 (o 0.01 para argumentos de "lottery ticket" en ReLU — pero es marginal).
  • fan_in vs fan_out vs fan_avg. En un Linear(in_features=N, out_features=M), \(n_\text{in} = N\), \(n_\text{out} = M\), \(n_\text{avg} = (N+M)/2\). Los defaults de PyTorch varían según la función — lee la fuente si importa.
  • Inicialización con semilla (seed). Todas las llamadas de inicialización aceptan una semilla (o un RNG). La fixture de semilla en tests/conftest.py hace que la inicialización sea determinista en toda la suite de tests.

Problemas de práctica (resolver antes del lab 01)

Soluciones en solutions/01-initialization-ref.md (apertura de fase). Intenta razonando.

  1. Una red de 10 capas con activación tanh se inicializa con \(\sigma_W^2 = 0.5/n_\text{in}\). Tras 10 capas, ¿cuál es la razón de varianza de la activación respecto a la entrada?
  2. Una red alterna Linear → ReLU → Linear → ReLU. Cada Linear tiene la misma \(n_\text{in} = n_\text{out} = 256\). ¿Qué \(\sigma_W^2\) preserva la varianza de la activación a través de 20 capas?
  3. nn.Linear.reset_parameters() de PyTorch usa kaiming_uniform_(weight, a=sqrt(5)) por defecto. ¿Cuál es la gain efectiva para a=sqrt(5), y por qué es un default raro para MLPs vanilla?

Resumen en un párrafo

La inicialización es la elección de la varianza \(\sigma_W^2\) para la matriz de pesos de modo que la cadena de capas preserve la varianza de la activación (forward) y la varianza del gradiente (backward). Para lineal/tanh, \(\sigma_W^2 = 1/n_\text{in}\) (Xavier). Para ReLU, \(\sigma_W^2 = 2/n_\text{in}\) (Kaiming) porque ReLU pone a cero la mitad de las unidades. Las versiones forward y backward discrepan; el compromiso de media armónica de Glorot \(2/(n_\text{in}+n_\text{out})\) es un default común. La mala inicialización se manifiesta como explosión o colapso del forward pass antes de que el entrenamiento siquiera empiece.


Siguiente: theory/02-normalization.md.