
02 — Matrix multiplication as composition

Matmul is the composition of two linear maps: (AB)x = A(Bx). That definition, not "the O(N³) cost", explains why batched matmul exists, why multi-head attention is parallelizable, and why the inner-dimension rule works. Central example: E.T @ one_hot(i) is a lookup in an embedding table (Phase 13).

This is the theory page of Phase 3. Re-derive every diagram below until you can sketch it on a napkin.


Three views of matmul

View 1 — sum of products

The textbook definition. For A of shape (M, K) and B of shape (K, N):

\[ C[i, j] = \sum_{k=0}^{K-1} A[i, k] \cdot B[k, j] \]

C has shape (M, N). Total scalar multiplies: M × K × N. Total scalar adds: M × (K - 1) × N. So FLOPs ≈ 2 × M × K × N.

This is the view you implement when you write naive matmul (three nested loops). It's correct, but it tells you little about why matmul matters.
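A minimal NumPy sketch of View 1, assuming small random matrices: it materializes every scalar product A[i, k] · B[k, j] as an (M, K, N) tensor so you can count the multiplies and adds directly. It is a teaching device (it uses O(M K N) memory), not a way to compute matmul.

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 3, 4, 5
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))

# Every scalar product A[i, k] * B[k, j], laid out as an (M, K, N) tensor
products = A[:, :, None] * B[None, :, :]
C = products.sum(axis=1)           # sum over k: K - 1 adds per output entry

assert np.allclose(C, A @ B)
multiplies = products.size         # M * K * N = 60
adds = M * (K - 1) * N             # 45
print(multiplies + adds)           # 105, vs. the 2 * M * K * N = 120 approximation
```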

View 2 — composition of linear maps

A matrix A of shape (M, K) represents a linear map f: R^K → R^M. A matrix B of shape (K, N) represents g: R^N → R^K. Their product AB represents the composition f ∘ g: R^N → R^M. The shape rule (M, K) @ (K, N) = (M, N) is literally the function composition rule "the input dimension of f (= K) must match the output dimension of g (= K)."

In code:

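import numpy as np
M, K, N = 3, 4, 5                # example sizes
A = np.random.randn(M, K)        # f: R^K -> R^M
B = np.random.randn(K, N)        # g: R^N -> R^K
assert A.shape == (M, K) and B.shape == (K, N)
assert (A @ B).shape == (M, N)   # f ∘ g: R^N -> R^M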

Apply to a vector x of shape (N,):

  • B @ x is g(x), shape (K,).
  • A @ (B @ x) is f(g(x)), shape (M,).
  • (A @ B) @ x is the same thing by associativity, computed in a different order.
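A quick numerical check of the bullets above, assuming small random A, B and x: both evaluation orders give the same vector, but they cost different amounts. The FLOP variables use the ≈ 2 × M × K × N rule from View 1.

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 3, 4, 5
A = rng.standard_normal((M, K))   # f: R^K -> R^M
B = rng.standard_normal((K, N))   # g: R^N -> R^K
x = rng.standard_normal(N)        # input vector in R^N

gx = B @ x                        # g(x), shape (K,)
fgx = A @ gx                      # f(g(x)), shape (M,)
composed = (A @ B) @ x            # (f ∘ g)(x) via the product matrix

assert gx.shape == (K,) and fgx.shape == (M,)
assert np.allclose(fgx, composed) # same map, different evaluation order

# Approximate FLOPs of each order (2 FLOPs per multiply-add)
right_to_left = 2 * K * N + 2 * M * K        # B @ x, then A @ (B @ x)
left_to_right = 2 * M * K * N + 2 * M * N    # A @ B, then (A @ B) @ x
```

Here evaluating right to left is cheaper (64 vs. 150 FLOPs) because it never forms the (M, N) product matrix; which order wins in general depends on the shapes.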

This is the view that explains why matmul is the central operation. Every layer in a neural network is "apply a linear map, then a non-linearity". The linear part of every layer is a matmul, and chaining linear maps is, by definition, multiplying their matrices.

View 3 — sum of outer products

AB can be written as a sum of K outer products:

\[ AB = \sum_{k=0}^{K-1} A[:, k] \otimes B[k, :] \]

Where A[:, k] is the k-th column of A (shape (M, 1)) and B[k, :] is the k-th row of B (shape (1, N)); their outer product is (M, N). Sum K of them.
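A sketch that builds AB term by term from View 3, assuming small random matrices: each np.outer call contributes one rank-1 (M, N) matrix, and the K terms add up to A @ B.

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 3, 4, 5
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))

C = np.zeros((M, N))
for k in range(K):
    # Column k of A times row k of B: one rank-1 term of shape (M, N)
    C += np.outer(A[:, k], B[k, :])

assert np.allclose(C, A @ B)
```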

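This view matters for understanding low-rank approximations and LoRA. A product whose inner dimension is r is a sum of r outer products, so its rank is at most r; LoRA builds its weight update exactly this way, as a (d, r) @ (r, k) product with small r. The best rank-r approximation to AB, however, comes from its SVD: write AB as a sum of rank-1 terms σ_i u_i v_iᵀ and keep the r with the largest σ_i (the Eckart-Young theorem). The K column-times-row terms above are generally not those terms.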
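A sketch of the low-rank point, assuming small random matrices and r = 2 (both arbitrary choices): truncate the SVD of C = AB to its r largest rank-1 terms, check the Eckart-Young error, and confirm that a LoRA-shaped product (d, r) @ (r, k) has rank at most r. The names delta_W, d and k are illustrative, not taken from any LoRA implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))
B = rng.standard_normal((4, 5))
C = A @ B                                          # rank at most 4

U, s, Vt = np.linalg.svd(C, full_matrices=False)   # s is sorted in decreasing order
r = 2
C_r = (U[:, :r] * s[:r]) @ Vt[:r, :]               # sum of the r largest terms s_i * outer(u_i, v_i)

# Eckart-Young: the spectral-norm error equals the first discarded singular value
assert np.isclose(np.linalg.norm(C - C_r, 2), s[r])
assert np.linalg.matrix_rank(C_r) == r

# LoRA-shaped update: (d, r) @ (r, k) is a sum of r outer products, so rank <= r
d, k = 8, 7
delta_W = rng.standard_normal((d, r)) @ rng.standard_normal((r, k))
assert np.linalg.matrix_rank(delta_W) <= r
```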

The §A13 embedding lookup, revisited

Recall the embedding matrix E of shape (V, D), with V = 600 verb forms and D = 64. Let e_i be the one-hot vector of length V: all zeros except a 1 at position i.

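import numpy as np
V, D, i = 600, 64, 7                   # i = 7 is an arbitrary example index
E = np.random.randn(V, D)              # embedding table, one row per verb form
e_i = np.eye(V)[i]                     # one-hot vector for index i

result = E.T @ e_i                     # shape (D, V) @ (V,) = (D,)
# or equivalently:
result = np.einsum('vd,v->d', E, e_i)
# or equivalently:
result = E[i]                          # direct indexing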

All three return the i-th row of E — the embedding of verb form i. They differ in how much work the machine does:

  • E.T @ e_i does V × D multiplies, of which (V-1) × D are multiplies by zero. Total FLOPs: 2 V D = 2 × 600 × 64 = 76,800.
  • E[i] does D memory loads, zero multiplies. Total FLOPs: 0.

The fast path is the indexing — embedding tables are accessed by lookup, not by matmul. But the mathematical interpretation is matrix-vector multiplication with a one-hot. Phase 13 explores the lookup table view; Phase 17 (MiniGPT) actually uses indexing.

Why this distinction matters. Some hardware (older TPUs, the original NPU designs) does not have efficient gather operations, so it implements embedding lookup as a one-hot matmul. The math is identical; the performance is not. Knowing both views lets you reason about what the hardware is actually doing.
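A sketch of the two paths side by side, assuming a hypothetical batch of token ids with B = 2 and T = 3: the gather path indexes rows directly, while the matmul path builds one-hot vectors and contracts them against E with the 'btv,vd->btd' pattern from the einsum cheatsheet. It shows that the math is identical; it does not model any particular accelerator.

```python
import numpy as np

rng = np.random.default_rng(0)
V, D = 600, 64                                  # vocabulary size and embedding width from the text
E = rng.standard_normal((V, D))                 # embedding table
token_ids = rng.integers(0, V, size=(2, 3))     # hypothetical batch: B = 2 sequences of T = 3 tokens

# Gather path: direct indexing, no multiplies
gathered = E[token_ids]                         # shape (2, 3, D)

# Matmul path: one-hot encode, then contract over the vocabulary axis
one_hot = np.eye(V)[token_ids]                  # shape (2, 3, V)
via_matmul = np.einsum('btv,vd->btd', one_hot, E)   # shape (2, 3, D)

assert gathered.shape == via_matmul.shape == (2, 3, D)
assert np.allclose(gathered, via_matmul)        # same numbers, very different amounts of work
```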

Batched matmul

For tensors with leading batch axes, matmul applies to the last two dimensions, broadcasting the leading axes:

Input A          Input B          Output          Note
(M, K)           (K, N)           (M, N)
(B, M, K)        (B, K, N)        (B, M, N)
(B, M, K)        (K, N)           (B, M, N)       (K, N) operand broadcast over the batch
(M, K)           (B, K, N)        (B, M, N)       (M, K) operand broadcast over the batch
(B, H, M, K)     (B, H, K, N)     (B, H, M, N)

The last row is the shape of multi-head attention's main matmul, Q @ K^T: B batch elements × H heads, with M = T queries, a contraction over K = D_k head dimensions, and N = T keys.

Cost is exactly the same per "matmul instance" — 2 M K N FLOPs each — multiplied by the number of batch elements. Total cost for (B, H, M, K) @ (B, H, K, N) is 2 B H M K N FLOPs. Memorize this.
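A small helper that applies these rules, assuming NumPy 1.20 or later for np.broadcast_shapes: it predicts the output shape and the 2 × (batch instances) × M × K × N FLOP count for each table row, and checks the predicted shape against what the @ operator actually returns. The function name is illustrative.

```python
import numpy as np

def batched_matmul_shape_and_flops(a_shape, b_shape):
    """Output shape and approximate FLOPs (2*M*K*N per instance) of a @ b, for ndim >= 2."""
    *a_batch, M, K = a_shape
    *b_batch, K2, N = b_shape
    assert K == K2, f"inner dimensions differ: {K} vs {K2}"
    batch = np.broadcast_shapes(tuple(a_batch), tuple(b_batch))   # leading axes broadcast
    n_instances = int(np.prod(batch))       # np.prod(()) == 1.0 for the unbatched case
    return (*batch, M, N), 2 * n_instances * M * K * N

B, H, M, K, N = 2, 4, 3, 5, 6
for a_shape, b_shape in [((M, K), (K, N)),
                         ((B, M, K), (B, K, N)),
                         ((B, M, K), (K, N)),
                         ((M, K), (B, K, N)),
                         ((B, H, M, K), (B, H, K, N))]:
    shape, flops = batched_matmul_shape_and_flops(a_shape, b_shape)
    out = np.random.randn(*a_shape) @ np.random.randn(*b_shape)
    assert out.shape == shape               # matches what NumPy actually produces
    print(a_shape, b_shape, '->', shape, flops)
```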

Reading shapes from code

Three exercises in shape inference:

Example 1 — a transformer FFN block

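import numpy as np
B, T, D, D_ff = 32, 16, 64, 256       # Borja's MiniGPT sizes
x = np.random.randn(B, T, D)          # input activations
W_1 = np.random.randn(D, D_ff)        # FFN expansion
W_2 = np.random.randn(D_ff, D)        # FFN contraction

h = x @ W_1                           # shape: (B, T, D_ff)
h = np.maximum(h, 0)                  # ReLU; shape unchanged
y = h @ W_2                           # shape: (B, T, D)
assert y.shape == (B, T, D)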

Each matmul broadcasts over (B, T) automatically. FLOPs total: B × T × (D × D_ff + D_ff × D) × 2 = 4 × B × T × D × D_ff.

For Borja's MiniGPT (B=32, T=16, D=64, D_ff=256): 4 × 32 × 16 × 64 × 256 ≈ 33M FLOPs per FFN per forward pass.
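The arithmetic above as a tiny function, assuming 2 FLOPs per multiply-add and ignoring the ReLU (one comparison per element). The function name is hypothetical.

```python
def ffn_flops(B, T, D, D_ff):
    """Approximate FLOPs of the two FFN matmuls; the ReLU is ignored."""
    expand = 2 * B * T * D * D_ff     # x @ W_1: B*T rows, each a (D,) @ (D, D_ff) product
    contract = 2 * B * T * D_ff * D   # h @ W_2: B*T rows, each a (D_ff,) @ (D_ff, D) product
    return expand + contract          # == 4 * B * T * D * D_ff

print(ffn_flops(32, 16, 64, 256))     # 33554432, about 33.6M
```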

Example 2 — attention head

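import numpy as np

def stable_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)    # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

B, H, T, D_k = 32, 4, 16, 16                   # Borja's MiniGPT sizes
Q = np.random.randn(B, H, T, D_k)
K = np.random.randn(B, H, T, D_k)
V = np.random.randn(B, H, T, D_k)

scores = np.einsum('bhqd,bhkd->bhqk', Q, K)   # shape (B, H, T, T)
scores = scores / np.sqrt(D_k)
attn = stable_softmax(scores, axis=-1)        # shape (B, H, T, T)
output = np.einsum('bhqk,bhkd->bhqd', attn, V) # shape (B, H, T, D_k)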

Two einsums, both 4D batched matmuls. Total FLOPs: 2 × 2 × B × H × T × T × D_k = 4 B H T² D_k. For Borja's MiniGPT (B=32, H=4, T=16, D_k=16): 4 × 32 × 4 × 256 × 16 ≈ 2.1M FLOPs for these two matmuls in each attention block (the Q/K/V and output projections come on top).
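A shape-and-cost check for Example 2, assuming random tensors at MiniGPT sizes and a random stand-in for the softmax weights (softmax does not change shapes): each einsum matches the corresponding batched @ once the key tensor's last two axes are swapped, and the FLOP count reproduces the 2.1M figure.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, T, D_k = 32, 4, 16, 16                  # Borja's MiniGPT sizes
Q = rng.standard_normal((B, H, T, D_k))
K = rng.standard_normal((B, H, T, D_k))
V = rng.standard_normal((B, H, T, D_k))

# Scores: Q @ K^T per (batch, head); swapaxes turns K into (B, H, D_k, T)
scores = np.einsum('bhqd,bhkd->bhqk', Q, K)
assert np.allclose(scores, Q @ K.swapaxes(-1, -2))

# Stand-in attention weights with the (B, H, T, T) shape softmax would produce
attn = rng.random((B, H, T, T))
output = np.einsum('bhqk,bhkd->bhqd', attn, V)
assert np.allclose(output, attn @ V)

flops = 2 * (2 * B * H * T * T * D_k)         # two batched matmuls, 2*M*K*N each
print(flops)                                  # 2097152
```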

You'll see these expressions repeatedly. Phase 3's job is to make them mechanical.

Example 3 — the §A13 tense classifier

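import numpy as np
B, D = 32, 64                          # batch size and hidden width
hidden = np.random.randn(B, D)         # post-attention hidden state for batch
W_tense = np.random.randn(D, 5)        # classifier weights for 5 tenses

logits = hidden @ W_tense              # shape: (B, 5)
z = logits - logits.max(axis=-1, keepdims=True)            # stable softmax: shift by the row max
probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)  # shape: (B, 5)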

With B = 32 that is 32 × 5 = 160 logits per forward pass: one five-way "which tense?" prediction per example. The cross-entropy (CE) loss against the true tense labels drives training.
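A minimal cross-entropy sketch for the tense classifier, assuming integer labels 0 to 4 for the five tenses (the labels here are random placeholders) and a mean over the batch. The log-softmax uses the same max-shift trick as the stable softmax. The function name cross_entropy is illustrative.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy from raw logits (B, C) and integer class labels (B,)."""
    z = logits - logits.max(axis=-1, keepdims=True)                # shift by the row max for stability
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax, shape (B, C)
    return -log_probs[np.arange(len(labels)), labels].mean()       # pick each row's true class

rng = np.random.default_rng(0)
B, D = 32, 64
hidden = rng.standard_normal((B, D))
W_tense = rng.standard_normal((D, 5))
labels = rng.integers(0, 5, size=B)      # hypothetical tense labels 0..4

loss = cross_entropy(hidden @ W_tense, labels)
```

As a sanity check, all-zero logits give every tense probability 1/5, so cross_entropy(np.zeros((4, 5)), np.array([0, 1, 2, 3])) returns log 5 ≈ 1.609.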

Special operations

Dot product

np.dot(a, b) for 1-D a, b of shape (N,) returns a scalar. Equivalent to np.einsum('i,i->', a, b). FLOPs: 2N - 1.

Outer product

For 1-D a of shape (M,) and b of shape (N,), np.outer(a, b) returns shape (M, N). Equivalent to np.einsum('i,j->ij', a, b). FLOPs: M × N (no additions).

Hadamard (element-wise)

a * b returns shape (N,) (or broadcasted shape). NOT a matmul. FLOPs: N.
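The three operations and their einsum equivalents on small vectors (a and b of length 3, c of length 2, arbitrary values). The einsum output spec decides the operation: drop the index to sum it (dot product), keep two distinct indices (outer product), or keep the shared index without summing (Hadamard product).

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
c = np.array([1.0, -1.0])

# Dot product: index i is summed away -> scalar (1*4 + 2*5 + 3*6 = 32)
assert np.dot(a, b) == np.einsum('i,i->', a, b) == 32.0
# Outer product: two distinct indices survive -> shape (3, 2)
assert np.array_equal(np.outer(a, c), np.einsum('i,j->ij', a, c))
# Hadamard: the shared index survives, nothing is summed -> shape (3,)
assert np.array_equal(a * b, np.einsum('i,i->i', a, b))
```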

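Common bug: writing A * B when you meant A @ B. The first is element-wise (shapes must be equal or broadcast-compatible); the second is matmul (requires (M, K) @ (K, N)). If the shapes are incompatible, NumPy and PyTorch raise an error either way. The dangerous case is when they are compatible: two (N, N) matrices work under both operators, and an (N, 1) column times an (N,) vector broadcasts to (N, N). You get a wrong result with no error. Use the shape-comment habit.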
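Two ways the bug slips through silently, with arbitrary small values: square matrices where both operators are legal, and a broadcast that turns an intended element-wise product into an (N, N) matrix.

```python
import numpy as np

A = np.array([[1.0, 2.0],
              [3.0, 4.0]])
B = np.array([[0.0, 1.0],
              [1.0, 0.0]])

hadamard = A * B        # element-wise: [[0, 2], [3, 0]]
product = A @ B         # matmul (swaps A's columns): [[2, 1], [4, 3]]
assert hadamard.shape == product.shape == (2, 2)   # same shape, no error raised
assert not np.array_equal(hadamard, product)

# Broadcasting trap: an (N, 1) column times an (N,) vector silently gives (N, N)
col = np.ones((3, 1))
vec = np.ones(3)
assert (col * vec).shape == (3, 3)
```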

Performance — the gap your eyes will see in lab 01

In Python, naive triple-loop matmul over fp32 arrays of size (1024, 1024) × (1024, 1024) takes minutes. np.matmul on the same arrays takes about 10 ms. The gap is 10⁴ to 10⁵×, much wider than the "50×" Phase 1 predicted.
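A small timing sketch to see the gap yourself, assuming n = 64 so the pure-Python loop finishes in a fraction of a second instead of the minutes a 1024 × 1024 run takes. Timings depend on your machine and BLAS build, and at this small size the ratio is typically smaller than at 1024 because np.matmul's fixed call overhead dominates. naive_matmul is an illustrative name.

```python
import time
import numpy as np

def naive_matmul(A, B):
    """Triple-loop matmul in pure Python: one interpreted multiply-add per iteration."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(M):
        for j in range(N):
            s = 0.0
            for k in range(K):
                s += A[i, k] * B[k, j]
            C[i, j] = s
    return C

n = 64   # small, hypothetical size so the naive loop finishes quickly
rng = np.random.default_rng(0)
A = rng.standard_normal((n, n)).astype(np.float32)
B = rng.standard_normal((n, n)).astype(np.float32)

t0 = time.perf_counter()
C_naive = naive_matmul(A, B)
t_naive = time.perf_counter() - t0

t0 = time.perf_counter()
C_blas = A @ B
t_blas = time.perf_counter() - t0

assert np.allclose(C_naive, C_blas, rtol=1e-4, atol=1e-4)
print(f"naive {t_naive:.3f} s, np.matmul {t_blas * 1e3:.3f} ms, ratio ~{t_naive / t_blas:.0f}x")
```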

Where does the gap come from?

  1. Python interpreter overhead. A for k in range(K) loop in Python is ~100 ns per iteration just for the bytecode. Naive matmul does M × K × N = 10⁹ iterations, so ~100 s just for the loop overhead.
  2. No SIMD. The BLAS library behind np.matmul uses vector instructions such as AVX2 (8 fp32 multiplies per instruction). The triple loop does 1.
  3. No cache blocking. np.matmul blocks for L1/L2, raising arithmetic intensity to ~100 FLOPs/byte. Naive matmul is at the 0.25 floor (theory 03-roofline-model.md in Phase 1).
  4. No multi-threading. OpenBLAS uses all 4 cores. Naive uses 1.

Compound effect: 100 × 8 × 40 × 4 = 128,000, on the order of 10⁵×. That's roughly what you'll see.

Phase 6 (Python for AI Engineering) covers (1). Phase 3's lab 01 just makes you see the gap and points at the cause. The conclusion: always vectorize through NumPy/BLAS; never write inner loops in Python.

Tying it together — the einsum cheatsheet

For Borja's reference, the most common einsum patterns in this curriculum:

Operation                                         Einsum               Result shape
(M, K) @ (K, N)                                   'ij,jk->ik'          (M, N)
Batched matmul (B, M, K) @ (B, K, N)              'bij,bjk->bik'       (B, M, N)
E^T @ one_hot(i) (embedding lookup)               'vd,v->d'            (D,)
Batched embedding: one-hot (B, T, V) @ E (V, D)   'btv,vd->btd'        (B, T, D)
Attention scores Q @ K^T                          'bhqd,bhkd->bhqk'    (B, H, T, T)
Attention output attn @ V                         'bhqk,bhkd->bhqd'    (B, H, T, D_k)
Frobenius inner product                           'ij,ij->'            ()
Trace                                             'ii->'               ()
Diagonal                                          'ii->i'              (N,)

Memorize the first six. The rest are derivable.
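A self-check of the cheatsheet, assuming small random arrays (sizes are arbitrary; Bn stands for the batch size so it does not clash with a matrix named B): each einsum pattern is compared against the equivalent NumPy expression. The two embedding rows are left out here because they need one-hot inputs.

```python
import numpy as np

rng = np.random.default_rng(0)
Bn, M, K, N, H, T, D_k = 2, 3, 4, 5, 2, 6, 3      # small hypothetical sizes
X = rng.standard_normal((M, K))
Y = rng.standard_normal((K, N))
Xb = rng.standard_normal((Bn, M, K))
Yb = rng.standard_normal((Bn, K, N))
Q = rng.standard_normal((Bn, H, T, D_k))
Kt = rng.standard_normal((Bn, H, T, D_k))
S = rng.standard_normal((N, N))                    # square matrix for trace and diagonal

checks = {
    'ij,jk->ik': (np.einsum('ij,jk->ik', X, Y), X @ Y),
    'bij,bjk->bik': (np.einsum('bij,bjk->bik', Xb, Yb), Xb @ Yb),
    'bhqd,bhkd->bhqk': (np.einsum('bhqd,bhkd->bhqk', Q, Kt), Q @ Kt.swapaxes(-1, -2)),
    'ij,ij->': (np.einsum('ij,ij->', X, X), np.sum(X * X)),
    'ii->': (np.einsum('ii->', S), np.trace(S)),
    'ii->i': (np.einsum('ii->i', S), np.diag(S)),
}
for pattern, (got, expected) in checks.items():
    assert np.allclose(got, expected), pattern
```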

Drill problems

Solutions in solutions/02-matmul-and-shapes-ref.md (phase-open).

  1. Given X.shape = (B, H, T, D_k) and Y.shape = (B, H, T, D_k), write the einsum that computes the per-head dot product over the last axis, i.e., a result of shape (B, H, T). What is the operation called in attention?
  2. The §A13 vocabulary has 600 verb forms. A small classifier has weight matrix W.shape = (5, 600) (5 tenses). Write the einsum that, given a one-hot encoding of a token (V,), produces the 5-tense logit vector (5,). (Yes, this is just W @ one_hot, written in einsum.)
  3. Prove that the einsum 'ik,kj->ij' is associative: for three matrices A, B, C, show that (AB)C = A(BC) by writing out the indices.
  4. FLOPs for the attention block of Borja's MiniGPT (B=32, H=4, T=16, D=64, D_k=16). Sum: Q/K/V projections + attention scores + attention output + output projection. Compare to FFN FLOPs.
  5. Why is np.matmul faster than for k in range(K): C += A[:, k:k+1] @ B[k:k+1, :]? Both compute the same outer-product sum.

One-paragraph recap

Matmul (M, K) @ (K, N) = (M, N) is the composition of linear maps; the inner-dimension rule is the function-composition signature constraint. It can equivalently be read as a sum of K outer products of column-of-A times row-of-B, which is the basis of low-rank approximations. Batched matmul broadcasts the leading axes; multi-head attention is one such batched matmul. The performance gap between naive Python matmul and np.matmul is 10⁴-10⁵× and comes from interpreter overhead, missing SIMD, missing cache blocking, and missing parallelism. Master einsum as the unifying grammar: every contraction then names which axes are summed and which survive, so shape mistakes are visible in the code itself.

What this page does NOT cover

  • Numerical precision of matmul (Phase 2; tests use rtol=1e-5).
  • Gradient through matmul (Phase 4 + 8).
  • Sparse matmul (out of scope).
  • GPU GEMM kernel internals (Phase 24).

Next: theory/03-svd-and-rank.md.