Skip to content

English · Español

Lab 00 — Traza el dispatcher sobre linear(x, W, b)

Instrumentas una llamada simple a nn.Linear(64, 600) y registras cada decisión del dispatcher: qué key set, qué backend, qué kernel aten. Sale un log; sale una tabla de "esta llamada disparó estas N decisiones"; y sale la convicción de que torch.matmul no es magia.

Objetivo

Ejecuta torch.nn.functional.linear(x, W, b) para la forma del LM head del grammar MiniGPT (x: (2, 64), W: (600, 64), b: (600,)) y emite una traza completa de cada decisión del dispatcher. Reporta qué key_set se usó (por ejemplo, CPU, AutogradCPU), qué op de aten fue seleccionada (aten::linear, aten::addmm), y cómo la llamada se descompone en ops de más bajo nivel.

Setup

  • torch >= 2.1. Una build de CPU vale (la máquina de Borja no tiene CUDA, según CLAUDE.md §6).
  • theory/01-dispatcher-and-aten.md para el marco conceptual.

Tareas

Parte A — Llamada de baseline

import torch

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

y = torch.nn.functional.linear(x, W, b)
print(y.shape, y.grad_fn)

Esperado: torch.Size([2, 600]), <AddmmBackward0 object at 0x...>.

La observación interesante: escribiste linear(x, W, b) pero el grad_fn es AddmmBackward0. Esa es la reescritura descendente del dispatcher en acción: linearaddmm → matmul + add.

Parte B — Imprime el conjunto de claves del dispatcher

print(x._dispatch_key_set())                 # CPU + AutogradCPU
print(torch.empty(1, dtype=torch.float16)._dispatch_key_set())  # CPU + AutogradCPU + Half

Registra lo que ves. El conjunto de claves es lo que el dispatcher usa para elegir un kernel.

Parte C — Usa TorchDispatchMode para loguear cada op

from torch.utils._python_dispatch import TorchDispatchMode

class LogDispatch(TorchDispatchMode):
    def __init__(self):
        self.calls = []
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.calls.append(str(func))
        return func(*args, **kwargs)

with LogDispatch() as logger:
    y = torch.nn.functional.linear(x, W, b)

for c in logger.calls:
    print(c)

Salida esperada (concreta):

aten.linear.default
aten.t.default
aten.addmm.default

(Los conteos y nombres exactos pueden variar según la versión de torch — fija la tuya en el informe.)

Parte D — Compara contra addmm directamente

with LogDispatch() as logger:
    y2 = torch.addmm(b, x, W.T)   # equivalent to linear(x, W, b)

print([str(c) for c in logger.calls])

Esperado: ['aten.addmm.default'] (o similar — sin aten.linear externo, sin aten.t porque ya transpusiste).

Esto muestra que nn.functional.linear se descompone hacia abajo en una secuencia de ops de aten; llamar a addmm directamente se salta la reescritura.

Parte E — Inspecciona las claves de dispatch para autograd

La clave de autograd está por encima de la clave de backend en el stack de dispatch. Con requires_grad=True, el motor encamina primero por AutogradCPU (que registra el nodo de backward), después re-dispatcha a CPU para el cómputo real.

import torch.autograd.profiler as profiler

with profiler.profile(record_shapes=True, with_stack=False) as prof:
    y = torch.nn.functional.linear(x, W, b)
    y.sum().backward()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

Lee la tabla. Identifica: - La fila forward de aten::linear (o aten::addmm). - La fila backward de AddmmBackward0. - Sus formas de entrada (deberían incluir [2, 64], [600, 64], [600]).

Parte F — Escribe el informe

experiments/25-dispatcher-trace/REPORT.md:

  1. La versión de torch fijada (torch.__version__).
  2. La salida completa de LogDispatch para linear(x, W, b) (sólo forward).
  3. La salida completa de LogDispatch para addmm directo. Anota la diferencia.
  4. La impresión del conjunto de claves del dispatcher (con y sin requires_grad).
  5. El extracto de la tabla del profiler para forward + backward.
  6. 3-5 frases interpretando: "la llamada a linear se descompuso en t + addmm; la clave de autograd encaminó primero, después la clave de CPU; el backward se registró como AddmmBackward0."

Entregable

experiments/25-dispatcher-trace/: - trace.log — salida cruda de LogDispatch. - REPORT.md — los puntos anteriores. - manifest.json — versión de torch, semilla, timestamp, hash de código.

Aceptación

  • La traza muestra al menos aten.linear, aten.t, aten.addmm (en algún orden) para la llamada a linear.
  • La llamada directa a addmm muestra menos ops que la llamada a linear.
  • La tabla del profiler es legible; puedes señalar las filas de forward y backward.
  • El párrafo de interpretación del informe menciona que el conjunto de claves de autograd se sitúa por encima de la clave de CPU.

Pitfalls

  • TorchDispatchMode requiere torch reciente. Si tu versión fijada es < 2.0, esta API puede no estar presente. O actualiza (dentro de la fijación de la Fase 25) o usa torch.profiler como fuente de traza principal.
  • Los nombres de operadores cambian entre versiones menores de torch. aten.linear.default puede ser aten._linear o similar en builds más antiguas. Documenta lo que ves; no hagas asserts sobre nombres exactos.
  • Mezclar fp16 cambia el conjunto de claves. Si x.dtype=torch.float16, verás Half en el conjunto de claves y posiblemente una op distinta (aten.linear haciendo lowering primero a _to_copy). Empieza con fp32 para una traza limpia.
  • El backward no es visible en LogDispatch por defecto. El hook captura sólo el dispatch del forward. Para el backward, usa el profiler.

Stretch

  • Repite con x.dtype=torch.float16. Identifica la nueva clave en el conjunto.
  • Envuelve la llamada en torch.no_grad(). Confirma que la clave de autograd desaparece del routing (sin AutogradCPU).
  • Compara linear(x, W, b) con (x @ W.T) + b via LogDispatch. ¿Descomposición distinta?

Siguiente lab: lab/01-autograd-by-hand.md.