Skip to content

English · Español

02 — La multiplicación matricial como composición

🇪🇸 Matmul es la composición de dos mapas lineales: (AB)x = A(Bx). Esa definición — y no "el ratio O(N³)" — explica por qué hay batched matmul, por qué multi-head attention es paralelizable, y por qué la regla de las dimensiones internas funciona. Ejemplo central: E @ one_hot(i) es una búsqueda en una tabla de embeddings (Fase 13).

Esta es la página de teoría de la Fase 3. Re-deriva cada diagrama de abajo hasta que puedas dibujarlo en una servilleta.


Tres vistas de matmul

Vista 1 — suma de productos

La definición de libro de texto. Para A de shape (M, K) y B de shape (K, N):

\[ C[i, j] = \sum_{k=0}^{K-1} A[i, k] \cdot B[k, j] \]

C tiene shape (M, N). Total de multiplicaciones escalares: M × K × N. Total de sumas escalares: M × (K - 1) × N. Así que FLOPs ≈ 2 × M × K × N.

Esta es la vista que implementas cuando escribes una matmul ingenua (tres bucles anidados). Es correcta, e irrelevante para entender.

Vista 2 — composición de mapas lineales

Una matriz A de shape (M, K) representa un mapa lineal f: R^K → R^M. Una matriz B de shape (K, N) representa g: R^N → R^K. Su producto AB representa la composición f ∘ g: R^N → R^M. La regla de shape (M, K) @ (K, N) = (M, N) es literalmente la regla de composición de funciones "la dimensión de entrada de f (= K) debe coincidir con la dimensión de salida de g (= K)".

En código:

# f: R^K -> R^M
A.shape == (M, K)
# g: R^N -> R^K
B.shape == (K, N)
# f ∘ g: R^N -> R^M
(A @ B).shape == (M, N)

Aplica a un vector x de shape (N,):

  • B @ x es g(x), shape (K,).
  • A @ (B @ x) es f(g(x)), shape (M,).
  • (A @ B) @ x es lo mismo por asociatividad, calculado en orden distinto.

Esta es la vista que explica por qué matmul es la operación central. Cada capa de una red neuronal es "aplicar un mapa lineal (luego una no-linealidad)". Componer capas es componer mapas lineales es matmul.

Vista 3 — suma de productos externos

AB puede escribirse como una suma de K productos externos:

\[ AB = \sum_{k=0}^{K-1} A[:, k] \otimes B[k, :] \]

Donde A[:, k] es la k-ésima columna de A (shape (M, 1)) y B[k, :] es la k-ésima fila de B (shape (1, N)); su producto externo es (M, N). Suma K de ellos.

Esta vista importa para entender aproximaciones de bajo rango y LoRA: si te quedas solo con los r productos externos más grandes (seleccionados por SVD), obtienes la mejor aproximación de rango-r a AB.

La búsqueda de embedding §A13, revisitada

Recuerda la matriz de embedding E de shape (V, D) donde V = 600. Vector one-hot e_i (longitud V, todo ceros excepto un 1 en la posición i).

result = E.T @ e_i   # shape (D, V) @ (V,) = (D,)
# o equivalentemente:
result = np.einsum('vd,v->d', E, e_i)
# o equivalentemente:
result = E[i]        # indexación directa

Los tres devuelven la i-ésima fila de E — el embedding de la forma verbal i. Difieren en cuánto trabajo hace la máquina:

  • E.T @ e_i hace V × D multiplicaciones, de las cuales (V-1) × D son multiplicaciones por cero. Total de FLOPs: 2 V D = 2 × 600 × 64 = 76,800.
  • E[i] hace D cargas de memoria, cero multiplicaciones. Total de FLOPs: 0.

El camino rápido es la indexación — las tablas de embedding se acceden por búsqueda, no por matmul. Pero la interpretación matemática es multiplicación matriz-vector con un one-hot. La Fase 13 explora la vista de tabla de búsqueda; la Fase 17 (MiniGPT) realmente usa indexación.

Por qué importa esta distinción. Algún hardware (las TPU antiguas, los diseños NPU originales) no tienen operaciones gather eficientes. Implementan la búsqueda de embedding como una matmul one-hot. La matemática es idéntica; el rendimiento no. Conocer ambas vistas te permite razonar sobre lo que el hardware está haciendo realmente.

Batched matmul

