
Lab 00 — Trace the dispatcher on linear(x, W, b)

You instrument a simple call to nn.Linear(64, 600) and record every dispatcher decision: which key set, which backend, which ATen kernel. You come away with a log, a table of "this call triggered these N decisions", and the conviction that torch.matmul is not magic.

Objective

Run torch.nn.functional.linear(x, W, b) with the LM-head shapes of the grammar MiniGPT (x: (2, 64), W: (600, 64), b: (600,)) and produce a complete trace of every dispatcher decision. Report which dispatch keys were involved (e.g., CPU, AutogradCPU), which ATen ops were selected (aten::linear, aten::addmm), and how the call decomposes into lower-level ops.

Setup

  • torch >= 2.1. CPU build is fine (Borja's machine has no CUDA, per CLAUDE.md §6).
  • theory/01-dispatcher-and-aten.md for the conceptual frame.

Tasks

Part A — Baseline call

import torch

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

y = torch.nn.functional.linear(x, W, b)
print(y.shape, y.grad_fn)

Expected: torch.Size([2, 600]), <AddmmBackward0 object at 0x...>.

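The interesting observation: you wrote linear(x, W, b), but the grad_fn is AddmmBackward0. aten::linear is a CompositeImplicitAutograd op: for a 2-D input with a bias, its kernel calls aten::t on W and then aten::addmm, which computes b + x @ W.T in a single fused kernel (addmm does not split further into a separate matmul and add). Autograd is recorded at the addmm level, so that is the backward node you see.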
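The decomposition claim is easy to check numerically. The sketch below, assuming the fp32 Part A tensors, computes the same output with torch.addmm(b, x, W.t()) and compares values and backward-node names; y_linear and y_addmm are illustrative names.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

# Public entry point: a composite op that lowers to t + addmm for 2-D input with bias.
y_linear = F.linear(x, W, b)
# The same computation spelled out: b + x @ W.T in one fused kernel.
y_addmm = torch.addmm(b, x, W.t())

print(torch.allclose(y_linear, y_addmm, atol=1e-5))                  # True
print(type(y_linear.grad_fn).__name__, type(y_addmm.grad_fn).__name__)  # AddmmBackward0 AddmmBackward0
```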

Part B — Print the dispatcher key set

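# torch._C._dispatch_keys is private API; the exact keys listed vary by version.
print(torch._C._dispatch_keys(x))                                    # includes CPU and AutogradCPU
print(torch._C._dispatch_keys(torch.empty(1, dtype=torch.float16)))  # same keys: dtype is not a dispatch key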

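Record what you see. The dispatcher combines the key sets of all tensor arguments and runs the kernel registered for the highest-priority key; the dtype is not a key and is resolved later, inside the kernel.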
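To make the Part B observations checkable, the sketch below prints the key set for three tensor variants and tests two properties: the dtype does not change the key set, and neither does requires_grad. It relies on torch._C._dispatch_keys, which is private API, and the helper keys() is a hypothetical convenience. Which extra keys appear (for example ADInplaceOrView or AutocastCPU) depends on the torch version, so only the comparisons are meant to be stable.

```python
import torch


def keys(t: torch.Tensor) -> str:
    # Private API: returns the tensor's DispatchKeySet; str() lists the key names.
    return str(torch._C._dispatch_keys(t))


fp32 = torch.randn(2, 64)
fp32_grad = torch.randn(2, 64, requires_grad=True)
fp16 = torch.randn(2, 64).to(torch.float16)

for name, t in [("fp32", fp32), ("fp32 + requires_grad", fp32_grad), ("fp16", fp16)]:
    print(f"{name:22s} {keys(t)}")

# dtype is handled inside the kernel, not by the key set.
print(keys(fp32) == keys(fp16))
# Autograd keys are attached even when requires_grad=False.
print(keys(fp32) == keys(fp32_grad))
```

Both comparisons should print True: ordinary tensors always carry the autograd keys, and the autograd kernel decides at run time whether anything needs recording.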

Part C — Use TorchDispatchMode to log every op

from torch.utils._python_dispatch import TorchDispatchMode

class LogDispatch(TorchDispatchMode):
    def __init__(self):
        super().__init__()             # let the base class set up its mode state
        self.calls = []
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.calls.append(str(func))   # e.g. "aten.addmm.default"
        return func(*args, **kwargs)   # run the real kernel

with LogDispatch() as logger:
    y = torch.nn.functional.linear(x, W, b)

for c in logger.calls:
    print(c)

Expected output (concrete):

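aten.t.default
aten.addmm.default

aten.linear itself does not appear: it is a CompositeImplicitAutograd op, so it is decomposed into t + addmm at the autograd layer, before calls reach the Python dispatch key that TorchDispatchMode intercepts. (Counts and exact names may vary by torch version; pin yours in the report.)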
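For the trace.log deliverable it helps to record argument shapes next to each op name. The sketch below is a variant of LogDispatch under that assumption; the class name ShapeLogDispatch, the tab-separated line format, and writing trace.log to the current directory are illustrative choices, not part of the lab spec (move the file into experiments/25-dispatcher-trace/ yourself).

```python
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode


class ShapeLogDispatch(TorchDispatchMode):
    """Records (op name, argument summary) for every op that reaches the Python dispatch key."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Tensors are summarized by shape; scalars and other arguments are kept as-is.
        summary = [list(a.shape) if isinstance(a, torch.Tensor) else a for a in args]
        self.calls.append((str(func), summary))
        return func(*args, **kwargs)


torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

with ShapeLogDispatch() as logger:
    y = F.linear(x, W, b)

# One line per op: name, a tab, then the argument summary.
with open("trace.log", "w") as f:
    for name, summary in logger.calls:
        f.write(f"{name}\t{summary}\n")

for name, summary in logger.calls:
    print(name, summary)
```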

Part D — Compare against addmm directly

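Wt = W.T   # transpose outside the mode so that only the addmm call is logged

with LogDispatch() as logger:
    y2 = torch.addmm(b, x, Wt)   # equivalent to linear(x, W, b)

print(logger.calls)

Expected: ['aten.addmm.default'] (or similar): no aten.t, because the transpose ran before the mode was active. A transpose taken inside the with block would be logged too.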

This shows that nn.functional.linear decomposes downward into a sequence of ATen ops (a transpose of W, then addmm); calling addmm directly on a pre-transposed weight skips both the rewrite and the transpose.
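The Part D comparison can be made mechanical. The sketch below runs both calls under a minimal logging mode (OpNames and trace are hypothetical names) and lists the ops that appear only in the linear trace. It takes the transpose of W outside any logged region on purpose, since a transpose dispatched inside the mode would be logged as well.

```python
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode


class OpNames(TorchDispatchMode):
    """Minimal logger: the string name of every op seen at the Python dispatch key."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.calls.append(str(func))
        return func(*args, **(kwargs or {}))


def trace(fn):
    # Run fn under a fresh logging mode; return its output and the logged op names.
    with OpNames() as mode:
        out = fn()
    return out, mode.calls


torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)
Wt = W.t()  # transposed outside any logging mode

y1, linear_ops = trace(lambda: F.linear(x, W, b))
y2, addmm_ops = trace(lambda: torch.addmm(b, x, Wt))

print("linear:", linear_ops)
print("addmm: ", addmm_ops)
print("only in linear:", [op for op in linear_ops if op not in addmm_ops])
print("same values:", torch.allclose(y1, y2))
```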

Part E — Inspect dispatch keys for autograd

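The autograd key has higher priority than the backend key in the dispatch stack. The dispatcher therefore sends each op to the AutogradCPU kernel first, which records the backward node (when an input requires grad and grad mode is enabled), then redispatches to the CPU kernel for the actual computation. For linear this happens at the addmm level, since the composite linear kernel has already issued t and addmm.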
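One way to watch the autograd layer's behavior from Python is to run the same call under different grad settings and inspect what the output carries. The sketch below assumes the Part A shapes. With requires_grad=True the autograd kernel records AddmmBackward0; under torch.no_grad() the kernel still runs but records nothing; under torch.inference_mode() the autograd keys are excluded from dispatch and the result is marked as an inference tensor.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

y_grad = F.linear(x, W, b)          # autograd kernel records a backward node
with torch.no_grad():
    y_nograd = F.linear(x, W, b)    # autograd kernel runs, but grad mode is off
with torch.inference_mode():
    y_inf = F.linear(x, W, b)       # autograd keys skipped entirely

for name, y in [("requires_grad", y_grad), ("no_grad", y_nograd), ("inference_mode", y_inf)]:
    node = type(y.grad_fn).__name__ if y.grad_fn is not None else None
    print(f"{name:15s} grad_fn={node} is_inference={y.is_inference()}")
```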

import torch.autograd.profiler as profiler

with profiler.profile(record_shapes=True, with_stack=False) as prof:
    y = torch.nn.functional.linear(x, W, b)
    y.sum().backward()

# group_by_input_shape=True gives one row per (op, input shapes) pair
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))

Read the table. Identify:

  • The forward aten::linear row and the aten::addmm row it calls.
  • The backward AddmmBackward0 row.
  • Their input shapes: the aten::linear row should list [2, 64], [600, 64], [600]; aten::addmm sees the transposed weight, [64, 600].
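If the table is long, the relevant rows can be pulled out programmatically. The sketch below assumes the legacy torch.autograd.profiler API used in Part E and filters averaged events by name; backward rows may carry a prefix such as autograd::engine::evaluate_function: depending on the version, so it matches on substrings rather than exact names.

```python
import torch
import torch.nn.functional as F
import torch.autograd.profiler as profiler

torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

with profiler.profile(record_shapes=True) as prof:
    y = F.linear(x, W, b)
    y.sum().backward()

# Keep only rows for linear, addmm, and the AddmmBackward0 node, one per input-shape group.
rows = [
    evt for evt in prof.key_averages(group_by_input_shape=True)
    if "linear" in evt.key or "addmm" in evt.key.lower()
]
for evt in rows:
    print(f"{evt.key:55s} count={evt.count} shapes={evt.input_shapes}")
```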

Part F — Write the report

experiments/25-dispatcher-trace/REPORT.md:

  1. The torch version pinned (torch.__version__).
  2. The full LogDispatch output for linear(x, W, b) (forward only).
  3. The full LogDispatch output for addmm directly. Note the difference.
  4. The dispatcher key set printout (with and without requires_grad).
  5. The profiler table excerpt for forward + backward.
  6. 3-5 sentences interpreting: "the linear call decomposed into t + addmm; the autograd key routed first, then the CPU key; backward registered as AddmmBackward0."

Deliverable

experiments/25-dispatcher-trace/:

  • trace.log: raw LogDispatch output.
  • REPORT.md: items above.
  • manifest.json: torch version, seed, timestamp, code hash.
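A minimal sketch of the manifest writer, assuming that "code hash" means a SHA-256 of the experiment script's bytes. The helper write_manifest, its parameters, and the JSON field names are hypothetical, chosen to mirror the list above.

```python
import hashlib
import json
import time
from pathlib import Path

import torch


def write_manifest(out_dir, script_path, seed):
    """Hypothetical helper: write manifest.json recording the run's provenance."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    manifest = {
        "torch_version": str(torch.__version__),
        "seed": seed,
        # UTC timestamp in ISO 8601 form.
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        # Hash of the exact script that produced trace.log and REPORT.md.
        "code_sha256": hashlib.sha256(Path(script_path).read_bytes()).hexdigest(),
    }
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2) + "\n")
    return manifest
```

Called at the end of the experiment script, for example as write_manifest("experiments/25-dispatcher-trace", __file__, seed=42), it records which code and torch build produced the other two files.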

Acceptance

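  • Trace shows aten.t and aten.addmm (in some order) for the linear call. The outer aten.linear shows up in the profiler table rather than in LogDispatch, since it is decomposed before calls reach the Python dispatch key.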
  • The addmm direct call shows fewer ops than the linear call.
  • Profiler table is readable; you can point at the forward and backward rows.
  • The report's interpretation paragraph mentions the autograd key sitting above the CPU key.

Pitfalls

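  • TorchDispatchMode lives in a private module (torch.utils._python_dispatch), so its behavior can shift between releases. If your pinned version is < 2.0, this API may not be present. Either upgrade (within Phase 25's pin) or use torch.profiler as the primary trace source.
  • Operator names and decompositions change between torch minor versions. Document what you see; don't assert on exact names.
  • fp16 changes the trace, not the key set. The dtype is not a dispatch key: a float16 CPU tensor carries the same keys as a float32 one, and the kernel picks the dtype-specific code internally. Mixing fp16 and fp32 operands makes addmm fail with a dtype mismatch, and casting (explicitly or under torch.autocast) adds conversion ops such as aten._to_copy. Start with fp32 for a clean trace.
  • Backward isn't visible in LogDispatch here. The mode only sees ops dispatched inside its with block, and Part C runs forward only. Calling .backward() inside the block would log the backward ATen ops as well, but not grouped by autograd node; for a per-node view (AddmmBackward0), use the profiler.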

Stretch

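  • Repeat with x, W and b in torch.float16. Compare the key set with the fp32 one (dtype should not add a key) and note whether the op sequence changes. If your CPU build raises a not-implemented error for Half, try torch.bfloat16.
  • Wrap the call in torch.no_grad(). AutogradCPU is still in the key set and its kernel still runs, but it records no backward node: confirm that y.grad_fn is None. For contrast, try torch.inference_mode(), which excludes the autograd keys from dispatch.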
  • Compare linear(x, W, b) to (x @ W.T) + b via LogDispatch. Different decomposition?
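For the last stretch item, the sketch below runs both spellings under a minimal logging mode and prints the two traces side by side, plus a value check. It reuses the fp32 Part A shapes; OpNames is a hypothetical name, and no particular decomposition is assumed for the second spelling, so read the output rather than expecting a fixed list.

```python
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode


class OpNames(TorchDispatchMode):
    """Minimal logger: the string name of every op seen at the Python dispatch key."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.calls.append(str(func))
        return func(*args, **(kwargs or {}))


torch.manual_seed(42)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

traces, outputs = {}, {}
for label, fn in [
    ("linear", lambda: F.linear(x, W, b)),
    ("matmul + add", lambda: (x @ W.T) + b),
]:
    with OpNames() as mode:
        outputs[label] = fn()
    traces[label] = mode.calls
    print(f"{label:13s} {traces[label]}")

# Different kernels may round differently, so compare with a tolerance.
print("same values:", torch.allclose(outputs["linear"], outputs["matmul + add"], atol=1e-4))
```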

Next lab: lab/01-autograd-by-hand.md.