02 — RNN, GRU, LSTM: recurrence as a state machine
An RNN is a state function: it reads a token, updates a "memory" vector, emits a prediction, and repeats. What matters is what is gained (distributed memory that generalizes, unlike an n-gram) and what is lost (the memory has a fixed size, and the computation cannot be parallelized over the sequence).
This file derives three closely-related architectures — vanilla RNN, GRU, LSTM — as a family of state machines, each one a patch to the previous one's failure mode.
The frame: a token-sequence machine
A language model is a function that consumes tokens one at a time and emits a distribution over the next token. An n-gram does this with a count-table indexed by the previous \(n - 1\) tokens. The n-gram has no internal state; everything it knows about the prefix is the literal identity of the last \(n - 1\) tokens.
A recurrent neural network does this with a learned, continuous, fixed-dimensional state vector \(h_t \in \mathbb{R}^d\):
\[
h_t = f_\theta(h_{t-1}, x_t), \qquad \hat{y}_t = g_\theta(h_t)
\]

where:
- \(x_t\) is the embedding of the \(t\)-th input token (from Phase 13);
- \(h_t\) is the hidden state, a learned summary of the prefix \(w_1, \ldots, w_t\);
- \(f_\theta\) is the recurrence function with parameters \(\theta\);
- \(g_\theta\) is the output head, mapping the state to logits \(\hat{y}_t\) over the vocabulary.
The hidden state has fixed dimension \(d\) (typically 32–256 for our corpus). All information about the prefix has to fit in that vector. This is the defining constraint of recurrent models.
Vanilla RNN (Elman, 1990)
The simplest possible parametrization:
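\[
h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h), \qquad \hat{y}_t = W_{ho} h_t + b_o
\]

with parameters:
- \(W_{hh} \in \mathbb{R}^{d \times d}\): state transition matrix;
- \(W_{xh} \in \mathbb{R}^{d \times d_\text{embed}}\): input projection;
- \(W_{ho} \in \mathbb{R}^{|V| \times d}\): output projection;
- \(b_h \in \mathbb{R}^{d}\), \(b_o \in \mathbb{R}^{|V|}\): biases.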
That's it. Three matrices, two biases. For \(d = 32\), \(|V| = 64\), \(d_\text{embed} = 16\), the parameter count (not counting the embedding table) is \(32^2 + 32 \cdot 16 + 64 \cdot 32 + 32 + 64 = 1024 + 512 + 2048 + 32 + 64 = 3680\) parameters. The model is genuinely small.
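The count above can be reproduced with a few lines of Python. This is a minimal sketch; the function name `vanilla_rnn_param_count` is made up for illustration, and the embedding table is deliberately left out, as in the arithmetic above.

```python
def vanilla_rnn_param_count(d: int, d_embed: int, vocab_size: int) -> int:
    """Count trainable parameters of the vanilla RNN (embedding table excluded)."""
    w_hh = d * d            # state transition matrix
    w_xh = d * d_embed      # input projection
    w_ho = vocab_size * d   # output projection
    b_h = d                 # hidden bias
    b_o = vocab_size        # output bias
    return w_hh + w_xh + w_ho + b_h + b_o


print(vanilla_rnn_param_count(d=32, d_embed=16, vocab_size=64))  # 3680
```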
Forward pass on the canonical example ["I", "work", ",", "you", "work", ",", "he"]:
h_0 = zeros(d) ← initial state
x_1 = embed("I")
h_1 = tanh(W_hh @ h_0 + W_xh @ x_1 + b_h)
x_2 = embed("work")
h_2 = tanh(W_hh @ h_1 + W_xh @ x_2 + b_h)
x_3 = embed(",")
h_3 = tanh(W_hh @ h_2 + W_xh @ x_3 + b_h)
...
x_7 = embed("he")
h_7 = tanh(W_hh @ h_6 + W_xh @ x_7 + b_h)
y_hat = W_ho @ h_7 + b_o ← logits over V; softmax → P(next | prefix)
If trained, the model has learned to extract from the prefix `I work, you work, he` whatever signal indicates "next token should be `works`". That signal lives in \(h_7\).
This is the whole model. No attention. No multi-head. No layers (well, one layer; you can stack RNNs but rarely deeply). Forward pass costs \(O(T \cdot d^2)\) in time and \(O(d)\) memory per step (plus the parameter memory, which is constant).
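The walkthrough above can be turned into a runnable NumPy sketch. Everything not named in the text is an assumption made for illustration: the toy vocabulary, the random initialization scale, and the helper names `init_params`, `rnn_forward`, and `softmax`. The weights are untrained, so the resulting distribution is meaningless; the point is only the shape of the computation.

```python
import numpy as np


def init_params(d, d_embed, vocab_size, rng, scale=0.1):
    """Random, untrained parameters (initialization scheme is an assumption)."""
    return {
        "E": rng.normal(0.0, scale, (vocab_size, d_embed)),  # embedding table
        "W_hh": rng.normal(0.0, scale, (d, d)),
        "W_xh": rng.normal(0.0, scale, (d, d_embed)),
        "W_ho": rng.normal(0.0, scale, (vocab_size, d)),
        "b_h": np.zeros(d),
        "b_o": np.zeros(vocab_size),
    }


def rnn_forward(token_ids, p):
    """Run the recurrence over a prefix; return all states and the final logits."""
    h = np.zeros(p["W_hh"].shape[0])  # h_0 = 0
    states = []
    for tok in token_ids:
        x = p["E"][tok]  # x_t = embed(token)
        h = np.tanh(p["W_hh"] @ h + p["W_xh"] @ x + p["b_h"])
        states.append(h)
    logits = p["W_ho"] @ h + p["b_o"]  # output head on the last state
    return np.stack(states), logits


def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


vocab = ["I", "work", ",", "you", "he", "works"]  # hypothetical toy vocabulary
tok2id = {t: i for i, t in enumerate(vocab)}
prefix = ["I", "work", ",", "you", "work", ",", "he"]

rng = np.random.default_rng(0)
params = init_params(d=32, d_embed=16, vocab_size=len(vocab), rng=rng)
states, logits = rnn_forward([tok2id[t] for t in prefix], params)
probs = softmax(logits)
print(states.shape, probs.shape)  # (7, 32) (6,)
```

Note that the loop body is the only place where time enters: each iteration needs the `h` produced by the previous one.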
What the hidden state can and cannot encode
The hidden state \(h_t \in \mathbb{R}^d\) is a fixed-capacity bottleneck. To predict \(w_{t+1}\), the model has access only to \(h_{t-1}\) and \(x_t\), which it folds into \(h_t\). Everything from the prefix \(w_1, \ldots, w_{t-1}\) has to live in \(h_{t-1}\).
Two implications:
- The state must summarize. With \(d = 32\), a state can encode (very loosely; the entries are real-valued, so this is a rule of thumb, not a bound) ~32 bits of information about the prefix. For our corpus, that's enough to encode the subject pronoun, the tense, and the auxiliary phase, but not the verbatim sequence of 50 prior tokens.
- The state is overwritten every step. Each new \(x_t\) enters and reshapes \(h_t\). Information from \(h_{t-1}\) that isn't reinforced gets diluted. For the prefix `I work, you work, he`, by the time we reach token 7, the pronoun `I` (token 1) has entered \(h_1\) and then been through 6 further applications of \(W_{hh}\). Whether the model "remembers" `I` depends on whether \(W_{hh}\) preserved that signal, which it usually doesn't (Phase 14 theory file 03 explains why).
The first implication is a good property: distributed, learned representations beat sparse n-gram indicators on tasks that generalize. The second is a bad property: it's the seed of the vanishing gradient problem and the reason long-range dependencies are hard.
The two failures of vanilla RNNs
After ~30 years of trying, the field converged on two precise failures:
- Gradient vanishing / exploding through time. Derived in `theory/03-vanishing-gradient.md`. In short: the gradient from a late-step loss to an early-step input flows through repeated multiplication by \(W_{hh}^\top\) (and by the \(\tanh\) derivative at each step). Whether the result vanishes or explodes is governed largely by the spectrum of \(W_{hh}\). Stabilizing this is hard.
- Serial compute. \(h_t\) depends on \(h_{t-1}\). You cannot compute \(h_t\) until \(h_{t-1}\) exists. Across a sequence of length \(T\), this is an inherently serial chain: no amount of GPU parallelism helps along the time axis. A transformer's attention layer, by contrast, computes all \(T\) outputs in parallel via a single \(T \times T\) matmul.
The first failure is what motivated LSTM/GRU. The second failure is what motivated attention. Note: LSTM/GRU patches failure 1 but does nothing for failure 2. Attention patches both.
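The first failure can be previewed numerically before the derivation in theory file 03. The sketch below is a deliberately simplified, linear stand-in for backpropagation through time: it ignores the \(\tanh\) derivative (which can only shrink the signal further) and uses a scaled orthogonal matrix as a hypothetical \(W_{hh}\), so that every singular value equals `scale`. The function name `norm_after_steps` is made up for illustration.

```python
import numpy as np


def norm_after_steps(scale, steps, d=32, seed=0):
    """Norm of a unit vector after `steps` multiplications by W_hh^T.

    W_hh = scale * Q with Q orthogonal, so each step multiplies the norm
    by exactly `scale` (a simplifying assumption, not a trained matrix).
    """
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.normal(size=(d, d)))  # random orthogonal matrix
    w_hh = scale * q
    v = np.ones(d) / np.sqrt(d)  # unit-norm stand-in for a late-step gradient
    for _ in range(steps):
        v = w_hh.T @ v  # one step backward through time
    return float(np.linalg.norm(v))


for scale in (0.5, 1.0, 1.5):
    print(scale, norm_after_steps(scale, steps=20))
# scale 0.5 shrinks the norm to about 0.5**20 (~1e-6),
# scale 1.0 keeps it at 1, and scale 1.5 grows it to about 1.5**20 (~3e3).
```

Only the knife-edge case `scale = 1.0` preserves the signal, which is why stabilizing a plain recurrence is hard.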
GRU (Cho et al., 2014)
The GRU (Gated Recurrent Unit) is a modification of the vanilla RNN that adds two gates — small networks that decide how much of the past to keep and how much new information to absorb.
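Definition:

\[
\begin{aligned}
z_t &= \sigma(W_z [h_{t-1}, x_t] + b_z) && \text{(update gate)} \\
r_t &= \sigma(W_r [h_{t-1}, x_t] + b_r) && \text{(reset gate)} \\
\tilde{h}_t &= \tanh(W_h [r_t \odot h_{t-1}, x_t] + b_h) && \text{(candidate state)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t && \text{(new state)}
\end{aligned}
\]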
where \(\sigma\) is the sigmoid, \(\odot\) is elementwise multiplication, and \([a, b]\) denotes concatenation along the feature axis.
The key structural change is the last line. Instead of overwriting \(h_{t-1}\) with \(\tanh(...)\), the GRU forms a convex combination of the old state and a new candidate. If \(z_t \approx 0\), the new state is basically the old state; if \(z_t \approx 1\), it's the new candidate.
Why this matters. The vanilla RNN's recurrence is multiplicative — every step is "matrix-multiply by \(W_{hh}\) then nonlinearity". Repeated multiplication contracts or expands signals (the vanishing/exploding gradient story). The GRU's recurrence has an additive path that lets information flow through time without being multiplied by anything when \(z_t \approx 0\). Gradients can flow back through that path without contraction.
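Parameter count. A GRU has three weight matrices each of shape \(d \times (d + d_\text{embed})\), plus biases, so \(\sim 3 d (d + d_\text{embed})\) parameters in the recurrence. For \(d = 32, d_\text{embed} = 16\), that's \(3 \cdot 32 \cdot 48 = 4608\) weights, exactly \(3\times\) the vanilla RNN's recurrent weights (\(32^2 + 32 \cdot 16 = 1536\)). With biases and the same output head added, the total is \(4608 + 96 + 2048 + 64 = 6816\) against the vanilla RNN's 3680, about \(1.85\times\). The cost of the patch is modest.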
Worked intuition on our corpus. When the model sees `I work, you work, he`, the GRU can learn to set \(z_t \approx 0\) when processing `you work, he` so that the subject-pronoun signal from `I` (encoded in \(h_1\)) propagates forward with minimal decay. The reset gate \(r_t\) similarly lets the model decide when to "forget" the prior subject (e.g., when it sees a separator). Whether the model actually learns this is a training question (Phase 18); the GRU at least makes it learnable.
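A single GRU step can be sketched in NumPy as follows. It assumes the standard formulation in which the new state is \((1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\), consistent with the description above that \(z_t \approx 0\) keeps the old state. The parameter names (`W_z`, `W_r`, `W_h` and their biases) and the random untrained weights are assumptions for illustration, not the lab's implementation.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_step(h_prev, x, p):
    """One GRU step; [a, b] in the text is np.concatenate here."""
    hx = np.concatenate([h_prev, x])
    z = sigmoid(p["W_z"] @ hx + p["b_z"])  # update gate
    r = sigmoid(p["W_r"] @ hx + p["b_r"])  # reset gate
    h_cand = np.tanh(p["W_h"] @ np.concatenate([r * h_prev, x]) + p["b_h"])
    return (1.0 - z) * h_prev + z * h_cand  # convex mix of old state and candidate


d, d_embed = 32, 16
rng = np.random.default_rng(0)
params = {k: rng.normal(0.0, 0.1, (d, d + d_embed)) for k in ("W_z", "W_r", "W_h")}
params.update({k: np.zeros(d) for k in ("b_z", "b_r", "b_h")})
n_weights = sum(params[k].size for k in ("W_z", "W_r", "W_h"))
print(n_weights)  # 4608, matching the count in the text

h_prev = rng.normal(size=d)
x = rng.normal(size=d_embed)

# Force the update gate shut with a very negative bias (z_t close to 0):
# the old state should pass through almost unchanged.
gate_shut = dict(params, b_z=np.full(d, -20.0))
h_keep = gru_step(h_prev, x, gate_shut)
max_change = float(np.max(np.abs(h_keep - h_prev)))
print(max_change)  # tiny: the additive path carried h_prev forward
```

In a trained model the gate is not forced by hand; the bias trick here only demonstrates that the pass-through behaviour is reachable by some setting of the parameters.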
LSTM (Hochreiter & Schmidhuber, 1997)
The LSTM (Long Short-Term Memory) is the older, more elaborate cousin of the GRU. It introduces a separate cell state \(c_t\) alongside the hidden state \(h_t\), with three gates instead of two.
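Sketch (we do not derive the backward pass):

\[
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) && \text{(forget gate)} \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) && \text{(input gate)} \\
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) && \text{(output gate)} \\
\tilde{c}_t &= \tanh(W_c [h_{t-1}, x_t] + b_c) && \text{(candidate cell state)} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{(cell-state update)} \\
h_t &= o_t \odot \tanh(c_t) && \text{(hidden state)}
\end{aligned}
\]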
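The cell-state update, \(c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\), is the heart of the LSTM. It is a gated sum of the previous cell state and a new candidate. This is the same idea as the GRU, except that the forget gate \(f_t\) and input gate \(i_t\) are independent, so the mix is not constrained to be convex. The path \(c_{t-1} \to c_t\) has no matrix multiplication (only elementwise gating), so when \(f_t \approx 1\) gradients flow through it without contraction.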
LSTM vs GRU. Empirically, they perform within a few percent of each other on most tasks. LSTM has more parameters (four weight matrices vs three) and one more gate (the explicit forget gate vs GRU's coupled \(1 - z_t\)). LSTM is the older standard; GRU is simpler and often preferred when you want a recurrent baseline.
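For comparison with the GRU, here is a minimal NumPy sketch of one LSTM step. It assumes the standard formulation (forget, input, and output gates plus a \(\tanh\) candidate, with \(c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\) and \(h_t = o_t \odot \tanh(c_t)\)); the parameter names and random untrained weights are illustrative assumptions, and nothing here is part of the lab.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def lstm_step(h_prev, c_prev, x, p):
    """One LSTM step: returns the new hidden state and the new cell state."""
    hx = np.concatenate([h_prev, x])
    f = sigmoid(p["W_f"] @ hx + p["b_f"])       # forget gate
    i = sigmoid(p["W_i"] @ hx + p["b_i"])       # input gate
    o = sigmoid(p["W_o"] @ hx + p["b_o"])       # output gate
    c_cand = np.tanh(p["W_c"] @ hx + p["b_c"])  # candidate cell state
    c = f * c_prev + i * c_cand                 # elementwise only: no matmul on c_prev
    h = o * np.tanh(c)
    return h, c


d, d_embed = 32, 16
rng = np.random.default_rng(0)
gates = ("f", "i", "o", "c")
params = {f"W_{g}": rng.normal(0.0, 0.1, (d, d + d_embed)) for g in gates}
params.update({f"b_{g}": np.zeros(d) for g in gates})

h_prev, c_prev = rng.normal(size=d), rng.normal(size=d)
x = rng.normal(size=d_embed)

# Forget gate forced open (f close to 1) and input gate forced shut (i close to 0):
# the cell state is carried forward almost untouched.
carry = dict(params, b_f=np.full(d, 20.0), b_i=np.full(d, -20.0))
h_new, c_new = lstm_step(h_prev, c_prev, x, carry)
cell_change = float(np.max(np.abs(c_new - c_prev)))
print(cell_change)  # tiny
```

Because `f` and `i` are computed separately, the model can also keep the old cell state and add the candidate at the same time, which the GRU's coupled gate cannot do.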
We implement the vanilla RNN and GRU in Phase 14's lab. The LSTM is sketched in theory only — its math is one page of work, but training and testing it would duplicate what the GRU already shows.
What recurrent models do well, on our corpus
Three things:
- Local consistency. `I work` consistently leads to certain continuations; `he` consistently leads to others. An RNN learns these regularities in the embeddings and the recurrence simultaneously, sharing representations across patterns (unlike an n-gram).
- Soft generalization to unseen combinations. If the training set has `I work, you work, he works` and `I play, you play, ...` but not `he plays`, an RNN with shared embeddings has a chance of predicting `plays` correctly because its representation of `play` is near `work` in embedding space. An n-gram has zero chance.
- Constant-memory inference. An RNN's state is fixed-size regardless of sequence length. The transformer's KV cache (Phase 22) grows linearly with sequence length. This is why people are revisiting recurrent ideas (Mamba, RWKV) for very long contexts. We mention it in passing; Phase 36 territory.
What recurrent models do badly, on our corpus and in general
- Long-range dependencies. The future tense `he is going to work` is a 4-token chain after the pronoun. By the time the model is generating `work`, it must have remembered `going to`. An RNN with \(d = 32\) and naive initialization will lose this signal within ~10 steps due to vanishing gradients. The GRU patches this somewhat, but only by training the gates to preserve the signal, which itself requires gradients to flow backward through many steps.
- Cross-paradigm transfer. "I work / yo trabajo" → "I worked / yo trabajé". The model has to learn that the `-ed` English suffix corresponds to the `-é` Spanish ending for `-ar` verbs. An RNN can learn this if the corpus shows enough examples, but it has no architectural prior toward such alignments: they have to be discovered in the embedding space and the recurrence weights.
- Inability to parallelize over the sequence axis. This is the killer for scaling. A 1000-token document goes through 1000 sequential RNN steps; nothing can parallelize this. Transformers do it in a single matmul over the entire sequence (Phase 15).
Why we still implement them
Two reasons:
- The forward pass is mechanically illuminating. When you watch \(h_t\) evolve token by token on `I work, you work, he`, you can see (loosely) how the state encodes "we're now in the third pronoun, post-separator, post-work-bigram", or not, depending on initialization. This is a sensation you cannot get from an n-gram (no state) or from a transformer (the state is the entire context, hard to read in one go).
- Phase 18 needs a baseline. We need a real baseline number to compare the trained Mini-GPT against. The n-gram from theory 01 is one baseline; an untrained RNN's logits give another (random) baseline; a trained RNN would give the strongest comparison. Phase 14 stops at "forward pass only". Phase 18 could optionally train an RNN for full comparison, but the spec says no.
What this phase does NOT cover
- Bidirectional RNNs. A BiRNN processes the sequence both left-to-right and right-to-left. Useful for tagging tasks, irrelevant for language modeling (we don't see the future when predicting the next token).
- Stacked / multi-layer RNNs. Stacking RNNs (output of one becomes input of the next) is straightforward but introduces depth-wise vanishing on top of time-wise vanishing. Out of scope.
- Teacher forcing vs sampling during training. Phase 18 territory.
- Modern recurrent revivals (Mamba, S4, RWKV, RetNet). Phase 36 (frontier architectures). They share the "linear recurrence + selective state" idea, which is conceptually upstream of the LSTM but uses very different math. Mentioned for vocabulary.
- Cross-entropy loss and backprop through time. Theory file 03 covers BPTT. Loss-computation details are Phase 18.
A drill before lab
Given a vanilla RNN with \(d = 4\), embedding \(d_\text{embed} = 2\), and the following parameters (all chosen for clean arithmetic):
The embedding of I is \(x_1 = (1, 0)\), of work is \(x_2 = (0, 1)\). Initial state \(h_0 = 0\).
Compute \(h_1\) and \(h_2\). (Use \(\tanh\) honestly; round to 2 decimals.)
If you can reproduce this arithmetic, you understand the recurrence. The rest of Phase 14 is data plumbing and instrumentation.
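To check hand arithmetic for a drill of this shape, a small helper like the one below can be used. It is a sketch only: the function name `two_steps` is made up, and the matrices in the usage lines are placeholder values chosen so the answer is obvious. They are not the drill's parameters; substitute the ones given in the drill.

```python
import numpy as np


def two_steps(W_hh, W_xh, b_h):
    """Compute h_1 and h_2 for x_1 = (1, 0), x_2 = (0, 1), h_0 = 0, rounded to 2 decimals."""
    x1 = np.array([1.0, 0.0])  # embedding of "I"
    x2 = np.array([0.0, 1.0])  # embedding of "work"
    h0 = np.zeros(W_hh.shape[0])
    h1 = np.tanh(W_hh @ h0 + W_xh @ x1 + b_h)
    h2 = np.tanh(W_hh @ h1 + W_xh @ x2 + b_h)
    return np.round(h1, 2), np.round(h2, 2)


# Placeholder parameters (NOT the drill's values): no recurrence, no bias,
# and an input projection that copies x into the first two state coordinates.
W_hh = np.zeros((4, 4))
W_xh = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [0.0, 0.0],
                 [0.0, 0.0]])
b_h = np.zeros(4)

h1, h2 = two_steps(W_hh, W_xh, b_h)
print(h1)  # first coordinate is tanh(1) rounded to 0.76, the rest are 0
print(h2)  # second coordinate is tanh(1) rounded to 0.76, the rest are 0
```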
Next: theory/03-vanishing-gradient.md.