

Break — Swap the contraction axis in a matmul

Break the matmul by swapping which axis gets contracted. The shapes can still line up, but the numbers are garbage. This class of bug is invisible in CI until the losses stop converging, which makes it perfect for a "verify against a known case" lesson.

Target: any handwritten matmul (in NumPy or pure-Python loops) from lab/01-matmul-perf.md, or a fresh naive_matmul.

Hypothesis

The learner predicts: "Swapping np.einsum('ik,kj->ij', A, B) to np.einsum('ki,kj->ij', A, B) will silently produce a valid-shaped but numerically wrong result. Tests that only check the output shape will pass; tests that check actual values will fail."

The break

In your matmul wrapper:

 def matmul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
-    return np.einsum('ik,kj->ij', A, B)
+    return np.einsum('ki,kj->ij', A, B)

Equivalently in loop form:

 for i in range(M):
   for j in range(N):
     for k in range(K):
-      C[i, j] += A[i, k] * B[k, j]
+      C[i, j] += A[k, i] * B[k, j]   # /break: contracted A's wrong axis
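The loop diff above can be run end to end. Below is a minimal pure-Python sketch on lists of lists. The name naive_matmul echoes the target above, but the broken flag is a hypothetical addition for this exercise, not part of lab/01-matmul-perf.md.

```python
def naive_matmul(A, B, broken=False):
    """Triple-loop matmul on lists of lists.

    broken=False computes C[i][j] = sum_k A[i][k] * B[k][j].
    broken=True (hypothetical flag) reproduces the /break: it reads
    A[k][i] instead, which is the contraction for A.T @ B.
    """
    M, K = len(A), len(A[0])
    K_b, N = len(B), len(B[0])
    assert K == K_b, "inner dimensions must agree"
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                a = A[k][i] if broken else A[i][k]
                C[i][j] += a * B[k][j]
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(naive_matmul(A, B))               # [[19.0, 22.0], [43.0, 50.0]]
print(naive_matmul(A, B, broken=True))  # [[26.0, 30.0], [38.0, 44.0]]
```

With a non-square A, the broken branch indexes past the end of a row (or past the last row) and raises IndexError, which is the loop-form counterpart of the einsum error below.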

Run procedure

Use a known-answer test case with §A13 shapes:

uv run python -c "
import numpy as np

# A: (3 persons × 5 tenses), B: (5 tenses × 4 features)
np.random.seed(0)
A = np.random.randn(3, 5).astype(np.float32)
B = np.random.randn(5, 4).astype(np.float32)

ref = A @ B                                  # ground truth, shape (3, 4)
broken = np.einsum('ki,kj->ij', A, B)        # uses the *wrong* axis of A

print('ref shape:   ', ref.shape)
print('broken shape:', broken.shape)
print('max abs diff:', np.max(np.abs(ref - broken)))
print('allclose?    ', np.allclose(ref, broken))
"

Expected failure mode

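Nothing is printed: the script dies on the einsum line. In 'ki,kj->ij' the label k names axis 0 of A (length 3) and axis 0 of B (length 5), so the contraction lengths disagree and np.einsum raises before any print runs:

ValueError: operands could not be broadcast together ...

This is the loud case: with these non-square shapes, the wrong contraction cannot be evaluated at all.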

If A is square (M = K), the wrong contraction becomes legal and the failure goes silent:

uv run python -c "
import numpy as np
np.random.seed(0)
A = np.random.randn(5, 5).astype(np.float32)
B = np.random.randn(5, 4).astype(np.float32)
ref = A @ B
broken = np.einsum('ki,kj->ij', A, B)
print('shapes match:', ref.shape == broken.shape)
print('values match:', np.allclose(ref, broken))
print('max diff:    ', np.max(np.abs(ref - broken)))
"

Output: shapes match: True, values match: False, max diff ≈ 5–10. Shape-only tests pass. The bug ships.
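Why is the square case silent rather than crashing? With a square A, 'ki,kj->ij' is a legal contraction that sums A[k, i] * B[k, j] over k, which is exactly (A.T @ B)[i, j]. The sketch below confirms that identity with the same seeded shapes. It keeps NumPy's default float64 instead of casting to float32 (an assumption made so that np.allclose's default tolerances are comfortably met). Matching A.T @ B is a quick way to recognize this particular swap.

```python
import numpy as np

np.random.seed(0)
A = np.random.randn(5, 5)   # float64 here, not float32
B = np.random.randn(5, 4)

broken = np.einsum('ki,kj->ij', A, B)

# The broken einsum is not random garbage: it is the transpose product.
print('equals A.T @ B:', np.allclose(broken, A.T @ B))   # True
# It differs from A @ B unless A happens to be symmetric.
print('equals A @ B:  ', np.allclose(broken, A @ B))     # False
```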

Diagnostic

Three independent checks, in order of cost:

  1. Numerical equality against np.matmul (or A @ B) on a non-square pair. The swapped contraction raises (ValueError from einsum, IndexError from the loop form), so the test fails loudly and instantly.
  2. A hand-computed test on a known small case. Set A = [[1,2],[3,4]], B = [[5,6],[7,8]] and work out the answer by hand: A @ B = [[19,22],[43,50]]. Compare. If you get [[26,30],[38,44]] instead, your axes are swapped: you actually computed A.T @ B.
  3. A property test with Hypothesis: for random A and B, np.allclose(your_matmul(A, B), A @ B). Run it with 50 examples (@settings(max_examples=50)); any failure points to the swap.
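The three checks above can be sketched as plain functions. This is a minimal sketch that assumes only NumPy. matmul_under_test is a hypothetical stand-in for your wrapper, deliberately set to the broken einsum. The third check replaces Hypothesis with a seeded NumPy loop over 50 random shapes so the sketch has no extra dependency. With Hypothesis installed, hypothesis.extra.numpy.arrays would generate the inputs instead and also shrink failures to a minimal case, which this loop does not do.

```python
import numpy as np


def matmul_under_test(A, B):
    # Hypothetical wrapper under test; this is the /break version.
    return np.einsum('ki,kj->ij', A, B)


def check_against_reference(matmul):
    """Check 1: non-square pair compared with A @ B."""
    rng = np.random.default_rng(0)
    A = rng.standard_normal((3, 5))
    B = rng.standard_normal((5, 4))
    try:
        out = matmul(A, B)
    except ValueError as exc:
        return f'crashed: {exc.__class__.__name__}'
    return 'ok' if np.allclose(out, A @ B) else 'wrong values'


def check_known_answer(matmul):
    """Check 2: the hand-computed 2x2 case."""
    A = np.array([[1, 2], [3, 4]])
    B = np.array([[5, 6], [7, 8]])
    expected = np.array([[19, 22], [43, 50]])   # A @ B, worked by hand
    return np.array_equal(matmul(A, B), expected)


def check_random_property(matmul, n_examples=50):
    """Check 3: randomized stand-in for a Hypothesis property test.

    Returns the first failing (m, k, n) shape, or None if all pass.
    """
    rng = np.random.default_rng(1)
    for _ in range(n_examples):
        m, k, n = rng.integers(1, 6, size=3)
        A = rng.standard_normal((m, k))
        B = rng.standard_normal((k, n))
        try:
            ok = np.allclose(matmul(A, B), A @ B)
        except ValueError:
            ok = False
        if not ok:
            return (m, k, n)
    return None


print(check_against_reference(matmul_under_test))            # crashed: ValueError
print(check_known_answer(matmul_under_test))                 # False
print(check_random_property(matmul_under_test) is not None)  # True
print(check_known_answer(lambda A, B: A @ B))                # True
```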

Lesson

A @ B contracts A's last axis with B's first axis. Contracting a different axis of A changes the computation ('ki,kj->ij' yields A.T @ B), and because a square A always makes the dimensions line up, nothing raises. Unless you have a known-answer test, the bug ships.
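In index notation, the correct and broken contractions differ only in which index of A is summed, and the broken one collapses to a transpose product:

```latex
\begin{aligned}
C_{ij} &= \sum_{k=1}^{K} A_{ik}\,B_{kj} = (AB)_{ij},
  && A \in \mathbb{R}^{M\times K},\ B \in \mathbb{R}^{K\times N},\\
\tilde{C}_{ij} &= \sum_{k=1}^{K} A_{ki}\,B_{kj} = (A^{\top}B)_{ij},
  && \text{defined only when } M = K.
\end{aligned}
```

The two agree for every B only when A = A^T, so a random square A almost never hides the swap from a value check, only from a shape check.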

Always include at least one small, hand-computed test for any new matmul. Phase 7's gradcheck and Phase 8's np.allclose(your_op, torch_op) are the same idea at a different grain.

References

  • Golub & Van Loan, Matrix Computations, §1.1 (matrix-multiplication conventions).
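  • The NumPy einsum docs, especially the rules for explicit and implicit mode: in explicit mode, every subscript label that appears in the inputs (left of ->) but not in the output (right of ->) is summed over; in implicit mode (no ->), repeated labels are summed.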