Skip to content

English · Español

02 — RNN, GRU, LSTM: la recurrencia como máquina de estados

🇪🇸 Una RNN es una función de estado: lee un token, actualiza un vector "memoria", emite una predicción, y repite. Lo importante es lo que se gana (memoria distribuida que generaliza, no como un n-grama) y lo que se pierde (la memoria es de tamaño fijo y el cálculo no se puede paralelizar sobre la secuencia).

Este archivo deriva tres arquitecturas estrechamente relacionadas — RNN vanilla, GRU, LSTM — como una familia de máquinas de estados, cada una un parche sobre el modo de fallo de la anterior.


El marco: una máquina de secuencias de tokens

Un modelo de lenguaje es una función que consume tokens uno a uno y emite una distribución sobre el siguiente token. Un n-grama hace esto con una tabla de cuentas indexada por los \(n - 1\) tokens previos. El n-grama no tiene estado interno; todo lo que sabe del prefijo es la identidad literal de los últimos \(n - 1\) tokens.

Una red neuronal recurrente hace esto con un vector de estado aprendido, continuo, de dimensión fija \(h_t \in \mathbb{R}^d\):

\[ h_t = f_\theta(h_{t-1}, x_t) \]
\[ \hat y_t = g_\theta(h_t) \]

donde: - \(x_t\) es el embedding del \(t\)-ésimo token de entrada (de la Fase 13); - \(h_t\) es el estado oculto — un resumen aprendido del prefijo \(w_1, \ldots, w_t\); - \(f_\theta\) es la función de recurrencia con parámetros \(\theta\); - \(g_\theta\) es la cabeza de salida, mapeando el estado a logits sobre el vocabulario.

El estado oculto tiene dimensión fija \(d\) (típicamente 32–256 para nuestro corpus). Toda la información sobre el prefijo tiene que caber en ese vector. Esta es la restricción definitoria de los modelos recurrentes.

RNN vanilla (Elman, 1990)

La parametrización más simple posible:

\[ h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h) \]
\[ \hat y_t = W_{ho} h_t + b_o \]

con parámetros: - \(W_{hh} \in \mathbb{R}^{d \times d}\) — matriz de transición de estado; - \(W_{xh} \in \mathbb{R}^{d \times d_\text{embed}}\) — proyección de entrada; - \(W_{ho} \in \mathbb{R}^{|V| \times d}\) — proyección de salida; - \(b_h, b_o\) — sesgos.

Eso es todo. Cuatro matrices, dos sesgos. Para \(d = 32\), \(|V| = 64\), \(d_\text{embed} = 16\), la cuenta de parámetros es \(32^2 + 32 \cdot 16 + 64 \cdot 32 + 32 + 64 = 1024 + 512 + 2048 + 32 + 64 = 3680\) parámetros. El modelo es genuinamente pequeño.

Forward pass sobre el ejemplo canónico ["I", "work", ",", "you", "work", ",", "he"]:

h_0 = zeros(d)                                  ← estado inicial
x_1 = embed("I")
h_1 = tanh(W_hh @ h_0 + W_xh @ x_1 + b_h)
x_2 = embed("work")
h_2 = tanh(W_hh @ h_1 + W_xh @ x_2 + b_h)
x_3 = embed(",")
h_3 = tanh(W_hh @ h_2 + W_xh @ x_3 + b_h)
...
x_7 = embed("he")
h_7 = tanh(W_hh @ h_6 + W_xh @ x_7 + b_h)
y_hat = W_ho @ h_7 + b_o                        ← logits sobre V; softmax → P(next | prefix)

Si está entrenado, el modelo ha aprendido a extraer del prefijo I work, you work, he la señal que indica "el siguiente token debe ser works". Esa señal vive en \(h_7\).

Este es el modelo entero. Sin attention. Sin multi-head. Sin capas (bueno, una capa; puedes apilar RNNs pero rara vez en profundidad). El forward pass cuesta \(O(T \cdot d^2)\) en tiempo y \(O(d)\) memoria por paso (más la memoria de parámetros, que es constante).

Lo que el estado oculto puede y no puede codificar

El estado oculto \(h_t \in \mathbb{R}^d\) es un cuello de botella de capacidad fija. Para predecir \(w_t\), el modelo tiene acceso solo a \(h_{t-1}\) y \(x_t\). Todo del prefijo \(w_1, \ldots, w_{t-1}\) tiene que vivir en \(h_{t-1}\).

