
Lab 01 — Multi-Head Attention

Goal: extend single-head attention to multi-head, lock in the MultiHeadAttention class API, and verify that with \(H = 1\) the class reproduces the previous lab's single-head output (followed by the output projection) to within 1e-5.

Estimated time: 90–120 minutes.

Prereq: lab 00 committed.


What you produce

A directory experiments/15-multi-head/ containing:

  • mha.py — your NumPy multi-head implementation, importing from src/minimodel/attention/attention.py.
  • verify.py — verification script.
  • verify_output.txt — captured printout.
  • heatmap.png — 4-panel figure: attention pattern for each of \(H = 4\) heads on a fixed input.
  • manifest.json.
  • README.md.

Background

theory/03-multi-head.md covers:

  • The split-and-stack construction.
  • Parameter count: \(4 d_\text{model}^2\) vs \(3 d_\text{model}^2\) for single-head.
  • Why the output projection \(W_O\) is essential.
  • The "one big matrix per role, reshape at runtime" implementation trick.
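You can check the "one big matrix per role" trick numerically. The sketch below uses illustrative sizes and a standard-normal init, both of which are assumptions. It projects once with the full W_Q, reshapes the result, and confirms that head h sees exactly columns h*d_head through (h+1)*d_head - 1 of W_Q. It also evaluates the \(4 d_\text{model}^2\) parameter count for d_model = 64.

```python
import numpy as np

# Illustrative sizes (assumed); any d_model divisible by n_heads works.
T, d_model, n_heads = 8, 64, 4
d_head = d_model // n_heads

rng = np.random.default_rng(0)
x = rng.standard_normal((T, d_model))
W_Q = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)

# One big projection for the whole "query" role, reshaped at runtime.
Q_heads = (x @ W_Q).reshape(T, n_heads, d_head).transpose(1, 0, 2)  # (H, T, d_head)

# Equivalent view: head h owns a contiguous block of d_head columns of W_Q.
for h in range(n_heads):
    W_Q_h = W_Q[:, h * d_head:(h + 1) * d_head]  # (d_model, d_head)
    assert np.allclose(Q_heads[h], x @ W_Q_h)

# W_Q, W_K, W_V and W_O are each d_model x d_model.
print(4 * d_model**2)  # 16384
```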

src/minimodel/attention/BLUEPRINT.md (read it!) locks the class API:

class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0) -> None: ...
    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray: ...

The class owns four weight matrices: W_Q, W_K, W_V, W_O, each \(d_\text{model} \times d_\text{model}\).

TODOs

Block A — implement the class

  • In src/minimodel/attention/attention.py, implement MultiHeadAttention.
  • In __init__: set d_head = d_model // n_heads (assert that d_model is divisible by n_heads). Then allocate the four matrices with a generator from np.random.default_rng(seed), scaled by 1 / sqrt(d_model) (the Phase 10 init).
  • In forward:
    • Compute Q = x @ self.W_Q, and likewise K and V. Each has shape (T, d_model).
    • Reshape each to (T, n_heads, d_head), then transpose to (n_heads, T, d_head).
    • For each head independently (or all at once with einsum / batched matmul):
      • scores = Q_h @ K_h.T / sqrt(d_head), shape (T, T).
      • If a mask is given, add it to scores (additive mask).
      • attn = softmax(scores), taken row-wise over the last (key) axis.
      • out_h = attn @ V_h, shape (T, d_head).
    • Transpose back to (T, n_heads, d_head) and reshape to (T, d_model). This concatenates the heads.
    • Return out @ self.W_O.
  • Aim for ≤ 30 LOC in forward. Use einsum if it helps readability, but a for h in range(H) loop is also fine for clarity.
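A minimal sketch of the class follows. It uses a batched matmul over the head axis rather than a per-head loop. Several details are assumptions that this page does not fix:

  • The weights are drawn with rng.standard_normal and scaled by 1 / sqrt(d_model).
  • The draw order is W_Q, W_K, W_V, W_O.
  • last_attn is a hypothetical attribute for the "side effect" option that Block C mentions.

Check BLUEPRINT.md and the Phase 10 init before adopting any of these.

```python
from __future__ import annotations

import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0) -> None:
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)
        # Assumed init: standard-normal draws scaled by 1/sqrt(d_model); no biases.
        self.W_Q = rng.standard_normal((d_model, d_model)) * scale
        self.W_K = rng.standard_normal((d_model, d_model)) * scale
        self.W_V = rng.standard_normal((d_model, d_model)) * scale
        self.W_O = rng.standard_normal((d_model, d_model)) * scale
        # Hypothetical side-effect slot: attention weights of the last forward, (H, T, T).
        self.last_attn: np.ndarray | None = None

    def _split(self, m: np.ndarray) -> np.ndarray:
        """(T, d_model) -> (T, H, d_head) -> (H, T, d_head)."""
        return m.reshape(m.shape[0], self.n_heads, self.d_head).transpose(1, 0, 2)

    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
        T = x.shape[0]
        Q = self._split(x @ self.W_Q)
        K = self._split(x @ self.W_K)
        V = self._split(x @ self.W_V)
        scores = Q @ K.transpose(0, 2, 1) / np.sqrt(self.d_head)  # (H, T, T)
        if mask is not None:
            scores = scores + mask  # additive mask, broadcast across heads
        attn = softmax(scores, axis=-1)  # row-wise over keys
        self.last_attn = attn
        out = attn @ V  # (H, T, d_head)
        out = out.transpose(1, 0, 2).reshape(T, self.d_model)  # concatenate heads
        return out @ self.W_O
```

Because np.matmul treats leading dimensions as a batch, the one Q @ K.transpose(0, 2, 1) expression computes every head's (T, T) score matrix at once.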

Block B — verify single-head equivalence

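When \(H = 1\), the multi-head class should reproduce lab 00's single-head function applied to the projected inputs, followed by the output projection \(W_O\). With one head, d_head = d_model. The split/merge reshapes therefore leave the data unchanged, and the score scale 1 / sqrt(d_head) equals 1 / sqrt(d_model).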

  • Construct a MultiHeadAttention(d_model=4, n_heads=1, seed=42).
  • Extract its W_Q, W_K and W_V (each of shape (4, 4)). For an input X of shape (T, 4), call lab 00's single_head_attention(X @ W_Q, X @ W_K, X @ W_V).
  • Multiply that result by W_O.
  • Compare it to mha.forward(X).
  • Assert that the two agree to within 1e-5 (maximum absolute difference).
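The Block B check can be sketched without the real modules. Two stand-ins are used here:

  • single_head_attention replaces lab 00's function. It is assumed to compute softmax(Q K^T / sqrt(d_k)) V.
  • mha_forward is a hypothetical functional version of the multi-head forward, written with einsum.

T = 5 is an arbitrary sequence length. In verify.py you would call the real class and the real lab 00 function instead.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def single_head_attention(Q, K, V):
    """Stand-in for lab 00's function (assumed: softmax(Q K^T / sqrt(d_k)) V)."""
    d_k = K.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V


def mha_forward(x, W_Q, W_K, W_V, W_O, n_heads):
    """Hypothetical functional multi-head forward (einsum form, no mask, no biases)."""
    T, d_model = x.shape
    d_head = d_model // n_heads

    def split(m):
        return m.reshape(T, n_heads, d_head).transpose(1, 0, 2)  # (H, T, d_head)

    Q, K, V = split(x @ W_Q), split(x @ W_K), split(x @ W_V)
    scores = np.einsum("htd,hsd->hts", Q, K) / np.sqrt(d_head)  # (H, T, T)
    out = np.einsum("hts,hsd->htd", softmax(scores), V)  # (H, T, d_head)
    return out.transpose(1, 0, 2).reshape(T, d_model) @ W_O


d_model, T = 4, 5  # T = 5 is arbitrary for this check
rng = np.random.default_rng(42)
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
X = rng.standard_normal((T, d_model))

reference = single_head_attention(X @ W_Q, X @ W_K, X @ W_V) @ W_O
multi = mha_forward(X, W_Q, W_K, W_V, W_O, n_heads=1)
max_diff = np.abs(reference - multi).max()
assert max_diff < 1e-5, max_diff
print(f"max_diff = {max_diff:.2e}")
```

The difference should come out at floating-point round-off level, far below the 1e-5 threshold, because the two paths are mathematically identical when n_heads = 1.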

Block C — explore: head specialization

For a fixed input, the canonical 8-token verb-grammar sequence <bos> I work , you work , he (tokenized as in Phase 14 lab 00):

  • Build MultiHeadAttention(d_model=64, n_heads=4, seed=0).
  • Run forward. Capture the attention matrices attn_h for each head (modify forward to optionally return them, or save them as a side effect).
  • Plot 4 heatmaps in a 2×2 grid. Each heatmap is \(T \times T\), with attention probability as color (use viridis).
  • Label both axes with the actual decoded tokens (rows are queries, columns are keys).
  • Save as heatmap.png.
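One way to produce heatmap.png with matplotlib is sketched below. Two assumptions apply:

  • The random x, W_Q and W_K stand in for the Phase 14 embeddings and your class's weights.
  • The attention weights are recomputed inline instead of being read from forward.

A shared vmin/vmax of 0 to 1 keeps all four panels on one color scale. Differences between heads are then real, not an artifact of per-panel normalization.

```python
import matplotlib

matplotlib.use("Agg")  # render straight to a file; no display needed
import matplotlib.pyplot as plt
import numpy as np

tokens = ["<bos>", "I", "work", ",", "you", "work", ",", "he"]
T, d_model, n_heads = len(tokens), 64, 4
d_head = d_model // n_heads

# Hypothetical stand-ins: random embeddings for x and random W_Q, W_K.
# In the lab, x comes from the Phase 14 tokenization and attn from forward.
rng = np.random.default_rng(0)
x = rng.standard_normal((T, d_model))
W_Q = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
W_K = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)


def split(m: np.ndarray) -> np.ndarray:
    """(T, d_model) -> (n_heads, T, d_head)."""
    return m.reshape(T, n_heads, d_head).transpose(1, 0, 2)


scores = split(x @ W_Q) @ split(x @ W_K).transpose(0, 2, 1) / np.sqrt(d_head)
scores -= scores.max(axis=-1, keepdims=True)  # stabilize the softmax
attn = np.exp(scores)
attn /= attn.sum(axis=-1, keepdims=True)  # (H, T, T); each row sums to 1

fig, axes = plt.subplots(2, 2, figsize=(9, 8))
for h, ax in enumerate(axes.flat):
    im = ax.imshow(attn[h], cmap="viridis", vmin=0.0, vmax=1.0)
    ax.set_title(f"head {h}")
    ax.set_xticks(range(T))
    ax.set_xticklabels(tokens, rotation=45)
    ax.set_yticks(range(T))
    ax.set_yticklabels(tokens)
    ax.set_xlabel("key")
    ax.set_ylabel("query")
fig.colorbar(im, ax=axes.ravel().tolist(), label="attention probability")
fig.savefig("heatmap.png", dpi=150)
```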

With random weights, the heads show random patterns, and that's fine: the point is shape, not semantics. (Trained attention patterns appear in Phase 18.)

Block D — write up

In README.md:

  1. Confirm the single-head equivalence (Block B). State the max diff.
  2. Describe the heatmap shapes (Block C). Are the four heads distinguishable from each other? With random weights they should look different, since each head is computed from its own slice of the randomly drawn columns of W_Q and W_K. Note any patterns you see (probably none, which is the expected null result for random init).
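One optional way to back up the Block D answer with a number (the lab does not require it) is to compare every pair of heads' attention matrices. pairwise_head_diffs is a hypothetical helper. The demo feeds it random row-normalized weights in place of your real (H, T, T) attn.

```python
import itertools

import numpy as np


def pairwise_head_diffs(attn: np.ndarray) -> dict[tuple[int, int], float]:
    """Max absolute difference between each pair of heads; attn has shape (H, T, T)."""
    return {
        (i, j): float(np.abs(attn[i] - attn[j]).max())
        for i, j in itertools.combinations(range(attn.shape[0]), 2)
    }


# Demo on hypothetical random attention weights; use your real attn from Block C instead.
rng = np.random.default_rng(0)
e = np.exp(rng.standard_normal((4, 8, 8)))
attn = e / e.sum(axis=-1, keepdims=True)  # each row sums to 1
diffs = pairwise_head_diffs(attn)
print(f"smallest pairwise max-diff: {min(diffs.values()):.3f}")
```

If the smallest value is close to 0, two heads are nearly identical and worth a second look.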

Block E — manifest

{
  "experiment": "15-multi-head",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "versions": { "python": "3.11.x", "numpy": "X.Y.Z", "matplotlib": "X.Y.Z" },
  "config": {
    "d_model": 64,
    "n_heads": 4,
    "T": 8,
    "input_snippet": "<bos> I work , you work , he"
  },
  "results_summary": {
    "single_head_equivalence_max_diff": null,
    "heads_visibly_distinct": null
  }
}
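The sketch below writes manifest.json from inside the experiment directory and fills in the versions from the running interpreter instead of by hand. The two results values are left as None placeholders, which serialize to null. In practice you would pass in the max_diff from Block B and your Block C judgment.

```python
import datetime
import json
import platform

import matplotlib
import numpy as np

# Placeholders: replace with the values from your verify.py run.
max_diff = None  # e.g. the float computed in Block B
heads_distinct = None  # your True/False judgment from Block C

manifest = {
    "experiment": "15-multi-head",
    "date": datetime.date.today().isoformat(),
    "seed": 42,
    "versions": {
        "python": platform.python_version(),
        "numpy": np.__version__,
        "matplotlib": matplotlib.__version__,
    },
    "config": {
        "d_model": 64,
        "n_heads": 4,
        "T": 8,
        "input_snippet": "<bos> I work , you work , he",
    },
    "results_summary": {
        "single_head_equivalence_max_diff": max_diff,
        "heads_visibly_distinct": heads_distinct,
    },
}

with open("manifest.json", "w", encoding="utf-8") as f:
    json.dump(manifest, f, indent=2)
    f.write("\n")
```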

Constraints

  • No PyTorch.
  • Reshape, don't loop, where possible. Both are correct; the reshaped (batched) version is faster and is what production code does. A loop is fine for clarity in your first pass.
  • Mask is None in this lab. Lab 02 adds the causal mask.
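The loop form and the reshape/batched form compute the same thing. This sketch checks that on random per-head Q, K and V, with shapes chosen arbitrarily. np.matmul broadcasts over the leading head axis, which is why a single batched expression can replace the loop.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
H, T, d_head = 4, 8, 16
Q, K, V = (rng.standard_normal((H, T, d_head)) for _ in range(3))

# First pass: an explicit loop over heads.
looped = np.stack([softmax(Q[h] @ K[h].T / np.sqrt(d_head)) @ V[h] for h in range(H)])

# Batched: one matmul over the leading head axis.
batched = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)) @ V

assert looped.shape == batched.shape == (H, T, d_head)
assert np.allclose(looped, batched)
```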

Stop conditions

Done when:

  1. All six files committed.
  2. Single-head equivalence assertion passes (max_diff < 1e-5).
  3. Heatmap shows four visibly different patterns (even if structureless).
  4. README.md answers both Block D questions.

Pitfalls

  • Reshape order matters. x.reshape(T, H, d_head) is different from x.reshape(T, d_head, H). The first gives each head a contiguous block of d_head columns; the second interleaves the heads across columns. Use the former.
  • Transpose for matmul. After reshaping to (T, H, d_head), you need (H, T, d_head) for batched matmul. Use .transpose(1, 0, 2).
  • Biases on W_Q, W_K, W_V? Phase 15 uses no bias in these projections, and \(W_O\) has no bias either. This is a common choice in recent transformers, though some earlier models (e.g. GPT-2 and BERT) do use biases here. Document this in README.md.
  • Don't forget \(W_O\). Concatenating heads is not the end. The output projection is essential.
  • Verifying with random weights. The patterns won't look like trained attention. That's expected.
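The reshape-order pitfall can be seen directly. In this sketch, the columns of x are labeled 0 to 5, so you can read off which columns each head receives. The "bad" version is deliberately transposed to the same (H, T, d_head) shape. This shows that the mistake raises no error: it silently interleaves the heads.

```python
import numpy as np

T, H, d_head = 2, 2, 3
# Column c of the (T, H * d_head) projection holds the value c.
x = np.tile(np.arange(H * d_head), (T, 1))  # each row: [0 1 2 3 4 5]

good = x.reshape(T, H, d_head).transpose(1, 0, 2)  # (H, T, d_head)
bad = x.reshape(T, d_head, H).transpose(2, 0, 1)  # also (H, T, d_head), but interleaved

print(good[0, 0], good[1, 0])  # [0 1 2] [3 4 5]: contiguous block per head
print(bad[0, 0], bad[1, 0])  # [0 2 4] [1 3 5]: heads interleaved across columns
```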

When to consult solutions/

After all six files are committed and the Block B assertion passes. The solution is at solutions/01-multi-head-ref.md.


Next lab: 02-causal-mask.md.