1. Transformer 架构概述
Transformer 架构自2017年由Vaswani等人提出以来,已成为自然语言处理领域的基石模型。与传统RNN和CNN不同,Transformer完全基于自注意力机制,能够并行处理序列数据并捕获长距离依赖关系。在CS336课程作业中,我们需要从零开始实现一个完整的Transformer模型,这要求我们深入理解每个组件的数学原理和实现细节。
现代大型语言模型(如GPT、LLaMA等)都基于Transformer架构,但通常会进行一些改进。本次作业实现的版本包含以下核心创新:
- 使用RMSNorm替代传统LayerNorm
- 采用SwiGLU作为前馈网络激活函数
- 引入旋转位置编码(RoPE)
- 使用预归一化(Pre-Norm)的残差连接
这些改进使得模型在训练稳定性和表现力方面都有显著提升。下面我们将逐一拆解每个模块的技术细节。
2. 基础线性与嵌入模块
2.1 线性变换层(无偏置版本)
线性层是神经网络中最基础的组件,其数学表示为:
$$ y = xW^\top $$
其中:
- $x \in \mathbb{R}^{..., d_{in}}$ 是输入张量
- $W \in \mathbb{R}^{d_{out}, d_{in}}$ 是权重矩阵
- $y \in \mathbb{R}^{..., d_{out}}$ 是输出张量
在实现时需要注意:
-
权重初始化采用截断正态分布:
$$ W \sim \mathcal{N}(0, \frac{2}{d_{in} + d_{out}}) $$
并截断到$[-3\sigma, 3\sigma]$范围内 -
现代Transformer实现通常省略偏置项,这可以:
- 减少参数数量
- 提高计算效率
- 与LayerNorm/RMSNorm更好地配合
提示:在实际编码时,可以使用
nn.Linear并设置bias=False来创建无偏置线性层。初始化权重时要注意保持方差稳定,防止梯度爆炸或消失。
2.2 词嵌入层
词嵌入层将离散的token ID映射到连续的向量空间:
$$ \text{Embedding}(i) = W_e[i,:] $$
其中:
- $i$ 是token ID(整数)
- $W_e \in \mathbb{R}^{vocab_size, d_{model}}$ 是嵌入矩阵
- 输出维度为$d_{model}$(通常256-4096)
关键细节:
- 嵌入矩阵通常与最后的线性输出层共享权重(节省参数)
- 需要对嵌入进行缩放(乘以$\sqrt{d_{model}}$),防止初始阶段注意力分数过大
- 现代模型通常不对嵌入层使用偏置
2.3 RMSNorm(均方根归一化)
RMSNorm是LayerNorm的简化版本,计算更高效:
$$ \text{RMSNorm}(a_i) = \frac{a_i}{\text{RMS}(a)} \cdot g_i $$
其中:
$$ \text{RMS}(a) = \sqrt{\frac{1}{d_{model}}\sum_{i=1}^{d_{model}} a_i^2 + \epsilon} $$
与LayerNorm相比:
- 去除了均值中心化
- 仅使用均方根进行缩放
- 保留可学习的增益参数$g_i$
优势:
- 计算量减少约20-30%
- 在大多数情况下性能相当
- 更适合大规模模型训练
实现要点:
python复制class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / (norm + self.eps) * self.g
3. 前馈网络模块
3.1 SiLU与GLU激活函数
SiLU(Sigmoid Linear Unit)激活函数:
$$ \text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1+e^{-x}} $$
特点:
- 比ReLU更平滑
- 在负区间有非零输出
- 常用于现代Transformer
门控线性单元(GLU):
$$ \text{GLU}(x, W_1, W_2) = \sigma(W_1x) \odot W_2x $$
其中$\odot$是逐元素乘法。GLU通过门控机制:
- 允许模型选择性地传递信息
- 缓解梯度消失问题
- 增强非线性表达能力
3.2 SwiGLU前馈网络
SwiGLU结合了SiLU和GLU的优点:
$$ \text{FFN}_{\text{SwiGLU}}(x) = W_2(\text{SiLU}(W_1x) \odot W_3x) $$
参数配置:
- $W_1, W_3 \in \mathbb{R}^{d_{ff} \times d_{model}}$
- $W_2 \in \mathbb{R}^{d_{model} \times d_{ff}}$
- 通常$d_{ff} = \frac{8}{3}d_{model}$(约为2.67倍)
优势:
- 更强的非线性表达能力
- 更平滑的梯度流动
- 已成为LLM标准配置
实现示例:
python复制class SwiGLU(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
4. 位置编码模块
4.1 旋转位置编码(RoPE)
RoPE通过旋转矩阵将位置信息注入注意力计算:
对于位置$i$的查询向量$q^{(i)}$,旋转后为:
$$ q'^{(i)} = R^i q^{(i)} $$
其中旋转矩阵$R^i$作用于每对相邻维度:
$$
\begin{bmatrix}
q'{2k-1} \
q'
\end
\begin{bmatrix}
\cos\theta_{i,k} & -\sin\theta_{i,k} \
\sin\theta_{i,k} & \cos\theta_{i,k}
\end{bmatrix}
\begin{bmatrix}
q_{2k-1} \
q_{2k}
\end{bmatrix}
$$
角度计算:
$$ \theta_{i,k} = \frac{i}{\Theta^{(2k-2)/d}} $$
其中$\Theta$是预设常数(通常10000)
优势:
- 相对位置编码,可以处理任意长度序列
- 保持注意力分数的相对性
- 计算高效,可以融合到注意力计算中
实现技巧:
python复制def apply_rotary_emb(x, freqs):
# x: [batch, seq_len, n_heads, head_dim]
# freqs: [seq_len, head_dim//2]
x_rot = x[..., :, :, :x.shape[-1]//2]
x_pass = x[..., :, :, x.shape[-1]//2:]
x_rot = x_rot.reshape(*x_rot.shape[:-1], -1, 2)
x_pass = x_pass.reshape(*x_pass.shape[:-1], -1, 2)
# 应用旋转
x_rot = torch.stack(
[x_rot[..., 0]*freqs.cos() - x_rot[..., 1]*freqs.sin(),
x_rot[..., 0]*freqs.sin() + x_rot[..., 1]*freqs.cos()],
dim=-1)
return torch.cat([x_rot.flatten(-2), x_pass.flatten(-2)], dim=-1)
5. 注意力机制模块
5.1 数值稳定的Softmax
标准Softmax实现:
$$ \text{softmax}(v)_i = \frac{\exp(v_i - \max(v))}{\sum_j \exp(v_j - \max(v))} $$
优化技巧:
- 减去最大值防止数值溢出
- 使用对数空间计算提高数值稳定性
- 对于因果注意力,需要添加掩码$M$
5.2 缩放点积注意力
核心公式:
$$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V $$
维度:
- $Q \in \mathbb{R}^{n \times d_k}$
- $K \in \mathbb{R}^{m \times d_k}$
- $V \in \mathbb{R}^{m \times d_v}$
关键点:
- 缩放因子$\frac{1}{\sqrt{d_k}}$防止点积过大
- 掩码$M$用于实现因果注意力
- 实际实现时使用矩阵运算优化
5.3 多头注意力
将注意力分成$h$个头并行计算:
$$ \text{MultiHead}(Q,K,V) = \text{Concat}(head_1,...,head_h) $$
其中:
$$ head_i = \text{Attention}(QW_{Q,i}, KW_{K,i}, VW_{V,i}) $$
参数配置:
- $d_k = d_v = \frac{d_{model}}{h}$
- 输出维度保持$d_{model}$
优势:
- 并行捕捉不同子空间的信息
- 计算效率更高(可并行化)
- 表达能力更强
实现示例:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_k = d_model // n_heads
self.n_heads = n_heads
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
def forward(self, q, k, v, mask=None):
# 投影到多头
q = self.q_proj(q).view(*q.shape[:-1], self.n_heads, self.d_k)
k = self.k_proj(k).view(*k.shape[:-1], self.n_heads, self.d_k)
v = self.v_proj(v).view(*v.shape[:-1], self.n_heads, self.d_k)
# 计算注意力
scores = torch.einsum('...qd,...kd->...qk', q, k) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.einsum('...qk,...kd->...qd', attn, v)
# 合并多头
out = out.reshape(*out.shape[:-2], -1)
return self.out_proj(out)
6. Transformer模块与整体架构
6.1 Pre-Norm Transformer Block
预归一化结构:
python复制z = x + MultiHeadSelfAttention(RMSNorm(x))
y = z + FFN(RMSNorm(z))
与传统Post-Norm的区别:
- 归一化在残差连接前
- 训练更稳定
- 梯度流动更顺畅
- 已成为现代Transformer标准
6.2 完整Transformer语言模型
架构流程:
- 输入嵌入:$x_0 = \text{Embedding}(\text{token_ids})$
- 多层Transformer:$x_l = \text{TransformerBlock}(x_{l-1}), l=1..N$
- 最终归一化:$x_{final} = \text{RMSNorm}(x_N)$
- 输出logits:$\text{logits} = x_{final} W_{vocab}^\top$
关键参数:
- $d_{model}$:模型维度(如512、768、1024等)
- $n_{heads}$:注意力头数(通常64-128)
- $n_{layers}$:Transformer层数(12-48)
- $d_{ff}$:前馈网络维度(通常$\frac{8}{3}d_{model}$)
7. 训练与优化模块
7.1 交叉熵损失
语言建模损失:
$$ \ell(\theta; D) = \frac{1}{|D|m}\sum_{x\in D}\sum_{i=1}^m -\log p_\theta(x_{i+1}|x_{1:i}) $$
其中:
- $p_\theta(x_{i+1}|x_{1:i}) = \text{softmax}(o_i)[x_{i+1}]$
- $o_i$是位置$i$的logits向量
实现要点:
- 使用F.cross_entropy直接计算
- 注意处理padding位置
- 可以结合标签平滑技术
7.2 困惑度(Perplexity)
评估指标:
$$ \text{perplexity} = \exp\left(\frac{1}{m}\sum_{i=1}^m \ell_i\right) $$
解释:
- 可以理解为"平均分支因子"
- 完美预测时困惑度为1
- 随机猜测时困惑度为vocab_size
7.3 AdamW优化器
改进版Adam:
-
动量更新:
$$ m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t $$ -
二阶矩更新:
$$ v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2 $$ -
偏置校正:
$$ \hat{\alpha}_t = \alpha \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} $$ -
参数更新(带解耦权重衰减):
$$ \theta_{t+1} = \theta_t - \hat{\alpha}_t \cdot \frac{m_t}{\sqrt{v_t}+\epsilon} - \alpha\lambda\theta_t $$
默认超参数:
- $\beta_1=0.9$
- $\beta_2=0.999$
- $\epsilon=1e-8$
- $\lambda=0.01$
7.4 余弦学习率调度
带预热的余弦退火:
$$
\alpha_t =
\begin{cases}
\frac{t}{T_w}\alpha_{max} & t < T_w \
\alpha_{min} + \frac{1}{2}(1+\cos(\frac{t-T_w}{T_c-T_w}\pi))(\alpha_{max}-\alpha_{min}) & T_w \leq t \leq T_c \
\alpha_{min} & t > T_c
\end{cases}
$$
典型配置:
- $T_w$:500-10000步
- $T_c$:总训练步数
- $\alpha_{max}$:1e-4到6e-4
- $\alpha_{min}$:$\alpha_{max}/10$
7.5 梯度裁剪
防止梯度爆炸:
$$ g \leftarrow g \cdot \min\left(1, \frac{M}{|g|_2+\epsilon}\right) $$
典型值:
- $M=1.0$
- $\epsilon=1e-6$
8. 文本生成模块
8.1 温度采样
调整softmax温度:
$$ \text{softmax}(v, \tau)_i = \frac{\exp(v_i/\tau)}{\sum_j \exp(v_j/\tau)} $$
影响:
- $\tau \to 0$:贪心搜索
- $\tau=1$:标准采样
- $\tau>1$:更均匀分布
8.2 Top-p(核采样)
从累积概率超过p的最小token集合中采样:
- 对logits排序
- 计算累积概率
- 选择超过阈值p的最小集合
- 从中采样
优势:
- 动态调整候选集大小
- 避免低概率token
- 生成质量更高
实现示例:
python复制def top_p_sampling(logits, p=0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 移除累积概率超过p的token
sorted_indices_to_remove = cum_probs > p
# 确保至少保留一个token
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
-1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float('-inf')
return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
9. 实现经验与技巧
9.1 调试技巧
- 梯度检查:初期设置很小的batch size和序列长度,检查梯度是否合理
- 过拟合测试:在极小数据集上(如10个样本)测试能否过拟合
- 数值稳定性:添加assert检查NaN和inf
- 内存分析:使用torch.cuda.memory_summary()监控显存使用
9.2 性能优化
- 融合操作:尽可能使用矩阵运算而非循环
- 内存效率:使用in-place操作和梯度检查点
- 混合精度:启用AMP自动混合精度训练
- 并行化:使用DataParallel或DistributedDataParallel
9.3 常见问题解决
-
损失不下降:
- 检查学习率是否合适
- 验证数据加载是否正确
- 检查初始化是否合理
-
梯度爆炸:
- 减小学习率
- 增加梯度裁剪阈值
- 检查权重初始化
-
生成质量差:
- 调整温度参数
- 尝试不同的采样策略
- 检查模型是否充分训练
10. 扩展与进阶
10.1 模型压缩技术
- 知识蒸馏:训练小模型模仿大模型行为
- 量化:将FP32转为INT8/INT4
- 剪枝:移除不重要的权重
- 参数共享:在不同层间共享部分参数
10.2 高效注意力变体
- 稀疏注意力:只计算部分位置的注意力
- 局部注意力:限制注意力窗口大小
- 内存高效的注意力:如FlashAttention
- 线性注意力:近似计算降低复杂度
10.3 持续学习技术
-
参数高效微调:
- LoRA:低秩适配
- Adapter:插入小型网络模块
- Prefix-tuning:学习连续提示
-
灾难性遗忘缓解:
- 弹性权重固化(EWC)
- 回放缓冲区
- 正则化技术
通过深入理解这些Transformer核心模块和实现细节,我们能够更好地构建、调试和优化自己的语言模型。在实际应用中,还需要根据具体任务和数据特点进行调整和创新。