注意力机制最初来源于人类视觉系统的启发。当我们观察一个复杂场景时,大脑会本能地聚焦于某些关键区域而忽略其他次要信息。这种选择性关注的能力被抽象化为机器学习中的注意力机制,其数学本质可以理解为一种动态权重分配策略。
在传统RNN结构中,序列处理存在明显的局限性:无论当前处理的内容是否需要历史信息的支持,网络都必须机械地按顺序处理所有先前的token。这就像要求一个人在阅读文章时,必须逐字回忆之前读过的所有内容才能理解当前句子——显然不符合人类的认知方式。
注意力机制通过三个关键向量实现了信息筛选的智能化:
这种设计使得模型可以动态决定哪些历史信息与当前计算相关。例如在处理"I arrived at the bank after crossing the river"这句话时,当模型处理"bank"这个词,注意力机制会自动提高"river"的权重,帮助确定这里指的是"河岸"而非"银行"。
标准的缩放点积注意力(Scaled Dot-Product Attention)计算公式为:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
其中$d_k$是键向量的维度。这个$\sqrt{d_k}$的缩放因子非常关键——当维度较高时,点积的结果会变得非常大,将softmax函数推入梯度极小的区域。通过缩放保持梯度稳定,作者发现这对训练深度Transformer至关重要。
实际实现时,这些计算会被批量处理为矩阵运算。假设我们有一个包含4个单词的序列,每个词的嵌入维度是512,那么典型的计算流程:
多头注意力(Multi-Head Attention)是Transformer最具创新性的设计之一。其核心思想是:
数学表达式为:
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O
$$
其中每个头的计算为:
$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$
这种设计允许模型在不同的表示子空间中学习不同的关注模式。例如在处理语言时,某些头可能专门关注句法关系,另一些头则关注语义关联。实验表明,不同的头确实会自发地发展出不同的关注模式。
自注意力(Self-Attention)是指Q、K、V都来自同一输入序列的情况。在Transformer的编码器中,这种机制允许每个位置直接关注输入序列中的所有位置,从而捕获长距离依赖关系。
一个关键特性是自注意力的排列等变性——改变输入序列的顺序只会相应改变输出的顺序,不会影响内容。这与CNN的平移等变性有本质区别。
在解码器中,除了自注意力层外,还存在编码器-解码器注意力层。这里的Q来自解码器的前一层的输出,而K、V来自编码器的最终输出。这种设计使得解码器在生成每个token时,可以动态地关注输入序列中最相关的部分。
在机器翻译任务中,可以观察到解码器生成目标语言单词时,会自动将高注意力权重分配给源语言中对应的单词或短语,形成清晰的对齐关系。
原始Transformer的注意力计算复杂度是序列长度的平方级($O(n^2)$),这对长序列处理构成了挑战。后续研究提出了多种改进:
原始Transformer使用绝对位置编码,将位置信息直接加到输入嵌入中。后续提出的相对位置编码考虑了token之间的相对距离:
$$
e_{ij} = \frac{(x_i + p_i)W^Q((x_j + p_j)W^K)^T}{\sqrt{d_k}}
$$
改进为:
$$
e_{ij} = \frac{x_iW^Q(x_jW^K + a_{ij}^K)^T}{\sqrt{d_k}}
$$
其中$a_{ij}^K$是基于相对位置(i-j)学习的嵌入。这种方法在处理长序列时表现出更好的泛化能力。
在实际应用中,我们需要处理变长序列和防止信息泄露。这通过注意力掩码实现:
python复制def create_padding_mask(seq):
mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
return mask[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len)
def create_look_ahead_mask(size):
mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
return mask # (seq_len, seq_len)
理解模型关注什么是调试和解释Transformer的重要方式。典型可视化方法包括:
python复制import matplotlib.pyplot as plt
def plot_attention_weights(attention_weights, sentence):
fig = plt.figure(figsize=(16, 8))
for h, head in enumerate(attention_weights):
ax = fig.add_subplot(2, 4, h+1)
ax.matshow(head, cmap='viridis')
ax.set_xticks(range(len(sentence)))
ax.set_yticks(range(len(sentence)))
ax.set_ylim(len(sentence)-1.5, -0.5)
ax.set_title(f'Head {h+1}')
plt.tight_layout()
plt.show()
尽管注意力机制非常强大,但仍存在一些固有局限:
目前一些有前景的改进方向包括:
我在实际应用中发现,对于特定任务,适当限制注意力的范围(如设置最大关注距离)往往能在保持性能的同时显著提升效率。此外,对注意力头进行正则化(如增加多样性惩罚项)可以防止多头注意力退化为少数头主导的情况。