English · Español
Lab 02 — Registra una op personalizada con autograd¶
Tomas el softmax que escribiste en Triton en Fase 24 (o un placeholder NumPy si no tienes CUDA), lo envuelves como
torch.library.custom_opcon backward registrado, y verificas que (a)gradchecklo aprueba y (b)torch.compilelo respeta como una caja negra. Esto es el patrón que Phase 27 reutiliza para Flash-Attention y Phase 26 para int-mm.
Objetivo¶
Registra un operador softmax_custom usando torch.library.custom_op, proporciona su forward y backward, verifica con torch.autograd.gradcheck, y confirma que torch.compile lo trata correctamente (como una frontera opaca o fusionado según corresponda).
Setup¶
torch >= 2.1(API custom_op).- Kernel de softmax en Triton de la Fase 24 si tienes CUDA. Si no: un sustituto de softmax respaldado por NumPy. El punto del laboratorio es el registro, no la velocidad del kernel.
theory/02-autograd-engine.mdpara el contexto de la fórmula de backward.
El forward y backward¶
Forward (numéricamente estable):
Backward (Jacobiano del softmax):
Equivalentemente: dx = s * (ds - (s * ds).sum(dim=-1, keepdim=True)).
Ésta es la derivación de la Fase 04 lab 00.
Tareas¶
Parte A — Implementa el forward y backward como funciones planas¶
import torch
def softmax_forward(x: torch.Tensor) -> torch.Tensor:
m = x.max(dim=-1, keepdim=True).values
e = (x - m).exp()
return e / e.sum(dim=-1, keepdim=True)
def softmax_backward(grad_out: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
# s is the saved forward output
return s * (grad_out - (s * grad_out).sum(dim=-1, keepdim=True))
(Si tienes CUDA + Triton, reemplaza softmax_forward por una llamada a kernel Triton. El backward queda igual — son ops de PyTorch sobre el tensor guardado.)
Parte B — Regístralo como op personalizada¶
from torch.library import custom_op, register_autograd
@custom_op("lynx_cortex::softmax", mutates_args=())
def softmax_custom(x: torch.Tensor) -> torch.Tensor:
return softmax_forward(x)
@softmax_custom.register_fake
def _(x):
return torch.empty_like(x) # shape-and-dtype only, no compute
def setup_context(ctx, inputs, output):
(x,) = inputs
ctx.save_for_backward(output) # save s, not x
def backward(ctx, grad_out):
(s,) = ctx.saved_tensors
return softmax_backward(grad_out, s)
register_autograd(softmax_custom, backward, setup_context=setup_context)
Tres cosas a notar:
mutates_args=()— declara que la op es pura (sin escrituras in-place). El pipeline de compile se apoya en esto.register_fake— una "función de forma" que permite al pipeline de compile/trace razonar sobre la forma de salida sin ejecutar el kernel real.setup_contextguarda la salida — el backward de softmax necesitas(el resultado), nox. Guardar la salida evita recomputar.
Parte C — Verifica con gradcheck¶
torch.manual_seed(0)
x = torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(softmax_custom, (x,), eps=1e-6, atol=1e-5)
print("gradcheck:", ok)
gradcheck perturba cada elemento de entrada en ±eps, calcula gradientes por diferencias finitas y compara con el backward analítico. Usa fp64 — el gradcheck en fp32 falla rutinariamente sobre softmax por el suelo de precisión de rsqrt/exp. fp64 es el estándar.
Parte D — Verifica equivalencia con torch.softmax¶
torch.manual_seed(1)
x = torch.randn(8, 600, requires_grad=True)
y_custom = softmax_custom(x)
y_ref = torch.softmax(x, dim=-1)
print("forward max-err:", (y_custom - y_ref).abs().max().item()) # ~1e-7 at fp32
(y_custom.sum()).backward()
g_custom = x.grad.clone()
x.grad.zero_()
(y_ref.sum()).backward()
g_ref = x.grad.clone()
print("backward max-err:", (g_custom - g_ref).abs().max().item()) # ~1e-7 at fp32
Parte E — Úsalo dentro de torch.compile¶
@torch.compile
def model(x, W, b):
h = torch.nn.functional.linear(x, W, b)
return softmax_custom(h)
x = torch.randn(2, 64)
W = torch.randn(600, 64)
b = torch.randn(600)
y = model(x, W, b)
print(y.shape, y.sum().item())
Vuelve a ejecutar por segunda vez — torch.compile no debería lanzar error. Si lo hace, tienes un bug de registro (lo más probable es que register_fake devuelva una forma/dtype erróneos).
Parte F — Lee la salida de Inductor¶
Pon la variable de entorno para conservar los kernels generados:
O en Python:
En el log verás el Triton/C++ generado para las porciones compiladas. Tu softmax_custom aparecerá como una llamada opaca (no fusionada) — eso es lo esperado para custom_ops sin un lowering de Inductor registrado. Anótalo en el informe.
Parte G — Escribe el informe¶
experiments/25-custom-op/REPORT.md:
- Las matemáticas de forward+backward (LaTeX).
- El snippet de registro (Parte B).
- Línea
gradcheckPASS. - Error máximo de forward/backward vs
torch.softmax(Parte D). - Salida de
torch.compile: el extracto del log de Inductor mostrando la op personalizada como una llamada de caja negra. - Un párrafo: "Registré
softmax_customcon autograd;gradcheckpasó a fp64; coincidió con la referencia dentro de 1e-7 a fp32. Bajotorch.compile, la op aparece como una frontera opaca (sin lowering de Inductor registrado) — éste es el comportamiento correcto para un kernel personalizado; la Fase 27 proporcionará una versión fusionada."
Entregable¶
experiments/25-custom-op/:
- REPORT.md — los puntos anteriores.
- inductor.log — el extracto de la salida de Inductor.
- manifest.json.
Aceptación¶
gradcheckdevuelveTrue.- Los errores de forward y backward vs
torch.softmaxson< 1e-6a fp32. - El modelo
torch.compile'd se ejecuta sin lanzar error. - El log de Inductor muestra la op personalizada como una llamada en lugar de fusionada.
Pitfalls¶
- Guardar el tensor erróneo para el backward. El backward de softmax necesita la salida
s, no la entradax. Guardarxy recomputar el softmax en el backward funciona pero malgasta flops; hazlo de la forma canónica. - gradcheck en fp32 fallando.
gradcheckes brutalmente sensible. Usa entradas fp64 como en la Parte C. mutates_argsmal puesto. Si tu kernel escribe in-place (por ejemplo,x.exp_()), decláralo. Si no, el pipeline de compile asume pureza y tu modelo produce resultados erróneos bajotorch.compile.register_fakedevolviendo dtype erróneo.torch.empty_like(x)es correcto para softmax. Para ops que devuelven un dtype distinto, devuelve el correcto explícitamente.torch.compilerecompilando en cada llamada. Causa probable: una forma de entrada cambia. La función fake debe aceptar cualquier forma compatible — no debería hardcodear una.- Sin CUDA — Triton no disponible. Sáltate la sustitución del kernel Triton y usa el softmax con ops de PyTorch en la custom_op. El punto del laboratorio es el registro, no el kernel.
Stretch¶
- Registra un lowering de Inductor para tu op personalizada de modo que
torch.compilepueda fusionarla en el grafo circundante. Compara el runtime antes/después. - Añade un dispatch CPU-y-CUDA. Registra dos backends de modo que la op elija el kernel correcto automáticamente.
- Pruébalo bajo autocast. Envuelve la llamada en
torch.autocast("cpu", torch.bfloat16)y confirma que la op registrada lo maneja.
Siguiente lab: lab/03-compile-and-distributed.md.