

02 — The Autograd Engine

PyTorch's autograd engine is exactly what we built in Phases 7–8: the forward pass records a graph, and the backward pass walks it in reverse. This page formalizes graph capture, grad_fn nodes, and leaves, and shows how torch.library.custom_op registers a backward for a new op. The running example is linear(x, W, b) with the LM head of the grammar MiniGPT.

This page is the autograd engine, made explicit. After it you can walk the grad_fn chain for any forward pass, derive the backward by hand, and verify your derivation against PyTorch's .backward() to numerical agreement at fp32.


The two-line model

forward:  every op on a requires_grad tensor records a grad_fn node into the graph.
backward: .backward() walks the graph in reverse-topological order, calling each grad_fn's backward formula.

That's it. The complexity is in (a) which ops record, (b) what the backward formula is for each, (c) how the graph is stored, and (d) when it's freed. Phase 25 covers each.
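To make the two-line model concrete, here is a toy scalar engine of the kind Phase 7 built. It is a minimal sketch, not PyTorch's implementation: the Value class and its backward_rule field are illustrative names. Each op records its parents and a local gradient rule during forward, and backward() visits the recorded nodes in reverse-topological order.

```python
# A toy reverse-mode autograd for Python floats (hypothetical `Value` class,
# not PyTorch's API): forward records parents plus a local backward rule,
# backward walks the recorded graph in reverse-topological order.
class Value:
    def __init__(self, data, parents=(), backward_rule=None):
        self.data = data
        self.grad = 0.0
        self.parents = parents              # edges to the nodes this one was built from
        self.backward_rule = backward_rule  # maps upstream grad -> one grad per parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other),
                     lambda g: (g, g))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other),
                     lambda g: (g * other.data, g * self.data))

    def backward(self):
        # Topologically sort everything reachable from self (parents before children).
        order, seen = [], set()

        def visit(node):
            if id(node) in seen:
                return
            seen.add(id(node))
            for p in node.parents:
                visit(p)
            order.append(node)

        visit(self)
        self.grad = 1.0
        # Reverse-topological order: a node's grad is complete before it is propagated.
        for node in reversed(order):
            if node.backward_rule is None:
                continue  # leaf: its .grad is the final result
            for parent, g in zip(node.parents, node.backward_rule(node.grad)):
                parent.grad += g  # accumulate, like PyTorch does into leaf .grad


x, w, b = Value(3.0), Value(-2.0), Value(1.0)
y = x * w + b      # forward records two nodes: mul, then add
y.backward()
print(x.grad, w.grad, b.grad)  # -2.0 3.0 1.0
```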

The forward pass: graph capture

import torch

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

y = torch.nn.functional.linear(x, W, b)   # y.grad_fn = <AddmmBackward0>
loss = y.sum()                             # loss.grad_fn = <SumBackward0>

The autograd engine records two nodes — AddmmBackward0 and SumBackward0 — linked by an edge. The graph at this point:

x ─┐
W ─┤
b ─┴→ AddmmBackward0 → y → SumBackward0 → loss

Each node holds:

  • Saved tensors needed for its backward formula. For Addmm: saves x and W, because the gradient for W needs x and the gradient for x needs W.
  • Edges to parent nodes (or to leaves). Each edge knows which output of the parent feeds which input of this node.
  • Backward function pointer — the C++ implementation of the gradient formula.

Leaves (x, W, b) have grad_fn = None (they weren't produced by an op) and is_leaf = True. They are where .backward() delivers its results: gradients accumulate into the .grad of each leaf that requires grad.
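One way to walk the grad_fn chain yourself is to follow each node's next_functions, a tuple of (parent node, output index) pairs. In this sketch the walk helper is our own; next_functions, name(), and is_leaf are PyTorch attributes. Leaves show up as AccumulateGrad nodes (the node that writes into .grad). The exact set of intermediate nodes, such as a transpose of W inside linear, may vary across PyTorch versions, so no printed graph is claimed here.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)
loss = F.linear(x, W, b).sum()

names = []

def walk(fn, depth=0):
    """Print the backward graph below `fn` by following next_functions edges."""
    if fn is None:          # edge to an input that does not require grad
        return
    names.append(fn.name())
    print("  " * depth + fn.name())
    for parent, _output_nr in fn.next_functions:
        walk(parent, depth + 1)

walk(loss.grad_fn)
print(x.is_leaf, x.grad_fn)   # True None
```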

The backward pass: reverse traversal

.backward():

  1. Starts at loss, which must be a scalar unless you pass an explicit gradient (e.g. loss.backward(gradient=torch.ones_like(loss))).
  2. Calls SumBackward0.backward(grad=1.0) → returns dy = ones_like(y).
  3. Calls AddmmBackward0.backward(grad=dy) → returns three gradients (one per input):
      • dx = dy @ W (shape (2, 64))
      • dW = dy.T @ x (shape (600, 64))
      • db = dy.sum(dim=0) (shape (600,))
  4. Adds each gradient to the corresponding leaf's .grad.

loss.backward()
print(x.grad.shape, W.grad.shape, b.grad.shape)
# torch.Size([2, 64]) torch.Size([600, 64]) torch.Size([600])
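Two consequences of the accumulate step are easy to check: gradients add up across backward calls unless you reset them, and a non-leaf tensor such as y keeps no .grad unless you call retain_grad() on it. A minimal sketch using the running example's shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

for _ in range(2):                 # two independent forward + backward passes
    y = F.linear(x, W, b)
    y.retain_grad()                # ask autograd to keep dL/dy on this non-leaf
    y.sum().backward()

print(b.grad[:3])    # each entry is 4.0: 2 (batch size) per pass, accumulated twice
print(y.grad.shape)  # torch.Size([2, 600]); without retain_grad(), y.grad is None
```

In a training loop this accumulation is why gradients are zeroed (or set to None) between steps.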

The traversal is reverse-topological. Cycles are forbidden (autograd errors if it detects one — rare but possible with hooks).

Deriving AddmmBackward0 by hand

The forward: \(y = b + x W^T\) (with broadcasting on \(b\)).

For a scalar loss \(L = \sum y\):

\[\frac{\partial L}{\partial y_{ij}} = 1 \quad \text{(from sum)}\]

Chain rule:

\[\frac{\partial L}{\partial x_{ik}} = \sum_j \frac{\partial L}{\partial y_{ij}} \cdot \frac{\partial y_{ij}}{\partial x_{ik}} = \sum_j 1 \cdot W_{jk} = \sum_j W_{jk}\]

In matrix form: \(\nabla_x L = (\nabla_y L) W\), shape (2, 600) @ (600, 64) = (2, 64). ✓

\[\frac{\partial L}{\partial W_{jk}} = \sum_i \frac{\partial L}{\partial y_{ij}} \cdot \frac{\partial y_{ij}}{\partial W_{jk}} = \sum_i 1 \cdot x_{ik} = \sum_i x_{ik}\]

In matrix form: \(\nabla_W L = (\nabla_y L)^T x\), shape (600, 2) @ (2, 64) = (600, 64). ✓

\[\frac{\partial L}{\partial b_j} = \sum_i \frac{\partial L}{\partial y_{ij}} \cdot \frac{\partial y_{ij}}{\partial b_j} = \sum_i 1 \cdot 1 = 2\]

In matrix form: \(\nabla_b L = (\nabla_y L).\text{sum}(\text{dim}=0)\), shape (600,). Every entry equals 2, the batch size, because each \(b_j\) is added to both rows of \(y\).

Verify in PyTorch:

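x.grad = W.grad = b.grad = None          # drop gradients left over from the earlier backward
y = torch.nn.functional.linear(x, W, b)  # fresh forward: the earlier graph was freed
loss = y.sum()
loss.backward()
assert torch.allclose(x.grad, torch.ones_like(y) @ W)
assert torch.allclose(W.grad, torch.ones_like(y).T @ x)
assert torch.allclose(b.grad, torch.ones_like(y).sum(dim=0))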

If those pass (and they do, within torch.allclose's default tolerances at fp32), you've replicated AddmmBackward0's formula by hand. Lab 01 makes you do this exercise.
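The derivation above used \(\partial L/\partial y = 1\). For an arbitrary upstream gradient \(G = \nabla_y L\), the same chain-rule steps give \(\nabla_x L = G W\), \(\nabla_W L = G^T x\), and \(\nabla_b L = \sum_i G_{ij}\). The sketch below checks this by passing a random G to y.backward(gradient=G); the seed and tolerances are arbitrary choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

y = F.linear(x, W, b)
G = torch.randn_like(y)        # an arbitrary upstream gradient dL/dy
y.backward(gradient=G)         # same as backprop from L = (y * G).sum()

# The hand-derived formulas, with G in place of ones_like(y):
assert torch.allclose(x.grad, G @ W, atol=1e-5)
assert torch.allclose(W.grad, G.T @ x, atol=1e-5)
assert torch.allclose(b.grad, G.sum(dim=0), atol=1e-5)
```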

This is the whole content of the autograd engine: graph capture in forward, reverse traversal in backward, each node knowing its derivative. Phase 7 implemented this for scalars; Phase 8 for tensors; PyTorch's version is the same idea at scale.

Saved tensors and memory

Each grad_fn saves the tensors it needs for backward. AddmmBackward0 saves x and W (not b — its gradient doesn't depend on b). Saved tensors increase the peak memory during training (they're kept alive until backward runs).
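You can observe which tensors a graph saves with torch.autograd.graph.saved_tensors_hooks, whose pack hook runs once per tensor saved for backward. This sketch only records shapes (the pack and unpack helpers are ours). The exact list may differ between PyTorch versions, but x should appear, b should not, and nothing should be saved under torch.no_grad().

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

packed_shapes = []

def pack(t):
    packed_shapes.append(tuple(t.shape))  # called once per tensor saved for backward
    return t                              # store the tensor unchanged

def unpack(t):
    return t                              # hand it back to the backward formula

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = F.linear(x, W, b).sum()
saved_with_grad = list(packed_shapes)
print(saved_with_grad)      # expect (2, 64) for x among them; b is never saved
loss.backward()             # backward unpacks the saved tensors as usual

packed_shapes.clear()
with torch.no_grad(), torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    F.linear(x, W, b).sum()
print(packed_shapes)        # []: no graph is recorded, so nothing is saved
```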

Optimizations:

  • torch.utils.checkpoint: recompute a segment's activations during backward instead of storing them. Trades compute for memory.
  • torch.no_grad(): skip graph construction entirely. Used in inference.
  • .detach(): produce a new tensor with requires_grad=False, breaking the graph at that point.
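A minimal sketch of two of these tools on the running example: checkpoint should reproduce the same gradient while recomputing the segment's forward during backward, and .detach() yields a tensor with no grad_fn. The block function is a hypothetical segment chosen for illustration, and use_reentrant=False selects the non-reentrant checkpoint variant.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

x = torch.randn(2, 64, requires_grad=True)
W = torch.randn(600, 64, requires_grad=True)
b = torch.randn(600, requires_grad=True)

def block(inp):
    # Hypothetical segment: its intermediates are what checkpointing avoids storing.
    return torch.relu(F.linear(inp, W, b)).sum()

block(inp=x).backward()          # ordinary autograd: intermediates saved in forward
ref = x.grad.clone()
x.grad = None

checkpoint(block, x, use_reentrant=False).backward()  # forward of block re-run in backward
print(torch.allclose(x.grad, ref))   # True: same gradient

d = F.linear(x, W, b).detach()
print(d.requires_grad, d.grad_fn)    # False None: the graph stops here
```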

Lab 01 measures peak memory with and without torch.no_grad() for a forward pass through grammar MiniGPT.

When is the graph freed?

By default, the graph is freed as soon as .backward() completes: its saved tensors are released. If you need to call .backward() twice on the same graph, pass retain_graph=True to the first call.

Forgetting retain_graph=True when you need it produces a common error: "Trying to backward through the graph a second time". The autograd engine frees saved tensors eagerly to save memory.
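The sketch below reproduces the error and the fix. It uses exp because ExpBackward0 saves its output, so the second call has a freed saved tensor to trip over; a graph that saves no tensors may not raise at all.

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = x.exp().sum()        # ExpBackward0 saves its output for backward
loss.backward()             # saved tensors are freed afterwards

second_backward_failed = False
try:
    loss.backward()         # needs the freed saved tensor
except RuntimeError as err:
    second_backward_failed = "second time" in str(err)

x.grad = None
loss = x.exp().sum()
loss.backward(retain_graph=True)   # keep saved tensors alive for another pass
loss.backward()                    # now allowed; x.grad accumulates to 2 * exp(x)
print(second_backward_failed, torch.allclose(x.grad, 2 * x.exp()))  # True True
```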

Custom autograd: torch.library.custom_op

The modern (torch 2.4+) API for registering a custom op with autograd:

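import torch
from torch import Tensor

def triton_softmax_impl(x: Tensor) -> Tensor:
    # Placeholder for Phase 24's Triton softmax launch (assumed, not shown here);
    # torch.softmax stands in so this snippet runs without Triton.
    return torch.softmax(x, dim=-1)

@torch.library.custom_op("mylib::softmax_triton", mutates_args=())
def softmax_triton(x: Tensor) -> Tensor:
    # Forward implementation: calls into the Triton kernel (Phase 24's softmax).
    return triton_softmax_impl(x)

# Shape inference for torch.compile / FakeTensor: describes the output without computing it.
@softmax_triton.register_fake
def _(x):
    return torch.empty_like(x)

# Backward formula; y = softmax(x) was saved by setup_context below.
def softmax_triton_backward(ctx, grad_output):
    y = ctx.saved_tensors[0]
    # Softmax backward: dL/dx = y * (dL/dy - sum(y * dL/dy, dim=-1, keepdim=True))
    return y * (grad_output - (y * grad_output).sum(dim=-1, keepdim=True))

# Runs after forward; saves only what backward needs (the output y).
def softmax_triton_setup_context(ctx, inputs, output):
    ctx.save_for_backward(output)

softmax_triton.register_autograd(
    softmax_triton_backward,
    setup_context=softmax_triton_setup_context,
)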

Now torch.ops.mylib.softmax_triton(x) behaves like a native PyTorch op:

  • Dispatcher finds it.
  • Autograd records a backward node for it in the graph during forward.
  • .backward() calls the registered formula.
  • torch.compile can trace it (thanks to the register_fake shape-inference).
  • gradcheck validates the backward formula numerically.

Lab 02 walks through this exact registration.

torch.autograd.gradcheck

PyTorch's tool for verifying a backward implementation:

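import torch
from torch.autograd import gradcheck

# Assumes mylib::softmax_triton was registered as in the custom_op section above.
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
gradcheck(torch.ops.mylib.softmax_triton, (x,), eps=1e-6, atol=1e-4)  # True, or raises on mismatch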

gradcheck numerically estimates the gradient (via finite differences) and compares to the analytic backward. If they don't match, the backward formula is wrong. Use fp64 inputs (fp32 is too noisy for finite-difference checks).

This is the first test you run after registering a custom backward. If gradcheck fails, your backward formula has a bug; debug before integrating into a real training loop.

Common autograd errors

| Error | Cause | Fix |
|---|---|---|
| RuntimeError: grad can be implicitly created only for scalar outputs | Called .backward() on a non-scalar | Pass gradient=torch.ones_like(y), or call .sum().backward() |
| RuntimeError: Trying to backward through the graph a second time | Called .backward() twice on the same graph | Pass retain_graph=True on the first call |
| RuntimeError: ... is at version N; expected version M | In-place op on a saved tensor | Avoid in-place ops, or .clone() before modifying |
| grad_fn=None on a tensor you expected to be in the graph | The tensor was created under torch.no_grad() or with .detach() | Re-create it with grad tracking enabled |
| gradcheck fails | Backward formula wrong, or the wrong tensors saved | Re-derive on paper; verify what setup_context saves |
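Two rows of the table reproduced in a minimal sketch: backward on a non-scalar output, and an in-place edit of a tensor that exp saved for backward. The substring checks use fragments of PyTorch's error messages; the full wording may vary by version.

```python
import torch

x = torch.randn(3, requires_grad=True)

# Row 1: .backward() on a non-scalar output.
y = x * 2
nonscalar_failed = False
try:
    y.backward()
except RuntimeError as err:
    nonscalar_failed = "only for scalar outputs" in str(err)
y.backward(gradient=torch.ones_like(y))   # fix: supply dL/dy explicitly

# Row 3: in-place op on a tensor that autograd saved for backward.
e = x.exp()          # ExpBackward0 saves its output e
e.add_(1)            # bumps e's version counter
inplace_failed = False
try:
    e.sum().backward()
except RuntimeError as err:
    inplace_failed = "inplace operation" in str(err)

x.grad = None
e = x.exp().clone()  # fix: modify a copy, so the saved exp output stays untouched
e.add_(1)
e.sum().backward()
print(nonscalar_failed, inplace_failed, torch.allclose(x.grad, x.exp()))  # True True True
```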

What you should now be able to do

  1. Walk the grad_fn chain of any model's forward pass.
  2. Derive the backward formula for any composition of linear, relu, softmax, cross_entropy.
  3. Use torch.library.custom_op to register a new op with backward.
  4. Use gradcheck to numerically validate the backward.
  5. Predict whether a model's peak memory is dominated by parameters, activations, or saved tensors.
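As a self-check for item 2, the LM head's loss is a good target: for cross_entropy with the default mean reduction over N rows, the gradient with respect to the logits \(z\) is \((\text{softmax}(z) - \text{onehot}(t))/N\). The batch and vocabulary sizes below are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 4, 10                                  # batch size, vocabulary size (arbitrary)
logits = torch.randn(N, V, requires_grad=True)
target = torch.randint(0, V, (N,))

F.cross_entropy(logits, target).backward()    # reduction="mean" by default

# Hand-derived: d(mean CE)/d logits = (softmax(logits) - one_hot(target)) / N
expected = (torch.softmax(logits, dim=-1) - F.one_hot(target, V).float()) / N
print(torch.allclose(logits.grad, expected.detach(), atol=1e-6))   # True
```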

What this page does NOT cover

  • torch.autograd.Function (the older API). Mentioned only for context; lab 02 uses the modern custom_op API exclusively.
  • __torch_dispatch__ for autograd interception. Niche; only relevant if you're building a parallel framework on top of PyTorch.
  • Second-order gradients (create_graph=True). Used for meta-learning; out of curriculum scope.
  • torch.func functional transforms (grad, vmap). Phase 38 may revisit.

Next: theory/03-compile-and-distributed.md — the compile pipeline + distributed survey.