Lab 02 — Register a custom op with autograd
You take the softmax you wrote in Triton in Phase 24 (or a NumPy placeholder if you do not have CUDA), wrap it as a torch.library.custom_op with a registered backward, and verify that (a) gradcheck passes it and (b) torch.compile respects it as a black box. This is the pattern that Phase 27 reuses for Flash-Attention and Phase 26 for int-mm.
Objective
Register a softmax_custom operator using torch.library.custom_op, provide its forward and backward, verify with torch.autograd.gradcheck, and confirm that torch.compile treats it correctly (as an opaque boundary or fused as appropriate).
Setup
- torch >= 2.4 (torch.library.custom_op, register_fake, and register_autograd first shipped in PyTorch 2.4).
- Phase 24's Triton softmax kernel if you have CUDA. Otherwise: a NumPy-backed softmax stand-in. The point of the lab is the registration, not the kernel speed.
- theory/02-autograd-engine.md for the backward-formula context.
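The forward and backward
Forward (numerically stable), applied to each row x along the last dimension:
$$ s_i = \frac{\exp(x_i - m)}{\sum_k \exp(x_k - m)}, \qquad m = \max_k x_k $$
Backward (Jacobian of softmax), with ds the upstream gradient:
$$ \frac{\partial s_i}{\partial x_j} = s_i\,(\delta_{ij} - s_j), \qquad dx_j = \sum_i ds_i \, \frac{\partial s_i}{\partial x_j} = s_j \Big( ds_j - \sum_i s_i \, ds_i \Big) $$
Equivalently: dx = s * (ds - (s * ds).sum(dim=-1, keepdim=True)).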
This is the derivation from Phase 04 lab 00.
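As a quick sanity check on the closed form, the sketch below compares it with the explicit Jacobian for a single row. It assumes torch.autograd.functional.jacobian applied to torch.softmax as the reference; the variable names are illustrative and not part of the lab.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
x = torch.randn(5, dtype=torch.float64)
ds = torch.randn(5, dtype=torch.float64)  # upstream gradient (illustrative)

s = torch.softmax(x, dim=-1)

# Reference Jacobian from autograd: J[i, j] = d s_i / d x_j
J = jacobian(lambda t: torch.softmax(t, dim=-1), x)

# Closed form of the same Jacobian: diag(s) - s s^T
J_closed = torch.diag(s) - torch.outer(s, s)

# Vector-Jacobian product two ways: dx_j = sum_i ds_i * J[i, j]
dx_vjp = ds @ J
dx_closed = s * (ds - (s * ds).sum(dim=-1, keepdim=True))

print("jacobian matches:", torch.allclose(J, J_closed))
print("vjp matches:", torch.allclose(dx_vjp, dx_closed))
```

The closed form never materializes the n-by-n Jacobian, which is why the backward in this lab needs only the saved output s and one reduction.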
Tasks
Part A — Implement the forward and backward as plain functions
    import torch

    def softmax_forward(x: torch.Tensor) -> torch.Tensor:
        # Subtract the row max before exponentiating so exp() cannot overflow.
        m = x.max(dim=-1, keepdim=True).values
        e = (x - m).exp()
        return e / e.sum(dim=-1, keepdim=True)

    def softmax_backward(grad_out: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        # s is the saved forward output
        return s * (grad_out - (s * grad_out).sum(dim=-1, keepdim=True))
(If you have CUDA + Triton, replace softmax_forward with a Triton kernel call. The backward stays the same — it's PyTorch ops over the saved tensor.)
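To see why the forward subtracts the row maximum, the following sketch contrasts it with a naive version on large logits. The function names softmax_naive and softmax_stable are hypothetical helpers for this illustration only.

```python
import torch

def softmax_naive(x: torch.Tensor) -> torch.Tensor:
    # No max subtraction: exp() overflows for large inputs.
    e = x.exp()
    return e / e.sum(dim=-1, keepdim=True)

def softmax_stable(x: torch.Tensor) -> torch.Tensor:
    # Same math as the lab's softmax_forward.
    m = x.max(dim=-1, keepdim=True).values
    e = (x - m).exp()
    return e / e.sum(dim=-1, keepdim=True)

x = torch.tensor([[1000.0, 1001.0, 1002.0]])  # fp32 logits far past exp()'s range

naive = softmax_naive(x)    # exp(1000) is inf in fp32, and inf / inf is nan
stable = softmax_stable(x)  # shifted logits are [-2, -1, 0], so exp() stays finite

print("naive has nan:", naive.isnan().any().item())
print("stable:", stable)
```

Both versions are mathematically identical, since the shift cancels between numerator and denominator; only the floating-point behavior differs.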
Part B — Register as a custom op
    from torch.library import custom_op, register_autograd

    @custom_op("lynx_cortex::softmax", mutates_args=())
    def softmax_custom(x: torch.Tensor) -> torch.Tensor:
        return softmax_forward(x)

    @softmax_custom.register_fake
    def _(x):
        return torch.empty_like(x)  # shape-and-dtype only, no compute

    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(output)  # save s, not x

    def backward(ctx, grad_out):
        (s,) = ctx.saved_tensors
        return softmax_backward(grad_out, s)

    register_autograd(softmax_custom, backward, setup_context=setup_context)
Three things to notice:
- mutates_args=() declares the op is pure (no in-place writes). The compile pipeline relies on this.
- register_fake is a "shape function" that lets the compile/trace pipeline reason about output shape without executing the real kernel.
- setup_context saves the output. Softmax backward needs s (the result), not x. Saving the output avoids recomputing.
Part C — Verify with gradcheck
    torch.manual_seed(0)
    x = torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
    ok = torch.autograd.gradcheck(softmax_custom, (x,), eps=1e-6, atol=1e-5)
    print("gradcheck:", ok)
gradcheck perturbs each input element by ±eps, computes finite-difference gradients, and compares them to the analytical backward. Use fp64: with eps=1e-6 the perturbation is close to fp32's relative precision (about 1.2e-7), so rounding error in the forward pass swamps the finite-difference estimate and fp32 gradcheck routinely fails. fp64 is the standard.
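The precision argument can be made concrete with a small experiment. This sketch runs gradcheck on torch.softmax (used here as a stand-in for the custom op) at both precisions; raise_exception=False makes gradcheck return a bool instead of raising. The fp32 result is printed rather than assumed, since it depends on the inputs.

```python
import torch

torch.manual_seed(0)
f = lambda t: torch.softmax(t, dim=-1)

# Rough size of the rounding noise in a finite difference: machine epsilon / eps.
for dtype in (torch.float32, torch.float64):
    print(dtype, "noise/eps ratio ~", torch.finfo(dtype).eps / 1e-6)

x64 = torch.randn(2, 16, dtype=torch.float64, requires_grad=True)
ok64 = torch.autograd.gradcheck(f, (x64,), eps=1e-6, atol=1e-5)

# Same values in fp32; gradcheck warns about non-double inputs.
x32 = x64.detach().float().requires_grad_()
ok32 = torch.autograd.gradcheck(
    f, (x32,), eps=1e-6, atol=1e-5, raise_exception=False
)

print("fp64 gradcheck:", ok64)
print("fp32 gradcheck:", ok32)
```

In fp32 the ratio of machine epsilon to eps is on the order of 0.1, so the numerical gradient carries an error far above atol; in fp64 it is around 1e-10.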
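Part D — Verify equivalence to torch.softmax
    torch.manual_seed(1)
    x = torch.randn(8, 600, requires_grad=True)
    # Each softmax row sums to 1, so y.sum() has zero gradient for any correct
    # or incorrect backward that returns small values. Weight the loss with
    # fixed random values so the backward comparison is meaningful.
    w = torch.randn(8, 600)
    y_custom = softmax_custom(x)
    y_ref = torch.softmax(x, dim=-1)
    print("forward max-err:", (y_custom - y_ref).abs().max().item())  # ~1e-7 at fp32
    (y_custom * w).sum().backward()
    g_custom = x.grad.clone()
    x.grad.zero_()
    (y_ref * w).sum().backward()
    g_ref = x.grad.clone()
    print("backward max-err:", (g_custom - g_ref).abs().max().item())  # ~1e-7 at fp32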
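One caveat when comparing backward passes for softmax: because every row sums to 1, a plain sum() loss is constant in x and its gradient is analytically zero. The sketch below, using torch.softmax and an illustrative weight tensor w, shows the difference between a plain and a weighted loss.

```python
import torch

torch.manual_seed(1)
x = torch.randn(8, 600, requires_grad=True)

# Loss 1: plain sum. Constant in x, so the gradient is zero up to rounding.
torch.softmax(x, dim=-1).sum().backward()
g_sum = x.grad.clone()
x.grad.zero_()

# Loss 2: weighted sum with fixed random weights (w is an illustrative name).
w = torch.randn(8, 600)
(torch.softmax(x, dim=-1) * w).sum().backward()
g_weighted = x.grad.clone()

print("sum-loss grad max:", g_sum.abs().max().item())
print("weighted-loss grad max:", g_weighted.abs().max().item())
```

With the plain sum, a backward that returned all zeros would also match the reference, so a weighted loss is the more discriminating comparison.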
Part E — Use inside torch.compile
    @torch.compile
    def model(x, W, b):
        h = torch.nn.functional.linear(x, W, b)
        return softmax_custom(h)

    x = torch.randn(2, 64)
    W = torch.randn(600, 64)
    b = torch.randn(600)
    y = model(x, W, b)
    print(y.shape, y.sum().item())
Run it a second time; torch.compile should not raise. If it does, you have a registration bug (most likely register_fake returning the wrong shape/dtype).
Part F — Read the Inductor output
Set the env var to keep generated kernels:
Or in Python:
In the log you'll see the generated Triton/C++ for the compiled portions. Your softmax_custom will appear as an opaque call (not fused) — that's expected for custom_ops without an Inductor lowering registered. Note this in the report.
Part G — Write the report
experiments/25-custom-op/REPORT.md:
- The forward+backward math (LaTeX).
- The registration snippet (Part B).
- gradcheck PASS line.
- Forward/backward max-error vs torch.softmax (Part D).
- torch.compile output: the Inductor log excerpt showing the custom op as a black-box call.
- One paragraph: "I registered softmax_custom with autograd; gradcheck passed at fp64; it matched the reference within 1e-7 at fp32. Under torch.compile, the op appears as an opaque boundary (no Inductor lowering registered); this is the right behavior for a custom kernel, and Phase 27 will provide a fused version."
Deliverable
experiments/25-custom-op/:
- REPORT.md — items above.
- inductor.log — the Inductor output excerpt.
- manifest.json.
Acceptance
- gradcheck returns True.
- Forward and backward errors vs torch.softmax are < 1e-6 at fp32.
- torch.compile'd model runs without raising.
- Inductor log shows the custom op as a call rather than fused.
Pitfalls
- Saving the wrong tensor for backward. Softmax backward needs the output s, not the input x. Saving x and recomputing the softmax in backward works but wastes flops; do it the canonical way.
- fp32 gradcheck failing. gradcheck is brutally sensitive. Use fp64 inputs as in Part C.
- mutates_args set wrong. If your kernel writes in-place (e.g., x.exp_()), declare it. Otherwise the compile pipeline assumes purity and your model produces wrong results under torch.compile.
- register_fake returning wrong dtype. torch.empty_like(x) is right for softmax. For ops that return a different dtype, return the right one explicitly.
- torch.compile recompiling on every call. Likely cause: an input shape changes. The fake function must accept any compatible shape; it should not hard-code one.
- No CUDA, so Triton is not available. Skip the Triton kernel substitution and use the PyTorch-op softmax in the custom_op. The lab's point is the registration, not the kernel.
Stretch
- Register an Inductor lowering for your custom op so torch.compile can fuse it into the surrounding graph. Compare runtime before/after.
- Add a CPU-and-CUDA dispatch. Register two backends so the op picks the right kernel automatically.
- Test under autocast. Wrap the call in torch.autocast("cpu", torch.bfloat16) and confirm the registered op handles it.
Next lab: lab/03-compile-and-distributed.md.