Dos implicaciones:

  1. El estado debe resumir. Con \(d = 32\), un estado puede codificar (de forma laxa) ~32 bits de información sobre el prefijo. Para nuestro corpus, eso es suficiente para codificar el pronombre sujeto, el tiempo verbal, la fase auxiliar — pero no la secuencia literal de 50 tokens previos.
  2. El estado se sobrescribe en cada paso. Cada nuevo \(x_t\) entra y reconfigura \(h_t\). La información de \(h_{t-1}\) que no se refuerza se diluye. Para el prefijo I work, you work, he, para cuando llegamos al token 7, el pronombre I (token 1) ha pasado por 7 aplicaciones de \(W_{hh}\). Si el modelo "recuerda" I depende de si \(W_{hh}\) preservó esa señal — lo que normalmente no hace (el archivo de teoría 03 de la Fase 14 explica por qué).

La primera implicación es una propiedad buena: las representaciones distribuidas, aprendidas, baten a los indicadores dispersos de n-grama en tareas que generalizan. La segunda es una propiedad mala: es la semilla del problema del gradiente desvaneciente y la razón por la que las dependencias de largo alcance son difíciles.

Los dos fallos de las RNN vanilla

Tras ~30 años de intentos, el campo convergió en dos fallos precisos:

  1. Desvanecimiento / explosión del gradiente a lo largo del tiempo. Derivado en theory/03-vanishing-gradient.md. En resumen: el gradiente desde una pérdida en un paso tardío hasta una entrada en un paso temprano fluye a través de la multiplicación repetida por \(W_{hh}\). Si el resultado se desvanece o explota está determinado por los autovalores de \(W_{hh}\). Estabilizar esto es difícil.
  2. Cómputo en serie. \(h_t\) depende de \(h_{t-1}\). No puedes computar \(h_t\) hasta que exista \(h_{t-1}\). A lo largo de una secuencia de longitud \(T\), esta es una cadena inherentemente serial — ninguna cantidad de paralelismo de GPU ayuda. Una capa de attention de un transformer, por contraste, computa las \(T\) salidas en paralelo mediante un único matmul de \(T \times T\).

El primer fallo es lo que motivó LSTM/GRU. El segundo fallo es lo que motivó attention. Nota: LSTM/GRU parchea el fallo 1 pero no hace nada por el fallo 2. Attention parchea ambos.

GRU (Cho et al., 2014)

La GRU (Gated Recurrent Unit) es una modificación de la RNN vanilla que añade dos puertas — redes pequeñas que deciden cuánto del pasado mantener y cuánta información nueva absorber.

Definición:

\[ z_t = \sigma(W_z [h_{t-1}, x_t] + b_z) \quad \text{(puerta de update)} \]
\[ r_t = \sigma(W_r [h_{t-1}, x_t] + b_r) \quad \text{(puerta de reset)} \]
\[ \tilde h_t = \tanh(W [r_t \odot h_{t-1}, x_t] + b) \quad \text{(estado candidato)} \]
\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde h_t \quad \text{(combinación convexa)} \]

donde \(\sigma\) es la sigmoid, \(\odot\) es la multiplicación elementwise, y \([a, b]\) denota concatenación a lo largo del eje de características.

El cambio estructural clave es la última línea. En lugar de sobrescribir \(h_{t-1}\) con \(\tanh(...)\), la GRU forma una combinación convexa del estado viejo y un candidato nuevo. Si \(z_t \approx 0\), el nuevo estado es básicamente el viejo; si \(z_t \approx 1\), es el candidato nuevo.

Por qué esto importa. La recurrencia de la RNN vanilla es multiplicativa — cada paso es "multiplica matricialmente por \(W_{hh}\) luego no-linealidad". La multiplicación repetida contrae o expande señales (la historia del gradiente desvaneciente/explosivo). La recurrencia de la GRU tiene un camino aditivo que deja que la información fluya a lo largo del tiempo sin ser multiplicada por nada cuando \(z_t \approx 0\). Los gradientes pueden fluir hacia atrás por ese camino sin contracción.

