在自然语言处理领域,处理变长序列数据一直是个核心挑战。传统RNN(循环神经网络)通过隐藏状态传递历史信息,但存在梯度消失和并行化困难两大痛点。2017年Google提出的Transformer架构彻底改变了这一局面,其核心创新正是Self-Attention机制。
我最初接触Self-Attention时最震撼的是它的"全连接视野"——每个词元可以直接关注序列中任何位置的词元,不受距离限制。这与人类阅读时"回看前文"或"跳读关键词"的认知方式高度吻合。实际在机器翻译任务中,这种特性让模型能精准捕捉跨长距离的代词指代关系(如"it"指代前文20个词之前的某个名词)。
Self-Attention的核心计算可分为五个步骤:
输入表示:假设输入序列包含n个词元,每个词元xi通过嵌入层转换为d维向量(通常d=512)。最终得到输入矩阵X ∈ R^(n×d)
线性变换:通过可学习的权重矩阵WQ, WK, WV ∈ R^(d×d_k)生成Query、Key、Value矩阵:
python复制Q = X @ WQ # Query矩阵 (n×d_k)
K = X @ WK # Key矩阵 (n×d_k)
V = X @ WV # Value矩阵 (n×d_v)
注意力分数计算:通过QK^T得到原始注意力分数,再除以√d_k进行缩放(防止softmax梯度消失):
python复制scores = Q @ K.T / sqrt(d_k) # (n×n)
Softmax归一化:对每行进行softmax得到注意力权重:
python复制attn_weights = softmax(scores, dim=-1) # (n×n)
加权求和:用注意力权重对Value矩阵加权求和:
python复制output = attn_weights @ V # (n×d_v)
关键细节:实际实现中会采用多头注意力(Multi-Head Attention),即将Q/K/V拆分为h个头分别计算后再拼接。这使模型能同时关注不同表示子空间的信息。
通过对比实验可以清晰看到Self-Attention的优势:
| 特性 | RNN | Self-Attention |
|---|---|---|
| 长距离依赖 | 梯度消失严重 | 直接全局访问 |
| 并行计算 | 序列依赖 | 完全并行 |
| 计算复杂度 | O(n) | O(n²) |
| 路径长度(信息传递) | O(n) | O(1) |
虽然理论复杂度更高,但由于GPU并行优势,实际训练速度反而更快。我在WMT英德翻译任务中测试,Transformer比LSTM快3倍的同时BLEU值提升2.4。
Transformer采用经典编码器-解码器结构,但每个组件都经过重新设计:
编码器层(N=6层):
解码器层(N=6层):
位置编码(Positional Encoding):
由于Self-Attention本身没有位置信息,需要通过正弦/余弦函数注入绝对位置信息:
python复制PE(pos,2i) = sin(pos/10000^(2i/d_model))
PE(pos,2i+1) = cos(pos/10000^(2i/d_model))
我在实现时发现两个细节:
层归一化(LayerNorm)放置位置:
原始论文采用"Post-LN"(在残差连接后),但后续研究发现"Pre-LN"(在子层前)更利于训练深层网络:
python复制# Pre-LN实现示例
x = x + self.dropout(self.attention(self.norm1(x)))
学习率调度:
采用带热启动的逆平方根调度器效果最佳:
python复制lr = d_model^-0.5 * min(step^-0.5, step*warmup^-1.5)
典型参数:warmup_steps=4000
标签平滑(Label Smoothing):
缓解模型过度自信,提升泛化能力:
python复制smoothed_labels = (1-ε)*one_hot + ε/K
其中ε=0.1,K为词汇表大小
问题1:训练初期loss震荡严重
问题2:验证集BLEU不升反降
问题3:长序列生成质量差
python复制score = logP(y|x) + α*length_penalty
典型值:α=0.6(长度惩罚系数)当处理超长序列(如文档级文本)时,标准Self-Attention的O(n²)复杂度成为瓶颈。以下是几种优化方案:
我在法律文书分析中使用Longformer(局部+全局注意力),成功将8000token文档的处理时间从45s降至3.2s。
Transformer在视觉、语音等领域的应用也展现出强大潜力:
视觉Transformer(ViT):
语音Transformer:
在自定义的医疗影像数据集上,ViT相比ResNet-50将肺结节检测F1-score从0.82提升到0.87。