Lab 01: Multi-Head Attention
Goal: extend single-head attention to multi-head, lock the MultiHeadAttention class API, and verify that the multi-head class with \(H = 1\) reproduces the previous lab's output to within 1e-5 (up to the output projection).
Estimated time: 90–120 minutes.
Prereq: lab 00 committed.
What you produce
A directory experiments/15-multi-head/ containing:
- mha.py: your NumPy multi-head implementation, importing from src/minimodel/attention/attention.py.
- verify.py: verification script.
- verify_output.txt: captured printout.
- heatmap.png: 4-panel figure showing the attention pattern for each of \(H = 4\) heads on a fixed input.
- manifest.json.
- README.md.
Background
theory/03-multi-head.md covers:
- The split-and-stack construction.
- Parameter count: \(4 d_\text{model}^2\) vs \(3 d_\text{model}^2\) for single-head.
- Why the output projection \(W_O\) is essential.
- The "one big matrix per role, reshape at runtime" implementation trick.
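As a quick numeric check of the parameter count and the reshape trick listed above, here is a minimal NumPy sketch. The sizes (d_model = 64, H = 4, T = 8) and all variable names are illustrative assumptions, not part of the locked API.

```python
import numpy as np

d_model, n_heads, T = 64, 4, 8           # illustrative sizes (assumed)
d_head = d_model // n_heads

# Parameter count: four square matrices with W_O, three without it.
params_multi_head = 4 * d_model**2        # W_Q, W_K, W_V, W_O
params_single_head = 3 * d_model**2       # W_Q, W_K, W_V only

rng = np.random.default_rng(0)
x = rng.standard_normal((T, d_model))
W_Q = rng.standard_normal((d_model, d_model))

# "One big matrix, reshape at runtime": project once, then split the last axis.
Q_heads = (x @ W_Q).reshape(T, n_heads, d_head)

# Equivalent view: head h uses its own (d_model, d_head) column block of W_Q.
h = 2
Q_h_from_slice = x @ W_Q[:, h * d_head:(h + 1) * d_head]

slices_match = np.allclose(Q_heads[:, h, :], Q_h_from_slice)
print(params_multi_head, params_single_head, slices_match)
```

The last comparison is the reason one large matrix per role is enough: reshaping the projected activations selects the same numbers as multiplying by a per-head column block.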
src/minimodel/attention/BLUEPRINT.md (read it!) locks the class API:
class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0) -> None: ...
    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray: ...
The class owns four weight matrices: W_Q, W_K, W_V, W_O, each \(d_\text{model} \times d_\text{model}\).
TODOs
Block A: implement the class
- In src/minimodel/attention/attention.py, implement MultiHeadAttention.
- In __init__: allocate the four matrices using np.random.default_rng(seed). Scale by 1 / sqrt(d_model) (Phase 10 init).
- In forward:
  - Compute Q = x @ self.W_Q, similarly for K, V. Shape (T, d_model).
  - Reshape each to (T, n_heads, d_head), where d_head = d_model // n_heads, then transpose to (n_heads, T, d_head).
  - For each head independently (or with einsum / batched matmul):
    - scores = Q_h @ K_h.T / sqrt(d_head), shape (T, T).
    - Apply mask if given (additive).
    - attn = softmax(scores).
    - out_h = attn @ V_h, shape (T, d_head).
  - Transpose back to (T, n_heads, d_head) and reshape to (T, d_model); this concatenates the heads.
  - Apply out @ self.W_O. Return.
- Aim for ≤ 30 LOC in the forward. Use einsum if it helps readability, but a for h in range(H) loop is also fine for clarity.
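The steps above can be sketched as follows. This is one possible shape for the class, not the reference solution: the name MultiHeadAttentionSketch, the softmax helper, and the choice of standard-normal draws before scaling are assumptions made for illustration (the actual Phase 10 init may differ).

```python
from typing import Optional

import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


class MultiHeadAttentionSketch:
    """Illustrative only; the real class lives in attention.py."""

    def __init__(self, d_model: int, n_heads: int, seed: int = 0) -> None:
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.d_model, self.n_heads = d_model, n_heads
        self.d_head = d_model // n_heads
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)
        # Assumption: standard-normal draws scaled by 1 / sqrt(d_model).
        self.W_Q, self.W_K, self.W_V, self.W_O = (
            rng.standard_normal((d_model, d_model)) * scale for _ in range(4)
        )

    def _split(self, a: np.ndarray) -> np.ndarray:
        # (T, d_model) -> (T, H, d_head) -> (H, T, d_head)
        T = a.shape[0]
        return a.reshape(T, self.n_heads, self.d_head).transpose(1, 0, 2)

    def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
        T = x.shape[0]
        Q, K, V = (self._split(x @ W) for W in (self.W_Q, self.W_K, self.W_V))
        # Batched matmul over the head axis: (H, T, d_head) @ (H, d_head, T).
        scores = Q @ K.transpose(0, 2, 1) / np.sqrt(self.d_head)
        if mask is not None:
            scores = scores + mask           # additive mask, broadcast over heads
        attn = softmax(scores)               # (H, T, T)
        out = attn @ V                       # (H, T, d_head)
        # (H, T, d_head) -> (T, H, d_head) -> (T, d_model): concatenate heads.
        out = out.transpose(1, 0, 2).reshape(T, self.d_model)
        return out @ self.W_O


x = np.random.default_rng(1).standard_normal((8, 64))
y = MultiHeadAttentionSketch(d_model=64, n_heads=4, seed=0).forward(x)
print(y.shape)  # (8, 64)
```

The batched form keeps forward well under the 30-line target; replacing the two batched matmuls with a loop over heads should give the same result.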
Block B: verify single-head equivalence
When \(H = 1\), the multi-head class should behave exactly like the single-head function from lab 00 (up to the output projection).
- Construct a MultiHeadAttention(d_model=4, n_heads=1, seed=42).
- Manually extract its W_Q, W_K, W_V (shape (4, 4)) and call single_head_attention(X @ W_Q, X @ W_K, X @ W_V) from lab 00.
- Then apply W_O to that result.
- Compare to mha.forward(X).
- The two must agree to 1e-5. Assert.
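A minimal sketch of that check, written as a helper that verify.py could call. The helper name check_single_head_equivalence is hypothetical, and it assumes that the class exposes W_Q, W_K, W_V, W_O as attributes and that single_head_attention(Q, K, V) from lab 00 applies the same 1 / sqrt(d) scaling as the multi-head class.

```python
import numpy as np


def check_single_head_equivalence(mha, single_head_attention, X, tol=1e-5):
    """Return the max abs difference between the two code paths for H = 1.

    Assumes `mha` exposes W_Q, W_K, W_V, W_O and forward(X), and that
    `single_head_attention(Q, K, V)` is the lab 00 function.
    """
    reference = single_head_attention(X @ mha.W_Q, X @ mha.W_K, X @ mha.W_V)
    reference = reference @ mha.W_O       # lab 00 has no output projection
    max_diff = float(np.max(np.abs(mha.forward(X) - reference)))
    assert max_diff < tol, f"single-head mismatch: max_diff={max_diff:.3e}"
    return max_diff
```

Returning max_diff makes it easy to print the value into verify_output.txt and copy it into the manifest.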
Block C: explore head specialization
For a fixed input, the canonical 8-token verb-grammar sequence <bos> I work , you work , he (use the Phase 14 lab 00 tokenization):
- Build MultiHeadAttention(d_model=64, n_heads=4, seed=0).
- Run forward. Capture the attention matrices attn_h for each head (modify forward to optionally return them, or save them as a side effect).
- Plot 4 heatmaps in a 2×2 grid. Each heatmap is \(T \times T\), with attention probability as color (use viridis).
- Annotate axes with the actual decoded tokens.
- Save as heatmap.png.
With random weights, the heads are random patterns — that's fine. The point is shape, not semantics. (Trained attention patterns appear in Phase 18.)
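One way to produce the 2×2 figure is sketched below. The function name plot_head_heatmaps is hypothetical, the whitespace split stands in for the Phase 14 tokenization, and the attention array here is random stand-in data so the snippet runs on its own; in the lab, pass the matrices captured from forward instead.

```python
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use("Agg")                     # render to file, no display needed
import matplotlib.pyplot as plt


def plot_head_heatmaps(attn, tokens, path):
    """attn: (4, T, T) attention probabilities; tokens: list of T strings."""
    T = attn.shape[-1]
    fig, axes = plt.subplots(2, 2, figsize=(9, 8), constrained_layout=True)
    for h, ax in enumerate(axes.flat):
        # Shared color scale so the four panels are directly comparable.
        im = ax.imshow(attn[h], cmap="viridis", vmin=0.0, vmax=float(attn.max()))
        ax.set_title(f"head {h}")
        ax.set_xticks(range(T))
        ax.set_xticklabels(tokens, rotation=90)
        ax.set_yticks(range(T))
        ax.set_yticklabels(tokens)
        ax.set_xlabel("key")
        ax.set_ylabel("query")
    fig.colorbar(im, ax=axes, label="attention probability")
    fig.savefig(path, dpi=150)
    plt.close(fig)


tokens = "<bos> I work , you work , he".split()   # assumed stand-in tokenization
rng = np.random.default_rng(0)
scores = rng.standard_normal((4, len(tokens), len(tokens)))
attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # stand-in data

out_path = Path("heatmap.png")
plot_head_heatmaps(attn, tokens, out_path)
```

Rows are queries and columns are keys, so each row of a panel sums to 1.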
Block D: write up
In README.md:
- Confirm the single-head equivalence (Block B). State the max diff.
- Describe the heatmap shapes (Block C). Are the four heads distinguishable from each other? With random weights they should look different, because each head uses its own slice of the randomly initialized projections. Note any patterns you see (probably none; that's the expected null result for random init).
Block E: manifest
{
  "experiment": "15-multi-head",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "versions": { "python": "3.11.x", "numpy": "X.Y.Z", "matplotlib": "X.Y.Z" },
  "config": {
    "d_model": 64,
    "n_heads": 4,
    "T": 8,
    "input_snippet": "<bos> I work , you work , he"
  },
  "results_summary": {
    "single_head_equivalence_max_diff": null,
    "heads_visibly_distinct": null
  }
}
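If you prefer to generate the manifest rather than edit it by hand, a sketch along these lines fills in the date and version fields automatically. The helper names build_manifest and package_version are hypothetical, and the values passed in the usage line are placeholders, not measured results.

```python
import json
import platform
from datetime import date
from importlib import metadata

import numpy as np


def package_version(name: str) -> str:
    """Installed version of a package, or 'not installed' if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def build_manifest(max_diff: float, heads_distinct: bool) -> dict:
    return {
        "experiment": "15-multi-head",
        "date": date.today().isoformat(),
        "seed": 42,
        "versions": {
            "python": platform.python_version(),
            "numpy": np.__version__,
            "matplotlib": package_version("matplotlib"),
        },
        "config": {
            "d_model": 64,
            "n_heads": 4,
            "T": 8,
            "input_snippet": "<bos> I work , you work , he",
        },
        "results_summary": {
            # Cast so NumPy scalars serialize as plain JSON numbers/booleans.
            "single_head_equivalence_max_diff": float(max_diff),
            "heads_visibly_distinct": bool(heads_distinct),
        },
    }


# Placeholder arguments: replace with the values from verify.py and your README.
manifest = build_manifest(max_diff=0.0, heads_distinct=True)
text = json.dumps(manifest, indent=2)
print(text)
```

Writing text to experiments/15-multi-head/manifest.json is then a single file write.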
Constraints
- No PyTorch.
- Reshape, don't loop, where possible. Both work; reshape is faster and is what production code does. Loop is allowed for clarity in your first pass.
- Mask is None in this lab. Lab 02 adds the causal mask.
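To see that the loop and the vectorized forms compute the same scores, a small comparison like the following can help. The sizes and variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
H, T, d_head = 4, 8, 16                   # illustrative sizes (assumed)
Q = rng.standard_normal((H, T, d_head))
K = rng.standard_normal((H, T, d_head))

# First-pass version: one (T, T) score matrix per head.
scores_loop = np.stack([Q[h] @ K[h].T for h in range(H)]) / np.sqrt(d_head)

# Vectorized versions: contract over d_head for all heads at once.
scores_einsum = np.einsum("htd,hsd->hts", Q, K) / np.sqrt(d_head)
scores_matmul = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)

print(np.allclose(scores_loop, scores_einsum),
      np.allclose(scores_loop, scores_matmul))
```

Starting with the loop and then swapping in one of the vectorized lines, with this comparison as a regression check, is a low-risk way to move from the first pass to the reshape-based version.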
Stop conditions
Done when:
- All six files committed.
- Single-head equivalence assertion passes (max_diff < 1e-5).
- Heatmap shows four visibly different patterns (even if structureless).
- README.md answers both Block D questions.
Pitfalls
- Reshape order matters. x.reshape(T, H, d_head) is different from x.reshape(T, d_head, H). The first puts each head's features adjacent in memory; the second does not. Use the former.
- Transpose for matmul. After reshaping to (T, H, d_head), you need (H, T, d_head) for batched matmul. Use .transpose(1, 0, 2).
- Bias of W_Q, W_K, W_V? Phase 15 uses no bias in these projections. This matches the attention equations of the original 2017 Transformer paper, which have no bias terms, although some implementations add them. \(W_O\) also has no bias. Document this in README.md.
- Don't forget \(W_O\). Concatenating heads is not the end. The output projection is essential.
- Verifying with random weights. The patterns won't look like trained attention. That's expected.
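The reshape-order pitfall is easiest to see on a tiny array whose entries are their own column indices. The sizes below are chosen only to keep the values readable.

```python
import numpy as np

T, H, d_head = 2, 2, 3                    # tiny sizes so the values are readable
x = np.arange(T * H * d_head).reshape(T, H * d_head)
# x = [[0 1 2 3 4 5], [6 7 8 9 10 11]]
# Columns 0..2 belong to head 0, columns 3..5 to head 1.

good = x.reshape(T, H, d_head)            # head axis before feature axis
bad = x.reshape(T, d_head, H)             # same data, wrong grouping

head0_good = good[:, 0, :]                # [[0 1 2], [6 7 8]]
head0_bad = bad[:, :, 0]                  # [[0 2 4], [6 8 10]]

per_head = good.transpose(1, 0, 2)        # (H, T, d_head), ready for batched matmul
print(head0_good, head0_bad, per_head.shape, sep="\n")
```

With the wrong order, "head 0" silently mixes features from both heads; no shape error is raised, which is why this bug tends to surface only in the equivalence check.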
When to consult solutions/
After all six files are committed and the Block B assertion passes. Solution at solutions/01-multi-head-ref.md.
Next lab: 02-causal-mask.md.