Recurrent Neural Network

RNN 循环神经网络基础

从序列建模、隐藏状态、BPTT 到 LSTM / GRU / BiRNN。RNN 是一种用于处理序列数据的神经网络,它通过隐藏状态把过去的信息传递到当前时间步。

sequence data hidden state parameter sharing BPTT LSTM / GRU
RNN unfolded through time x1 h1 y1 x2 h2 y2 x3 h3 y3 time memory

1. RNN 为什么出现?

普通神经网络很擅长处理固定形状的输入,例如一张图片或一个固定维度向量。但语言、音频、股价、用户行为都有一个共同特征:数据是按时间或顺序展开的,前面的内容会影响后面的判断。

直观例子:只看到 happy 时,模型可能判断为正面;但句子 I am not happy 的关键在于前面的 not。模型需要把前面读到的信息保留下来。

MLP

一次性输入固定长度向量,输入位置之间没有天然的时间记忆。

CNN

擅长捕捉局部模式,例如局部词组或图像纹理,但“记住过去”不是它的核心机制。

RNN

按顺序读取数据,每一步把历史压缩进 hidden state,再传给下一步。

2. 什么是序列数据?

序列数据不是一个孤立样本,而是一串有顺序的元素。顺序一旦改变,含义往往就变了。

数据类型 序列元素 任务示例
文本 word 或 token 情感分类、机器翻译、POS tagging、NER
时间序列 time step 温度预测、股价预测、传感器异常检测
DNA / 蛋白质 nucleotide 或 amino acid 分类、功能预测、结构预测
音频 采样点或语音帧 语音识别、说话人识别、声音分类
用户行为 点击、购买、浏览事件 推荐系统、流失预测、下一步行为预测

3. RNN 的核心思想:隐藏状态 hidden state

RNN cell 在时间步 t 接收当前输入 x_t 和上一个时间步的隐藏状态 h_{t-1},生成当前隐藏状态 h_t。可以把 h_t 理解为“到目前为止模型记住的信息”。

h_t = tanh(W_x x_t + W_h h_{t-1} + b) x_t: 当前输入;h_{t-1}: 上一步记忆;h_t: 当前记忆;W_x/W_h: 权重;b: 偏置;tanh: 激活函数。
current input x_t previous h_{t-1} RNN cell same function at each step current h_t

4. RNN 如何处理一句话?

I am not happy 为例,RNN 不会一次性把整句话扔进去,而是按 token 一个一个读。每一步的 hidden state 都会携带前面读过的信息。

时间步展开演示
Step 1: I读入 I,得到 h1。
Step 2: am结合 h1,得到 h2。
Step 3: not结合 h2,得到 h3;否定信息进入记忆。
Step 4: happy结合 h3,得到 h4;分类时可用 h4。
当前高亮:I → h1
考试关键词:文本分类通常用最后的 h_T 做预测;序列标注通常用每个时间步的 h_t 做预测。

5. RNN 的输入和输出形式

One-to-one

普通输入输出,不是 RNN 的典型用途。

xy

One-to-many

一个输入生成序列,例如图像描述。

xy1 y2 y3

Many-to-one

序列输出一个标签,例如情感分类。

x1 x2 x3label

Many-to-many 同长度

每个输入对应输出,例如 POS tagging。

x1 x2 x3y1 y2 y3

Many-to-many 不同长度

输入输出长度不同,例如机器翻译。

EnglishChinese

6. RNN 中的参数共享

RNN 在每个时间步使用同一组参数:W_xW_hb。因此参数量不会随序列长度线性增长,模型也能处理不同长度的序列。

如果不共享参数

  • 每个时间步都有自己的参数。
  • 序列越长,参数越多。
  • 泛化能力差,训练更不稳定。

共享参数

  • 同一个 RNN cell 被重复使用。
  • 学习通用的序列转移模式。
  • 可以自然处理可变长度输入。
RNN cell W_x, W_h, b RNN cell same params RNN cell same params 同一套参数沿时间展开,不是复制出不同参数

7. RNN 的训练:BPTT

Backpropagation Through Time 的意思是:把循环结构沿时间展开,再像训练一个很深的前馈网络一样反向传播。损失可以来自最后一步,也可以来自每一步。