Para tensores con ejes de batch principales, matmul se aplica a las dos últimas dimensiones, haciendo broadcast sobre los ejes principales:

Entrada A Entrada B Salida
(M, K) (K, N) (M, N)
(B, M, K) (B, K, N) (B, M, N)
(B, M, K) (K, N) (B, M, N) (B se broadcast)
(M, K) (B, K, N) (B, M, N) (M se broadcast sobre batch)
(B, H, M, K) (B, H, K, N) (B, H, M, N)

La última fila es el shape del matmul principal del multi-head attention: B batches × H cabezas × M queries × K keys.

El coste es exactamente el mismo por "instancia de matmul" — 2 M K N FLOPs cada una — multiplicado por el número de elementos de batch. El coste total para (B, H, M, K) @ (B, H, K, N) es 2 B H M K N FLOPs. Memorízalo.

Lectura de shapes desde el código

Tres ejercicios de inferencia de shape:

Ejemplo 1 — un bloque FFN de transformer

x.shape = (B, T, D)           # activaciones de entrada
W_1.shape = (D, D_ff)         # expansión FFN
W_2.shape = (D_ff, D)         # contracción FFN

h = x @ W_1                   # shape: (B, T, D_ff)
h = np.maximum(h, 0)          # ReLU; shape sin cambios
y = h @ W_2                   # shape: (B, T, D)

Cada matmul hace broadcast sobre (B, T) automáticamente. FLOPs totales: B × T × (D × D_ff + D_ff × D) × 2 = 4 × B × T × D × D_ff.

Para el MiniGPT de Borja (B=32, T=16, D=64, D_ff=256): 4 × 32 × 16 × 64 × 256 ≈ 33M FLOPs por FFN por forward pass.

Ejemplo 2 — cabeza de attention

Q.shape = (B, H, T, D_k)
K.shape = (B, H, T, D_k)
V.shape = (B, H, T, D_k)

scores = np.einsum('bhqd,bhkd->bhqk', Q, K)   # shape (B, H, T, T)
scores = scores / np.sqrt(D_k)
attn = stable_softmax(scores, axis=-1)        # shape (B, H, T, T)
output = np.einsum('bhqk,bhkd->bhqd', attn, V) # shape (B, H, T, D_k)

Dos einsums, ambos matmuls batched 4D. FLOPs totales: 2 × 2 × B × H × T × T × D_k = 4 B H T² D_k. Para el MiniGPT de Borja (B=32, H=4, T=16, D_k=16): 4 × 32 × 4 × 256 × 16 ≈ 2.1M FLOPs por bloque de attention.

Verás estas expresiones repetidamente. El trabajo de la Fase 3 es hacerlas mecánicas.

Ejemplo 3 — el clasificador de tiempos §A13

hidden.shape = (B, D)         # estado oculto post-attention para el batch
W_tense.shape = (D, 5)        # pesos del clasificador para 5 tiempos

logits = hidden @ W_tense     # shape: (B, 5)
probs = stable_softmax(logits, axis=-1)   # shape: (B, 5)

B × 5 = 160 predicciones de "¿qué tiempo verbal?" por cada forward pass. La pérdida (loss) CE contra etiquetas reales de tiempo impulsa el entrenamiento.

Operaciones especiales

Producto escalar (dot product)

np.dot(a, b) para a, b 1-D de shape (N,) devuelve un escalar. Equivalente a np.einsum('i,i->', a, b). FLOPs: 2N - 1.

Producto externo

np.outer(a, b) devuelve shape (M, N). Equivalente a np.einsum('i,j->ij', a, b). FLOPs: M × N (sin sumas).

Hadamard (elemento a elemento)

a * b devuelve shape (N,) (o shape broadcasted). NO es un matmul. FLOPs: N.

Bug común: escribir A * B cuando querías decir A @ B. El primero es elementwise (requiere mismo shape); el segundo es matmul (requiere (M, K) × (K, N)). Los mensajes de error de NumPy distinguen, pero PyTorch a veces silenciosamente hace broadcast de formas confusas. Usa el hábito del comentario de shape.

Rendimiento — el gap que tus ojos verán en el lab 01

En Python, la matmul ingenua de tres bucles sobre arrays fp32 de tamaño (1024, 1024) × (1024, 1024) tarda ~minutos. np.matmul de los mismos arrays tarda ~10 ms. El gap es 10⁴-10⁵×, mucho más amplio que el "50×" que predijo la Fase 1.

