1. RNN 为什么出现?
普通神经网络很擅长处理固定形状的输入,例如一张图片或一个固定维度向量。但语言、音频、股价、用户行为都有一个共同特征:数据是按时间或顺序展开的,前面的内容会影响后面的判断。
happy 时,模型可能判断为正面;但句子
I am not happy 的关键在于前面的
not。模型需要把前面读到的信息保留下来。
MLP
一次性输入固定长度向量,输入位置之间没有天然的时间记忆。
CNN
擅长捕捉局部模式,例如局部词组或图像纹理,但“记住过去”不是它的核心机制。
RNN
按顺序读取数据,每一步把历史压缩进 hidden state,再传给下一步。
2. 什么是序列数据?
序列数据不是一个孤立样本,而是一串有顺序的元素。顺序一旦改变,含义往往就变了。
| 数据类型 | 序列元素 | 任务示例 |
|---|---|---|
| 文本 | word 或 token | 情感分类、机器翻译、POS tagging、NER |
| 时间序列 | time step | 温度预测、股价预测、传感器异常检测 |
| DNA / 蛋白质 | nucleotide 或 amino acid | 分类、功能预测、结构预测 |
| 音频 | 采样点或语音帧 | 语音识别、说话人识别、声音分类 |
| 用户行为 | 点击、购买、浏览事件 | 推荐系统、流失预测、下一步行为预测 |
4. RNN 如何处理一句话?
以 I am not happy 为例,RNN
不会一次性把整句话扔进去,而是按 token 一个一个读。每一步的 hidden
state 都会携带前面读过的信息。
h_T 做预测;序列标注通常用每个时间步的
h_t 做预测。
5. RNN 的输入和输出形式
One-to-one
普通输入输出,不是 RNN 的典型用途。
One-to-many
一个输入生成序列,例如图像描述。
Many-to-one
序列输出一个标签,例如情感分类。
Many-to-many 同长度
每个输入对应输出,例如 POS tagging。
Many-to-many 不同长度
输入输出长度不同,例如机器翻译。
7. RNN 的训练:BPTT
Backpropagation Through Time 的意思是:把循环结构沿时间展开,再像训练一个很深的前馈网络一样反向传播。损失可以来自最后一步,也可以来自每一步。
backward: loss → h3 → h2 → h1 BPTT 本质上是普通反向传播在时间维度上的应用。
8. RNN 的主要问题:梯度消失和梯度爆炸
Vanishing Gradient
梯度在反向传播中不断乘以小于 1 的数,越来越接近 0。早期时间步几乎学不到,长距离依赖难以建模。
0.2 → 0.04 → 0.008 → 0.0016
Exploding Gradient
梯度不断乘以大于 1 的数,数值快速变大,参数更新不稳定,loss 震荡,甚至出现 NaN。
2 → 4 → 8 → 16 → 32
9. 长距离依赖 Long-term Dependency
基础 RNN 理论上可以把很早的信息传到后面,但实际训练中,因为 BPTT 路径太长,早期信息很容易被弱化。
主语
book 和谓语
was 距离很远,模型需要记住很久之前的主语信息。
10. LSTM:为什么比普通 RNN 更强?
LSTM = Long Short-Term Memory。它增加了
cell state,像一条信息高速公路,让重要信息更容易跨越很多时间步传递。三个门控制信息如何流动。
Forget gate
决定旧信息保留多少。像删除旧笔记。
Input gate
决定新信息写入多少。像写入新笔记。
Output gate
决定当前展示哪部分内部状态。
i_t = sigmoid(...)
o_t = sigmoid(...)
C_t = f_t * C_{t-1} + i_t * C~_t f_t 控制遗忘;i_t 控制写入;o_t 控制输出;C_t 是 cell state。
11. GRU:LSTM 的简化版本
GRU = Gated Recurrent Unit。它没有单独的 cell state,而是把 hidden state 同时作为记忆,参数比 LSTM 少,训练通常更快,很多任务中效果接近 LSTM。
Update gate z_t
控制保留旧信息和加入新信息的比例。
Reset gate r_t
控制在生成候选记忆时忘记多少过去信息。
r_t = sigmoid(...) // reset gate
h_t = (1 - z_t) * h_{t-1} + z_t * h~_t LSTM: 三个门 + cell state;GRU: 两个门 + hidden state,更轻量。
12. Bidirectional RNN / BiLSTM / BiGRU
普通 RNN 只能从左到右读取序列。但很多 NLP 任务需要同时看左侧和右侧上下文。
bank 是“银行”而不是“河岸”,后面的
deposit money 非常关键。
POS taggingNamed Entity Recognition序列标注需要完整上下文的任务
13. RNN 在 NLP 中的典型应用
文本分类
tokens → embeddings → RNN/LSTM/GRU → last hidden state → classifier。
I love this movie → positive
POS Tagging
tokens → embeddings → BiLSTM → classifier at each time step。
I / love / football → PRON / VERB / NOUN
NER
tokens → embeddings → BiLSTM → CRF or classifier。
Barack Obama → B-PER I-PER
机器翻译
早期结构是 Encoder RNN → context vector → Decoder RNN。现在 Transformer 主导,但 RNN 是理解 seq2seq 的基础。
14. PyTorch 中 RNN 的基本用法
下面是一个最小文本分类模型。输入 x 的形状是
[batch_size, seq_len],先变成 embedding,再送入
nn.RNN。
import torch
import torch.nn as nn
class RNNClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.rnn = nn.RNN(
input_size=embed_dim,
hidden_size=hidden_dim,
batch_first=True
)
self.fc = nn.Linear(hidden_dim, num_classes)
def forward(self, x):
# x: [batch_size, seq_len]
embedded = self.embedding(x)
# embedded: [batch_size, seq_len, embed_dim]
output, hidden = self.rnn(embedded)
# output: [batch_size, seq_len, hidden_dim]
# hidden: [num_layers, batch_size, hidden_dim]
last_hidden = hidden[-1]
logits = self.fc(last_hidden)
return logits
output 保存每个时间步的 hidden state;hidden
保存最后一个时间步、每一层的 hidden state。文本分类常用
hidden[-1],序列标注常用 output。
15. POS Tagging:为什么用 output 而不是 hidden?
POS tagging 是 many-to-many 同长度任务,每个词都需要一个标签。如果只用最后的 hidden state,就只得到整句话的一个表示,无法给每个 token 单独分类。
linear(output): [batch_size, seq_len, num_tags]
output, hidden = self.rnn(embedded)
logits = self.fc(output) # 每个 token 都有自己的预测结果
16. RNN / LSTM / GRU / Transformer 对比
| 模型 | 核心思想 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|---|
| Vanilla RNN | hidden state 递归传递 | 结构简单,适合理解序列模型 | 长距离依赖差,易梯度消失/爆炸 | 短序列、教学基础 |
| LSTM | cell state + 三个 gates | 记忆能力强,缓解长期依赖问题 | 参数多,计算更慢 | 较长文本、时间序列 |
| GRU | 两个 gates + hidden state | 更轻量,训练通常更快 | 表达能力有时不如 LSTM | 资源受限或中等长度序列 |
| BiLSTM / BiGRU | 正向和反向 RNN 拼接 | 同时利用左右上下文 | 不适合实时预测未来不可见任务 | POS、NER、序列标注 |
| Transformer | attention 直接建模 token 关系 | 并行强,长距离建模强 | 资源需求较高,需要更多数据 | 现代 NLP、大模型 |
17. RNN 和 Transformer 的核心区别
RNN
- 顺序处理,一个时间步接一个时间步。
- 依靠 hidden state 传递信息。
- 难以并行,长距离依赖困难。
Transformer
- 同时处理整个序列。
- 使用 attention 直接建模 token 之间关系。
- 更容易捕捉长距离依赖,更适合大规模 NLP。