02 — The Autograd Engine
PyTorch's autograd engine is exactly what we built in Phases 7–8: the forward pass records a graph, and the backward pass walks it in reverse. This page formalizes graph capture, `grad_fn` nodes, and leaves, and shows how `torch.library.custom_op` registers a backward for a new operation. The running example is `linear(x, W, b)` with the LM head of the grammar MiniGPT.
This page is the autograd engine, made explicit. After it you can walk the grad_fn chain for any forward pass, derive the backward by hand, and verify your derivation against PyTorch's .backward() to numerical agreement at fp32.
The two-line model
- forward: every op on a tensor with `requires_grad=True` records a `grad_fn` node into the graph.
- backward: `.backward()` walks the graph in reverse-topological order, calling each `grad_fn`'s backward formula.
That's it. The complexity is in (a) which ops record, (b) what the backward formula is for each, (c) how the graph is stored, (d) when it's freed. Phase 25 covers each.
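For point (a), a minimal sketch using only core tensor ops: an op records a node when at least one input requires grad and its output is differentiable. The tensor names are illustrative.

```python
import torch

a = torch.randn(3, requires_grad=True)  # leaf that requires grad
c = torch.randn(3)                      # plain tensor, requires_grad=False

mixed = a * c   # one input requires grad, so a MulBackward0 node is recorded
plain = c * 2   # no input requires grad, so nothing is recorded
mask = a > 0    # boolean output is not differentiable, so nothing is recorded

print(type(mixed.grad_fn).__name__, plain.grad_fn, mask.grad_fn)  # MulBackward0 None None
```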
The forward pass: graph capture
```python
import torch

x = torch.randn(2, 64, requires_grad=True)   # batch of 2, 64 input features
W = torch.randn(600, 64, requires_grad=True)  # LM head weight: 600 outputs x 64 inputs
b = torch.randn(600, requires_grad=True)
y = torch.nn.functional.linear(x, W, b)  # y.grad_fn = <AddmmBackward0>
loss = y.sum()                           # loss.grad_fn = <SumBackward0>
```
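The autograd engine records two nodes, `AddmmBackward0` and `SumBackward0`, linked by an edge. The graph at this point: `x`, `W`, `b` (leaves) → `AddmmBackward0` (produces `y`) → `SumBackward0` (produces `loss`); backward follows these edges from right to left.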
Each node holds:
- Saved tensors needed for its backward formula. For `Addmm`: saves `x` and `W` (needed for the gradient of the matmul).
- Edges to parent nodes (or to leaves). Each edge knows which output of the parent feeds which input of this node.
- A backward function: a pointer to the C++ implementation of the gradient formula.
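Leaves (`x`, `W`, `b`) have `grad_fn = None` (they weren't produced by an op) and `is_leaf = True`. They're where `.backward()` delivers its results: gradients accumulate into `.grad` on every leaf with `requires_grad=True`. Non-leaf tensors such as `y` get no `.grad` unless you call `y.retain_grad()` before backward.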
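To see the graph rather than take it on faith, you can walk `grad_fn.next_functions` recursively. A minimal sketch; `walk` is a hypothetical helper, not a PyTorch API, and it does not deduplicate nodes reachable by more than one path.

```python
import torch

def walk(fn, depth=0, lines=None):
    """Collect the backward graph below `fn`, depth-first, one line per node."""
    lines = [] if lines is None else lines
    if fn is None:  # an input that does not require grad has no node
        return lines
    lines.append("  " * depth + type(fn).__name__)
    for next_fn, _input_nr in fn.next_functions:
        walk(next_fn, depth + 1, lines)
    return lines

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)
loss = torch.nn.functional.linear(x, W, b).sum()

print("\n".join(walk(loss.grad_fn)))
print(x.grad_fn, x.is_leaf)  # None True
```

On a recent eager PyTorch build you should see `SumBackward0` at the root and `AddmmBackward0` under it, with each leaf showing up as an `AccumulateGrad` node (the node that writes into `.grad`). `W` typically reaches `AddmmBackward0` through an extra transpose node, since `linear` multiplies by `W.T`; the exact node set can vary between versions.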
The backward pass: reverse traversal
.backward():
- Starts with `loss` (a scalar by default; if not, you pass a gradient of the same shape, e.g. `gradient=ones_like(loss)`).
- Calls `SumBackward0.backward(grad=1.0)` → returns `dy = ones_like(y)`.
- Calls `AddmmBackward0.backward(grad=dy)` → returns three gradients (one per input):
  - `dx = dy @ W` (shape `(2, 64)`)
  - `dW = dy.T @ x` (shape `(600, 64)`)
  - `db = dy.sum(dim=0)` (shape `(600,)`)
- Adds each gradient to the corresponding leaf's `.grad`.
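The seed and accumulation steps can be checked directly. A minimal sketch with the running example's shapes: seeding a non-scalar `y` with `ones_like(y)` gives the same leaf gradients as `.sum().backward()`, and a second backward on a fresh graph adds to `.grad` instead of overwriting it.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)
leaves = (x, W, b)

# Scalar root: backward() seeds the traversal with dL/dL = 1.0.
torch.nn.functional.linear(x, W, b).sum().backward()
grads_scalar = [t.grad.clone() for t in leaves]

# Clear .grad so the next run starts from scratch.
for t in leaves:
    t.grad = None

# Non-scalar root: we supply the seed gradient ourselves.
y = torch.nn.functional.linear(x, W, b)
y.backward(gradient=torch.ones_like(y))
grads_seeded = [t.grad.clone() for t in leaves]
same = all(torch.allclose(g1, g2) for g1, g2 in zip(grads_scalar, grads_seeded))

# Accumulation: another backward on a fresh graph adds 2 to every entry of b.grad.
torch.nn.functional.linear(x, W, b).sum().backward()
print(same, b.grad[:3])  # True tensor([4., 4., 4.])
```

This is also why training loops clear gradients between steps: without it, `.grad` keeps summing across iterations.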
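The traversal is reverse-topological: a node runs only after every node that sends gradient into it has finished. This is always well defined because the recorded graph is acyclic by construction; each new node can only point at nodes that already exist.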
Deriving AddmmBackward0 by hand
The forward: \(y = b + x W^T\) (with broadcasting on \(b\)).
For a scalar loss \(L = \sum y\):
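Chain rule, index by index. Write \(y_{ik} = b_k + \sum_j x_{ij} W_{kj}\), with \(i\) over the batch (2), \(j\) over input features (64), and \(k\) over outputs (600). The formulas below hold for any upstream gradient \(\nabla_y L\); for \(L = \sum y\) every entry \(\partial L / \partial y_{ik}\) is 1.
\[ \frac{\partial L}{\partial x_{ij}} = \sum_k \frac{\partial L}{\partial y_{ik}} \frac{\partial y_{ik}}{\partial x_{ij}} = \sum_k \frac{\partial L}{\partial y_{ik}} W_{kj} \]
In matrix form: \(\nabla_x L = (\nabla_y L) W\), shape (2, 600) @ (600, 64) = (2, 64). ✓
\[ \frac{\partial L}{\partial W_{kj}} = \sum_i \frac{\partial L}{\partial y_{ik}} \frac{\partial y_{ik}}{\partial W_{kj}} = \sum_i \frac{\partial L}{\partial y_{ik}} x_{ij} \]
In matrix form: \(\nabla_W L = (\nabla_y L)^T x\), shape (600, 2) @ (2, 64) = (600, 64). ✓
\[ \frac{\partial L}{\partial b_k} = \sum_i \frac{\partial L}{\partial y_{ik}} \]
In matrix form: \(\nabla_b L = (\nabla_y L).\text{sum}(\text{dim}=0)\), shape (600,). Because \(\nabla_y L\) is all ones here, every entry equals the batch size, 2 (each output position summed over the batch dim).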
Verify in PyTorch:
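```python
# Continues from the forward pass above (x, W, b, y, loss).
loss.backward()
assert torch.allclose(x.grad, torch.ones_like(y) @ W)        # dx = dy @ W
assert torch.allclose(W.grad, torch.ones_like(y).T @ x)      # dW = dy.T @ x
assert torch.allclose(b.grad, torch.ones_like(y).sum(dim=0))  # db = dy.sum(dim=0)
```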
If those pass (and they do, to about 1e-7 at fp32), you've replicated `AddmmBackward0`'s formula by hand. Lab 01 makes you do this exercise.
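An all-ones upstream gradient is a weak test: every row of `dW` comes out identical and every entry of `db` equals 2, so a mistake in how `dy` is indexed can slip through. A sketch of a stronger check, assuming the loss is `L = sum(y * G)` for a fixed random `G` (a choice made here, not from the lab), which makes \(\nabla_y L = G\) exactly:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)
G = torch.randn(2, 600)  # arbitrary upstream gradient dL/dy

y = torch.nn.functional.linear(x, W, b)
(y * G).sum().backward()  # L = sum(y * G)  =>  dL/dy = G

with torch.no_grad():
    dx = G @ W         # (2, 600) @ (600, 64) -> (2, 64)
    dW = G.T @ x       # (600, 2) @ (2, 64)   -> (600, 64)
    db = G.sum(dim=0)  # (600,)

checks = [torch.allclose(x.grad, dx, atol=1e-4),
          torch.allclose(W.grad, dW, atol=1e-4),
          torch.allclose(b.grad, db, atol=1e-4)]
print(checks)  # [True, True, True]
```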
This is the whole content of the autograd engine: graph capture in forward, reverse traversal in backward, each node knowing its derivative. Phase 7 implemented this for scalars; Phase 8 for tensors; PyTorch's version is the same idea at scale.
Saved tensors and memory
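Each `grad_fn` saves the tensors it needs for backward. `AddmmBackward0` saves `x` and `W` but not `b`: the gradient with respect to `b` is just the upstream gradient summed over the batch, so `b`'s value is never needed. Saved tensors increase peak memory during training, because they stay alive from the forward pass until backward consumes them.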
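One way to see what a node saves is `torch.autograd.graph.saved_tensors_hooks`, which calls a pack hook for every tensor saved during forward. A minimal sketch; the hooks here only record shapes and hand the tensor back unchanged.

```python
import torch

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

saved_shapes = []

def pack(t):
    # Called once for each tensor a grad_fn saves during forward.
    saved_shapes.append(tuple(t.shape))
    return t

def unpack(t):
    # Called when backward reads a saved tensor back; here a no-op.
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = torch.nn.functional.linear(x, W, b).sum()

print(saved_shapes)
loss.backward()  # backward still works; saved tensors come back through `unpack`
```

You should see shapes matching `x` and `W` (the latter possibly transposed, since `linear` multiplies by `W.T`) and nothing of shape `(600,)`, i.e. `b` is not saved. `torch.autograd.graph.save_on_cpu` is built on these same hooks.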
Optimizations:
- `torch.utils.checkpoint`: recompute saved tensors during backward instead of storing them. Trades compute for memory.
- `torch.no_grad()`: skip graph construction entirely. Used in inference.
- `.detach()`: produce a new tensor with `requires_grad=False` that shares storage with the original, breaking the graph at that point.
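A sketch of all three on the running example's linear head. `head` is a hypothetical wrapper; checkpointing a single `linear` saves no memory in practice and is here only to show the call and that gradients still come out right.

```python
import torch
from torch.utils.checkpoint import checkpoint

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

def head(inp):
    # Hypothetical wrapper around the running example's linear layer.
    return torch.nn.functional.linear(inp, W, b)

# no_grad: the op runs, but no grad_fn is recorded and nothing is saved.
with torch.no_grad():
    y_inf = head(x)

# detach: same values and same storage, but cut off from the graph.
y = head(x)
y_det = y.detach()

# checkpoint: forward does not keep head's intermediates;
# backward re-runs head(x) to rebuild them before differentiating.
y_ckpt = checkpoint(head, x, use_reentrant=False)
y_ckpt.sum().backward()

print(y_inf.grad_fn, y_det.requires_grad, y_det.data_ptr() == y.data_ptr())  # None False True
```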
Lab 01 measures peak memory with and without torch.no_grad() for a forward pass through grammar MiniGPT.
When is the graph freed?
By default, during `.backward()` itself: each node's saved tensors are released as soon as that node has run, so they are gone by the time `.backward()` returns. If you need to call `.backward()` twice on the same graph, pass `retain_graph=True` to the first call.
Forgetting `retain_graph=True` when you need it produces a common error: "Trying to backward through the graph a second time". The engine frees saved tensors eagerly to save memory.
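A minimal sketch that reproduces the error and then the `retain_graph=True` fix, on the running example:

```python
import torch

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

# Without retain_graph: the first backward frees AddmmBackward0's saved x and W.
loss = torch.nn.functional.linear(x, W, b).sum()
loss.backward()
try:
    loss.backward()  # needs the freed tensors again
    error = None
except RuntimeError as e:
    error = str(e)
print(error is not None)  # True

# With retain_graph=True on the first call, the second traversal succeeds
# and the gradients accumulate twice.
b.grad = None
loss = torch.nn.functional.linear(x, W, b).sum()
loss.backward(retain_graph=True)
loss.backward()
print(b.grad[:3])  # tensor([4., 4., 4.])
```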
Custom autograd: torch.library.custom_op
The modern API (PyTorch 2.4+) for registering a custom op with autograd:
```python
import torch
from torch import Tensor

@torch.library.custom_op("mylib::softmax_triton", mutates_args=())
def softmax_triton(x: Tensor) -> Tensor:
    # In the lab: return triton_softmax_impl(x)  (Phase 24's Triton softmax kernel).
    # torch.softmax stands in here so the block runs without Triton or a GPU.
    return torch.softmax(x, dim=-1)

# Shape inference for torch.compile / FakeTensor:
@softmax_triton.register_fake
def _(x):
    return torch.empty_like(x)

# Backward formula. With y = softmax(x) and g = dL/dy:
# dL/dx = y * (g - sum(y * g, dim=-1, keepdim=True))
def softmax_triton_backward(ctx, grad_output):
    y = ctx.saved_tensors[0]
    return y * (grad_output - (y * grad_output).sum(dim=-1, keepdim=True))

# Save the forward output; the backward formula needs y, not x.
def softmax_triton_setup_context(ctx, inputs, output):
    ctx.save_for_backward(output)

softmax_triton.register_autograd(
    softmax_triton_backward,
    setup_context=softmax_triton_setup_context,
)
```
Now torch.ops.mylib.softmax_triton(x) behaves like a native PyTorch op:
- The dispatcher finds it.
- Autograd records a backward node for it in the graph during forward.
- `.backward()` calls the registered formula.
- `torch.compile` can trace it (thanks to the `register_fake` shape inference).
- `gradcheck` validates the backward formula numerically.
Lab 02 walks through this exact registration.
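The backward formula in the registration above comes from the softmax Jacobian. A short derivation for one row, with \(g = \nabla_y L\) the upstream gradient and \(\delta_{ij}\) the Kronecker delta:
\[ y_i = \frac{e^{x_i}}{\sum_k e^{x_k}}, \qquad \frac{\partial y_i}{\partial x_j} = y_i (\delta_{ij} - y_j) \]
\[ \frac{\partial L}{\partial x_j} = \sum_i g_i \, y_i (\delta_{ij} - y_j) = y_j g_j - y_j \sum_i y_i g_i = y_j \Big( g_j - \sum_i y_i g_i \Big) \]
The sum over \(i\) is `(y * grad_output).sum(dim=-1, keepdim=True)`. The result depends only on \(y\), which is why `setup_context` saves the output rather than the input.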
torch.autograd.gradcheck
PyTorch's tool for verifying a backward implementation:
```python
import torch
from torch.autograd import gradcheck

# Assumes mylib::softmax_triton has been registered as above.
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
gradcheck(torch.ops.mylib.softmax_triton, (x,), eps=1e-6, atol=1e-4)
```
`gradcheck` estimates the gradient numerically (via finite differences) and compares it to the analytic backward, raising an error by default if the two disagree. Use fp64 inputs: in fp32, rounding error in the finite differences is too large for a reliable comparison.
This is the first test you run after registering a custom backward. If gradcheck fails, your backward formula has a bug; debug before integrating into a real training loop.
Common autograd errors
| Error | Cause | Fix |
|---|---|---|
| `RuntimeError: grad can be implicitly created only for scalar outputs` | Called `.backward()` on a non-scalar | Pass `gradient=ones_like(y)` or call `.sum().backward()` |
| `RuntimeError: Trying to backward through the graph a second time` | Called `.backward()` twice on the same graph | Pass `retain_graph=True` on the first call |
| `RuntimeError: ... is at version N; expected version M` | In-place op on a tensor saved for backward | Avoid the in-place op, or `.clone()` the tensor before modifying it |
| `grad_fn=None` on a tensor that should be an intermediate | The tensor was created under `torch.no_grad()` or with `.detach()` | Re-create it with grad tracking enabled |
| `gradcheck` fails | Backward formula wrong, or the wrong tensors saved | Re-derive on paper; check what `setup_context` saves |
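The version-counter row is the least obvious, so here is a minimal sketch that triggers it and applies the `.clone()` fix. It uses `exp` because its backward node saves its output.

```python
import torch

x = torch.randn(3, requires_grad=True)

y = x.exp()  # ExpBackward0 saves its output y for the backward pass
y.add_(1)    # in-place op bumps y's version counter from 0 to 1
try:
    y.sum().backward()
    message = ""
except RuntimeError as e:
    message = str(e)
print("inplace operation" in message)  # True

# Fix: modify a clone, so the tensor ExpBackward0 saved is left untouched.
x.grad = None
y = x.exp()
z = y.clone()
z.add_(1)
z.sum().backward()
print(torch.allclose(x.grad, x.detach().exp()))  # True
```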
What you should now be able to do
- Walk the `grad_fn` chain of any model's forward pass.
- Derive the backward formula for any composition of `linear`, `relu`, `softmax`, `cross_entropy`.
- Use `torch.library.custom_op` to register a new op with backward.
- Use `gradcheck` to numerically validate the backward.
- Predict whether a model's peak memory is dominated by parameters, activations, or saved tensors.
What this page does NOT cover
- `torch.autograd.Function` (the old API). Mentioned; lab 02 uses the modern `custom_op` API exclusively.
- `__torch_dispatch__` for autograd interception. Niche; only relevant if you're building a parallel framework on top of PyTorch.
- Second-order gradients (`create_graph=True`). Used for meta-learning; out of curriculum scope.
- `torch.func` functional transforms (`grad`, `vmap`). Phase 38 may revisit.
Next: theory/03-compile-and-distributed.md — the compile pipeline + distributed survey.