I've recently been studying the architecture of LLaMA2, Meta's open-source large language model, and decided to implement a simplified version by hand. The project uses PyTorch to build an autoregressive language model based on the Transformer architecture from scratch. Unlike the original LLaMA2, which ships in 7B to 70B parameter sizes, our implementation is a lightweight version that can be trained and debugged on a single consumer GPU.
Why implement LLaMA2? First, its architecture is relatively simple yet performs strongly; second, Meta has open-sourced the base model weights, which makes it easy to verify that an implementation is correct; most importantly, building it by hand is the best way to understand the core mechanisms of modern large language models.
We start by defining the model's configuration class, which determines the model's size and behavior:
```python
from transformers import PretrainedConfig

class ModelConfig(PretrainedConfig):
    model_type = "Tiny-K"

    def __init__(
        self,
        dim: int = 768,          # hidden dimension of the model
        n_layers: int = 12,      # number of Transformer layers
        n_heads: int = 16,       # number of attention heads
        n_kv_heads: int = 8,     # number of key/value heads (GQA)
        vocab_size: int = 6144,  # vocabulary size
        hidden_dim: int = None,  # FFN hidden dimension
        multiple_of: int = 64,   # round hidden_dim up to a multiple of this
        norm_eps: float = 1e-5,  # epsilon for RMSNorm
        max_seq_len: int = 512,  # maximum sequence length
        dropout: float = 0.0,    # dropout rate
        flash_attn: bool = True, # use FlashAttention
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
        super().__init__(**kwargs)
```
This configuration class inherits from HuggingFace's PretrainedConfig, which makes it easy to integrate with the transformers ecosystem later. Key parameters:

- `dim=768`: a medium-sized model, well suited for educational purposes
- `n_kv_heads=8`: grouped-query attention (GQA), with fewer key/value heads than query heads
- `flash_attn=True`: use FlashAttention by default to speed up attention

Compared with the vanilla Transformer, LLaMA2's core architectural changes include:

- RMSNorm pre-normalization
- Rotary position embeddings (RoPE)
- Grouped-query attention (GQA)
- A SwiGLU-activated feed-forward network

We will implement each of these components in the following sections.
The traditional Transformer uses LayerNorm for normalization:
$$
\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sigma + \epsilon} + \beta
$$
where $\mu$ is the mean and $\sigma$ is the standard deviation. RMSNorm (Root Mean Square Layer Normalization) is a simplified version of LayerNorm that drops the mean-centering step:
$$
\begin{aligned}
\text{RMS}(x) &= \sqrt{\frac{1}{n}\sum_{i=1}^n x_i^2 + \epsilon} \\
\bar{x} &= \frac{x}{\text{RMS}(x)} \\
y &= \gamma \odot \bar{x}
\end{aligned}
$$
Advantages of RMSNorm:

- no mean or bias term $\beta$ to compute, so it is cheaper than LayerNorm
- fewer parameters (only the scale $\gamma$)
- in practice it matches LayerNorm in quality
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        # learnable scale parameter, initialized to all ones
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # compute the RMS and normalize
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
```
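A quick sanity check (not part of the original code; the tensor shapes are illustrative) confirms that, with the default weight of ones, each normalized vector has an RMS close to 1:

```python
# verify that RMSNorm output has RMS ≈ 1 when weight is all ones
norm = RMSNorm(dim=768, eps=1e-5)
x = torch.randn(2, 16, 768) * 5.0      # input at an arbitrary scale
y = norm(x)
print(y.pow(2).mean(-1).sqrt())        # per-position RMS, should be ≈ 1
print(x.shape == y.shape)              # shape is preserved -> True
```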
Implementation notes:

- `torch.rsqrt` computes the reciprocal square root, which is faster than taking the square root and dividing separately
- `type_as` keeps the output dtype identical to the input dtype
- the `weight` parameter lets the model rescale the normalized activations

LLaMA2 uses grouped-query attention (Grouped-Query Attention), a compromise between standard multi-head attention (MHA) and multi-query attention (MQA):
| Type | Query heads (Q) | Key heads (K) | Value heads (V) | Characteristics |
|---|---|---|---|---|
| MHA | N | N | N | Best quality but large KV cache |
| MQA | N | 1 | 1 | Efficient but quality drops |
| GQA | N | G | G | Balances quality and efficiency |
With the default configuration (n_heads=16, n_kv_heads=8) we have G=8, i.e. the 16 query heads are split into 8 groups and every 2 query heads share one key/value head.
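To get a feel for how much KV cache GQA saves, here is a rough back-of-the-envelope calculation (illustrative only, assuming fp16 storage) of the per-layer, per-token KV cache for the three schemes under the default configuration:

```python
# per-layer, per-token KV cache size in bytes (fp16 = 2 bytes per element)
dim, n_heads, n_kv_heads = 768, 16, 8
head_dim = dim // n_heads                      # 48

def kv_bytes_per_token(kv_heads: int) -> int:
    # K and V each store kv_heads * head_dim elements
    return 2 * kv_heads * head_dim * 2         # 2 tensors (K, V) * elements * 2 bytes

print(kv_bytes_per_token(n_heads))     # MHA: 16 KV heads -> 3072 bytes
print(kv_bytes_per_token(1))           # MQA:  1 KV head  ->  192 bytes
print(kv_bytes_per_token(n_kv_heads))  # GQA:  8 KV heads -> 1536 bytes
```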
Because K and V have fewer heads than Q in GQA, we need a repeat_kv function to align the head dimensions:
```python
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # add a new axis, expand it, then fold it back into the head dimension
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
```
This repeats each KV head n_rep times so that the final number of heads matches Q. For example, with n_heads=16 and n_kv_heads=8 we get n_rep=2; the shape change is shown below.
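A minimal example under those assumptions (not part of the original post):

```python
# a K tensor with 8 KV heads becomes 16 heads after repeat_kv, matching Q
k = torch.randn(1, 10, 8, 48)        # (batch, seq_len, n_kv_heads, head_dim)
k_rep = repeat_kv(k, n_rep=2)
print(k_rep.shape)                   # torch.Size([1, 10, 16, 48])
```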
RoPE encodes position information directly into the attention computation via rotation matrices.
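Concretely, for every pair of adjacent dimensions $(x^{(2i)}, x^{(2i+1)})$ the vector at position $m$ is rotated by the angle $m\theta_i$ (this is the standard RoPE formulation, and it maps one-to-one onto the cos/sin arithmetic in the code below):

$$
\begin{pmatrix} x'^{(2i)} \\ x'^{(2i+1)} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x^{(2i)} \\ x^{(2i+1)} \end{pmatrix},
\qquad \theta_i = 10000^{-2i/d}
$$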
The implementation has three parts: precomputing the frequencies (precompute_freqs_cis), reshaping them for broadcasting (reshape_for_broadcast), and applying the rotation (apply_rotary_emb):
```python
from typing import Tuple

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # frequency vector: theta_i = 10000^(-2i/d)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # angles m * theta_i for every position m
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # return the cos and sin values
    return torch.cos(freqs), torch.sin(freqs)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # reshape the frequencies so they broadcast against (bsz, seqlen, n_heads, head_dim//2)
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # pair up the last dimension of Q and K, treating each pair as (real, imaginary)
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
    # reshape the frequency tensors for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
    # apply the rotation
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
    # merge the real and imaginary parts back together
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```
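A quick shape check (not from the original post) for head_dim=48 and a sequence of length 10:

```python
# sanity-check the shapes for head_dim=48, seq_len=10
cos, sin = precompute_freqs_cis(dim=48, end=512)
xq = torch.randn(1, 10, 16, 48)      # (batch, seq_len, n_heads, head_dim)
xk = torch.randn(1, 10, 8, 48)       # (batch, seq_len, n_kv_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, cos[:10], sin[:10])
print(xq_rot.shape, xk_rot.shape)    # same shapes as the inputs
```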
Putting the pieces above together, we can now implement LLaMA2's attention layer:
```python
import math
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, args: ModelConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        # linear projections
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.dropout = args.dropout
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        # check whether FlashAttention is available (PyTorch >= 2.0)
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # fall back to a manual causal mask
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        bsz, seqlen, _ = x.shape
        # project Q, K, V
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # apply RoPE
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        # repeat the KV heads to match the number of query heads
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)
        # reorder to (bsz, n_heads, seqlen, head_dim)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        # attention
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True
            )
        else:
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)
        # merge the heads back together
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
```
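To check the layer end to end, here is a small forward pass with the default configuration (illustrative, not from the original post):

```python
# build an attention layer with the default config and verify the output shape
args = ModelConfig()
attn = Attention(args)
cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)
x = torch.randn(2, 32, args.dim)
out = attn(x, cos[:32], sin[:32])
print(out.shape)   # torch.Size([2, 32, 768])
```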
LLaMA2 replaces the traditional FFN with an MLP using the SwiGLU activation. Matching the code below, the feed-forward layer computes
$$
\text{FFN}_{\text{SwiGLU}}(x) = W_2\big(\text{SiLU}(W_1 x) \odot W_3 x\big)
$$
where $\odot$ is element-wise multiplication and SiLU is the Sigmoid Linear Unit:
$$
\text{SiLU}(x) = x \cdot \sigma(x)
$$
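One quick check (not from the original post) that `F.silu` really is $x \cdot \sigma(x)$:

```python
# verify F.silu(x) == x * sigmoid(x)
x = torch.randn(5)
print(torch.allclose(F.silu(x), x * torch.sigmoid(x)))   # True
```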
```python
class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            # start from 4*dim, scale by 2/3, then round up to a multiple of multiple_of
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SwiGLU: w2( SiLU(w1(x)) * w3(x) )
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
```
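For dim=768 and multiple_of=64 the hidden dimension works out as follows (illustrative arithmetic, not from the original post):

```python
# how hidden_dim is derived for dim=768
dim, multiple_of = 768, 64
hidden_dim = 4 * dim                    # 3072
hidden_dim = int(2 * hidden_dim / 3)    # 2048
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)                       # 2048 (already a multiple of 64)
```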
Key points:

- `hidden_dim` is rounded up to a multiple of `multiple_of`, which keeps the dimensions aligned for efficient hardware kernels
- `F.silu` implements the SiLU activation
- `w3` provides the gating signal that dynamically reweights each dimension

Each decoder layer consists of two sub-layers, attention and FFN, each with pre-RMSNorm and a residual connection:
```python
class DecoderLayer(nn.Module):
    def __init__(self, layer_id: int, args: ModelConfig):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = MLP(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        # attention sub-layer: pre-norm + residual
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        # FFN sub-layer: pre-norm + residual
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
```
Finally, we assemble the embeddings, the stack of decoder layers, and the output head into the full model:

```python
from transformers import PreTrainedModel

class Transformer(PreTrainedModel):
    config_class = ModelConfig

    def __init__(self, args: ModelConfig = None):
        super().__init__(args)
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        # token embeddings
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.dropout = nn.Dropout(args.dropout)
        # stack of decoder layers
        self.layers = nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(DecoderLayer(layer_id, args))
        # output head
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        # tie the embedding and output weights
        self.tok_embeddings.weight = self.output.weight
        # precompute the RoPE frequencies
        freqs_cos, freqs_sin = precompute_freqs_cis(
            args.dim // args.n_heads, args.max_seq_len
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
        # initialize weights
        self.apply(self._init_weights)
        # scaled init for the projections on the residual path
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * args.n_layers))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets=None):
        bsz, seqlen = tokens.shape
        # embeddings
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        # slice the RoPE frequencies to the current sequence length
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]
        # run the decoder stack
        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        # final normalization
        h = self.norm(h)
        if targets is not None:
            # training mode: logits for every position plus cross-entropy loss
            logits = self.output(h)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=0
            )
        else:
            # inference mode: only compute logits for the last token
            logits = self.output(h[:, [-1], :])
            loss = None
        return logits, loss

    @torch.inference_mode()
    def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # truncate the context if it exceeds max_seq_len
            idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
            # forward pass; only the last position's logits are returned
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            # top-k sampling: mask out everything below the k-th largest logit
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # sample the next token
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
```
With the model defined, we can instantiate a concrete configuration:

```python
# instantiate a concrete configuration (adjust to your GPU memory)
config = ModelConfig(
    dim=768,
    n_layers=12,
    n_heads=12,
    n_kv_heads=4,
    vocab_size=32000,
    max_seq_len=2048,
    dropout=0.1
)
model = Transformer(config).to('cuda')
```
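It is worth printing the parameter count to confirm the model is the size we expect (illustrative snippet, not from the original post):

```python
# count trainable parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {n_params / 1e6:.1f}M")
```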
A minimal mixed-precision training loop looks like this:

```python
# minimal training loop: AdamW + automatic mixed precision
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    inputs, targets = batch
    inputs, targets = inputs.to('cuda'), targets.to('cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        logits, loss = model(inputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
After training, we can generate text to check the model (this assumes a tokenizer matching the model's vocabulary has been loaded):

```python
# assumes `tokenizer` matches the model's vocabulary
prompt = "The future of AI is"
input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
output_ids = model.generate(
    input_ids,
    max_new_tokens=50,
    temperature=0.7,
    top_k=40
)
print(tokenizer.decode(output_ids[0]))
```
A few practical optimization techniques:

Gradient checkpointing: enable it when training larger models to trade compute for memory:

```python
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    # sketch: assumes _forward wraps the original forward computation
    return checkpoint(self._forward, x)
```
Mixed-precision training: use AMP to reduce memory usage:

```python
with torch.cuda.amp.autocast():
    outputs = model(inputs)
```
Gradient clipping: prevents exploding gradients:

```python
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```
Learning-rate warmup: increase the learning rate linearly over roughly the first 1% of training steps.
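A simple way to express a linear warmup (illustrative sketch; `warmup_steps` and the schedule are assumptions, not from the original post):

```python
# linear warmup: scale the LR from 0 up to its target over warmup_steps steps
warmup_steps = 100
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)
# call scheduler.step() once per training step
```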
Weight initialization: use a scaled initialization for the key layers:

```python
nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * n_layers))
```
Finally, a few common problems and how to approach them.

Symptom: the loss becomes NaN or oscillates wildly.

Solution:
- lower the learning rate and add learning-rate warmup
- enable gradient clipping (as shown above)
- check whether mixed precision is overflowing; rely on the GradScaler or fall back to full precision if it is

Symptom: generated text is incoherent or repetitive.

Solution:
- tune the sampling parameters: a very low temperature tends to produce repetition, a very high one produces incoherence; top_k helps cut off the long tail
- train longer and on more data; a small model needs plenty of tokens before its output becomes fluent
- make sure the tokenizer matches the vocabulary the model was trained with

Symptom: CUDA out of memory.

Solution:
- reduce the batch size or max_seq_len
- enable gradient checkpointing and mixed-precision training
- shrink the model (fewer layers or a smaller hidden dimension)
Although this implementation is simplified, it captures LLaMA2's core architectural ideas. Implementing these components by hand gave me a much deeper understanding of how modern large language models work. For real projects, I recommend starting experiments with a small model like this one and scaling up gradually.