Skip to content

English · Español

Lab 02 — Registra una op personalizada con autograd

Tomas el softmax que escribiste en Triton en Fase 24 (o un placeholder NumPy si no tienes CUDA), lo envuelves como torch.library.custom_op con backward registrado, y verificas que (a) gradcheck lo aprueba y (b) torch.compile lo respeta como una caja negra. Esto es el patrón que Phase 27 reutiliza para Flash-Attention y Phase 26 para int-mm.

Objetivo

Registra un operador softmax_custom usando torch.library.custom_op, proporciona su forward y backward, verifica con torch.autograd.gradcheck, y confirma que torch.compile lo trata correctamente (como una frontera opaca o fusionado según corresponda).

Setup

  • torch >= 2.1 (API custom_op).
  • Kernel de softmax en Triton de la Fase 24 si tienes CUDA. Si no: un sustituto de softmax respaldado por NumPy. El punto del laboratorio es el registro, no la velocidad del kernel.
  • theory/02-autograd-engine.md para el contexto de la fórmula de backward.

El forward y backward

Forward (numéricamente estable):

\[s_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}} \qquad m = \max_i x_i\]

Backward (Jacobiano del softmax):

\[\frac{\partial L}{\partial x_i} = s_i \left( \frac{\partial L}{\partial s_i} - \sum_j s_j \frac{\partial L}{\partial s_j} \right)\]

Equivalentemente: dx = s * (ds - (s * ds).sum(dim=-1, keepdim=True)).

Ésta es la derivación de la Fase 04 lab 00.

Tareas

Parte A — Implementa el forward y backward como funciones planas

import torch

def softmax_forward(x: torch.Tensor) -> torch.Tensor:
    m = x.max(dim=-1, keepdim=True).values
    e = (x - m).exp()
    return e / e.sum(dim=-1, keepdim=True)