¿De dónde sale el gap?

  1. Sobrecarga del intérprete de Python. Un bucle for k in range(K) en Python son ~100 ns por iteración solo por el bytecode. Matmul ingenuo hace M × K × N = 10⁹ iteraciones, así que ~100 s solo por la sobrecarga del bucle.
  2. Sin SIMD. np.matmul usa AVX2 (8 multiplicaciones fp32 por instrucción). El triple bucle hace 1.
  3. Sin cache blocking. np.matmul bloquea para L1/L2, elevando la intensidad aritmética a ~100 FLOPs/byte. El matmul ingenuo está en el piso 0.25 (teoría 03-roofline-model.md en la Fase 1).
  4. Sin multi-threading. OpenBLAS usa los 4 cores. El ingenuo usa 1.

Efecto compuesto: 100 × 8 × 40 × 4 ≈ 100,000×. Eso es aproximadamente lo que verás.

La Fase 6 (Python para Ingeniería de IA) cubre (1). El lab 01 de la Fase 3 simplemente te hace ver el gap y apunta a la causa. La conclusión: vectoriza siempre a través de NumPy/BLAS; nunca escribas bucles internos en Python.

Atándolo todo — el cheatsheet de einsum

Para referencia de Borja, los patrones einsum más comunes en este currículo:

Operación Einsum Shape resultante
(M, K) @ (K, N) 'ij,jk->ik' (M, N)
Batched matmul (B, M, K) @ (B, K, N) 'bij,bjk->bik' (B, M, N)
E^T @ one_hot(i) (búsqueda de embedding) 'vd,v->d' (D,)
Embedding por lotes 'btv,vd->btd' (B, T, D)
Scores de attention Q @ K^T 'bhqd,bhkd->bhqk' (B, H, T, T)
Salida de attention attn @ V 'bhqk,bhkd->bhqd' (B, H, T, D_k)
Producto interno de Frobenius 'ij,ij->' ()
Traza 'ii->' ()
Diagonal 'ii->i' (N,)

Memoriza los seis primeros. El resto son derivables.

Problemas de práctica

Soluciones en solutions/02-matmul-and-shapes-ref.md (apertura de fase).

  1. Dados A.shape = (B, H, T, D_k) y B.shape = (B, H, T, D_k), escribe el einsum que calcula el producto escalar por cabeza sobre el último eje — es decir, shape (B, H, T). ¿Cómo se llama la operación en attention?
  2. El vocabulario §A13 tiene 600 formas verbales. Un clasificador pequeño tiene matriz de pesos W.shape = (5, 600) (5 tiempos). Escribe el einsum que, dado un one-hot de un token (V,), produce el vector de logits de 5 tiempos (5,). (Sí, esto es simplemente W @ one_hot, escrito en einsum.)
  3. Demuestra que el einsum 'ik,kj->ij' es asociativo: para tres matrices A, B, C, muestra que (AB)C = A(BC) escribiendo los índices.
  4. FLOPs para el bloque de attention del MiniGPT de Borja (B=32, H=4, T=16, D=64, D_k=16). Suma: proyecciones Q/K/V + scores de attention + salida de attention + proyección de salida. Compara con los FLOPs del FFN.
  5. ¿Por qué np.matmul es más rápido que for k in range(K): C += A[:, k:k+1] @ B[k:k+1, :]? Ambos calculan la misma suma de productos externos.

Recapitulación en un párrafo

Matmul (M, K) @ (K, N) = (M, N) es la composición de mapas lineales; la regla de la dimensión interna es la restricción de firma de composición de funciones. Equivalentemente, puede leerse como una suma de K productos externos de columna-de-A por fila-de-B, que es la base de las aproximaciones de bajo rango. Batched matmul hace broadcast de los ejes principales; multi-head attention es uno de esos batched matmul. El gap de rendimiento entre la matmul ingenua de Python y np.matmul es 10⁴-10⁵× y viene de sobrecarga del intérprete, falta de SIMD, falta de cache blocking y falta de paralelismo. Domina einsum como gramática unificadora y tu código será seguro por tipos por construcción.

Lo que esta página NO cubre

  • Precisión numérica de matmul (Fase 2; los tests usan rtol=1e-5).
  • Gradiente a través de matmul (Fase 4 + 8).
  • Matmul disperso (fuera de alcance).
  • Internals de los kernels GEMM en GPU (Fase 24).

Siguiente: theory/03-svd-and-rank.md.