forward: x1 → h1 → h2 → h3 → y → loss
backward: loss → h3 → h2 → h1 BPTT 本质上是普通反向传播在时间维度上的应用。
考试答法:BPTT unfolds the recurrent network through time and applies backpropagation on the unfolded computational graph. Longer sequences create longer gradient paths, which may cause vanishing or exploding gradients.

8. RNN 的主要问题:梯度消失和梯度爆炸

Vanishing Gradient

梯度在反向传播中不断乘以小于 1 的数,越来越接近 0。早期时间步几乎学不到,长距离依赖难以建模。

0.2 → 0.04 → 0.008 → 0.0016

Exploding Gradient

梯度不断乘以大于 1 的数,数值快速变大,参数更新不稳定,loss 震荡,甚至出现 NaN。

2 → 4 → 8 → 16 → 32

Vanishing vs Exploding Gradient 动画
绿色表示梯度消失,红色表示梯度爆炸。
解决思路:梯度爆炸常用 gradient clipping;梯度消失常用 LSTM / GRU、更好的初始化、残差结构等。

9. 长距离依赖 Long-term Dependency

基础 RNN 理论上可以把很早的信息传到后面,但实际训练中,因为 BPTT 路径太长,早期信息很容易被弱化。

例句:The book that I bought yesterday from the old store near the station was interesting.
主语 book 和谓语 was 距离很远,模型需要记住很久之前的主语信息。
考试提示:如果题目问 why vanilla RNN has difficulty modeling long-term dependencies,通常要回答 BPTT、gradient vanishing/exploding、repeated multiplication through time。

10. LSTM:为什么比普通 RNN 更强?

LSTM = Long Short-Term Memory。它增加了 cell state,像一条信息高速公路,让重要信息更容易跨越很多时间步传递。三个门控制信息如何流动。

Forget gate

决定旧信息保留多少。像删除旧笔记。

Input gate

决定新信息写入多少。像写入新笔记。

Output gate

决定当前展示哪部分内部状态。

f_t = sigmoid(...)
i_t = sigmoid(...)
o_t = sigmoid(...)
C_t = f_t * C_{t-1} + i_t * C~_t f_t 控制遗忘;i_t 控制写入;o_t 控制输出;C_t 是 cell state。
LSTM Cell forget input output cell state highway C_{t-1} h_{t-1}, x_t C_t h_t

11. GRU:LSTM 的简化版本

GRU = Gated Recurrent Unit。它没有单独的 cell state,而是把 hidden state 同时作为记忆,参数比 LSTM 少,训练通常更快,很多任务中效果接近 LSTM。

Update gate z_t

控制保留旧信息和加入新信息的比例。

Reset gate r_t

控制在生成候选记忆时忘记多少过去信息。

z_t = sigmoid(...) // update gate
r_t = sigmoid(...) // reset gate
h_t = (1 - z_t) * h_{t-1} + z_t * h~_t LSTM: 三个门 + cell state;GRU: 两个门 + hidden state,更轻量。

12. Bidirectional RNN / BiLSTM / BiGRU

普通 RNN 只能从左到右读取序列。但很多 NLP 任务需要同时看左侧和右侧上下文。

例句:He went to the bank to deposit money. 为了判断 bank 是“银行”而不是“河岸”,后面的 deposit money 非常关键。
h_t = [forward_h_t ; backward_h_t] 两个方向的 hidden state 拼接,得到更完整的上下文表示。
Forward RNN Backward RNN He bank deposit money

POS taggingNamed Entity Recognition序列标注需要完整上下文的任务

13. RNN 在 NLP 中的典型应用

文本分类

tokens → embeddings → RNN/LSTM/GRU → last hidden state → classifier。

I love this movie → positive

POS Tagging

tokens → embeddings → BiLSTM → classifier at each time step。

I / love / football → PRON / VERB / NOUN

NER

tokens → embeddings → BiLSTM → CRF or classifier。

Barack Obama → B-PER I-PER

机器翻译

早期结构是 Encoder RNN → context vector → Decoder RNN。现在 Transformer 主导,但 RNN 是理解 seq2seq 的基础。

14. PyTorch 中 RNN 的基本用法

下面是一个最小文本分类模型。输入 x 的形状是 [batch_size, seq_len],先变成 embedding,再送入 nn.RNN

import torch
import torch.nn as nn

class RNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # x: [batch_size, seq_len]
        embedded = self.embedding(x)
        # embedded: [batch_size, seq_len, embed_dim]
        output, hidden = self.rnn(embedded)
        # output: [batch_size, seq_len, hidden_dim]
        # hidden: [num_layers, batch_size, hidden_dim]
        last_hidden = hidden[-1]
        logits = self.fc(last_hidden)
        return logits