Cuenta de parámetros. Una GRU tiene tres matrices de pesos cada una de forma \(d \times (d + d_\text{embed})\), más sesgos, así que \(\sim 3 d (d + d_\text{embed})\) parámetros. Para \(d = 32, d_\text{embed} = 16\), son \(3 \cdot 32 \cdot 48 = 4608\) parámetros — aproximadamente \(1{,}25\times\) la RNN vanilla. El coste del parche es modesto.

Intuición trabajada sobre nuestro corpus. Cuando la RNN ve I work, you work, he, la GRU puede aprender a fijar \(z_t \approx 0\) al procesar you work, he de modo que la señal del pronombre-sujeto de I (codificada en \(h_1\)) se propague hacia adelante con decaimiento mínimo. La puerta de reset \(r_t\) similarmente deja que el modelo decida cuándo "olvidar" el sujeto previo (p. ej., al ver un separador). Si el modelo realmente aprende esto es una cuestión de entrenamiento (Fase 18); la GRU al menos lo hace aprendible.

LSTM (Hochreiter & Schmidhuber, 1997)

La LSTM (Long Short-Term Memory) es la prima mayor y más elaborada de la GRU. Introduce un estado de celda separado \(c_t\) junto al estado oculto \(h_t\), con tres puertas en lugar de dos.

Esbozo (no derivamos el backward pass):

\[ f_t = \sigma(W_f [h_{t-1}, x_t] + b_f) \quad \text{(puerta forget)} \]
\[ i_t = \sigma(W_i [h_{t-1}, x_t] + b_i) \quad \text{(puerta input)} \]
\[ o_t = \sigma(W_o [h_{t-1}, x_t] + b_o) \quad \text{(puerta output)} \]
\[ \tilde c_t = \tanh(W_c [h_{t-1}, x_t] + b_c) \quad \text{(celda candidata)} \]
\[ c_t = f_t \odot c_{t-1} + i_t \odot \tilde c_t \quad \text{(actualización de celda — el famoso camino aditivo)} \]
\[ h_t = o_t \odot \tanh(c_t) \quad \text{(salida)} \]

La actualización del estado de celda es el corazón del LSTM. Es una combinación convexa del estado de celda previo y un candidato nuevo — misma idea que la GRU. El camino \(c_{t-1} \to c_t\) no tiene multiplicación matricial (solo gating elementwise), así que los gradientes fluyen por él sin contracción.

LSTM vs GRU. Empíricamente, rinden dentro de un pequeño porcentaje uno del otro en la mayoría de tareas. LSTM tiene más parámetros (cuatro matrices de pesos vs tres) y una puerta más (la puerta forget explícita vs el acoplado 1 - z_t de la GRU). LSTM es el estándar más antiguo; GRU es más simple y a menudo preferida cuando quieres una baseline recurrente.

Implementamos la RNN vanilla y la GRU en el laboratorio de la Fase 14. La LSTM se esboza solo en teoría — sus matemáticas son una página de trabajo, pero entrenarla y probarla duplicaría lo que la GRU ya muestra.

Lo que los modelos recurrentes hacen bien, sobre nuestro corpus

Tres cosas:

  1. Consistencia local. I work lleva consistentemente a ciertas continuaciones; he lleva consistentemente a otras. Una RNN aprende estas regularidades en los embeddings y la recurrencia simultáneamente, compartiendo representaciones entre patrones (a diferencia de un n-grama).
  2. Generalización suave a combinaciones no vistas. Si el set de entrenamiento tiene I work, you work, he works y I play, you play, ... pero no he plays, una RNN con embeddings compartidos tiene una oportunidad de predecir plays correctamente porque su representación de play está cerca de work en el espacio de embeddings. Un n-grama tiene cero oportunidad.
  3. Inferencia en memoria constante. El estado de una RNN es de tamaño fijo independientemente de la longitud de la secuencia. El KV cache del transformer (Fase 22) crece linealmente con la longitud de secuencia. Por eso la gente está revisitando ideas recurrentes (Mamba, RWKV) para contextos muy largos. Lo mencionamos de pasada; territorio de la Fase 36.

