Lab 00 — Predict shapes from einsum strings

Goal: make einsum shape-arithmetic mechanical. No code in Part A; just paper. Examples are anchored on §A13 verb-form encodings.

Estimated time: 60–90 minutes.

Prereq: read theory/01-tensors-and-shapes.md and theory/02-matmul-and-shapes.md.


What you produce

A directory experiments/03-shapes-by-hand/ containing:

  • predictions.md — your hand-written shape predictions for every problem below.
  • verify.py — short script that constructs the operands and runs the einsums, comparing to your predictions.
  • results.json — pass/fail per prediction.
  • manifest.json.

The §A13 dimension constants

For every problem, assume these standard sizes (per theory/01-tensors-and-shapes.md):

B   = 32      # batch
T   = 16      # sequence length
V   = 600     # vocabulary (§A13: 20 verbs × 5 tenses × 3 persons = 300 forms, plus their 300 Spanish pairs)
D   = 64      # embedding dim
H   = 4       # number of attention heads
D_k = 16      # per-head dim = D / H
D_ff= 256     # FFN intermediate
K_classes = 5 # number of tense classes
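The same constants can live in a small Python module that verify.py imports, so the numbers are not retyped. This is a sketch, and the module layout is up to you. Deriving D_k from D and H keeps the per-head relationship exact, and the assertion catches a head count that does not divide D.

```python
# Dimension constants from the table above (module layout is an assumption, not prescribed by the lab).
B = 32          # batch
T = 16          # sequence length
V = 600         # vocabulary size
D = 64          # embedding dim
H = 4           # number of attention heads
D_k = D // H    # per-head dim, derived so it always equals D / H
D_ff = 256      # FFN intermediate
K_classes = 5   # number of tense classes

# Sanity check on the relationship the table states.
assert D_k * H == D, "the heads must evenly split the embedding dim"
```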

TODOs

Part A — predict on paper

For each einsum expression below, write in predictions.md:

  • The shape of every operand.
  • The shape of the output.
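  • The total FLOP count, using the convention 2 × (product of the sizes of all distinct indices): one multiply plus one add per term. The convention overstates the cost of expressions that only copy or only multiply, such as 15 and 18.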
  • A one-sentence English description of what the operation does in the §A13 context.

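No code yet. Solve by reading the einsum string and applying the two rules. An index that appears after -> is kept as an output axis, in the order written there. An index that appears in the inputs but not after -> is summed over, whether or not it is repeated.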
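This is a toy check of the two rules on operands that are not part of the problem set. The sizes 2, 3, 4 are arbitrary. Dropping k from the output sums it away even though it appears in only one operand. Keeping the repeated j turns it into an elementwise (broadcast) axis.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3)).astype(np.float32)   # indices i=2, j=3
M = rng.standard_normal((3, 4)).astype(np.float32)   # indices j=3, k=4

# j is absent from the output, so it is summed: an ordinary matrix product.
out_ik = np.einsum('ij,jk->ik', A, M)
assert out_ik.shape == (2, 4)
assert np.allclose(out_ik, A @ M, atol=1e-5)

# Output order follows the string after '->': 'ki' gives the transpose.
assert np.allclose(np.einsum('ij,jk->ki', A, M), (A @ M).T, atol=1e-5)

# k is absent from the output, so it is summed even though it is not repeated;
# j is repeated but kept, because it appears after '->'.
out_ij = np.einsum('ij,jk->ij', A, M)
assert out_ij.shape == (2, 3)
assert np.allclose(out_ij, A * M.sum(axis=1), atol=1e-5)
```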


1. Embedding lookup (single token).

einsum('v,vd->d', one_hot, E)

one_hot.shape = (V,), E.shape = (V, D).


2. Batched embedding lookup.

einsum('btv,vd->btd', tokens_one_hot, E)

tokens_one_hot.shape = (B, T, V), E.shape = (V, D).


3. Tense classification.

einsum('bd,kd->bk', hidden, W_tense)

hidden.shape = (B, D), W_tense.shape = (K_classes, D).


4. Per-token tense classification (batched + sequential).

einsum('btd,kd->btk', x, W_tense)

x.shape = (B, T, D), W_tense.shape = (K_classes, D).


5. Linear projection (Q in attention).

einsum('btd,de->bte', x, W_Q)

x.shape = (B, T, D), W_Q.shape = (D, D).


6. Reshape for multi-head — Q, K, V split.

After computing Q of shape (B, T, D), you reshape to (B, T, H, D_k) and then transpose to (B, H, T, D_k). einsum can permute axes, but it cannot split one axis into two, so it cannot perform this step on its own. Instead, predict the shape after Q.reshape(B, T, H, D_k), and then the shape after the subsequent .transpose(0, 2, 1, 3).
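Here is a sketch of the two steps on small stand-in sizes. The values 2, 3, 4, 5 are arbitrary, not the §A13 constants. reshape does the split of D into (H, D_k). A one-operand einsum reproduces the transpose exactly but has no way to express the split.

```python
import numpy as np

# Stand-in sizes so the check runs instantly; verify.py would use the §A13 constants.
B, T, H, D_k = 2, 3, 4, 5
D = H * D_k

Q = np.random.default_rng(42).standard_normal((B, T, D)).astype(np.float32)

# Split the last axis into heads: only reshape can do this.
Q_split = Q.reshape(B, T, H, D_k)

# Move the head axis in front of the sequence axis, two equivalent ways.
Q_heads = Q_split.transpose(0, 2, 1, 3)
Q_heads_einsum = np.einsum('bthd->bhtd', Q_split)

assert Q_heads.shape == (B, H, T, D_k)
assert np.array_equal(Q_heads, Q_heads_einsum)
```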


7. Attention scores.

einsum('bhqd,bhkd->bhqk', Q, K)

Q.shape = (B, H, T, D_k), K.shape = (B, H, T, D_k).


8. Attention output.

einsum('bhqk,bhkd->bhqd', attn_probs, V)

attn_probs.shape = (B, H, T, T), V.shape = (B, H, T, D_k).


9. Output projection.

einsum('bhtd,hde->bte', attn_out, W_O)

attn_out.shape = (B, H, T, D_k), W_O.shape = (H, D_k, D).


10. FFN expansion.

einsum('btd,df->btf', x, W_1)

x.shape = (B, T, D), W_1.shape = (D, D_ff).


11. FFN contraction.

einsum('btf,fd->btd', h, W_2)

h.shape = (B, T, D_ff), W_2.shape = (D_ff, D).


12. Vocabulary projection (final layer).

einsum('btd,vd->btv', x, E)

x.shape = (B, T, D), E.shape = (V, D). (Note E here is tied with the input embedding — same matrix.)


13. Cross-entropy log-likelihood reduction.

einsum('btv,btv->', log_probs, labels_one_hot)

log_probs.shape = (B, T, V), labels_one_hot.shape = (B, T, V).


14. Per-sequence average log-likelihood.

einsum('btv,btv->b', log_probs, labels_one_hot)

(Then divide by T.) What's the result shape?


15. Diagonal of a square matrix.

einsum('ii->i', M)

M.shape = (5, 5).


16. Trace.

einsum('ii->', M)

M.shape = (5, 5).


17. Frobenius inner product.

einsum('mn,mn->', A, B)

A.shape = (20, 15), B.shape = (20, 15). (Here B names a matrix, not the batch size B = 32; give it another name in verify.py.)


18. Outer product of two §A13 verb-form vectors.

einsum('v,w->vw', a, b)

a.shape = (V,), b.shape = (V,). What's the size in MB at fp32?


19. Mixed batched contraction.

einsum('btd,dvf->btvf', x, T)

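x.shape = (B, T, D), T.shape = (D, V, D_ff). (Unusual; just for shape practice. Here T names the tensor, not the sequence length T = 16. Rename it in verify.py so it does not shadow the constant.)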


20. The §A13 conjugation-count matrix dot.

einsum('vp,wp->vw', C, C)

C.shape = (20, 15) (20 verbs × 15 conjugation indices). What is C @ C^T computing in §A13 terms?

Part B — verify with code

verify.py: for each of the 20 expressions above, construct random operands with the specified shapes (use np.random.default_rng(42).standard_normal(shape).astype(np.float32)), run the einsum, print the actual shape, and check against your prediction.

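import numpy as np

B, T, V, D, H, D_k, D_ff, K_classes = 32, 16, 600, 64, 4, 16, 256, 5

# id -> (einsum string, operand shapes); add the remaining problems from Part A
problems = {
    1: ('v,vd->d', [(V,), (V, D)]),
    2: ('btv,vd->btd', [(B, T, V), (V, D)]),
    3: ('bd,kd->bk', [(B, D), (K_classes, D)]),
}

predictions = {
    1: (D,),
    2: (B, T, D),
    3: (B, K_classes),
    # ...
}

rng = np.random.default_rng(42)
for expr_id, expected_shape in predictions.items():
    subscripts, shapes = problems[expr_id]
    operands = [rng.standard_normal(s).astype(np.float32) for s in shapes]  # construct operands
    actual_shape = np.einsum(subscripts, *operands).shape                   # run einsum
    pass_fail = (actual_shape == expected_shape)                            # check shape
    print(f"{expr_id}: predicted {expected_shape}, got {actual_shape}, {'PASS' if pass_fail else 'FAIL'}")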

Save results to results.json. You must get 20/20. A miss means re-deriving that problem on paper before re-running.
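The lab does not fix a layout for results.json. The sketch below assumes one: a per-problem map plus a pass count. The save_results helper is hypothetical and takes a dict of problem id to pass/fail boolean, such as the values your verify loop computes.

```python
import json


def save_results(results, path="results.json"):
    """Write per-problem pass/fail plus a summary count (layout is an assumption).

    `results` maps problem id -> bool (True = the shape prediction matched).
    """
    summary = {
        "results": {str(k): bool(v) for k, v in sorted(results.items())},
        "passed": sum(bool(v) for v in results.values()),
        "total": len(results),
    }
    with open(path, "w") as f:
        json.dump(summary, f, indent=2)
    return summary


# Example: three problems checked so far, one miss to re-derive on paper.
summary = save_results({1: True, 2: True, 3: False})
print(f"{summary['passed']}/{summary['total']} PASS")  # prints "2/3 PASS"
```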

Part C — FLOPs verification

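For three of the expressions (your choice; try 2, 7, and 12), compute the theoretical FLOPs and write the formula in predictions.md. Then time the einsum and compare. The achieved rate is theoretical FLOPs / measured time; compare it with your machine's peak GFLOP/s from Phase 1's roofline. Equivalently, compare the theoretical FLOPs with measured time × peak. Expect to land well below peak. Python call overhead dominates small problems, and np.einsum with its default optimize=False typically runs its own loops instead of dispatching to BLAS the way np.matmul does.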
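One way to take the measurement is sketched below. The helper names einsum_flops and achieved_gflops are made up for this lab. The 256×512×128 matrix product is an arbitrary stand-in, so the formulas for your three chosen problems remain yours to write. Divide the printed GFLOP/s by your roofline peak to get the fraction of peak achieved. Timing with optimize=True as well as the default gives a rough sense of how much the contraction strategy matters.

```python
import math
import time

import numpy as np


def einsum_flops(subscripts, *operands):
    """FLOP estimate under the lab's convention: 2 x product of every distinct index size.

    Handles explicit ('ij,jk->ik') and implicit ('ij,jk') strings, but not '...' ellipses.
    """
    input_specs = subscripts.replace(' ', '').split('->')[0].split(',')
    sizes = {}
    for spec, op in zip(input_specs, operands):
        for label, size in zip(spec, op.shape):
            sizes[label] = size
    return 2 * math.prod(sizes.values())


def achieved_gflops(subscripts, *operands, repeats=5, optimize=False):
    """Best-of-`repeats` wall-clock time of np.einsum, converted to GFLOP/s."""
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        np.einsum(subscripts, *operands, optimize=optimize)
        best = min(best, time.perf_counter() - start)
    return einsum_flops(subscripts, *operands) / best / 1e9


rng = np.random.default_rng(42)
A = rng.standard_normal((256, 512)).astype(np.float32)
M = rng.standard_normal((512, 128)).astype(np.float32)

print(einsum_flops('ij,jk->ik', A, M))  # 33554432 (= 2 * 256 * 512 * 128)
for opt in (False, True):
    print(f"optimize={opt}: {achieved_gflops('ij,jk->ik', A, M, optimize=opt):.2f} GFLOP/s")
```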

Part D — the killer question

Expression 12 ('btd,vd->btv') and expression 2 ('btv,vd->btd') look almost like inverses of each other. Are they? In §A13 terms, expression 2 is "look up embedding for each token"; expression 12 is "project hidden state back to vocabulary logits". What property of the embedding matrix E would make them genuine inverses? (Hint: it relates to orthogonality.) Discuss in predictions.md.

Constraints

  • Predict before running. Part A → Part B, in order.
  • One-sentence English description for each. Forces you to think in §A13 terms, not just shape arithmetic.
  • np.einsum is the only allowed implementation, even for operations that have specialized NumPy functions (np.dot, np.matmul). The point is practicing einsum. The one exception is the reshape in problem 6, which einsum cannot express.

Stop conditions

Done when:

  1. predictions.md has all 20 predictions with shape + FLOPs + English description.
  2. verify.py prints 20/20 PASS.
  3. Part D's killer question has a written answer.
  4. You can read any new einsum string and predict its shape without consulting notes.

Pitfalls

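  • Dropping an index from the output. 'btv,vd->btv' is not an error. d appears in an input but not in the output, so NumPy silently sums over it and returns a valid but wrong result of shape (B, T, V). NumPy only raises when an output index never appears in any input (e.g. 'btv,vd->btx').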
  • Inconsistent dimension in different operands. If tokens_one_hot.shape = (B, T, 599) instead of (B, T, V=600), the v in E (size 600) won't match. NumPy raises an error.
  • The trailing ->. If you omit -> and the right-hand side, NumPy uses implicit mode: indices that appear exactly once become output axes in alphabetical order, and repeated indices are summed. So 'ij,jk' is a matrix product, but 'ba' silently transposes its operand. Be explicit always.
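The sketch below shows how NumPy actually treats each pitfall, using small stand-in arrays. The sizes 2, 3, 6, 4 substitute for B, T, V, D. Dropping an index from the output is accepted and summed rather than rejected. Both errors surface as ValueError.

```python
import numpy as np

rng = np.random.default_rng(42)
tokens = rng.standard_normal((2, 3, 6)).astype(np.float32)   # stand-in (B, T, V)
E = rng.standard_normal((6, 4)).astype(np.float32)           # stand-in (V, D)

# Dropping d from the output is legal: d is silently summed away.
wrong = np.einsum('btv,vd->btv', tokens, E)
assert wrong.shape == (2, 3, 6)

# What NumPy does reject: an output index that never appears in an input...
try:
    np.einsum('btv,vd->btx', tokens, E)
except ValueError as err:
    print("unknown output index rejected:", err)
else:
    raise AssertionError("expected a ValueError")

# ...and a size mismatch on a shared index (v = 5 here vs 6 in E).
bad_tokens = rng.standard_normal((2, 3, 5)).astype(np.float32)
try:
    np.einsum('btv,vd->btd', bad_tokens, E)
except ValueError as err:
    print("size mismatch rejected:", err)
else:
    raise AssertionError("expected a ValueError")

# Implicit mode orders output axes alphabetically, so 'ba' is a transpose.
M = rng.standard_normal((2, 5)).astype(np.float32)
assert np.array_equal(np.einsum('ba', M), M.T)
assert np.einsum('ij,jk', M, M.T).shape == (2, 2)   # repeated j summed: a matrix product
```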

When to consult solutions/

After committing all four files. Solution at solutions/00-shapes-by-hand-ref.md (written at phase open).


Next lab: lab/01-matmul-perf.md.