Many-to-many (in sync) · e.g., POS tagging — one label per input token.
Many-to-many (encoder-decoder) · e.g., translation — input sequence, output sequence, different lengths. This is the Seq2Seq pattern covered in L11.
Four shapes, same cell.
```python
import torch
import torch.nn as nn

class RNNCell(nn.Module):
    def __init__(self, d_in, d_h):
        super().__init__()
        self.W = nn.Linear(d_in, d_h, bias=False)  # input-to-hidden
        self.U = nn.Linear(d_h, d_h, bias=True)    # hidden-to-hidden

    def forward(self, x_t, h_prev):
        # tanh folds the output back to [-1, 1]
        return torch.tanh(self.W(x_t) + self.U(h_prev))

# Unrolled loop · x has shape [batch, seq_len, d_in]
h = torch.zeros(batch, d_h)
for t in range(seq_len):
    h = cell(x[:, t], h)
```
nn.RNN, nn.LSTM, nn.GRU handle the loop + CUDA kernels for you.
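For example, a minimal usage sketch (the dimensions here are illustrative):

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=64, hidden_size=128, batch_first=True)
x = torch.randn(32, 100, 64)   # [batch, seq_len, d_in]
out, h_T = rnn(x)              # out: h at every step, [32, 100, 128]; h_T: final h, [1, 32, 128]
```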
Same problem as depth, now along the time axis
Analogy · Chinese whispers. Tell a secret to person 1. They whisper to person 2, etc. By person 20, the message is garbled (vanished) or wildly distorted (exploded).
The gradient is that secret. It travels backward from the loss to the start of the sequence, and at each step it gets transformed.
Start with the simplest possible RNN: $h_t = \tanh(W x_t + U h_{t-1})$.

Chain rule:

$$\frac{\partial \mathcal{L}}{\partial h_1} = \frac{\partial \mathcal{L}}{\partial h_T} \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}}, \qquad \frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}\!\left(1 - h_t^2\right) U$$

For sequence length $T$, the gradient at step 1 carries $T - 1$ of these Jacobian factors. Add the bound $|\tanh'| \le 1$, and each factor's norm is at most $\|U\|$.

A product of $T - 1$ factors shrinks exponentially when the norms sit below 1 and explodes when they sit above 1.

Consider a tiny RNN with scalar state · $h_t = w\,h_{t-1}$. Backprop through $T$ steps multiplies the gradient by $w^T$: with $w = 0.9$, $0.9^{100} \approx 3 \times 10^{-5}$; with $w = 1.1$, $1.1^{100} \approx 1.4 \times 10^{4}$.

That's why vanilla RNNs can't learn dependencies across more than ~20 timesteps. Every step in the product pulls the gradient toward zero if $|w| < 1$ (or toward infinity if $|w| > 1$).
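A minimal sketch verifying the scalar case · autograd recovers exactly $w^{100}$:

```python
import torch

h0 = torch.tensor(1.0, requires_grad=True)
h = h0
for _ in range(100):
    h = 0.9 * h            # 100 steps of h_t = w * h_{t-1} with w = 0.9
h.backward()
print(h0.grad.item())      # 0.9**100 ≈ 2.66e-05 · the gradient has all but vanished
```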
For long sequences (thousands of steps), full BPTT is expensive.
Truncated BPTT (TBPTT) — only backpropagate through the last $K$ timesteps, detaching the hidden state between chunks:
```python
h = torch.zeros(batch, d_h)
for chunk in sequence.split(K, dim=1):   # chunks of K timesteps
    h = h.detach()                       # cut the gradient here
    for t in range(chunk.size(1)):
        h = cell(chunk[:, t], h)
    loss = criterion(h, target)
    opt.zero_grad()
    loss.backward()                      # gradients flow at most K steps back
    opt.step()
```
Typical $K$ is in the tens to low hundreds; the exact value trades gradient reach against memory and speed.
Exploding gradients are worse than vanishing — one bad step can destroy weeks of training. Clip by global norm:
```python
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
```
Pascanu et al. 2013 showed clipping at norm ~1 makes RNN training robust. Still the default for any sequence model — RNN, LSTM, Transformer. Cheap insurance against numerical catastrophe.
Three sigmoid gates protect a cell state
Interactive: drag forget/input/output sliders; see the cell state freeze, flow, or reset — lstm-gates.
A vanilla RNN treats every input the same. The LSTM has three "specialists":

- Forget gate · decides what to erase from memory.
- Input gate · decides what new information to write.
- Output gate · decides what part of memory to reveal.
The next slide gives the math · keep these roles in mind as you read the equations.
A vanilla RNN crams everything into one memory · the hidden state $h_t$.

LSTM adds a separate memory conveyor belt · the cell state $c_t$.
A protected long-term memory + learned controllers for write / forget / read. That's the whole idea.
Combine inputs into one control vector · every gate reads the concatenation $[h_{t-1}; x_t]$.
Step 1 · Forget gate (sigmoid → values in $[0, 1]$):

$$f_t = \sigma(W_f [h_{t-1}; x_t] + b_f)$$

0 → forget · 1 → keep.

Step 2 · Input gate + candidate:

$$i_t = \sigma(W_i [h_{t-1}; x_t] + b_i), \qquad \tilde{c}_t = \tanh(W_c [h_{t-1}; x_t] + b_c)$$

Step 3 · Update the cell state · throw out old, add new:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
The crucial + is what makes gradients flow.
Step 4 · Output gate + hidden state:

$$o_t = \sigma(W_o [h_{t-1}; x_t] + b_o), \qquad h_t = o_t \odot \tanh(c_t)$$
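Collecting the four steps into code · a from-scratch sketch (the class name is mine; in practice you'd reach for `nn.LSTMCell`):

```python
import torch
import torch.nn as nn

class NaiveLSTMCell(nn.Module):
    def __init__(self, d_in, d_h):
        super().__init__()
        # one matrix computes all four pre-activations from [h_{t-1}; x_t]
        self.gates = nn.Linear(d_in + d_h, 4 * d_h)

    def forward(self, x_t, h_prev, c_prev):
        f, i, o, g = self.gates(torch.cat([x_t, h_prev], dim=-1)).chunk(4, dim=-1)
        f, i, o = torch.sigmoid(f), torch.sigmoid(i), torch.sigmoid(o)  # Steps 1, 2, 4: gates in (0, 1)
        c_t = f * c_prev + i * torch.tanh(g)    # Step 3: the crucial "+"
        h_t = o * torch.tanh(c_t)               # Step 4: gated read-out
        return h_t, c_t
```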
1D state. Setup (illustrative values): the cell is tracking "subject is plural" with $c_{t-1} = 2.0$, and a singular subject has just arrived.

Network has learned (for this input): $f_t = 0.1$ (drop the old fact), $i_t = 0.9$ (write the new one), $\tilde{c}_t = -1.0$.

Compute new cell state: $c_t = f_t \cdot c_{t-1} + i_t \cdot \tilde{c}_t = 0.1 \cdot 2.0 + 0.9 \cdot (-1.0) = -0.7$.

The memory flipped from large positive (plural) to negative (singular) in one step — exactly because the forget gate was small and the input gate was large.
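The same update as one line of Python, with the illustrative values above:

```python
f, i, c_prev, c_tilde = 0.1, 0.9, 2.0, -1.0
print(f * c_prev + i * c_tilde)   # -0.7 · plural flipped to singular in one step
```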
- Forget gate $f_t$ · what to erase from long-term memory.
- Input gate $i_t$ · what new information to write.
- Output gate $o_t$ · what part of memory to expose as $h_t$.
The LSTM's "memory" is the cell state $c_t$.
Vanilla RNN. $\dfrac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}(1 - h_t^2)\, U$ — a matrix multiplication at every step → product of matrices → vanishing/exploding.

LSTM cell state. $\dfrac{\partial c_t}{\partial c_{t-1}} = \operatorname{diag}(f_t)$.

That's it — no matrix multiplication. Just element-wise multiplication by the forget gate.
If the network learns $f_t \approx 1$, the gradient passes through step after step essentially unchanged.

Suppose memory must be preserved → forget gate trained to $f_t = 0.999$:
| After 100 steps | Gradient factor |
|---|---|
| LSTM along cell state ($f_t = 0.999$) | $0.999^{100} \approx 0.90$ |
| Vanilla RNN (optimistic factor 0.5) | $0.5^{100} \approx 10^{-30}$ |
LSTM signal survives. RNN signal is below floating-point precision.
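Checking both factors directly:

```python
print(0.999 ** 100)   # ≈ 0.905 · the cell-state path barely decays
print(0.5 ** 100)     # ≈ 7.9e-31 · the vanilla RNN path is numerically gone
```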
This is the same idea as ResNet skip connections, in the time dimension.
Fewer gates, comparable accuracy
Two design questions Cho et al. asked in 2014: do we really need two separate states (cell + hidden)? And do we really need separate forget and input gates?

Result · the update gate $z_t$ alone does the work of the forget/input pair, and the hidden state doubles as the memory.
Step 1 · Reset gate. How much of the past to use when forming the candidate: $r_t = \sigma(W_r [h_{t-1}; x_t] + b_r)$

Step 2 · Candidate. Reset gate filters the past first: $\tilde{h}_t = \tanh(W_h [r_t \odot h_{t-1}; x_t] + b_h)$

Step 3 · Update gate. How much new vs old to use: $z_t = \sigma(W_z [h_{t-1}; x_t] + b_z)$

Step 4 · Final state · linear interpolation: $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
The additive structure (like LSTM's cell state) is what keeps gradients flowing.
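The four steps as a from-scratch sketch (my naming; `nn.GRUCell` is the production version):

```python
import torch
import torch.nn as nn

class NaiveGRUCell(nn.Module):
    def __init__(self, d_in, d_h):
        super().__init__()
        self.rz = nn.Linear(d_in + d_h, 2 * d_h)   # reset + update gates together
        self.cand = nn.Linear(d_in + d_h, d_h)     # candidate state

    def forward(self, x_t, h_prev):
        r, z = torch.sigmoid(self.rz(torch.cat([x_t, h_prev], dim=-1))).chunk(2, dim=-1)
        h_tilde = torch.tanh(self.cand(torch.cat([x_t, r * h_prev], dim=-1)))  # reset filters the past
        return (1 - z) * h_prev + z * h_tilde      # Step 4: linear interpolation
```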
1D state. Take $h_{t-1} = 0.8$ and candidate $\tilde{h}_t = -0.5$ (illustrative values).

Case 1 · $z_t = 0.1$: $h_t = 0.9 \cdot 0.8 + 0.1 \cdot (-0.5) = 0.67$.

Close to old state — input mostly ignored.

Case 2 · $z_t = 0.9$: $h_t = 0.1 \cdot 0.8 + 0.9 \cdot (-0.5) = -0.37$.

Close to candidate — state nearly fully replaced.

The single update gate gives the network smooth control between "preserve" and "overwrite."
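The interpolation, checked numerically with the values above:

```python
h_prev, h_tilde = 0.8, -0.5
for z in (0.1, 0.9):
    print((1 - z) * h_prev + z * h_tilde)   # 0.67 (preserve), then -0.37 (overwrite)
```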
In the late 2010s, ML papers often included an ablation · "we tried LSTM and GRU and picked whichever worked." By 2019, most groups defaulted to whichever they had better library support for.
That casualness told the story · the gating trick matters; which gates you pick doesn't much. Any additive-gated recurrence works; the architectural variants are micro-optimizations on the core idea.
| | LSTM | GRU |
|---|---|---|
| Gates | 3 + candidate | 2 + candidate |
| State | cell + hidden | hidden only |
| Params (d_in = d_h = 128, weights only) | 4 · 128 · 256 ≈ 131k | 3 · 128 · 256 ≈ 98k |
| Accuracy | baseline | often tied |
| Training speed | slower | ~15% faster |
Empirically close — both far beyond vanilla RNNs on long-range tasks. GRU was often preferred pre-Transformer; today either is fine for the few remaining RNN use cases.
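The parameter counts in the table are easy to verify (128-dim input and hidden; the totals below also include the bias terms):

```python
import torch.nn as nn

lstm, gru = nn.LSTM(128, 128), nn.GRU(128, 128)
print(sum(p.numel() for p in lstm.parameters()))  # 132096 = 4·128·256 weights + 2·4·128 biases
print(sum(p.numel() for p in gru.parameters()))   # 99072  = 3·128·256 weights + 2·3·128 biases
```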
Bidirectional · run one RNN left-to-right and another right-to-left; concatenate the outputs. Used for classification/tagging (not generation).
```python
self.rnn = nn.LSTM(d_in, d_h, num_layers=2, bidirectional=True, batch_first=True)
```
Stacked (deep) RNN · layer $\ell$ consumes the hidden-state sequence of layer $\ell - 1$ as its input.
Both tricks combinable: 2-layer bidirectional LSTMs were the standard NLP architecture from 2015–2017 (before Transformers).
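A shape sketch of that standard setup (dimensions are illustrative):

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(64, 128, num_layers=2, bidirectional=True, batch_first=True)
x = torch.randn(8, 50, 64)     # [batch, seq_len, d_in]
out, (h_n, c_n) = rnn(x)
print(out.shape)               # [8, 50, 256] · forward and backward outputs concatenated per step
print(h_n.shape)               # [4, 8, 128] · 2 layers × 2 directions
```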
Transformers have largely replaced RNNs for translation, language modeling, and most large-scale NLP · anywhere the full sequence is available up front and parallel training wins.
But RNNs are still the right choice for:
- Streaming / online inference · constant-size state, O(1) work per new token.
- Tiny devices · small, fixed memory footprint.
Starting in 2023, a new class of models has re-emerged · state-space models and linear RNNs that match Transformer quality with O(1) inference per token.
The story isn't "RNNs are dead" — it's "vanilla RNNs with sequential gradients couldn't scale." Modern parallelizable RNNs are a quiet comeback. Watch this space.
Consider translating:
"The animal didn't cross the street because it was too tired."
To translate "it" correctly, the model must look back to "animal" — maybe 6 tokens ago.
An RNN compresses all of that into a single fixed-size hidden vector $h_t$ and hopes the relevant information survives.
The next lecture (L11) examines encoder-decoder Seq2Seq, which also struggles with this fixed-length bottleneck. That struggle motivates attention (L12) — the idea that finally let sequence models scale.