Lo que los modelos recurrentes hacen mal, sobre nuestro corpus y en general

  1. Dependencias de largo alcance. El simple future he is going to work es una cadena de 4 tokens. Para cuando el modelo está generando work, debe haber recordado going to. Una RNN con \(d = 32\) e inicialización naive perderá esta señal en ~10 pasos debido a los gradientes desvanecientes. La GRU parchea esto algo, pero solo entrenando las puertas para preservar la señal — lo que en sí mismo requiere que los gradientes fluyan hacia atrás por muchos pasos.
  2. Transferencia entre paradigmas. "I work / yo trabajo" → "I worked / yo trabajé". El modelo tiene que aprender que el sufijo -ed en inglés corresponde a la terminación en español para verbos -ar. Una RNN puede aprender esto si el corpus muestra suficientes ejemplos, pero no tiene un prior arquitectónico hacia tales alineamientos — tienen que ser descubiertos en el espacio de embeddings + pesos de recurrencia.
  3. Incapacidad de paralelizar sobre el eje de secuencia. Este es el asesino para escalar. Un documento de 1000 tokens pasa por 1000 pasos secuenciales de RNN; nada puede paralelizar esto. Los transformers lo hacen en un único matmul sobre la secuencia entera (Fase 15).

Por qué aún los implementamos

Dos razones:

  1. El forward pass es mecánicamente esclarecedor. Cuando observas evolucionar \(h_t\) token a token sobre I work, you work, he, puedes ver (de forma laxa) cómo el estado codifica "estamos ahora en el tercer pronombre, post-separador, post-bigrama-work" — o no, dependiendo de la inicialización. Esta es una sensación que no puedes obtener de un n-grama (sin estado) o de un transformer (el estado es el contexto entero, difícil de leer de un golpe).
  2. La Fase 18 necesita una baseline. Necesitamos un número baseline real contra el que comparar el Mini-GPT entrenado. El n-grama de teoría 01 es una baseline; los logits de una RNN sin entrenar dan otra baseline (aleatoria); una RNN entrenada daría la comparación más fuerte. La Fase 14 se detiene en "solo forward pass" — la Fase 18 podría opcionalmente entrenar una RNN para comparación completa, pero la especificación dice que no.

Lo que esta fase NO cubre

  • RNNs bidireccionales. Una BiRNN procesa la secuencia tanto de izquierda a derecha como de derecha a izquierda. Útil para tareas de etiquetado, irrelevante para modelado de lenguaje (no vemos el futuro al predecir el siguiente token).
  • RNNs apiladas / multi-capa. Apilar RNNs (la salida de una se convierte en entrada de la siguiente) es directo pero introduce desvanecimiento en profundidad encima del desvanecimiento en tiempo. Fuera del alcance.
  • Teacher forcing vs muestreo durante entrenamiento. Territorio de la Fase 18.
  • Revivificaciones recurrentes modernas (Mamba, S4, RWKV, RetNet). Fase 36 (arquitecturas de frontera). Comparten la idea de "recurrencia lineal + estado selectivo", que es conceptualmente aguas arriba del LSTM pero usa matemáticas muy diferentes. Mencionado por vocabulario.
  • Pérdida de entropía cruzada y backprop a lo largo del tiempo. El archivo de teoría 03 cubre BPTT. Los detalles de cómputo de la pérdida son de la Fase 18.

Un ejercicio antes del laboratorio

Dada una RNN vanilla con \(d = 4\), embedding \(d_\text{embed} = 2\), y los siguientes parámetros (todos elegidos para aritmética limpia):

\[ W_{hh} = 0.5 I_4, \quad W_{xh} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \\ -1 & 1 \end{pmatrix}, \quad b_h = 0 \]

El embedding de I es \(x_1 = (1, 0)\), de work es \(x_2 = (0, 1)\). Estado inicial \(h_0 = 0\).

Computa \(h_1\) y \(h_2\). (Usa \(\tanh\) honradamente; redondea a 2 decimales.)

\[ h_1 = \tanh(0.5 \cdot 0 + W_{xh} \cdot (1, 0)^\top) = \tanh((1, 0, 1, -1)^\top) = (0.76, 0, 0.76, -0.76)^\top \]
\[ h_2 = \tanh(0.5 \cdot h_1 + W_{xh} \cdot (0, 1)^\top) = \tanh((0.38 + 0, 0 + 1, 0.38 + 1, -0.38 + 1)^\top) = \tanh((0.38, 1, 1.38, 0.62)^\top) \approx (0.36, 0.76, 0.88, 0.55)^\top \]

Si puedes reproducir esta aritmética, entiendes la recurrencia. El resto de la Fase 14 es plomería de datos e instrumentación.


Siguiente: theory/03-vanishing-gradient.md.