Lab 00 — Predict shapes from einsum strings¶
Goal: make einsum shape-arithmetic mechanical. No code in Part A; just paper. Examples are anchored on §A13 verb-form encodings.
Estimated time: 60–90 minutes.
Prereq: theory 01-tensors-and-shapes.md and 02-matmul-and-shapes.md read.
What you produce¶
A directory experiments/03-shapes-by-hand/ containing:
- predictions.md: your hand-written shape predictions for every problem below.
- verify.py: a short script that constructs the operands and runs the einsums, comparing to your predictions.
- results.json: pass/fail per prediction.
- manifest.json.
The §A13 dimension constants¶
For every problem, assume these standard sizes (per theory/01-tensors-and-shapes.md):
B = 32 # batch
T = 16 # sequence length
V = 600 # vocabulary (§A13: 20 verbs × 5 tenses × 3 persons + Spanish pairs)
D = 64 # embedding dim
H = 4 # number of attention heads
D_k = 16 # per-head dim = D / H
D_ff = 256 # FFN intermediate
K_classes = 5 # number of tense classes
TODOs¶
Part A — predict on paper¶
For each einsum expression below, write in predictions.md:
- The shape of every operand.
- The shape of the output.
- The total number of multiply-add FLOPs (2 × product_of_all_indices, counting each distinct index once).
- A one-sentence English description of what the operation does in the §A13 context.
No code yet. Solve by reading the einsum string and applying the two rules: an index that appears to the right of -> is free and becomes an output axis; an index that appears only on the left is summed over.
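Once the paper predictions are written, the same two rules can be turned into a small checker. The sketch below is a hypothetical helper (the name `predict` is not part of the lab) that handles only explicit-mode strings without ellipsis, and it uses the lab's FLOP convention of 2 × the product of all distinct index sizes. The two example strings are the ones quoted in Part D.

```python
import math

def predict(expr, *shapes):
    """Return (output_shape, flops) for an explicit-mode einsum string."""
    lhs, out = expr.split("->")
    terms = lhs.split(",")
    if len(terms) != len(shapes):
        raise ValueError("one shape per operand is required")
    size = {}  # index letter -> dimension size
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f"{term!r} does not match shape {shape}")
        for letter, n in zip(term, shape):
            # The first sighting records the size; later sightings must agree.
            if size.setdefault(letter, n) != n:
                raise ValueError(f"index {letter!r} has sizes {size[letter]} and {n}")
    out_shape = tuple(size[letter] for letter in out)  # free indices, in output order
    flops = 2 * math.prod(size.values())               # lab convention for multiply-adds
    return out_shape, flops

B, T, V, D = 32, 16, 600, 64
print(predict("btv,vd->btd", (B, T, V), (V, D)))  # expression 2
print(predict("btd,vd->btv", (B, T, D), (V, D)))  # expression 12
```

Both strings touch the same four indices, so they report the same FLOP count even though their output shapes differ.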
1. Embedding lookup (single token).
one_hot.shape = (V,), E.shape = (V, D).
2. Batched embedding lookup.
tokens_one_hot.shape = (B, T, V), E.shape = (V, D).
3. Tense classification.
hidden.shape = (B, D), W_tense.shape = (K_classes, D).
4. Per-token tense classification (batched + sequential).
x.shape = (B, T, D), W_tense.shape = (K_classes, D).
5. Linear projection (Q in attention).
x.shape = (B, T, D), W_Q.shape = (D, D).
6. Reshape for multi-head — Q, K, V split.
After computing Q of shape (B, T, D), you reshape to (B, T, H, D_k) and then transpose to (B, H, T, D_k). einsum cannot do this on its own, because it cannot split one axis into two, so this problem has no einsum string. Instead, predict the shape after Q.reshape(B, T, H, D_k) and the shape after the following .transpose(0, 2, 1, 3).
7. Attention scores.
Q.shape = (B, H, T, D_k), K.shape = (B, H, T, D_k).
8. Attention output.
attn_probs.shape = (B, H, T, T), V.shape = (B, H, T, D_k).
9. Output projection.
attn_out.shape = (B, H, T, D_k), W_O.shape = (H, D_k, D).
10. FFN expansion.
x.shape = (B, T, D), W_1.shape = (D, D_ff).
11. FFN contraction.
h.shape = (B, T, D_ff), W_2.shape = (D_ff, D).
12. Vocabulary projection (final layer).
x.shape = (B, T, D), E.shape = (V, D). (Note E here is tied with the input embedding — same matrix.)
13. Cross-entropy log-likelihood reduction.
log_probs.shape = (B, T, V), labels_one_hot.shape = (B, T, V).
14. Per-sequence average log-likelihood.
(Then divide by T.) What's the result shape?
15. Diagonal of a square matrix.
M.shape = (5, 5).
16. Trace.
M.shape = (5, 5).
17. Frobenius inner product.
A.shape = (20, 15), B.shape = (20, 15).
18. Outer product of two §A13 verb-form vectors.
a.shape = (V,), b.shape = (V,). What's the size in MB at fp32?
19. Mixed batched contraction.
x.shape = (B, T, D), T.shape = (D, V, D_ff). (Unusual; just for shape practice.)
20. The §A13 conjugation-count matrix dot.
C.shape = (20, 15) (20 verbs × 15 conjugation indices). What is C @ C^T computing in §A13 terms?
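Problem 6 is the one item in the list that is not a contraction, so it is worth seeing what the reshape and transpose do to the axes. The sketch below uses the shapes already stated in that problem. The single-operand einsum at the end is an addition for illustration: it shows that einsum can permute axes once the reshape has split D into (H, D_k). The last line shows the byte-count arithmetic (elements × 4 bytes at fp32) on this tensor.

```python
import numpy as np

B, T, D, H = 32, 16, 64, 4
D_k = D // H

Q = np.zeros((B, T, D), dtype=np.float32)
Q_split = Q.reshape(B, T, H, D_k)                  # split the last axis D into (H, D_k)
Q_heads = Q_split.transpose(0, 2, 1, 3)            # swap the T and H axes
Q_heads_einsum = np.einsum("bthd->bhtd", Q_split)  # same permutation written as einsum

print(Q_split.shape, Q_heads.shape, Q_heads_einsum.shape)
print(Q_heads.nbytes / 1e6, "MB at fp32")          # B * H * T * D_k * 4 bytes
```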
Part B — verify with code¶
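verify.py: for each of the 20 problems above, construct random operands with the specified shapes (use np.random.default_rng(42).standard_normal(shape).astype(np.float32)), run the einsum (for problem 6, the reshape and transpose), print the actual shape, and check against your prediction. Some operand names collide with the dimension constants (V in problem 8, B in problem 17, T in problem 19), so give those arrays different variable names in the script.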
predictions = {
    1: (D,),
    2: (B, T, D),
    3: (B, K_classes),
    # ...
}
for expr_id, expected_shape in predictions.items():
    # construct operands
    # run einsum
    # check shape
    pass_fail = (actual_shape == expected_shape)
    print(f"{expr_id}: predicted {expected_shape}, got {actual_shape}, {'PASS' if pass_fail else 'FAIL'}")
Save results to results.json. You must get 20/20. A miss means re-derive on paper before re-running.
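One possible shape for verify.py is sketched below. It fills in only problems 2 and 12, whose einsum strings are quoted in Part D; the remaining entries are left for you to add from predictions.md. The `problems` table layout and the results.json format (problem id mapped to "PASS" or "FAIL") are assumptions, not a required schema.

```python
import json
import numpy as np

B, T, V, D = 32, 16, 600, 64
rng = np.random.default_rng(42)

# problem id -> (einsum string, operand shapes, predicted output shape)
# Hypothetical layout; extend with one entry per remaining problem.
problems = {
    2: ("btv,vd->btd", [(B, T, V), (V, D)], (B, T, D)),
    12: ("btd,vd->btv", [(B, T, D), (V, D)], (B, T, V)),
}

results = {}
for expr_id, (expr, shapes, expected_shape) in problems.items():
    # Build random float32 operands with the stated shapes.
    operands = [rng.standard_normal(s).astype(np.float32) for s in shapes]
    actual_shape = np.einsum(expr, *operands).shape
    verdict = "PASS" if actual_shape == expected_shape else "FAIL"
    results[str(expr_id)] = verdict
    print(f"{expr_id}: predicted {expected_shape}, got {actual_shape}, {verdict}")

# Assumed file format: a flat JSON object keyed by problem id.
with open("results.json", "w") as f:
    json.dump(results, f, indent=2)
```

Keeping the einsum string next to its prediction in one table means a failed check points directly at the line to re-derive on paper.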
Part C — FLOPs verification¶
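For three of the expressions (your choice; try 2, 7, 12), compute the theoretical FLOPs (write the formula in predictions.md). Compare them to measured time × your machine's peak GFLOPS (from Phase 1's roofline), which is the number of FLOPs the machine could have executed in that time. Expect the measured time to be much higher than the theoretical minimum (FLOPs ÷ peak), because of Python call overhead and because np.einsum does not necessarily dispatch to an optimized BLAS kernel.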
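A minimal timing sketch for expression 2 follows. `PEAK_GFLOPS` is a placeholder value, not a measurement: replace it with the number from your own Phase 1 roofline. Taking the best of five runs after a warm-up call is one reasonable way to reduce noise, not a prescribed method.

```python
import time
import numpy as np

B, T, V, D = 32, 16, 600, 64
rng = np.random.default_rng(42)
tokens = rng.standard_normal((B, T, V)).astype(np.float32)
E = rng.standard_normal((V, D)).astype(np.float32)

flops = 2 * B * T * V * D  # 2 x product of all index sizes in 'btv,vd->btd'

np.einsum("btv,vd->btd", tokens, E)  # warm-up call, not timed
times = []
for _ in range(5):
    start = time.perf_counter()
    np.einsum("btv,vd->btd", tokens, E)
    times.append(time.perf_counter() - start)
best = min(times)

PEAK_GFLOPS = 100.0  # placeholder: substitute your Phase 1 roofline number
achieved_gflops = flops / best / 1e9
print(f"theoretical FLOPs: {flops:,}")
print(f"best of 5: {best * 1e3:.3f} ms, {achieved_gflops:.2f} GFLOP/s")
print(f"fraction of assumed peak: {achieved_gflops / PEAK_GFLOPS:.1%}")
```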
Part D — the killer question¶
Expression 12 ('btd,vd->btv') and expression 2 ('btv,vd->btd') look almost like inverses of each other. Are they? In §A13 terms, expression 2 is "look up embedding for each token"; expression 12 is "project hidden state back to vocabulary logits". What property of the embedding matrix E would make them genuine inverses? (Hint: it relates to orthogonality.) Discuss in predictions.md.
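To explore the question numerically before writing your answer, you can measure how far a round trip through the two expressions lands from where it started. The sketch below is a hypothetical helper (`round_trip_error` is not part of the lab) that applies expression 12 and then feeds the logits back through the expression 2 string, which is a use of that string on non-one-hot input. It is run here only on a random E; trying other choices of E is left to you.

```python
import numpy as np

B, T, V, D = 32, 16, 600, 64
rng = np.random.default_rng(42)

def round_trip_error(E, x):
    """Apply expression 12, then expression 2, and return max |x_back - x|."""
    logits = np.einsum("btd,vd->btv", x, E)       # expression 12: hidden -> vocab
    x_back = np.einsum("btv,vd->btd", logits, E)  # expression 2 string: vocab -> hidden
    return float(np.max(np.abs(x_back - x)))

x = rng.standard_normal((B, T, D)).astype(np.float32)
E_random = rng.standard_normal((V, D)).astype(np.float32)
print("random E, round-trip error:", round_trip_error(E_random, x))
```

A large error for a random E shows that the two expressions are not inverses in general.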
Constraints¶
- Predict before running. Part A → Part B, in order.
- One-sentence English description for each. Forces you to think in §A13 terms, not just shape arithmetic.
- np.einsum is the only allowed implementation, even for operations that have specialized NumPy functions (np.dot, np.matmul). The point is practicing einsum.
Stop conditions¶
Done when:
- predictions.md has all 20 predictions with shape + FLOPs + English description.
- verify.py prints 20/20 PASS.
- Part D's killer question has a written answer.
- You can read any new einsum string and predict its shape without consulting notes.
Pitfalls¶
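- Index dropped from the output. If you write 'btv,vd->btv', the d does not appear in the output, so NumPy sums over it and returns shape (B, T, V) without complaint. That is a valid einsum, just not the embedding lookup you meant. NumPy raises an error only when the output names an index that no operand supplies (for example 'btv,vd->btk').
- Inconsistent dimension in different operands. If tokens_one_hot.shape = (B, T, 599) instead of (B, T, V=600), the v in E (size 600) won't match. NumPy raises an error.
- The trailing ->. If you omit -> and the right-hand side, NumPy uses an implicit convention: every index that appears more than once is summed, and the indices that appear exactly once form the output in alphabetical order. Be explicit always.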
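A quick way to see how NumPy actually behaves in these cases is to run them. The sketch below uses zero-filled operands with the lab's shapes; the string 'btv,vd->btk' and the small matrix M are made up for illustration.

```python
import numpy as np

B, T, V, D = 32, 16, 600, 64
tokens = np.zeros((B, T, V), dtype=np.float32)
E = np.zeros((V, D), dtype=np.float32)

# 1. An index dropped from the output is summed, not rejected.
print(np.einsum("btv,vd->btv", tokens, E).shape)  # (32, 16, 600)

# 2. An output index that no operand supplies is an error.
try:
    np.einsum("btv,vd->btk", tokens, E)
except ValueError as err:
    print("ValueError:", err)

# 3. Mismatched sizes for the same index are an error.
try:
    np.einsum("btv,vd->btd", tokens[:, :, :599], E)
except ValueError as err:
    print("ValueError:", err)

# 4. Implicit mode: indices that appear once are kept, in alphabetical order.
M = np.zeros((2, 3), dtype=np.float32)
print(np.einsum("ba", M).shape)      # (3, 2): implicit output order is 'ab'
print(np.einsum("ba->ba", M).shape)  # (2, 3): explicit output keeps 'ba'
```

Case 1 is the dangerous one, because the shape check alone may not catch it if your prediction was made from the same wrong string.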
When to consult solutions/¶
After committing all four files. Solution at solutions/00-shapes-by-hand-ref.md (written at phase open).
Next lab: lab/01-matmul-perf.md.