
Break 00 — Remove the sqrt(d_k) scaling in attention

We remove the / sqrt(d_k) from the attention-logit computation. At small d_k nothing visible happens. At d_k = 64 the logit variance grows × 64, the softmax saturates, and the gradients die. We demonstrate the effect with §A13 and d_k ∈ {4, 16, 64}.

Anchors: LYNX_CORTEX.md §4 / PHASE 15; theory §02 scaled dot-product; theory §05 worked length-4; .claude/commands/break.md.


The break

In src/minimodel/nn/attention.py:

class ScaledDotProductAttention(Module):
    def forward(self, Q: Tensor, K: Tensor, V: Tensor, mask: Tensor | None = None) -> Tensor:
        d_k = Q.shape[-1]
        # BUG: removed the /sqrt(d_k) scaling.
        scores = Q @ K.transpose(-2, -1)
        # was: scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores + mask
        attn = scores.softmax(dim=-1)
        return attn @ V

Single-line edit.

Predict, then run

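The pre-softmax logits have variance proportional to d_k. If the entries of Q and K are independent with mean 0 and variance 1 (the idealized init regime, Q, K ~ N(0, 1)), each entry of Q K^T is a sum of d_k independent products, each with mean 0 and variance 1, so: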

\[ \mathrm{Var}((Q K^T)_{ij}) = d_k \]
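Spelled out, under the idealizing assumption that all components of q_i (row i of Q) and k_j (row j of K) are independent with mean 0 and variance 1:

\[ \mathrm{Var}\big((Q K^T)_{ij}\big) = \sum_{l=1}^{d_k} \mathrm{Var}(q_{il} k_{jl}) = \sum_{l=1}^{d_k} \mathbb{E}[q_{il}^2]\,\mathbb{E}[k_{jl}^2] = d_k, \qquad \mathrm{Var}\!\left(\frac{(Q K^T)_{ij}}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1. \]

The first step uses that the d_k products are independent with mean 0, so their variances add. The second uses E[q_{il} k_{jl}] = 0 and the independence of q_{il} and k_{jl}. Dividing by sqrt(d_k) is exactly what brings the variance back to 1 for every d_k.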

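Softmax turns logit gaps into weight ratios exponentially: two logits one standard deviation apart get a weight ratio of exp(sqrt(d_k)) without the scaling, versus exp(1) with it. Quantitatively, if the largest logit in a row of T logits (one per key) exceeds each of the others by a gap Δ, its weight is 1 / (1 + (T - 1)·exp(-Δ)), i.e. 1 - O(exp(-Δ)) at fixed T. At d_k = 64 the logit standard deviation is sqrt(64) = 8, so gaps of that order are common, and exp(-8) ≈ 3e-4: with a gap of 8 over a single competitor the row is [0.9997, 0.0003], and longer rows with the same gap stay close to [1, ε, ε, ..., ε]. The softmax is saturated.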
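A quick stdlib check of this arithmetic. It assumes an idealized row in which the top logit exceeds every other logit by the same gap (real rows have varied gaps), and `top_weight` is a hypothetical helper, not part of minimodel.

```python
import math


def top_weight(gap: float, num_keys: int) -> float:
    """Softmax weight of the largest logit when it exceeds each of the
    other num_keys - 1 logits by exactly `gap` (idealized row)."""
    return 1.0 / (1.0 + (num_keys - 1) * math.exp(-gap))


for d_k in (4, 16, 64):
    std = math.sqrt(d_k)  # std of an unscaled logit; it is 1 after / sqrt(d_k)
    # Weight ratio for two logits one std apart: exp(std) unscaled vs exp(1) scaled.
    print(f"d_k={d_k:2d}  ratio unscaled={math.exp(std):7.1f}  scaled={math.e:.2f}  "
          f"top weight, 2 keys, gap={std:.0f}: {top_weight(std, 2):.4f}")
# For d_k = 64 this prints a ratio of about 2981 and a top weight of about 0.9997.
```

Adding more keys at the same gap lowers the top weight only linearly in T, while raising the gap lowers the remainder exponentially, which is why d_k dominates.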

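A saturated softmax has a near-zero Jacobian: ∂softmax_i/∂z_j = p_i(δ_ij − p_j), and when p is nearly one-hot every entry is ≈ 0, on and off the diagonal. Almost no gradient flows back through the softmax into the logits, so the query and key projections, and with them the attention pattern, stop learning.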
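A small stdlib sketch of the softmax Jacobian, J_ij = ∂softmax_i/∂z_j = p_i(δ_ij − p_j), comparing a diffuse row with a saturated one. The function names are illustrative, and the saturated row simply places one logit 8 above three others, mimicking the d_k = 64 scale.

```python
import math


def softmax(z: list[float]) -> list[float]:
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in z]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jacobian(z: list[float]) -> list[list[float]]:
    """J[i][j] = d softmax_i / d z_j = p_i * (delta_ij - p_j)."""
    p = softmax(z)
    n = len(p)
    return [[p[i] * ((1.0 if i == j else 0.0) - p[j]) for j in range(n)] for i in range(n)]


def frobenius(m: list[list[float]]) -> float:
    return math.sqrt(sum(x * x for row in m for x in row))


diffuse = [0.0, 0.0, 0.0, 0.0]    # uniform row: p = 0.25 everywhere
saturated = [8.0, 0.0, 0.0, 0.0]  # one logit 8 above the rest
print(frobenius(softmax_jacobian(diffuse)))    # ≈ 0.433
print(frobenius(softmax_jacobian(saturated)))  # ≈ 0.0014
```

The saturated row passes roughly 300× less gradient (in Frobenius norm) back to its logits than the uniform one.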

Predictions

  • d_k = 4: subtle. Attention sharper than scaled version but loss converges.
  • d_k = 16: noticeable. Loss curve plateaus 30-50% sooner.
  • d_k = 64: catastrophic. Loss barely moves; attention is one-hot from step 1.
  • Inspecting attn (the softmax output): rows look like [1.0, 0, 0, ...] instead of the expected [0.25, 0.30, 0.20, 0.25].

Write predictions in learners/borja/phase-15/notes/breaks.md before running.

Observe

just exp 15-train-attn --tag broken-no-scaling --d-k 64

Diagnostics:

  1. Plot attn.max(axis=-1) per row over training steps — should be near 1 immediately if broken.
  2. Plot loss curve for d_k ∈ {4, 16, 64} with and without scaling.
  3. Compute the gradient norm at the Q and K projections at step 1: it should be ~0 in the broken d_k = 64 case (the V projection still receives gradient through attn @ V).
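Before running the real experiment, diagnostic 1 can be previewed with an init-time simulation: random queries and keys with N(0, 1) entries, no training, and no learned projections. The number of keys (8) and trials (500) are arbitrary choices for this sketch, and `max_attention_stats` is a hypothetical helper, so treat the numbers as qualitative only.

```python
import math
import random


def softmax(z: list[float]) -> list[float]:
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in z]
    total = sum(exps)
    return [e / total for e in exps]


def max_attention_stats(d_k: int, num_keys: int, scaled: bool,
                        trials: int, rng: random.Random) -> tuple[float, float]:
    """Mean of max(attn row) and fraction of rows with max > 0.99,
    for one random query against num_keys random keys (entries ~ N(0, 1))."""
    scale = 1.0 / math.sqrt(d_k) if scaled else 1.0
    maxes = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        logits = []
        for _ in range(num_keys):
            k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
            logits.append(scale * sum(qi * ki for qi, ki in zip(q, k)))
        maxes.append(max(softmax(logits)))
    mean_max = sum(maxes) / trials
    frac_saturated = sum(m > 0.99 for m in maxes) / trials
    return mean_max, frac_saturated


rng = random.Random(0)
results = {}
for d_k in (4, 16, 64):
    for scaled in (True, False):
        results[(d_k, scaled)] = max_attention_stats(d_k, 8, scaled, 500, rng)
        mean_max, frac = results[(d_k, scaled)]
        print(f"d_k={d_k:2d} scaled={scaled!s:5}  mean max={mean_max:.3f}  "
              f"frac > 0.99={frac:.2f}")
```

The scaled rows should stay roughly equally diffuse across d_k, while the unscaled rows concentrate more as d_k grows.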

Symptom Borja will see

  • d_k = 4: training works (this is why the bug is subtle on small models).
  • d_k = 16: training works but converges 30% slower.
  • d_k = 64: loss is flat. attn.max(axis=-1).min() > 0.99 (saturated).
  • Gradient through the attention block: vanishing.

Hidden cause (one sentence)

The Q K^T product has variance proportional to d_k; without /sqrt(d_k) normalization the softmax saturates more and more as d_k grows (mildly at 16, severely at 64), and gradients into Q and K vanish.

Hint cascade

  1. Print attn.max(axis=-1).mean() over training. For a healthy run at init it should sit well below 1, within a small multiple of the uniform value 1/T.
  2. Compute the empirical variance of pre-softmax logits. Compare to d_k. What does Vaswani et al. 2017 §3.2.1 say should happen?
  3. Look at forward in attention.py. Is the matmul divided by sqrt(d_k)?
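For hint 2, a stdlib sketch of the empirical check, assuming q and k have independent N(0, 1) entries (the idealized init regime). `logit_variance` is a hypothetical helper; in the real run you would compute the variance of the actual `scores` tensor instead.

```python
import math
import random
import statistics


def logit_variance(d_k: int, samples: int, scaled: bool, rng: random.Random) -> float:
    """Empirical variance of q . k (optionally divided by sqrt(d_k)), q, k ~ N(0, I)."""
    scale = 1.0 / math.sqrt(d_k) if scaled else 1.0
    logits = []
    for _ in range(samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        logits.append(scale * sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pvariance(logits)


rng = random.Random(0)
for d_k in (4, 16, 64):
    raw = logit_variance(d_k, 5_000, scaled=False, rng=rng)
    fixed = logit_variance(d_k, 5_000, scaled=True, rng=rng)
    print(f"d_k={d_k:2d}  Var(unscaled)={raw:6.2f}  Var(scaled)={fixed:.2f}")
```

Unscaled variances should land near 4, 16 and 64; scaled ones near 1.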

Fix diff

scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
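For reference, a pure-Python sketch of the fixed forward pass on plain lists (no batch dimension). It mirrors the structure of `forward` above (scale, additive mask, softmax over keys, weighted sum of V) but is not minimodel's Tensor API; the function name and list layout are assumptions for illustration.

```python
import math


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q: T_q x d_k, K: T_k x d_k, V: T_k x d_v (lists of lists);
    mask: optional T_q x T_k additive mask (0.0 or -math.inf)."""
    d_k = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        # The fix: divide every dot product by sqrt(d_k) before the softmax.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[i])]
        top = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - top) for s in scores]  # exp(-inf) == 0.0 for masked keys
        total = sum(exps)
        attn = [e / total for e in exps]
        out.append([sum(w * v[c] for w, v in zip(attn, V)) for c in range(len(V[0]))])
    return out


causal = [[0.0, -math.inf], [0.0, 0.0]]
result = scaled_dot_product_attention([[1.0, 0.0], [0.0, 1.0]],
                                      [[1.0, 0.0], [0.0, 1.0]],
                                      [[1.0, 2.0], [3.0, 4.0]], causal)
print(result[0])  # first query may only see key 0, so this is V[0] = [1.0, 2.0]
```

A row whose mask is entirely -inf would produce NaN here (max is -inf), so real implementations guard against fully masked rows.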

Why this teaches the concept

Vaswani et al. 2017 §3.2.1 introduces the scaling factor precisely to keep the dot products from pushing the softmax into saturation as d_k grows. This is a classic subtle attention bug because it passes tests at d_k = 8 and breaks at d_k = 64. The break drives home that sqrt(d_k) isn't polish: it's a load-bearing piece of variance preservation that becomes essential at any realistic model size. Phase 17's mini-GPT uses d_k = 64 per head; Phase 27's FlashAttention preserves this scaling inside its tiled algorithm.


Next: Phase 16's /break on shuffled positional encodings.