02 — Scaled Dot-Product Attention: The Full Derivation

The central equation of the transformer is \(\text{softmax}(Q K^\top / \sqrt{d_k}) V\). Here we derive each piece: why the dot product (rotationally invariant similarity), why softmax (turning similarities into a distribution), why divide by \(\sqrt{d_k}\) (unit variance), and the numerically stable rewrite that every real implementation uses (subtracting the maximum before exponentiating). Read this file twice.

This file is the densest in Phase 15. We derive every term of

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \]

from first principles. Five pieces: dot-product similarity, the \(T \times T\) score matrix, the \(\sqrt{d_k}\) scaling, row-wise softmax, the final matmul with V.


Piece 1 — Dot product as similarity

We need to score, for each query \(Q_i \in \mathbb{R}^{d_k}\) and each key \(K_j \in \mathbb{R}^{d_k}\), "how relevant is \(j\) to \(i\)?". Many choices for the scoring function exist (cosine similarity, additive, bilinear, learned MLP — the Bahdanau attention from 2015 used an additive MLP). The transformer paper chose the dot product:

\[ \text{score}(Q_i, K_j) = Q_i \cdot K_j = \sum_{a=1}^{d_k} Q_{i,a} K_{j,a} \]

Two reasons for dot-product over alternatives:

  1. Cheap and parallel. All \(T^2\) scores are one matrix multiplication: \(Q K^\top \in \mathbb{R}^{T \times T}\). On a GPU, one big matmul is faster than \(T^2\) small operations or \(T\) MLPs.
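  2. Empirically competitive with additive scoring. The 2017 transformer paper reports that the two perform similarly at small \(d_k\), and that unscaled dot-product attention falls behind additive attention at larger \(d_k\). The scaling in Piece 3 is the paper's remedy for that gap, and the dot product remains cheaper.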

The cost of choosing dot product is scale dependence — addressed in Piece 3.

Piece 2 — The full score matrix

Stack all queries as \(Q \in \mathbb{R}^{T \times d_k}\) and all keys as \(K \in \mathbb{R}^{T \times d_k}\). The pairwise scores are:

\[ S = Q K^\top \in \mathbb{R}^{T \times T}, \qquad S_{ij} = Q_i \cdot K_j \]

Each row of \(S\) is one query's scores against all keys. Each column is one key's scores from all queries.
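As a quick check that one matmul really produces every pairwise score, the sketch below compares \(Q K^\top\) against an explicit double loop. The toy sizes, the seed, and the names `S_loop` and `S_matmul` are illustrative assumptions, not part of the curriculum code.

```python
import numpy as np

rng = np.random.default_rng(0)   # seed chosen arbitrarily for reproducibility
T, d_k = 4, 3                    # toy sizes, assumed for illustration
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))

# Explicit definition: S[i, j] = Q_i . K_j, one dot product at a time.
S_loop = np.empty((T, T))
for i in range(T):
    for j in range(T):
        S_loop[i, j] = np.dot(Q[i], K[j])

# The same scores as a single matrix multiplication.
S_matmul = Q @ K.T

print(S_matmul.shape)                 # (4, 4)
print(np.allclose(S_loop, S_matmul))  # True
```

Row `S_matmul[i]` holds query `i` scored against every key, and column `S_matmul[:, j]` holds key `j` scored by every query.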

Computational cost.

  • Memory: \(T^2\) floats. For \(T = 2048\) and fp32, that's 16 MiB per layer per head. With 24 layers and 16 heads, 6 GiB just for the scores. This is the bottleneck Flash Attention (Phase 27) attacks.
  • FLOPs: \(2 T^2 d_k\) for the matmul. Quadratic in \(T\): the famous "quadratic attention".
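The memory figures above can be reproduced with a few lines of arithmetic. This is only a sketch of the calculation; the variable names are illustrative.

```python
T = 2048
bytes_per_float = 4          # fp32
layers, heads = 24, 16       # the configuration quoted above

score_bytes = T * T * bytes_per_float          # one (T, T) score matrix
total_bytes = score_bytes * layers * heads     # every layer and head at once

print(score_bytes / 2**20)   # 16.0  (MiB per layer per head)
print(total_bytes / 2**30)   # 6.0   (GiB across all layers and heads)
```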

For Phase 15, we don't try to be clever. We compute \(S\) as a dense matrix. Phase 27 will revisit.

Piece 3 — Why divide by \(\sqrt{d_k}\)

This is the most important piece of the derivation. The argument is variance-control.

Suppose \(Q_i\) and \(K_j\) are random vectors with i.i.d. components: \(Q_{i,a}, K_{j,a} \sim \mathcal{N}(0, 1)\), independent across \(a\) and across \(i, j\). Then

\[ \mathbb{E}[Q_i \cdot K_j] = \sum_a \mathbb{E}[Q_{i,a}] \mathbb{E}[K_{j,a}] = 0 \]
\[ \text{Var}(Q_i \cdot K_j) = \sum_a \text{Var}(Q_{i,a} K_{j,a}) = \sum_a \mathbb{E}[Q_{i,a}^2] \mathbb{E}[K_{j,a}^2] = \sum_a 1 = d_k \]

So \(Q_i \cdot K_j \sim \mathcal{N}(0, d_k)\) approximately. Standard deviation = \(\sqrt{d_k}\).
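The variance claim is easy to test empirically. The following is a minimal Monte Carlo sketch, assuming exactly the i.i.d. \(\mathcal{N}(0, 1)\) setup above; the sample count and the helper name `empirical_score_variance` are hypothetical choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000   # (q, k) pairs per d_k; assumed large enough for a stable estimate

def empirical_score_variance(d_k):
    """Sample variance of q . k for q, k with i.i.d. N(0, 1) components."""
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    scores = np.sum(q * k, axis=1)     # one dot product per row
    return scores.var()

for d_k in (2, 16, 64, 256):
    var = empirical_score_variance(d_k)
    # The last column is var / d_k and should sit near 1.
    print(d_k, round(var, 1), round(var / d_k, 2))
```

The estimates fluctuate from run to run, but the ratio in the last column should stay close to 1 for every \(d_k\).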

The problem: for \(d_k = 64\), scores have standard deviation \(8\). For \(d_k = 256\), std = \(16\). The largest score in a row of \(T\) such scores can easily be 3–4 standard deviations above the mean.

Now apply softmax. \(\text{softmax}([s_1, \ldots, s_T])_i = e^{s_i} / \sum_j e^{s_j}\). If one \(s_i\) is much larger than the rest, \(e^{s_i}\) dominates and the softmax output is nearly one-hot — one entry near 1, all others near 0.

Why this is bad: in the saturated regime, the gradient of softmax is nearly zero. The model can't learn. Specifically, \(\partial \text{softmax}_i / \partial s_j = \text{softmax}_i (\delta_{ij} - \text{softmax}_j)\). When the softmax is near one-hot, this product is near zero for all entries.
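To see the vanishing gradient concretely, the sketch below builds the softmax Jacobian from the formula above and compares a mild score vector with a sharply peaked one. The two example vectors are made up for illustration, and the helper already uses the max-subtraction from Piece 4 so that it stays finite.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())      # max-subtraction keeps exp() from overflowing
    return e / e.sum()

def softmax_jacobian(s):
    """J[i, j] = p_i * (delta_ij - p_j), with p = softmax(s)."""
    p = softmax(s)
    return np.diag(p) - np.outer(p, p)

mild = np.array([0.5, 0.0, -0.5, 0.2])      # scores of order 1
sharp = np.array([30.0, 0.0, -5.0, 2.0])    # one score far above the rest

for name, s in (("mild", mild), ("sharp", sharp)):
    J = softmax_jacobian(s)
    # Print the distribution and the largest Jacobian entry in absolute value.
    print(name, softmax(s).round(3), np.abs(J).max())
```

For the mild vector the largest Jacobian entry is of order 0.1; for the sharp vector every entry is vanishingly small, so almost no gradient reaches the scores.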

The fix: divide the scores by \(\sqrt{d_k}\) before softmax. Now

\[ \text{Var}\left(\frac{Q_i \cdot K_j}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1 \]

— variance is back to 1, regardless of \(d_k\). The softmax sees scores in a reasonable range, doesn't saturate, gradient flows. The scaling is independent of training data — it's a function of the architectural choice \(d_k\).
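One way to watch the saturation appear and disappear is to track the largest attention weight in each row, with and without the scaling. This is a sketch under the same random \(\mathcal{N}(0, 1)\) assumption; the sequence length and the helper name `mean_max_weight` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 32    # sequence length, assumed for illustration

def softmax_rows(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def mean_max_weight(d_k, scale):
    """Average of the largest attention weight per row, for random N(0, 1) Q and K."""
    Q = rng.standard_normal((T, d_k))
    K = rng.standard_normal((T, d_k))
    S = Q @ K.T
    if scale:
        S = S / np.sqrt(d_k)
    return softmax_rows(S).max(axis=-1).mean()

for d_k in (2, 16, 64, 256):
    unscaled = mean_max_weight(d_k, scale=False)
    scaled = mean_max_weight(d_k, scale=True)
    print(d_k, round(unscaled, 2), round(scaled, 2))
```

A value near 1 means the rows are close to one-hot. The unscaled column should climb toward 1 as \(d_k\) grows, while the scaled column should stay roughly flat.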

Summary of the variance argument: without scaling, the scores have std \(\sqrt{d_k}\). If \(d_k\) is large, the scores are large, the softmax saturates, and the gradient dies. Dividing by \(\sqrt{d_k}\) restores std = 1 and keeps the softmax in its useful regime.

Sanity check: what if \(d_k\) is tiny?

For \(d_k = 2\), std of scores is \(\sqrt{2} \approx 1.4\). Softmax is fine without scaling. The fix is unnecessary at small \(d_k\), but it doesn't hurt — dividing by \(\sqrt{2}\) leaves softmax behavior almost unchanged.

For Phase 15 toy examples (\(d_k = 2\)), the scaling is barely visible. For Phase 17's Mini-GPT (\(d_k = 16\) per head with 4 heads, \(d_\text{model} = 64\)), it matters. For modern LLMs (\(d_k = 128+\) per head), it is essential.

Lab 00 verification

In lab/00-attention-by-hand.md, Borja will run attention twice on the same Q, K, V — once with scaling, once without — at \(d_k = 64\), and observe that the unscaled version's attention matrix is nearly one-hot. The visual is convincing.

Piece 4 — Softmax (with numerical stability)

The naive softmax:

\[ \text{softmax}(s_i) = \frac{e^{s_i}}{\sum_j e^{s_j}} \]

For large positive \(s\), \(e^s\) overflows fp32 (max around \(e^{88}\)). For our scaled scores this is rarely a problem, but in real implementations it eventually bites, especially during training, when an occasional large update can push some scores far outside their usual range.

Numerical-stability rewrite:

\[ \text{softmax}(s_i) = \frac{e^{s_i - m}}{\sum_j e^{s_j - m}}, \qquad m = \max_k s_k \]

After subtracting the max, the largest value is \(0\), so \(e^{s_i - m} \leq 1\). No overflow. The output is identical to the naive form (the \(e^{-m}\) factor cancels between numerator and denominator).

Every production attention implementation does this max-subtraction. Phase 27 (Flash Attention) does it incrementally in tiles — same idea, more bookkeeping.

Implementation note for src/minimodel/attention/:

import numpy as np

def softmax_stable(s, axis=-1):
    # Subtract the max along `axis` so the largest exponent is exp(0) = 1.
    m = s.max(axis=axis, keepdims=True)
    exp = np.exp(s - m)
    return exp / exp.sum(axis=axis, keepdims=True)

Three lines in the function body. Always use this form, never the naive np.exp(s) / np.exp(s).sum().
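The difference between the two forms shows up as soon as the scores are large. The sketch below feeds fp32 scores around 100 to both versions; `softmax_naive` is a hypothetical helper written only for this comparison.

```python
import numpy as np

def softmax_naive(s):
    e = np.exp(s)
    return e / e.sum()

def softmax_stable(s, axis=-1):
    m = s.max(axis=axis, keepdims=True)
    e = np.exp(s - m)
    return e / e.sum(axis=axis, keepdims=True)

s = np.array([100.0, 99.0, 98.0], dtype=np.float32)   # e^100 exceeds the fp32 maximum

with np.errstate(over="ignore", invalid="ignore"):     # silence the expected warnings
    print(softmax_naive(s))    # all nan: every exp() is inf, and inf / inf is nan

print(softmax_stable(s))       # finite values that sum to 1

# On inputs where nothing overflows, the two forms agree.
small = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(np.allclose(softmax_naive(small), softmax_stable(small)))   # True
```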

Piece 5 — Multiply by V

The softmax outputs row-wise probabilities: \(A = \text{softmax}(S / \sqrt{d_k}) \in \mathbb{R}^{T \times T}\). Row \(i\) of \(A\) is a probability distribution over the \(T\) positions.

The final output is

\[ \text{Attention}(Q, K, V) = A V \in \mathbb{R}^{T \times d_v} \]

Row \(i\) of the output is the weighted average of value rows: \(\text{out}_i = \sum_j A_{ij} V_j\).
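A hand-built attention matrix makes the "weighted average" reading concrete. The numbers below are toy values chosen for illustration, not the output of a real softmax.

```python
import numpy as np

V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0]])           # T = 3 value rows, d_v = 2 (toy numbers)

A = np.array([[1.0, 0.0, 0.0],      # one-hot row: copies V[0]
              [1/3, 1/3, 1/3],      # uniform row: mean of all value rows
              [0.0, 0.5, 0.5]])     # mixes V[1] and V[2] equally

out = A @ V
print(out)
# Row 0 is [1, 0], row 1 is [1, 1], row 2 is [1, 1.5].
```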

That's the entire forward pass. It fits in five effective lines of code (see "Putting it together" below).

Putting it together

import numpy as np

def single_head_attention(Q, K, V, mask=None):
    # Q, K: (T, d_k), V: (T, d_v); relies on softmax_stable from Piece 4
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)         # (T, T)
    if mask is not None:
        scores = scores + mask              # additive -inf, see file 04
    attn = softmax_stable(scores, axis=-1)  # (T, T) row-normalized
    return attn @ V                         # (T, d_v)

Five effective lines. This is the entire attention mechanism. The rest of the curriculum builds on this.
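A short usage sketch follows, with both functions repeated so that it runs on its own. The toy sizes and the construction of `causal_mask` are assumptions made here for illustration; the mask itself is derived in file 04.

```python
import numpy as np

def softmax_stable(s, axis=-1):
    m = s.max(axis=axis, keepdims=True)
    e = np.exp(s - m)
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    if mask is not None:
        scores = scores + mask
    attn = softmax_stable(scores, axis=-1)
    return attn @ V

rng = np.random.default_rng(0)
T, d_k, d_v = 5, 8, 4                       # toy sizes, assumed for illustration
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))

# Assumed additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
causal_mask = np.triu(np.full((T, T), -np.inf), k=1)

out = single_head_attention(Q, K, V, mask=causal_mask)
print(out.shape)                     # (5, 4)
print(np.allclose(out[0], V[0]))     # True: position 0 can only attend to itself
```

Because the diagonal is never masked, every row keeps at least one finite score and the max-subtraction stays well defined.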

Backward pass (sketch)

You will not implement the backward by hand in Phase 15 — the autograd from Phase 8 handles it. But the gradient structure is worth knowing for the Phase 27 flash-attention preview:

Let \(L\) be the downstream loss, \(\delta_{\text{out}} = \partial L / \partial \text{out}\).

\[ \frac{\partial L}{\partial V} = A^\top \delta_{\text{out}}, \qquad \frac{\partial L}{\partial A} = \delta_{\text{out}} V^\top \]

Through softmax (standard derivation):

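\[ \frac{\partial L}{\partial S} = A \odot \left( \frac{\partial L}{\partial A} - \left( \left( \frac{\partial L}{\partial A} \odot A \right) \mathbf{1} \right) \mathbf{1}^\top \right) \]

where \(\mathbf{1} \in \mathbb{R}^T\) is the all-ones vector: the subtracted term is the sum of row \(i\) of \(\frac{\partial L}{\partial A} \odot A\), broadcast across row \(i\).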

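Then through the matmul \(S = Q K^\top / \sqrt{d_k}\) (in this sketch \(S\) denotes the scaled scores fed to the softmax, not the unscaled \(S\) of Piece 2):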

\[ \frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial S} K, \qquad \frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_k}} \left(\frac{\partial L}{\partial S}\right)^\top Q \]

Key observation: computing \(\partial L / \partial Q\) requires the full \(A\) matrix — \(O(T^2)\) memory. This is what Flash Attention recomputes in tiles to avoid materializing.

Don't implement this — autograd does. Just notice the \(O(T^2)\) memory requirement.
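For readers who want to convince themselves that the formulas in this sketch are right, a finite-difference check is enough. The code below is not the Phase 8 autograd; it is a standalone sanity check under an assumed toy loss \(L = \sum W \odot \text{out}\), with hypothetical helper names throughout.

```python
import numpy as np

def softmax_rows(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def forward(Q, K, V):
    A = softmax_rows(Q @ K.T / np.sqrt(Q.shape[-1]))
    return A, A @ V

def backward(Q, K, V, A, d_out):
    """Gradients from the formulas above; d_out is dL/d(out)."""
    scale = 1.0 / np.sqrt(Q.shape[-1])
    dV = A.T @ d_out
    dA = d_out @ V.T
    # Row-wise softmax backward: subtract each row's sum of dA * A, broadcast over the row.
    dS = A * (dA - (dA * A).sum(axis=-1, keepdims=True))
    dQ = scale * (dS @ K)
    dK = scale * (dS.T @ Q)
    return dQ, dK, dV

rng = np.random.default_rng(0)
T, d_k, d_v = 4, 3, 2                  # toy sizes, assumed for illustration
Q, K = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
W = rng.standard_normal((T, d_v))      # arbitrary weights for the toy loss sum(W * out)

A, out = forward(Q, K, V)
dQ, dK, dV = backward(Q, K, V, A, d_out=W)   # dL/d(out) = W for this loss

def numeric_grad(X, eps=1e-6):
    """Central finite differences of the toy loss with respect to X (perturbed in place)."""
    g = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        old = X[idx]
        X[idx] = old + eps
        plus = np.sum(W * forward(Q, K, V)[1])
        X[idx] = old - eps
        minus = np.sum(W * forward(Q, K, V)[1])
        X[idx] = old
        g[idx] = (plus - minus) / (2 * eps)
    return g

for name, analytic, X in (("dQ", dQ, Q), ("dK", dK, K), ("dV", dV, V)):
    print(name, np.allclose(analytic, numeric_grad(X), atol=1e-6))   # True for each
```

Note that `backward` needs the full `A` as an input, which is the \(O(T^2)\) memory requirement pointed out above.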

Complexity summary

| Operation | FLOPs | Memory |
| --- | --- | --- |
| \(Q K^\top\) | \(2 T^2 d_k\) | \(T^2\) |
| Scale | \(T^2\) | |
| Softmax | \(T^2\) | \(T^2\) |
| \(A V\) | \(2 T^2 d_v\) | \(T^2 + T d_v\) |
| Total | \(\sim 4 T^2 d_k\) (with \(d_k = d_v\)) | \(\sim T^2\) |

For \(T = 256, d_k = 16\): FLOPs \(\approx 4 \cdot 65536 \cdot 16 \approx 4\) MFLOP. Trivial on the i5-8250U.

For \(T = 2048, d_k = 64\): FLOPs \(\approx 4 \cdot 4 \cdot 10^6 \cdot 64 \approx 1\) GFLOP per layer per head. With 24 layers and 16 heads, that is ~400 GFLOP per forward pass for attention alone, or about 2 seconds on the i5-8250U even at its 200 GFLOPS peak. That's why we wait for cloud GPU in Phase 23.
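The estimates above come from the leading-order total \(4 T^2 d_k\). The sketch below reproduces the arithmetic; the helper name `attention_flops` is hypothetical.

```python
def attention_flops(T, d_k):
    """Leading-order FLOPs for one head: Q K^T plus A V, with d_v = d_k."""
    return 4 * T**2 * d_k

small = attention_flops(T=256, d_k=16)
large = attention_flops(T=2048, d_k=64)
layers, heads = 24, 16
peak_flops = 200e9                     # the 200 GFLOPS peak quoted above

print(small / 1e6)                              # about 4.2 MFLOP
print(large / 1e9)                              # about 1.07 GFLOP per layer per head
print(large * layers * heads / 1e9)             # about 412 GFLOP per forward pass
print(large * layers * heads / peak_flops)      # about 2.1 seconds at peak
```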

Roofline analysis (preview of Phase 27)

Arithmetic intensity of the attention matmul \(Q K^\top\):

  • FLOPs: \(2 T^2 d_k\).
  • Bytes: read \(Q\) (\(T d_k\) fp32 = \(4 T d_k\) bytes), read \(K\) (\(4 T d_k\)), write \(S\) (\(4 T^2\)). Total \(\approx 8 T d_k + 4 T^2\) bytes.
  • Intensity: \(\frac{2 T^2 d_k}{8 T d_k + 4 T^2} = \frac{T d_k}{2(2 d_k + T)}\).

For \(T \gg d_k\) (long sequences), intensity \(\approx d_k / 2\). The machine balance of the i5-8250U from Phase 1 is 200 GFLOPS / 20 GB/s = 10 FLOPs/byte, so the matmul is memory-bound if \(d_k < 20\), and \(d_k = 20\) is the crossover. For typical \(d_k = 64\), compute-bound; for \(d_k = 16\) per head as in our Mini-GPT, memory-bound.
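The regime for a given head size can be checked by evaluating the FLOP and byte counts from the bullets above directly. This is a rough sketch: it assumes each matrix crosses the memory bus exactly once, and the helper name `qk_intensity` is hypothetical.

```python
def qk_intensity(T, d_k, bytes_per_float=4):
    """FLOPs per byte for S = Q K^T, counting one read of Q and K and one write of S."""
    flops = 2 * T**2 * d_k
    bytes_moved = bytes_per_float * (2 * T * d_k + T**2)
    return flops / bytes_moved

machine_balance = 200e9 / 20e9    # 10 FLOPs/byte, from the figures quoted above

for d_k in (16, 64):
    ai = qk_intensity(T=2048, d_k=d_k)
    regime = "compute-bound" if ai > machine_balance else "memory-bound"
    print(d_k, round(ai, 1), regime)
# 16 7.9 memory-bound
# 64 30.1 compute-bound
```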

The softmax pass is always memory-bound (1–2 FLOPs per byte). This is the kernel Flash Attention attacks.

We don't fix this in Phase 15 — we just see it. Lab 03 measures.

To connect with Phase 1: on long sequences, the \(Q K^\top\) matmul is compute-bound with large heads and memory-bound with small heads. The softmax part is always memory-bound. This motivates Flash Attention (Phase 27), which changes not the FLOPs but the number of bytes moved.

What this file does NOT cover

  • Multi-head extension. Next file (03-multi-head.md). Here we did single-head.
  • Causal masking. 04-masking.md. The mask parameter in the code above is teased but not derived here.
  • Backward pass implementation. Phase 8's autograd handles it; we sketched the math for context only.
  • Memory-efficient attention. Flash Attention is Phase 27. Phase 15 implements the naive \(O(T^2)\) form.
  • Bias terms on Q/K/V projections. Many recent decoder-only models drop them (GPT-2 itself keeps them); we follow the no-bias convention without re-derivation.
  • Alternative similarity functions. Bahdanau (additive), bilinear, learned MLP. Mentioned as historical context; we only implement dot-product.

What to memorize

Before lab, Borja should be able to write — from memory, on paper — the following in under 5 minutes:

  1. The full equation: \(\text{Attention}(Q, K, V) = \text{softmax}(Q K^\top / \sqrt{d_k}) V\).
  2. The shapes: \(Q, K \in \mathbb{R}^{T \times d_k}\), \(V \in \mathbb{R}^{T \times d_v}\), output \(\in \mathbb{R}^{T \times d_v}\).
  3. The variance argument for \(\sqrt{d_k}\) scaling. (Var of \(q \cdot k\) = \(d_k\) when components are \(\mathcal{N}(0, 1)\); scaling restores unit variance.)
  4. The max-subtraction trick for stable softmax.
  5. The five-line NumPy implementation.

The /quiz 15 set checks these.


Next: 03-multi-head.md.