output vs hidden:output 保存每个时间步的 hidden state;hidden 保存最后一个时间步、每一层的 hidden state。文本分类常用 hidden[-1],序列标注常用 output

15. POS Tagging:为什么用 output 而不是 hidden?

POS tagging 是 many-to-many 同长度任务,每个词都需要一个标签。如果只用最后的 hidden state,就只得到整句话的一个表示,无法给每个 token 单独分类。

output: [batch_size, seq_len, hidden_dim]
linear(output): [batch_size, seq_len, num_tags]
output, hidden = self.rnn(embedded)
logits = self.fc(output)  # 每个 token 都有自己的预测结果

16. RNN / LSTM / GRU / Transformer 对比

模型 核心思想 优点 缺点 适用场景
Vanilla RNN hidden state 递归传递 结构简单,适合理解序列模型 长距离依赖差,易梯度消失/爆炸 短序列、教学基础
LSTM cell state + 三个 gates 记忆能力强,缓解长期依赖问题 参数多,计算更慢 较长文本、时间序列
GRU 两个 gates + hidden state 更轻量,训练通常更快 表达能力有时不如 LSTM 资源受限或中等长度序列
BiLSTM / BiGRU 正向和反向 RNN 拼接 同时利用左右上下文 不适合实时预测未来不可见任务 POS、NER、序列标注
Transformer attention 直接建模 token 关系 并行强,长距离建模强 资源需求较高,需要更多数据 现代 NLP、大模型

17. RNN 和 Transformer 的核心区别

RNN

  • 顺序处理,一个时间步接一个时间步。
  • 依靠 hidden state 传递信息。
  • 难以并行,长距离依赖困难。
x1x2x3x4

Transformer

  • 同时处理整个序列。
  • 使用 attention 直接建模 token 之间关系。
  • 更容易捕捉长距离依赖,更适合大规模 NLP。
x1 x2 x3 x4

18. 考试高频问答区

1. What is the main idea of RNN?
RNN processes sequence data step by step and uses hidden state to carry information from previous time steps to the current step.
2. Why can RNN handle variable-length sequences?
Because the same RNN cell and parameters are reused at every time step, so the network can be unrolled for different sequence lengths.
3. What is hidden state?
Hidden state is the memory representation of what the model has read so far. It depends on the current input and the previous hidden state.
4. What is parameter sharing in RNN?
The same weights, such as W_x and W_h, are used across all time steps. This reduces parameters and helps learn general sequence patterns.
5. What is BPTT?
Backpropagation Through Time unfolds the RNN along the time dimension and applies backpropagation on the unfolded graph.
6. Why does vanilla RNN suffer from vanishing gradients?
During BPTT, gradients are repeatedly multiplied through many time steps. If these factors are smaller than 1, gradients shrink toward zero.
7. What is long-term dependency?
It means the prediction at a later time step depends on information from much earlier time steps.
8. How does LSTM solve the weakness of vanilla RNN?
LSTM uses cell state and gates to control remembering, forgetting, and outputting information, which helps preserve long-term information.
9. What is the difference between LSTM and GRU?
LSTM has cell state and three gates; GRU has no separate cell state and usually uses update and reset gates, so it is lighter.
10. Why is BiLSTM useful for POS tagging and NER?
Because these tasks need both left and right context. BiLSTM reads the sequence in both directions and concatenates the two hidden states.
11. What is the difference between output and hidden in PyTorch RNN?
output contains hidden states for all time steps; hidden contains the final hidden state for each layer.
12. Why are Transformers often preferred over RNNs in modern NLP?
Transformers process tokens in parallel and use attention to model long-range dependencies more directly.

19. RNN Exam Cheat Sheet

RNN 用于序列数据。
hidden state 保存历史信息。
h_t = f(x_t, h_{t-1})。
参数在时间步共享。
BPTT 是沿时间展开后的反向传播。
vanilla RNN 容易梯度消失/爆炸。
LSTM 用 cell state 和 gates 解决长期依赖。
GRU 是更轻量的 gated RNN。
BiRNN 同时看左右上下文。
文本分类用最后 hidden。
序列标注用每个 time step 的 output。
Transformer 更常用,但 RNN 是理解序列模型的重要基础。