def softmax_backward(grad_out: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # s is the saved forward output
    return s * (grad_out - (s * grad_out).sum(dim=-1, keepdim=True))

(Si tienes CUDA + Triton, reemplaza softmax_forward por una llamada a kernel Triton. El backward queda igual — son ops de PyTorch sobre el tensor guardado.)

Parte B — Regístralo como op personalizada

from torch.library import custom_op, register_autograd

@custom_op("lynx_cortex::softmax", mutates_args=())
def softmax_custom(x: torch.Tensor) -> torch.Tensor:
    return softmax_forward(x)

@softmax_custom.register_fake
def _(x):
    return torch.empty_like(x)   # shape-and-dtype only, no compute

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(output)   # save s, not x

def backward(ctx, grad_out):
    (s,) = ctx.saved_tensors
    return softmax_backward(grad_out, s)

register_autograd(softmax_custom, backward, setup_context=setup_context)

Tres cosas a notar:

  1. mutates_args=() — declara que la op es pura (sin escrituras in-place). El pipeline de compile se apoya en esto.
  2. register_fake — una "función de forma" que permite al pipeline de compile/trace razonar sobre la forma de salida sin ejecutar el kernel real.
  3. setup_context guarda la salida — el backward de softmax necesita s (el resultado), no x. Guardar la salida evita recomputar.

Parte C — Verifica con gradcheck

torch.manual_seed(0)
x = torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(softmax_custom, (x,), eps=1e-6, atol=1e-5)
print("gradcheck:", ok)

gradcheck perturba cada elemento de entrada en ±eps, calcula gradientes por diferencias finitas y compara con el backward analítico. Usa fp64 — el gradcheck en fp32 falla rutinariamente sobre softmax por el suelo de precisión de rsqrt/exp. fp64 es el estándar.

Parte D — Verifica equivalencia con torch.softmax

torch.manual_seed(1)
x = torch.randn(8, 600, requires_grad=True)
y_custom = softmax_custom(x)
y_ref = torch.softmax(x, dim=-1)
print("forward max-err:", (y_custom - y_ref).abs().max().item())   # ~1e-7 at fp32

(y_custom.sum()).backward()
g_custom = x.grad.clone()
x.grad.zero_()
(y_ref.sum()).backward()
g_ref = x.grad.clone()
print("backward max-err:", (g_custom - g_ref).abs().max().item())  # ~1e-7 at fp32

Parte E — Úsalo dentro de torch.compile

@torch.compile
def model(x, W, b):
    h = torch.nn.functional.linear(x, W, b)
    return softmax_custom(h)

x = torch.randn(2, 64)
W = torch.randn(600, 64)
b = torch.randn(600)
y = model(x, W, b)
print(y.shape, y.sum().item())

Vuelve a ejecutar por segunda vez — torch.compile no debería lanzar error. Si lo hace, tienes un bug de registro (lo más probable es que register_fake devuelva una forma/dtype erróneos).

Parte F — Lee la salida de Inductor

Pon la variable de entorno para conservar los kernels generados:

TORCH_LOGS=output_code python your_script.py

O en Python:

import os
os.environ["TORCH_LOGS"] = "output_code"

En el log verás el Triton/C++ generado para las porciones compiladas. Tu softmax_custom aparecerá como una llamada opaca (no fusionada) — eso es lo esperado para custom_ops sin un lowering de Inductor registrado. Anótalo en el informe.

Parte G — Escribe el informe

experiments/25-custom-op/REPORT.md:

  1. Las matemáticas de forward+backward (LaTeX).
  2. El snippet de registro (Parte B).
  3. Línea gradcheck PASS.
  4. Error máximo de forward/backward vs torch.softmax (Parte D).
  5. Salida de torch.compile: el extracto del log de Inductor mostrando la op personalizada como una llamada de caja negra.
  6. Un párrafo: "Registré softmax_custom con autograd; gradcheck pasó a fp64; coincidió con la referencia dentro de 1e-7 a fp32. Bajo torch.compile, la op aparece como una frontera opaca (sin lowering de Inductor registrado) — éste es el comportamiento correcto para un kernel personalizado; la Fase 27 proporcionará una versión fusionada."

Entregable

experiments/25-custom-op/: - REPORT.md — los puntos anteriores. - inductor.log — el extracto de la salida de Inductor. - manifest.json.

Aceptación

  • gradcheck devuelve True.
  • Los errores de forward y backward vs torch.softmax son < 1e-6 a fp32.
  • El modelo torch.compile'd se ejecuta sin lanzar error.
  • El log de Inductor muestra la op personalizada como una llamada en lugar de fusionada.

Pitfalls

  • Guardar el tensor erróneo para el backward. El backward de softmax necesita la salida s, no la entrada x. Guardar x y recomputar el softmax en el backward funciona pero malgasta flops; hazlo de la forma canónica.
  • gradcheck en fp32 fallando. gradcheck es brutalmente sensible. Usa entradas fp64 como en la Parte C.
  • mutates_args mal puesto. Si tu kernel escribe in-place (por ejemplo, x.exp_()), decláralo. Si no, el pipeline de compile asume pureza y tu modelo produce resultados erróneos bajo torch.compile.
  • register_fake devolviendo dtype erróneo. torch.empty_like(x) es correcto para softmax. Para ops que devuelven un dtype distinto, devuelve el correcto explícitamente.
  • torch.compile recompilando en cada llamada. Causa probable: una forma de entrada cambia. La función fake debe aceptar cualquier forma compatible — no debería hardcodear una.
  • Sin CUDA — Triton no disponible. Sáltate la sustitución del kernel Triton y usa el softmax con ops de PyTorch en la custom_op. El punto del laboratorio es el registro, no el kernel.

Stretch

  • Registra un lowering de Inductor para tu op personalizada de modo que torch.compile pueda fusionarla en el grafo circundante. Compara el runtime antes/después.
  • Añade un dispatch CPU-y-CUDA. Registra dos backends de modo que la op elija el kernel correcto automáticamente.
  • Pruébalo bajo autocast. Envuelve la llamada en torch.autocast("cpu", torch.bfloat16) y confirma que la op registrada lo maneja.

Siguiente lab: lab/03-compile-and-distributed.md.