在自然语言处理任务中,词与词之间的关系往往具有多维度的特性。以简单的句子"我爱你"为例:
单头注意力机制(Single-Head Attention)在处理这种复杂关系时存在明显局限。它只能通过单一的注意力权重分布来捕捉所有语义关系,相当于让一个模型同时学习多种完全不同的匹配模式。这就像让一个人同时处理多项需要完全不同思维方式的任务——结果往往是每项任务都难以达到最佳效果。
多头注意力(Multi-Head Attention)的创新之处在于,它将高维特征空间划分为多个子空间,每个子空间可以独立学习不同的注意力模式。这种设计带来了三个关键优势:
首先明确多头注意力中的关键参数:
多头注意力的计算可以分解为五个关键步骤:
将输入序列$X \in \mathbb{R}^{L \times d_{model}}$通过三个可学习的投影矩阵转换为查询(Q)、键(K)、值(V)表示:
$$
Q = XW^Q, \quad K = XW^K, \quad V = XW^V
$$
其中$W^Q, W^K, W^V \in \mathbb{R}^{d_{model} \times d_{model}}$。
将Q、K、V矩阵按头的数量$h$进行拆分:
$$
Q = [Q_1, Q_2, ..., Q_h], \quad K = [K_1, K_2, ..., K_h], \quad V = [V_1, V_2, ..., V_h]
$$
每个$Q_i, K_i, V_i \in \mathbb{R}^{L \times d_k}$。
每个头独立计算缩放点积注意力:
$$
\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i
$$
将所有头的输出在特征维度拼接:
$$
\text{Concat} = [\text{head}_1, \text{head}_2, ..., \text{head}_h]
$$
通过可学习矩阵$W^O \in \mathbb{R}^{d_{model} \times d_{model}}$进行线性变换:
$$
\text{MultiHead}(Q,K,V) = \text{Concat} \cdot W^O
$$
为了更好地理解多头注意力的工作原理,我们通过一个具体的计算示例来演示整个过程。
考虑输入序列["我", "爱", "你"],设置参数:
输入矩阵(已包含词嵌入和位置编码):
$$
X = \begin{bmatrix}
0.5 & 1.1 & 0.2 & 0.1 \
1.0 & 1.1 & 0.3 & 0.2 \
1.2 & -0.2 & 0.1 & 0.4 \
\end{bmatrix}
$$
为简化计算,设所有投影矩阵为单位矩阵$I$。
输入:
$$
Q_1 = K_1 = V_1 = \begin{bmatrix}
0.5 & 1.1 \
1.0 & 1.1 \
1.2 & -0.2 \
\end{bmatrix}
$$
计算注意力分数(以第一个词"我"为例):
输入:
$$
Q_2 = K_2 = V_2 = \begin{bmatrix}
0.2 & 0.1 \
0.3 & 0.2 \
0.1 & 0.4 \
\end{bmatrix}
$$
类似计算可得:
$$
\text{head}_2 ≈ \begin{bmatrix}
0.201 & 0.233 \
0.201 & 0.242 \
0.199 & 0.237 \
\end{bmatrix}
$$
将两个头的输出在特征维度拼接:
$$
\text{Output} = \begin{bmatrix}
0.847 & 0.872 & 0.201 & 0.233 \
0.874 & 0.843 & 0.201 & 0.242 \
0.986 & 0.498 & 0.199 & 0.237 \
\end{bmatrix}
$$
为了验证我们的数学推导,我们实现了一个简化版的多头注意力模块,并与PyTorch官方实现进行对比。
python复制import torch
import torch.nn as nn
class SimpleMultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 投影矩阵
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.shape
# 1. 输入投影
q = self.w_q(x)
k = self.w_k(x)
v = self.w_v(x)
# 2. 拆分多头
q = q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
k = k.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
v = v.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# 3. 计算注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
# 4. 拼接输出
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.w_o(out)
# 测试
d_model = 4
n_heads = 2
x = torch.tensor([[
[0.5, 1.1, 0.2, 0.1],
[1.0, 1.1, 0.3, 0.2],
[1.2, -0.2, 0.1, 0.4]
]], dtype=torch.float32)
# 自定义实现
model = SimpleMultiHeadAttention(d_model, n_heads)
with torch.no_grad():
for param in model.parameters():
param.data = torch.eye(d_model) # 设为单位矩阵
print("自定义实现输出:")
print(model(x))
# PyTorch官方实现
official_mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True, bias=False)
with torch.no_grad():
official_mha.in_proj_weight.data = torch.cat([torch.eye(d_model)]*3, dim=0)
official_mha.out_proj.weight.data = torch.eye(d_model)
out, _ = official_mha(x, x, x)
print("\n官方实现输出:")
print(out)
运行结果显示两种实现的输出完全一致,验证了我们推导的正确性。
多头注意力机制相比单头注意力具有多方面的优势:
每个注意力头相当于一个独立的"专家",可以专注于学习特定类型的语义关系。在我们的例子中:
这种分工使得模型能够并行处理多种不同类型的语言现象。
在高维空间中,随机向量的点积会趋向于相同值,这使得注意力分数难以区分。通过将高维计算分解为多个低维子空间的计算,多头注意力有效缓解了这个问题。
多头注意力的参数量与单头注意力基本相同(仅多了一个最终的投影矩阵$W^O$),但表达能力却显著增强。这相当于用相同的计算成本获得了更强大的建模能力。
在实际的NLP任务中,多头注意力已被证明能够:
在实际实现多头注意力时,有几个关键点需要注意:
投影矩阵$W^Q, W^K, W^V$的初始化对模型性能有重要影响。通常采用:
在处理变长序列或特定任务(如解码器自注意力)时,需要正确应用注意力掩码:
对于长序列,标准的注意力计算复杂度为$O(L^2)$,可以采用以下优化:
多头注意力已成为现代深度学习的基础组件,衍生出多种改进版本:
原始Transformer使用绝对位置编码,后续工作提出了相对位置编码,能更好处理位置关系:
降低计算复杂度的变体:
通过可视化不同头的注意力模式,可以直观理解模型的工作机制:
这种可视化不仅是理解模型的有力工具,也是调试模型性能的重要手段。