Lab 00: Trace the dispatcher on linear(x, W, b)
You instrument a simple call to nn.Linear(64, 600) and record every decision the dispatcher makes: which key set, which backend, which ATen kernel. You come away with a log, a table of "this call triggered these N decisions", and the conviction that torch.matmul is not magic.
Objective
Run torch.nn.functional.linear(x, W, b) with the shapes of the grammar MiniGPT's LM head (x: (2, 64), W: (600, 64), b: (600,)) and emit a complete trace of every dispatcher decision. Report which dispatch key set was used (e.g., CPU, AutogradCPU), which ATen ops were selected (aten::linear, aten::addmm), and how the call decomposes into lower-level ops.
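For reference, F.linear computes an affine map with the weight stored as (out_features, in_features), which is why a transpose shows up in the decomposition. The bias is broadcast across the 2 rows:

$$
y = x W^\top + b,\qquad x \in \mathbb{R}^{2\times 64},\; W \in \mathbb{R}^{600\times 64},\; b \in \mathbb{R}^{600},\; y \in \mathbb{R}^{2\times 600}
$$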
Setup
torch >= 2.1. A CPU build is fine (Borja's machine has no CUDA, per CLAUDE.md §6). See theory/01-dispatcher-and-aten.md for the conceptual frame.
Tasks
Part A: Baseline call
```python
import torch
torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)
y = torch.nn.functional.linear(x, W, b)
print(y.shape, y.grad_fn)
```
Expected: torch.Size([2, 600]), <AddmmBackward0 object at 0x...>.
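The interesting observation: you wrote linear(x, W, b), but the grad_fn is AddmmBackward0. aten::linear is a CompositeImplicitAutograd op with no backward formula of its own; for a 2-D input with a bias, its kernel rewrites the call as addmm(b, x, W.t()), and autograd records the addmm. So the chain is linear → t + addmm, where addmm computes b + x @ W.T as one ATen op (on CPU, typically a single BLAS GEMM that accumulates onto the bias) rather than as a separate matmul and add.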
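The rewrite can be checked numerically. A minimal sketch on fresh tensors with the Part A shapes (no requires_grad, since only values matter here), comparing F.linear with the addmm call it lowers to and with the unfused x @ W.t() + b:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
x = torch.randn(2, 64)
W = torch.randn(600, 64)
b = torch.randn(600)

y_linear = F.linear(x, W, b)        # what you call
y_addmm = torch.addmm(b, x, W.t())  # what linear rewrites to for 2-D input with a bias
y_manual = x @ W.t() + b            # unfused: matmul, then a broadcast add

# Same math; the unfused path may differ in the last bits of float32 rounding.
print(torch.allclose(y_linear, y_addmm), torch.allclose(y_linear, y_manual, atol=1e-4))
```

The tolerance on the unfused comparison is deliberate: a separate add rounds once more than the fused GEMM.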
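Part B: Print the dispatcher key set
```python
print(torch._C._dispatch_keys(x))  # includes CPU and AutogradCPU
print(torch._C._dispatch_keys(torch.empty(1, dtype=torch.float16)))  # same keys: dtype is not a dispatch key
```
Record what you see. The dispatcher takes the union of the key sets of all tensor arguments, applies the thread-local include/exclude sets, and runs the kernel registered for the highest-priority remaining key.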
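A minimal sketch of the comparisons worth recording, using torch._C._dispatch_keys (a private, underscore-prefixed helper, so its printed format may change between versions). On a regular CPU tensor the autograd key is present whether or not requires_grad is set, and dtype adds no key; tensors created under torch.inference_mode() are the exception and carry no autograd keys.

```python
import torch

AUTOGRAD_CPU = torch._C.DispatchKey.AutogradCPU

def keys(t: torch.Tensor) -> str:
    """Readable dump of a tensor's dispatch key set."""
    return repr(torch._C._dispatch_keys(t))

plain = torch.randn(2, 64)
tracked = torch.randn(2, 64, requires_grad=True)
half = torch.empty(1, dtype=torch.float16)

with torch.inference_mode():
    inference = torch.randn(2, 64)  # inference tensor: created without autograd-related keys

for name, t in [("plain", plain), ("requires_grad", tracked), ("fp16", half), ("inference", inference)]:
    print(f"{name:14s} {keys(t)}")

# requires_grad only decides whether a graph node gets recorded; the key is always there.
print("plain has AutogradCPU:    ", torch._C._dispatch_keys(plain).has(AUTOGRAD_CPU))
print("inference has AutogradCPU:", torch._C._dispatch_keys(inference).has(AUTOGRAD_CPU))
```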
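Part C: Use TorchDispatchMode to log every op
```python
from torch.utils._python_dispatch import TorchDispatchMode

class LogDispatch(TorchDispatchMode):
    def __init__(self):
        super().__init__()  # let TorchDispatchMode initialize itself
        self.calls = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.calls.append(str(func))  # e.g. "aten.addmm.default"
        return func(*args, **kwargs)  # run the real kernel and pass its result through

with LogDispatch() as logger:
    y = torch.nn.functional.linear(x, W, b)
for c in logger.calls:
    print(c)
```
Expected output (typical): on torch 2.x CPU builds, two lines, aten.t.default followed by aten.addmm.default. aten.linear itself usually does not appear: it is CompositeImplicitAutograd, so it has already decomposed by the time the call reaches the Python dispatch key where TorchDispatchMode runs.
(Counts and exact names may vary by torch version; pin yours in the report.)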
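For trace.log it helps to record a bit more than the op name. A minimal sketch, assuming the Part A shapes; ShapeLogDispatch is a hypothetical extension of LogDispatch (not a torch API) that also records the shapes of tensor arguments and of the output, and writes one line per op to trace.log in the current working directory.

```python
from pathlib import Path

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class ShapeLogDispatch(TorchDispatchMode):
    """Hypothetical LogDispatch variant: op name plus tensor input/output shapes."""

    def __init__(self):
        super().__init__()
        self.rows = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)  # run the real kernel first so the output shape is known
        in_shapes = [tuple(a.shape) for a in args if isinstance(a, torch.Tensor)]
        out_shape = tuple(out.shape) if isinstance(out, torch.Tensor) else None
        self.rows.append((str(func), in_shapes, out_shape))
        return out

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

with ShapeLogDispatch() as logger:
    y = torch.nn.functional.linear(x, W, b)

lines = [f"{i:02d} {op} in={ins} out={out}" for i, (op, ins, out) in enumerate(logger.rows)]
Path("trace.log").write_text("\n".join(lines) + "\n")
print("\n".join(lines))
```

Writing the output shape next to each op makes the t step visible as the (600, 64) to (64, 600) flip that feeds addmm.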
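Part D: Compare against addmm directly
```python
Wt = W.t()  # transpose before entering the mode so only addmm is logged
with LogDispatch() as logger:
    y2 = torch.addmm(b, x, Wt)  # equivalent to linear(x, W, b)
print(logger.calls)
```
Expected: ['aten.addmm.default'] (or similar: no aten.t, since the transpose happened outside the mode). If you transpose inside the with block instead, the trace also shows a view op (aten.t.default for W.t(), or something like aten.permute.default for W.T).
This shows that nn.functional.linear decomposes into a sequence of ATen ops (t, then addmm); calling addmm directly skips that decomposition.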
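The op-count comparison from the Acceptance list can be scripted. A minimal sketch that repeats the Part C recorder so it runs on its own; ops_for is a hypothetical helper name, and the transpose for the direct call is done outside the mode so only addmm itself is counted.

```python
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode

class LogDispatch(TorchDispatchMode):
    # Same recorder as Part C, repeated so this block is self-contained.
    def __init__(self):
        super().__init__()
        self.calls = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.calls.append(str(func))
        return func(*args, **(kwargs or {}))

def ops_for(fn, *args):
    """Hypothetical helper: names of the ATen ops that reach the mode while fn(*args) runs."""
    with LogDispatch() as logger:
        fn(*args)
    return logger.calls

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)
Wt = W.t()  # transposed before entering the mode

via_linear = ops_for(F.linear, x, W, b)
via_addmm = ops_for(torch.addmm, b, x, Wt)
print(len(via_linear), via_linear)
print(len(via_addmm), via_addmm)
```

The same helper works for the Stretch comparison, e.g. ops_for(lambda a, w, c: a @ w.T + c, x, W, b).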
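Part E: Inspect dispatch keys for autograd
The autograd key sits above the backend key in the dispatch order. A call on regular CPU tensors routes through AutogradCPU first; because the inputs have requires_grad=True, that kernel records the backward node (AddmmBackward0 here) and then redispatches to the CPU key for the actual computation. The routing happens even without requires_grad; the autograd kernel then simply records nothing.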
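The autograd layer's effect can be observed from Python. A minimal sketch: under torch.no_grad() the call still goes through the autograd kernel but records nothing, so grad_fn is None; torch.inference_mode() goes further and excludes the autograd keys from dispatch, so its outputs carry no AutogradCPU key at all (checked with the private torch._C._dispatch_keys helper from Part B).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

y = F.linear(x, W, b)
print("default:       ", type(y.grad_fn).__name__)  # a backward node was recorded

with torch.no_grad():
    y_ng = F.linear(x, W, b)
print("no_grad:       ", y_ng.grad_fn)  # None: the autograd kernel ran but recorded nothing

with torch.inference_mode():
    y_inf = F.linear(x, W, b)
print("inference_mode:", y_inf.grad_fn, "is_inference =", y_inf.is_inference())
print("AutogradCPU on output:", torch._C._dispatch_keys(y_inf).has(torch._C.DispatchKey.AutogradCPU))
```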
```python
import torch.autograd.profiler as profiler

with profiler.profile(record_shapes=True, with_stack=False) as prof:
    y = torch.nn.functional.linear(x, W, b)
    y.sum().backward()
# group_by_input_shape=True gives one row per (op, input shapes) pair
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))
```
Read the table. Identify:
- The forward aten::linear (or aten::addmm) row.
- The backward AddmmBackward0 row.
- Their input shapes (should include [2, 64], [600, 64], [600]).
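If you would rather pull the relevant rows out of the profiler than scan the whole table, here is a minimal sketch. It relies on the FunctionEventAvg objects returned by key_averages() exposing key, count, input_shapes and cpu_time_total; row names such as aten::addmm and AddmmBackward0 match what recent builds emit, but treat them as version-dependent.

```python
import torch
import torch.autograd.profiler as profiler
import torch.nn.functional as F

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

with profiler.profile(record_shapes=True) as prof:
    y = F.linear(x, W, b)
    y.sum().backward()

rows = prof.key_averages(group_by_input_shape=True)
keys = {r.key for r in rows}

# Keep the forward linear/addmm rows and the backward AddmmBackward0 rows.
interesting = [r for r in rows if "linear" in r.key or "addmm" in r.key.lower()]
for r in interesting:
    print(f"{r.key:50s} calls={r.count} shapes={r.input_shapes} cpu_total_us={r.cpu_time_total:.1f}")
```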
Part F: Write the report
experiments/25-dispatcher-trace/REPORT.md:
- The pinned torch version (torch.__version__).
- The full LogDispatch output for linear(x, W, b) (forward only).
- The full LogDispatch output for calling addmm directly. Note the difference.
- The dispatcher key set printout (with and without requires_grad).
- The profiler table excerpt for forward + backward.
- 3-5 sentences interpreting the results, along the lines of: "the linear call decomposed into t + addmm; the autograd key routed first, then the CPU key; backward registered as AddmmBackward0."
Deliverable
experiments/25-dispatcher-trace/:
- trace.log — raw LogDispatch output.
- REPORT.md — items above.
- manifest.json — torch version, seed, timestamp, code hash.
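A minimal sketch of writing manifest.json. The field names are assumptions (the lab only lists what the file should contain), and "code hash" is taken here to mean the SHA-256 of the script that produced the trace; adapt both to your repo's conventions. write_manifest is a hypothetical helper.

```python
import hashlib
import json
from datetime import datetime, timezone
from pathlib import Path

import torch

def write_manifest(out_dir: Path, code_path: Path, seed: int) -> Path:
    """Hypothetical helper: write manifest.json with torch version, seed, UTC timestamp, code hash."""
    manifest = {
        "torch_version": str(torch.__version__),
        "seed": seed,
        "timestamp_utc": datetime.now(timezone.utc).isoformat(timespec="seconds"),
        # Assumption: "code hash" means the SHA-256 of the lab script's bytes.
        "code_sha256": hashlib.sha256(code_path.read_bytes()).hexdigest(),
    }
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "manifest.json"
    path.write_text(json.dumps(manifest, indent=2) + "\n")
    return path
```

From the lab script itself, a call such as write_manifest(Path("experiments/25-dispatcher-trace"), Path(__file__), seed=42) keeps the hash tied to the exact code that produced trace.log.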
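Acceptance
- The LogDispatch trace for the linear call shows aten.t and aten.addmm, in that order. aten.linear itself is CompositeImplicitAutograd and decomposes before reaching TorchDispatchMode, so expect it in the profiler (as aten::linear) rather than in LogDispatch.
- The direct addmm call shows fewer ops than the linear call.
- The profiler table is readable; you can point at the forward and backward rows.
- The report's interpretation paragraph mentions the autograd key sitting above the CPU key.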
Pitfalls
- TorchDispatchMode requires recent torch. If your pinned version is < 2.0, this API may not be present. Either upgrade (within Phase 25's pin) or use torch.profiler as the primary trace source.
- Operator names change between torch minor versions. The same call may show up under a different op or overload name in older builds. Document what you see; don't assert on exact names.
- Mixing fp16 does not change the key set. dtype is not a dispatch key; the kernel selects its Half implementation internally. Extra ops such as aten._to_copy appear only when something casts (an explicit .half(), or torch.autocast), and older CPU builds may not implement addmm for Half at all. Start with fp32 for a clean trace.
- Backward isn't visible in Part C's LogDispatch output. The mode only sees ops dispatched inside its with block, and that block runs the forward only. For backward, use the profiler.
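Stretch
- Repeat with float16 tensors. Compare the key set with the fp32 one: dtype is not a dispatch key, so expect the same set. Does the op trace change?
- Wrap the call in torch.no_grad() and confirm y.grad_fn is None. Then try torch.inference_mode(), which excludes the autograd keys from dispatch altogether (no_grad still routes through AutogradCPU but records nothing).
- Compare linear(x, W, b) to (x @ W.T) + b via LogDispatch. Different decomposition?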
Next lab: lab/01-autograd-by-hand.md.