English · Español
Break — Invertir el eje de contracción en un matmul¶
🇪🇸 Rompe el matmul invirtiendo qué eje se contrae. Las formas siguen cuadrando pero los números son basura. Esta clase de bug es invisible en CI hasta que las pérdidas no convergen — perfecto para una lección de "verifica con un caso conocido".
Objetivo: cualquier matmul hecho a mano (en numpy o bucles puros en python) de lab/01-matmul-perf.md, o un naive_matmul nuevo.
Hipótesis¶
El aprendiz predice: "Cambiar np.einsum('ik,kj->ij', A, B) por np.einsum('ki,kj->ij', A, B) producirá silenciosamente un resultado con forma válida pero numéricamente incorrecto. Los tests que solo comprueban la forma de la salida pasarán; los tests que comprueban los valores reales fallarán."
El break¶
En tu wrapper de matmul:
def matmul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
- return np.einsum('ik,kj->ij', A, B)
+ return np.einsum('ki,kj->ij', A, B)
Equivalente en forma de bucle:
for i in range(M):
for j in range(N):
for k in range(K):
- C[i, j] += A[i, k] * B[k, j]
+ C[i, j] += A[k, i] * B[k, j] # /break: se contrae el eje incorrecto de A
Procedimiento de ejecución¶
Usa un caso de test con respuesta conocida y formas del §A13:
uv run python -c "
import numpy as np
# A: (3 persons × 5 tenses), B: (5 tenses × 4 features)
np.random.seed(0)
A = np.random.randn(3, 5).astype(np.float32)
B = np.random.randn(5, 4).astype(np.float32)
ref = A @ B # ground truth, shape (3, 4)
broken = np.einsum('ki,kj->ij', A, B) # uses the *wrong* axis of A
print('ref shape: ', ref.shape)
print('broken shape:', broken.shape)
print('max abs diff:', np.max(np.abs(ref - broken)))
print('allclose? ', np.allclose(ref, broken))
"
Modo de fallo esperado¶
ref shape: (3, 4)
broken shape: (5, 4) <-- different! M=3 became M=5
max abs diff: ValueError: operands could not be broadcast together
Si las formas coinciden por casualidad (p. ej., A cuadrada), el fallo se vuelve silencioso:
uv run python -c "
import numpy as np
np.random.seed(0)
A = np.random.randn(5, 5).astype(np.float32)
B = np.random.randn(5, 4).astype(np.float32)
ref = A @ B
broken = np.einsum('ki,kj->ij', A, B)
print('shapes match:', ref.shape == broken.shape)
print('values match:', np.allclose(ref, broken))
print('max diff: ', np.max(np.abs(ref - broken)))
"
Salida: shapes match: True, values match: False, max diff ≈ 5–10. Los tests basados solo en forma pasan. El bug se publica.
Diagnóstico¶
Tres comprobaciones independientes, por orden de coste:
- Igualdad numérica contra
np.matmuloA @ Bcon un par no cuadrado. El desajuste de formas falla ruidosamente; fallo instantáneo. - Un test simbólico sobre un caso pequeño conocido. Define
A = [[1,2],[3,4]],B = [[5,6],[7,8]], calcula a mano la respuesta:A @ B = [[19,22],[43,50]]. Compara. Si el test dice que obtuviste[[26,30],[38,44]], tus ejes están invertidos (en realidad calculasteA.T @ B). - Test de propiedades con
hypothesis: paraA, Baleatorias,your_matmul(A, B) == A @ B. Ejecútalo con 50 ejemplos; los fallos apuntan al intercambio.
Lección¶
A @ B contrae el último eje de A con el primer eje de B. Cambiar cualquier otro eje altera silenciosamente el cálculo — y como las dimensiones a menudo encajan por accidente en casos cuadrados, el bug se publica salvo que tengas un test con respuesta conocida.
Incluye siempre un único test (pequeño) calculado a mano contra cualquier matmul nuevo. El gradcheck de la Fase 7 y el np.allclose(your_op, torch_op) de la Fase 8 son la misma idea con otra granularidad.
Referencias¶
- Golub & Van Loan, Matrix Computations, §1.1 (convenciones de multiplicación de matrices).
- La documentación de
einsumde NumPy, especialmente la sección sobre "implicit summation" — cada subíndice con letra griega a la derecha que falte a la izquierda se contrae.