01 — Variance-Preserving Initialization (Xavier, Kaiming)

Good initialization means choosing the weight variance so that the activation variance neither explodes nor collapses as it passes through the layers. Xavier (linear/tanh) and Kaiming (ReLU) are the same derivation, differing by a factor of 2 because ReLU zeros out half of its inputs.


The setup we'll preserve

Consider a single layer:

\[ y = W x + b, \qquad W \in \mathbb{R}^{n_\text{out} \times n_\text{in}} \]

Followed by an elementwise activation \(\sigma\):

\[ z = \sigma(y) \]

We want a chain of these to be "well-behaved" at initialization — meaning, both:

  1. The activation variance doesn't explode or collapse layer by layer.
  2. The gradient variance doesn't explode or collapse layer by layer either.

If both hold, the forward pass is informative (signal preserved) and the backward pass is informative (gradient reaches every layer). That's the bar.

Assumptions (state them so you know when they break)

  1. \(W_{ij}\) are i.i.d., mean zero, variance \(\sigma_W^2\).
  2. \(b = 0\) at init.
  3. \(x_j\) are mean zero, variance \(\sigma_x^2\), independent of \(W\) and of each other across j.
  4. The activation \(\sigma\) is "nice" — either linear, or symmetric saturating like tanh, or one-sided like ReLU. (Different cases land below.)

These are textbook assumptions for the derivation; they're approximations in practice. The Glorot and He papers show empirically that the approximations are good enough to predict the right scale.

Forward-pass derivation (Xavier / Glorot)

Compute \(\mathrm{Var}(y_i)\):

\[ y_i = \sum_{j=1}^{n_\text{in}} W_{ij} x_j \]

Variance of a sum of independent terms is the sum of variances:

\[ \mathrm{Var}(y_i) = \sum_{j=1}^{n_\text{in}} \mathrm{Var}(W_{ij} x_j) \]

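For independent random variables with \(W_{ij}\) mean zero, \(\mathrm{Var}(W_{ij} x_j) = \mathrm{Var}(W_{ij}) \cdot \mathbb{E}[x_j^2]\). When \(x_j\) is also mean zero, \(\mathbb{E}[x_j^2] = \mathrm{Var}(x_j)\), so \(\mathrm{Var}(W_{ij} x_j) = \sigma_W^2 \sigma_x^2\). So: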

\[ \boxed{\mathrm{Var}(y_i) = n_\text{in} \cdot \sigma_W^2 \cdot \sigma_x^2} \]

For a linear (or near-linear, like the linear region of tanh) activation, \(\mathrm{Var}(z_i) \approx \mathrm{Var}(y_i)\). To preserve variance (\(\mathrm{Var}(z) = \mathrm{Var}(x) = \sigma_x^2\)), set:

\[ \sigma_W^2 = \frac{1}{n_\text{in}} \qquad \Rightarrow \qquad W_{ij} \sim \mathcal{N}\!\left(0, \frac{1}{n_\text{in}}\right) \]

That's Xavier (forward variant).
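The boxed formula can be checked numerically. The following is a minimal Monte Carlo sketch, assuming Gaussian weights and inputs (the derivation only needs zero mean and independence); the function names and the sizes `n_in = 64` and `n_samples = 5000` are illustrative choices, not part of the course code.

```python
import random
import statistics

def sample_preactivation(n_in, sigma_w2, sigma_x2, rng):
    """Draw one y = sum_j W_j * x_j with fresh zero-mean Gaussian W and x."""
    sw, sx = sigma_w2 ** 0.5, sigma_x2 ** 0.5
    return sum(rng.gauss(0.0, sw) * rng.gauss(0.0, sx) for _ in range(n_in))

def empirical_var_y(n_in, sigma_w2, sigma_x2, n_samples=5000, seed=0):
    """Estimate Var(y) over the joint randomness of weights and inputs."""
    rng = random.Random(seed)
    ys = [sample_preactivation(n_in, sigma_w2, sigma_x2, rng)
          for _ in range(n_samples)]
    return statistics.pvariance(ys)

n_in = 64
sigma_x2 = 1.0
var_unit = empirical_var_y(n_in, 1.0, sigma_x2)           # prediction: n_in
var_xavier = empirical_var_y(n_in, 1.0 / n_in, sigma_x2)  # prediction: 1

print(f"sigma_W^2 = 1      -> Var(y) ~ {var_unit:.2f} (predicted {n_in})")
print(f"sigma_W^2 = 1/n_in -> Var(y) ~ {var_xavier:.3f} (predicted 1)")
```

The estimates should land within a few percent of the predictions; the residual is sampling noise, not a flaw in the formula.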

Backward-pass derivation

Now think about the gradient at the input:

\[ \frac{\partial L}{\partial x_j} = \sum_{i=1}^{n_\text{out}} W_{ij} \frac{\partial L}{\partial y_i} \]

Same shape of argument:

\[ \mathrm{Var}\!\left(\frac{\partial L}{\partial x_j}\right) = n_\text{out} \cdot \sigma_W^2 \cdot \mathrm{Var}\!\left(\frac{\partial L}{\partial y_i}\right) \]

To preserve gradient variance: \(\sigma_W^2 = 1/n_\text{out}\).

The compromise

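Forward wants \(\sigma_W^2 = 1/n_\text{in}\). Backward wants \(\sigma_W^2 = 1/n_\text{out}\). They disagree unless \(n_\text{in} = n_\text{out}\). Glorot's compromise is the harmonic mean of the two variances, which amounts to using the average of \(n_\text{in}\) and \(n_\text{out}\) as the fan: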

\[ \boxed{\sigma_W^2 = \frac{2}{n_\text{in} + n_\text{out}}} \]

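That's Xavier-Glorot (compromise variant). PyTorch's nn.init.xavier_normal_ and nn.init.xavier_uniform_ implement it, with an optional gain multiplying the standard deviation. It is not the default initialization of nn.Linear.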
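To see what each variant costs, here is a small sketch that computes the three variances and the per-layer factor each one applies to the activation variance (forward) and the gradient variance (backward). The helper names `init_variance` and `harmonic_mean` and the layer size 784 → 256 are illustrative assumptions.

```python
def init_variance(n_in, n_out, mode="fan_avg"):
    """Weight variance for a linear/tanh layer under each variant."""
    if mode == "fan_in":
        return 1.0 / n_in
    if mode == "fan_out":
        return 1.0 / n_out
    if mode == "fan_avg":
        return 2.0 / (n_in + n_out)
    raise ValueError(f"unknown mode: {mode}")

def harmonic_mean(a, b):
    return 2.0 * a * b / (a + b)

n_in, n_out = 784, 256
for mode in ("fan_in", "fan_out", "fan_avg"):
    s2 = init_variance(n_in, n_out, mode)
    forward_gain = n_in * s2    # factor on activation variance per layer
    backward_gain = n_out * s2  # factor on gradient variance per layer
    print(f"{mode:8s} sigma_W^2={s2:.6f} "
          f"forward x{forward_gain:.3f} backward x{backward_gain:.3f}")

# The compromise is the harmonic mean of the forward and backward variances.
compromise = init_variance(n_in, n_out, "fan_avg")
print(abs(compromise - harmonic_mean(1.0 / n_in, 1.0 / n_out)) < 1e-12)
```

With the compromise, neither factor is exactly 1 when the widths differ; the forward and backward factors always sum to 2, so the error is split between the two passes.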

Kaiming / He for ReLU

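For \(\sigma = \text{ReLU}\), half the units (in expectation, for \(y_i\) symmetric about zero) output zero. The quantity the next layer's variance depends on is the second moment of its input, and ReLU halves it:

\[ \mathbb{E}\!\left[\text{ReLU}(y_i)^2\right] = \tfrac{1}{2} \mathrm{Var}(y_i) \]

(Exact for any distribution symmetric about zero. The variance of \(\text{ReLU}(y_i)\) is smaller still, because the output has a positive mean; the second moment is what propagates, since \(\mathrm{Var}(W_{ij} z_j) = \sigma_W^2 \, \mathbb{E}[z_j^2]\) for mean-zero weights.) Plug this into the forward derivation: to preserve the signal through ReLU, \(\sigma_W^2 \cdot n_\text{in}\) must be twice as large as in the linear case: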

\[ \boxed{\sigma_W^2 = \frac{2}{n_\text{in}} \qquad \text{(Kaiming/He, forward)}} \]

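This is what PyTorch's kaiming_normal_ gives by default (mode='fan_in', with the default a=0 yielding the ReLU gain \(\sqrt{2}\)). The backward variant is \(2/n_\text{out}\) (mode='fan_out'); a fan-average compromise is rare in practice. Most code uses the forward variant, and He et al. note that either variant alone is sufficient: the mismatch in the other pass reduces to a ratio of layer widths rather than a factor that compounds with depth.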
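The factor of 2 can be traced in one step. A sketch, assuming \(y_i\) is symmetric about zero and writing the layer index as a superscript \((l)\) (notation introduced here for illustration, with \(n_\text{in}\) the fan-in of layer \(l+1\)):

\[ \mathbb{E}\!\left[\text{ReLU}(y_i)^2\right] = \mathbb{E}\!\left[y_i^2 \, \mathbf{1}\{y_i > 0\}\right] = \tfrac{1}{2} \, \mathbb{E}\!\left[y_i^2\right] = \tfrac{1}{2} \, \mathrm{Var}(y_i) \]

\[ \mathrm{Var}\!\left(y_i^{(l+1)}\right) = n_\text{in} \cdot \sigma_W^2 \cdot \mathbb{E}\!\left[\big(z_j^{(l)}\big)^2\right] = \tfrac{1}{2} \, n_\text{in} \cdot \sigma_W^2 \cdot \mathrm{Var}\!\left(y_j^{(l)}\right) \]

Setting the per-layer factor \(\tfrac{1}{2} n_\text{in} \sigma_W^2\) to 1 gives \(\sigma_W^2 = 2/n_\text{in}\). With \(\sigma_W^2 = 1/n_\text{in}\) the factor is \(\tfrac{1}{2}\) per layer, so after 10 ReLU layers the signal has shrunk by \(2^{-10} \approx 10^{-3}\).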

Worked example — a 4-layer MLP with input variance 1

Take an MLP 784 → 256 → 256 → 64 → 10. Inputs are MNIST pixels normalized to variance 1.

Bad init: W ~ N(0, 1) (i.e., σ_W² = 1).

Forward variance after each linear layer (linear activation for now):

Var after layer 1 = 784 × 1 × 1   = 784
Var after layer 2 = 256 × 1 × 784 ≈ 2 × 10⁵
Var after layer 3 = 256 × 1 × 2e5 ≈ 5 × 10⁷
Var after layer 4 = 64  × 1 × 5e7 ≈ 3 × 10⁹

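The logit variance is ~3 × 10⁹, so individual logits have magnitude ~5 × 10⁴. Softmax saturates to an essentially one-hot output: the gradient with respect to the logits, \(p - \text{onehot}(\text{label})\), is zero when the argmax happens to be right and ±1 on two entries when it is wrong, and the initial loss is on the order of 10⁴ to 10⁵. In practice the network does not train.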
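The saturation is easy to reproduce. The sketch below takes a set of made-up O(1) logits (the values in `base` and the label index are hypothetical) and rescales them to the standard deviation implied by a variance of ~3 × 10⁹.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    """Cross-entropy via log-sum-exp, so saturated logits do not hit log(0)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return lse - logits[label]

# Hypothetical unit-scale logits for a 10-class problem.
base = [0.3, -1.2, 0.8, 0.1, -0.5, 1.5, -0.9, 0.4, -0.2, 0.6]
label = 2  # assumed true class; it is not the argmax of `base`

std_bad = math.sqrt(3e9)  # logit scale implied by a variance of ~3e9
bad_logits = [std_bad * v for v in base]

p_unit = softmax(base)
p_bad = softmax(bad_logits)

print("max prob, O(1) logits:  ", round(max(p_unit), 3))
print("max prob, scaled logits:", max(p_bad))
print("loss, O(1) logits:      ", round(cross_entropy(base, label), 3))
print("loss, scaled logits:    ", round(cross_entropy(bad_logits, label), 1))
```

At the scaled logits every non-argmax probability underflows to zero in double precision, so the softmax output carries no information about the runner-up classes.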

Xavier init: σ_W² = 1/n_in:

Var after layer 1 = 784 × (1/784) × 1 = 1
Var after layer 2 = 256 × (1/256) × 1 = 1
Var after layer 3 = 256 × (1/256) × 1 = 1
Var after layer 4 = 64  × (1/64)  × 1 = 1

Variance preserved exactly. Logits are O(1). Softmax behaves. Gradients are nontrivial. The network trains.

Kaiming for ReLU: same calculation with the factor-2 correction; the post-ReLU second moment stays at 1 across layers (and the pre-activation variance stays at 2).

Lab 00 (variance-walk) reproduces this with measured numbers on Borja's machine.
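The closed-form numbers in this worked example can be recomputed in a few lines. This is a sketch of the arithmetic only (no random weights are drawn); `variance_walk` and its `act_gain` parameter are names chosen here, with `act_gain` standing for the factor the activation applies to the second moment (1 for linear, 1/2 for ReLU).

```python
def variance_walk(layer_sizes, weight_var, input_var=1.0, act_gain=1.0):
    """Closed-form walk of the signal through a stack of linear layers.

    weight_var: function n_in -> sigma_W^2
    act_gain:   factor the activation applies to the second moment
    Returns the pre-activation variance after each linear layer.
    """
    variances = []
    signal = input_var  # second moment of the current layer's input
    for n_in in layer_sizes[:-1]:
        var_y = n_in * weight_var(n_in) * signal
        variances.append(var_y)
        signal = act_gain * var_y
    return variances

sizes = [784, 256, 256, 64, 10]

bad = variance_walk(sizes, lambda n: 1.0)                         # W ~ N(0, 1)
xavier = variance_walk(sizes, lambda n: 1.0 / n)                  # linear
kaiming_relu = variance_walk(sizes, lambda n: 2.0 / n, act_gain=0.5)

for name, walk in [("bad", bad), ("xavier", xavier), ("kaiming+relu", kaiming_relu)]:
    print(f"{name:13s}", [f"{v:.3g}" for v in walk])
```

Swapping in `lambda n: 1.0 / n` with `act_gain=0.5` shows the Xavier-with-ReLU case, where the signal halves at every layer.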

What this argument doesn't cover

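  1. Nonlinear activations beyond ReLU/tanh. GeLU and SiLU (also called Swish) each have their own effective gain \(g\) such that \(\mathbb{E}[\sigma(y)^2] \approx g \cdot \mathrm{Var}(y)\); use \(\sigma_W^2 = 1/(g \cdot n_\text{in})\). In this convention ReLU has \(g = 1/2\). PyTorch's nn.init.calculate_gain uses a different convention: it returns a multiplier on the standard deviation, so \(\sigma_W^2 = \text{gain}^2 / n_\text{in}\), with \(\sqrt{2}\) for ReLU and \(5/3\) for tanh. It has no entries for GeLU or SiLU. The values of ≈ 1.7 quoted for those two are also std multipliers, and the exact figure depends on the input scale and on whether the variance or the second moment is normalized, so estimate it for your own setup.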
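  2. Convolutions. For a 2D convolution with a square kernel, \(n_\text{in}\) becomes \(\text{kernel\_size}^2 \cdot \text{in\_channels}\) (and \(n_\text{out}\) becomes \(\text{kernel\_size}^2 \cdot \text{out\_channels}\)). Same math, different fan-in.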
  3. Embeddings. No fan-in in the same sense; we use a small fixed std (e.g., \(0.02\), the GPT-2 convention). Phase 13 discusses.
  4. Bias init. Almost always zero. Non-zero bias init breaks the symmetric mean-zero assumption.
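For activations without a tabulated gain, \(g\) can be estimated by sampling. The following is a rough sketch, assuming Gaussian pre-activations with unit standard deviation and the second-moment definition of \(g\); `second_moment_gain` is a hypothetical helper, and the GeLU here is the exact erf form rather than the tanh approximation.

```python
import math
import random

def relu(y):
    return max(0.0, y)

def gelu(y):
    return 0.5 * y * (1.0 + math.erf(y / math.sqrt(2.0)))

def silu(y):
    return y / (1.0 + math.exp(-y))

def second_moment_gain(act, input_std=1.0, n_samples=100_000, seed=0):
    """Estimate g = E[act(y)^2] / Var(y) for y ~ N(0, input_std^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        total += act(rng.gauss(0.0, input_std)) ** 2
    return total / n_samples / input_std ** 2

for name, act in [("relu", relu), ("tanh", math.tanh), ("gelu", gelu), ("silu", silu)]:
    g = second_moment_gain(act)
    print(f"{name:5s} g ~ {g:.3f}   std multiplier 1/sqrt(g) ~ {1 / math.sqrt(g):.3f}")
```

ReLU is the sanity check: it should come out near 0.5 at any `input_std`, because it is positively homogeneous. For tanh, GeLU and SiLU the estimate changes with `input_std`, which is one reason published gain values for them do not all agree.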

What goes wrong if you pick the wrong one

| Wrong choice | Symptom |
| --- | --- |
| Xavier with ReLU | Variance halves every layer → activations collapse → flat loss curve. |
| Kaiming with tanh | Variance doubles every layer (the factor of 2 isn't needed for tanh) → saturation. |
| Forward-only variant with very wide output | Gradient explodes in backward through large n_out. |
| Uniform [-1, 1] regardless of n_in | Variance scales with n_in → instant explosion past 4 layers. |

Lab 01 ablates the first two on Borja's MLP and shows the loss curves.

Practical implementation notes

  • Uniform vs Normal. Both are used in the wild. Uniform with bounds ±sqrt(3 σ_W²) matches Normal's variance. The shape of the distribution barely matters compared to its variance.
  • Bias. Initialize to 0. A small positive value such as 0.01 is sometimes used with ReLU so that units start out active, but the benefit is marginal.
  • fan_in vs fan_out vs fan_avg. In a Linear(in_features=N, out_features=M), \(n_\text{in} = N\), \(n_\text{out} = M\), \(n_\text{avg} = (N+M)/2\). PyTorch's defaults vary by function — read the source if it matters.
  • Seeded init. All init calls take a seed (or RNG). The seed-fixture in tests/conftest.py makes init deterministic across the test suite.
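The uniform-versus-normal and seeding points can be sketched together. This is an illustrative stdlib version, not the course's init code: `uniform_bound` and `init_weights` are hypothetical names, and the layer shape 256 → 64 with the Kaiming forward variance is an assumed example.

```python
import math
import random
import statistics

def uniform_bound(sigma_w2):
    """Half-width a such that U(-a, a) has variance sigma_w2 (Var = a^2 / 3)."""
    return math.sqrt(3.0 * sigma_w2)

def init_weights(n_out, n_in, sigma_w2, dist="normal", seed=0):
    """Return an n_out x n_in nested list with zero mean and variance sigma_w2."""
    rng = random.Random(seed)  # explicit RNG makes the init reproducible
    if dist == "normal":
        std = math.sqrt(sigma_w2)
        draw = lambda: rng.gauss(0.0, std)
    elif dist == "uniform":
        a = uniform_bound(sigma_w2)
        draw = lambda: rng.uniform(-a, a)
    else:
        raise ValueError(f"unknown dist: {dist}")
    return [[draw() for _ in range(n_in)] for _ in range(n_out)]

n_in, n_out = 256, 64
target = 2.0 / n_in  # Kaiming forward variance

for dist in ("normal", "uniform"):
    W = init_weights(n_out, n_in, target, dist)
    flat = [w for row in W for w in row]
    ratio = statistics.pvariance(flat) / target
    print(f"{dist:8s} empirical / target variance ~ {ratio:.3f}")
```

Both ratios should sit close to 1, and calling `init_weights` twice with the same seed returns identical matrices, which is the property a seed fixture relies on.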

Drill problems (work before lab 01)

Solutions are in solutions/01-initialization-ref.md (phase open). Try each one by reasoning first.

  1. A 10-layer network with tanh activation is initialized with \(\sigma_W^2 = 0.5/n_\text{in}\). After 10 layers, what's the activation variance ratio to the input?
  2. A network alternates Linear → ReLU → Linear → ReLU. Each Linear has the same \(n_\text{in} = n_\text{out} = 256\). What \(\sigma_W^2\) preserves activation variance across 20 layers?
  3. PyTorch's nn.Linear.reset_parameters() uses kaiming_uniform_(weight, a=sqrt(5)) by default. What's the effective gain for a=sqrt(5), and why is this an odd default for vanilla MLPs?

One-paragraph recap

Initialization is the variance choice \(\sigma_W^2\) for the weight matrix so the chain of layers preserves activation variance (forward) and gradient variance (backward). For linear/tanh, \(\sigma_W^2 = 1/n_\text{in}\) (Xavier). For ReLU, \(\sigma_W^2 = 2/n_\text{in}\) (Kaiming) because ReLU zeros out half the units. The forward and backward versions disagree; Glorot's harmonic-mean compromise \(2/(n_\text{in}+n_\text{out})\) is a common default. Wrong init manifests as forward-pass blow-up or collapse before training even starts.


Next: theory/02-normalization.md.