Lab 02 — Register a custom op with autograd

You take the softmax you wrote in Triton in Phase 24 (or a NumPy placeholder if you don't have CUDA), wrap it as a torch.library.custom_op with a registered backward, and verify that (a) gradcheck passes it and (b) torch.compile respects it as a black box. This is the pattern that Phase 27 reuses for Flash-Attention and Phase 26 for int-mm.

Objective

Register a softmax_custom operator using torch.library.custom_op, provide its forward and backward, verify with torch.autograd.gradcheck, and confirm that torch.compile runs it correctly: as an opaque call by default, or fused once an Inductor lowering is registered (see Stretch).

Setup

  • torch >= 2.4 (the release that introduced the torch.library.custom_op API).
  • Phase 24's Triton softmax kernel if you have CUDA. Otherwise: a NumPy-backed softmax stand-in. The point of the lab is the registration, not the kernel speed.
  • theory/02-autograd-engine.md for the backward-formula context.

The forward and backward

Forward (numerically stable):

\[s_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j\]

Backward (Jacobian of softmax):

\[\frac{\partial L}{\partial x_i} = s_i \left( \frac{\partial L}{\partial s_i} - \sum_j s_j \frac{\partial L}{\partial s_j} \right)\]

Equivalently: dx = s * (ds - (s * ds).sum(dim=-1, keepdim=True)).

This is the derivation from Phase 04 lab 00.
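
For reference, here is a sketch of the standard route from the softmax Jacobian to the backward formula above. The Kronecker delta \(\delta_{ij}\) (1 when i = j, else 0) is notation introduced here, not taken from the lab text.

\[\frac{\partial s_j}{\partial x_i} = s_j \left( \delta_{ij} - s_i \right)\]

\[\frac{\partial L}{\partial x_i} = \sum_j \frac{\partial L}{\partial s_j} \frac{\partial s_j}{\partial x_i} = \sum_j \frac{\partial L}{\partial s_j} \, s_j \left( \delta_{ij} - s_i \right) = s_i \frac{\partial L}{\partial s_i} - s_i \sum_j s_j \frac{\partial L}{\partial s_j}\]

Factoring out \(s_i\) gives the expression used in the backward: the sum over j is the (s * ds).sum(dim=-1, keepdim=True) term.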

Tasks

Part A — Implement the forward and backward as plain functions

import torch

def softmax_forward(x: torch.Tensor) -> torch.Tensor:
    m = x.max(dim=-1, keepdim=True).values
    e = (x - m).exp()
    return e / e.sum(dim=-1, keepdim=True)

def softmax_backward(grad_out: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # s is the saved forward output
    return s * (grad_out - (s * grad_out).sum(dim=-1, keepdim=True))

(If you have CUDA + Triton, replace softmax_forward with a Triton kernel call. The backward stays the same — it's PyTorch ops over the saved tensor.)
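
Before registering anything, it can help to confirm the two plain functions agree with PyTorch's own softmax and its autograd. The snippet below is a minimal sketch of such a check; the tensor shapes, the random upstream gradient grad_out, and the variable names fwd_err and bwd_err are illustrative choices, not part of the lab.

```python
import torch

def softmax_forward(x: torch.Tensor) -> torch.Tensor:
    m = x.max(dim=-1, keepdim=True).values
    e = (x - m).exp()
    return e / e.sum(dim=-1, keepdim=True)

def softmax_backward(grad_out: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    return s * (grad_out - (s * grad_out).sum(dim=-1, keepdim=True))

torch.manual_seed(0)
x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
grad_out = torch.randn(4, 16, dtype=torch.float64)  # arbitrary upstream gradient

# Manual path: forward on detached data, then the closed-form backward.
s = softmax_forward(x.detach())
dx_manual = softmax_backward(grad_out, s)

# Reference path: torch.softmax plus autograd's vector-Jacobian product.
y_ref = torch.softmax(x, dim=-1)
(dx_ref,) = torch.autograd.grad(y_ref, x, grad_out)

fwd_err = (s - y_ref.detach()).abs().max().item()
bwd_err = (dx_manual - dx_ref).abs().max().item()
print("forward max-err:", fwd_err)
print("backward max-err:", bwd_err)
```

Using a random grad_out rather than a tensor of ones matters: softmax rows sum to 1, so an all-ones upstream gradient yields a zero input gradient and hides backward bugs.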

Part B — Register as a custom op

from torch.library import custom_op, register_autograd

@custom_op("lynx_cortex::softmax", mutates_args=())
def softmax_custom(x: torch.Tensor) -> torch.Tensor:
    return softmax_forward(x)

@softmax_custom.register_fake
def _(x):
    return torch.empty_like(x)   # shape-and-dtype only, no compute

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(output)   # save s, not x

def backward(ctx, grad_out):
    (s,) = ctx.saved_tensors
    return softmax_backward(grad_out, s)

register_autograd(softmax_custom, backward, setup_context=setup_context)

Three things to notice:

  1. mutates_args=() — declares the op is pure (no in-place writes). The compile pipeline relies on this.
  2. register_fake — a "shape function" that lets the compile/trace pipeline reason about output shape without executing the real kernel.
  3. setup_context saves the output — softmax backward needs s (the result), not x. Saving the output avoids recomputing.

Part C — Verify with gradcheck

torch.manual_seed(0)
x = torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(softmax_custom, (x,), eps=1e-6, atol=1e-5)
print("gradcheck:", ok)

gradcheck perturbs each input element by ±eps, computes central finite-difference gradients, and compares them to the analytical backward. Use fp64: with eps=1e-6, fp32 rounding error in the forward (about 1e-7 relative) is divided by eps and swamps the tolerance, so fp32 gradcheck routinely fails even when the backward is correct. fp64 is the standard.
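
To make the precision argument concrete, the sketch below hand-rolls a central finite difference and compares it with autograd in fp64 and fp32. It illustrates the idea only and is not gradcheck's actual implementation; the helper names central_difference and max_error, the random weights w, and the 2×8 shape are assumptions made for this example.

```python
import torch

def central_difference(f, x, eps):
    """Estimate df/dx one element at a time with (f(x+eps) - f(x-eps)) / (2*eps)."""
    grad = torch.zeros_like(x)
    for i in range(x.numel()):
        step = torch.zeros_like(x)
        step.view(-1)[i] = eps
        grad.view(-1)[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

def max_error(dtype, eps=1e-6):
    """Max abs difference between autograd and finite-difference gradients."""
    torch.manual_seed(0)
    # Draw in fp64, then cast, so both dtypes see the same underlying values.
    x = torch.randn(2, 8, dtype=torch.float64).to(dtype)
    w = torch.randn(2, 8, dtype=torch.float64).to(dtype)
    loss = lambda t: (torch.softmax(t, dim=-1) * w).sum()

    x_req = x.clone().requires_grad_(True)
    (analytic,) = torch.autograd.grad(loss(x_req), x_req)
    numeric = central_difference(loss, x, eps)
    return (analytic - numeric).abs().max().item()

err64 = max_error(torch.float64)
err32 = max_error(torch.float32)
print(f"fp64 finite-difference error: {err64:.2e}")
print(f"fp32 finite-difference error: {err32:.2e}")
```

The fp32 error is typically several orders of magnitude larger than the fp64 one, which is why the lab runs gradcheck in fp64 only.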

Part D — Verify equivalence to torch.softmax

torch.manual_seed(1)
x = torch.randn(8, 600, requires_grad=True)
# Random upstream weights: softmax rows sum to 1, so y.sum() has zero gradient
# and would make the backward comparison pass trivially.
w = torch.randn(8, 600)
y_custom = softmax_custom(x)
y_ref = torch.softmax(x, dim=-1)
print("forward max-err:", (y_custom - y_ref).abs().max().item())   # ~1e-7 at fp32

(y_custom * w).sum().backward()
g_custom = x.grad.clone()
x.grad.zero_()
(y_ref * w).sum().backward()
g_ref = x.grad.clone()
print("backward max-err:", (g_custom - g_ref).abs().max().item())  # ~1e-7 at fp32

Part E — Use inside torch.compile

@torch.compile
def model(x, W, b):
    h = torch.nn.functional.linear(x, W, b)
    return softmax_custom(h)

x = torch.randn(2, 64)
W = torch.randn(600, 64)
b = torch.randn(600)
y = model(x, W, b)
print(y.shape, y.sum().item())

Re-run a second time — torch.compile should not raise. If it does, you have a registration bug (most likely register_fake returning wrong shape/dtype).

Part F — Read the Inductor output

Set the TORCH_LOGS environment variable to print the generated code:

TORCH_LOGS=output_code python your_script.py

Or in Python:

import os
os.environ["TORCH_LOGS"] = "output_code"   # must be set before `import torch`; it is read at import time
import torch

In the log you'll see the generated Triton/C++ for the compiled portions. Your softmax_custom will appear as an opaque call (not fused) — that's expected for custom_ops without an Inductor lowering registered. Note this in the report.

Part G — Write the report

experiments/25-custom-op/REPORT.md:

  1. The forward+backward math (LaTeX).
  2. The registration snippet (Part B).
  3. gradcheck PASS line.
  4. Forward/backward max-error vs torch.softmax (Part D).
  5. torch.compile output: the Inductor log excerpt showing the custom op as a black-box call.
  6. One paragraph: "I registered softmax_custom with autograd; gradcheck passed at fp64; it matched the reference within 1e-7 at fp32. Under torch.compile, the op appears as an opaque boundary (no Inductor lowering registered) — this is the right behavior for a custom kernel; Phase 27 will provide a fused version."

Deliverable

experiments/25-custom-op/:

  • REPORT.md: the items above.
  • inductor.log: the Inductor output excerpt.
  • manifest.json.

Acceptance

  • gradcheck returns True.
  • Forward and backward errors vs torch.softmax are < 1e-6 at fp32.
  • torch.compile'd model runs without raising.
  • Inductor log shows the custom op as a call rather than fused.

Pitfalls

  • Saving the wrong tensor for backward. Softmax backward needs the output s, not the input x. Saving x and recomputing the softmax in backward works but wastes flops; do it the canonical way.
  • fp32 gradcheck failing. gradcheck is brutally sensitive. Use fp64 inputs as in Part C.
  • mutates_args set wrong. If your kernel writes in-place (e.g., x.exp_()), declare the mutated argument by name, e.g. mutates_args=("x",). Otherwise the compile pipeline assumes purity and your model produces wrong results under torch.compile.
  • register_fake returning wrong dtype. torch.empty_like(x) is right for softmax. For ops that return a different dtype, return the right one explicitly.
  • torch.compile recompiling on every call. Likely cause: an input shape changes. The fake function must accept any compatible shape — it should not hard-code one.
  • No CUDA — Triton not available. Skip the Triton kernel substitution and use the PyTorch-op softmax in the custom_op. The lab's point is the registration, not the kernel.

Stretch

  • Register an Inductor lowering for your custom op so torch.compile can fuse it into the surrounding graph. Compare runtime before/after.
  • Add a CPU-and-CUDA dispatch. Register two backends so the op picks the right kernel automatically.
  • Test under autocast. Wrap the call in torch.autocast("cpu", torch.bfloat16) and confirm the registered op handles it.

Next lab: lab/03-compile-and-distributed.md.