Break 00: Remove the sqrt(d_k) scaling in attention
We remove the `/ sqrt(d_k)` from the computation of the attention logits. At small `d_k` nothing visible happens. At `d_k = 64` the variance of the logits grows 64-fold, the softmax saturates, and the gradients die. We demonstrate the effect with §A13 and `d_k ∈ {4, 16, 64}`.
Anchors: `LYNX_CORTEX.md` §4 / PHASE 15; theory §02 scaled dot-product; theory §05 worked length-4; `.claude/commands/break.md`.
The break
In src/minimodel/nn/attention.py:
```python
class ScaledDotProductAttention(Module):
    def forward(self, Q: Tensor, K: Tensor, V: Tensor, mask: Tensor | None = None) -> Tensor:
        d_k = Q.shape[-1]  # still computed, but no longer used after the break
        # BUG: removed the /sqrt(d_k) scaling.
        scores = Q @ K.transpose(-2, -1)
        # was: scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores + mask  # additive mask
        attn = scores.softmax(dim=-1)
        return attn @ V
```
Single-line edit.
Predict, then run
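The pre-softmax logits have variance proportional to `d_k`. With the entries of Q and K i.i.d. N(0, 1) (init), each entry of Q K^T is a sum of `d_k` products of independent standard normals, and each product has mean 0 and variance 1, so:

$$
\operatorname{Var}\big[(QK^\top)_{ij}\big] = \sum_{l=1}^{d_k} \operatorname{Var}\big(Q_{il} K_{jl}\big) = d_k .
$$

The logits therefore have standard deviation sqrt(d_k); dividing by sqrt(d_k) brings their variance back to 1.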
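After the softmax, logits with standard deviation sqrt(d_k) give a much sharper distribution: the ratio of two attention weights is exp(z_i - z_j), and typical gaps z_i - z_j grow like sqrt(d_k). Quantitatively, if the largest logit in a row of T logits (one per key) exceeds every other logit by at least Δ, its weight is at least 1 - (T - 1)·exp(-Δ). At `d_k = 64` the logit standard deviation is sqrt(64) = 8; a gap of that size gives exp(-8) ≈ 3e-4 per competing key, so the softmax row looks like [0.9997, ε, ε, ..., ε]: saturated.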
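Both claims can be checked numerically. The NumPy sketch below draws query and key entries i.i.d. N(0, 1), matching the init assumption above, with T = 8 keys per row (an arbitrary choice for illustration). It reports the empirical variance of the raw logits and the mean of the largest softmax weight per row, with and without the 1/sqrt(d_k) factor. It uses NumPy rather than `minimodel`'s `Tensor`, and `logit_stats` is a hypothetical helper name.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def logit_stats(d_k: int, T: int = 8, n_rows: int = 5_000, seed: int = 0):
    """Return (Var of raw logits, mean row-max weight unscaled, mean row-max weight scaled)."""
    rng = np.random.default_rng(seed)
    # Init assumption: query and key entries are i.i.d. N(0, 1).
    q = rng.standard_normal((n_rows, d_k))        # one query per row
    k = rng.standard_normal((n_rows, T, d_k))     # T keys per row
    logits = np.einsum("nd,ntd->nt", q, k)        # raw q . k_t, shape (n_rows, T)
    max_unscaled = softmax(logits).max(axis=-1).mean()
    max_scaled = softmax(logits / np.sqrt(d_k)).max(axis=-1).mean()
    return logits.var(), max_unscaled, max_scaled

for d_k in (4, 16, 64):
    var, mu, ms = logit_stats(d_k)
    print(f"d_k={d_k:2d}  Var(logits)={var:5.1f}  "
          f"mean max attn: unscaled={mu:.3f}  scaled={ms:.3f}")
```

The variance column should track `d_k`. The unscaled row maxima should climb toward 1 as `d_k` grows, while the scaled ones stay roughly constant. The exact numbers depend on the seed and on T.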
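A saturated softmax has a near-zero Jacobian: ∂p_i/∂z_j = p_i(δ_ij - p_j), and when p is nearly one-hot every entry, diagonal included, is ≈ 0. Gradients into the scores, and from there into the Q and K projections, vanish, so the attention pattern stops learning. V still receives gradient through `attn @ V`.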
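The Jacobian formula can be checked on two hand-picked rows of four logits. The first is at roughly the scale the scaled version sees at init. In the second, the largest logit is 7.5 ahead of the runner-up, which is roughly the unscaled `d_k = 64` regime. The logit values are made up for illustration, and the sketch uses plain NumPy.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax of a single row of logits."""
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_jacobian(z: np.ndarray) -> np.ndarray:
    """J[i, j] = d p_i / d z_j = p_i * (delta_ij - p_j)."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

healthy = np.array([0.3, -0.2, 0.1, 0.0])     # near-uniform row (scaled regime)
saturated = np.array([8.0, 0.0, -1.0, 0.5])   # top logit 7.5 ahead (unscaled, d_k = 64 regime)

for name, z in [("healthy", healthy), ("saturated", saturated)]:
    J = softmax_jacobian(z)
    print(f"{name:9s} max p = {softmax(z).max():.4f}  ||J||_F = {np.linalg.norm(J):.2e}")
```

Each row of J sums to zero because the weights always sum to 1. What matters here is that every entry shrinks by orders of magnitude once one weight dominates.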
Predictions
- `d_k = 4`: subtle. Attention is sharper than in the scaled version, but the loss converges.
- `d_k = 16`: noticeable. The loss curve plateaus 30-50% sooner.
- `d_k = 64`: catastrophic. The loss barely moves; attention is one-hot from step 1.
- Inspecting `attn` (the softmax output): rows look like `[1.0, 0, 0, ...]` instead of the expected `[0.25, 0.30, 0.20, 0.25]`.
Write predictions in learners/borja/phase-15/notes/breaks.md before running.
Observe
Diagnostics:
- Plot `attn.max(axis=-1)` per row over training steps: it should be near 1 immediately if broken.
- Plot the loss curve for `d_k ∈ {4, 16, 64}` with and without scaling.
- Compute the gradient norm at the Q and K projections at step 1: it should be ~0 in the broken `d_k = 64` case.
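The first diagnostic reduces to a few NumPy reductions. The sketch below assumes you can pull the attention weights out as an array with keys on the last axis, for example by logging `attn` inside `forward`; how to do that in `minimodel` is left open. `attention_saturation` is a hypothetical helper name, and its 0.99 threshold matches the saturation check in the symptoms list.

```python
import numpy as np

def attention_saturation(attn: np.ndarray, threshold: float = 0.99) -> dict:
    """Summarize how peaked the attention rows are.

    attn: softmax output with keys on the last axis, shape (..., T_k).
    """
    row_max = attn.max(axis=-1)
    return {
        "mean_row_max": float(row_max.mean()),
        "min_row_max": float(row_max.min()),
        # Fraction of rows that are effectively one-hot.
        "frac_saturated": float((row_max > threshold).mean()),
        # What mean_row_max would be if every row were exactly uniform.
        "uniform_baseline": 1.0 / attn.shape[-1],
    }

# Usage on made-up weights with shape (batch, heads, T_q, T_k).
T = 4
uniform = np.full((2, 3, T, T), 1.0 / T)
one_hot = np.broadcast_to(np.eye(T), (2, 3, T, T))
print(attention_saturation(uniform))   # frac_saturated = 0.0
print(attention_saturation(one_hot))   # frac_saturated = 1.0
```

Logging these numbers every few steps gives the "max weight over training" curve directly. A broken `d_k = 64` run would be expected to show `frac_saturated` near 1 from the start.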
Symptom Borja will see
- `d_k = 4`: training works (this is why the bug is subtle on small models).
- `d_k = 16`: training works but converges 30% slower.
- `d_k = 64`: loss is flat.
- `attn.max(axis=-1).min() > 0.99` (saturated).
- Gradient through the attention scores into Q and K: vanishing.
Hidden cause (one sentence)
The entries of Q K^T have variance proportional to `d_k`; without the `/ sqrt(d_k)` normalization, the softmax saturates for any `d_k` ≥ ~16 and gradients vanish.
Hint cascade
- Print `attn.max(axis=-1).mean()` over training. For a healthy run it should be well below 1 at init, near 1/T when rows are close to uniform.
- Compute the empirical variance of the pre-softmax logits and compare it to `d_k`. What does Vaswani et al. 2017 §3.2.1 say should happen?
- Look at `forward` in `attention.py`. Is the matmul divided by `sqrt(d_k)`?
Fix diff¶
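In `attention.py` the fix is the one-line change back to the `# was:` line: `scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)`. `minimodel`'s `Tensor` API isn't shown in this note, so the sketch below mirrors the fixed forward pass in NumPy. The function name is illustrative. Unlike the original `forward`, which returns only `attn @ V`, it also returns the weights so `attn.max(axis=-1)` can be inspected as in the Observe diagnostics. The additive-mask convention (0 to keep a key, a large negative number to block it) is an assumption that matches the `scores + mask` line.

```python
import math
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """NumPy mirror of the fixed forward: softmax(Q K^T / sqrt(d_k) + mask) V."""
    d_k = Q.shape[-1]
    # The fix: divide the raw dot products by sqrt(d_k) before the softmax.
    scores = (Q @ np.swapaxes(K, -2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores + mask                              # additive mask
    scores = scores - scores.max(axis=-1, keepdims=True)    # stability shift; softmax unchanged
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    return attn @ V, attn

rng = np.random.default_rng(0)
T, d_k = 4, 64
Q, K, V = (rng.standard_normal((T, d_k)) for _ in range(3))
causal = np.triu(np.full((T, T), -1e9), k=1)   # block attention to future positions
out, attn = scaled_dot_product_attention(Q, K, V, mask=causal)
print(out.shape)   # (4, 64)
print(attn.round(3))
```

Using a large finite negative value instead of `-inf` keeps a fully masked row from turning into NaNs after the max shift.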
Why this teaches the concept
Vaswani et al. 2017 §3.2.1 introduces the scaling factor precisely to counter softmax saturation as `d_k` grows. The bug is one of the most common subtle attention bugs because it passes tests at `d_k = 8` and breaks at `d_k = 64`. The break drives home that sqrt(d_k) isn't polish: it's a load-bearing piece of variance preservation that becomes essential at any realistic model size. Phase 17's mini-GPT uses `d_k = 64` per head; Phase 27's Flash attention preserves this scaling inside its tiled algorithm.
Next: Phase 16's /break on shuffled positional encodings.