作为一名长期从事深度学习研究的工程师,我一直对大型语言模型的内部工作机制充满好奇。最近,Andrej Karpathy发布了一个极具启发性的4小时视频教程,详细演示了如何从零开始构建GPT-2模型。这个124M参数版本的复现过程不仅揭示了现代语言模型的核心架构,更重要的是展示了如何用PyTorch高效实现这些复杂组件。
本文将深入解析Karpathy视频第一部分的核心代码,带你逐行理解GPT-2架构的实现细节。不同于原版GPT-2使用的TensorFlow实现,我们将使用更易调试的PyTorch框架重构模型。通过这个项目,你将掌握:
任何优秀的深度学习项目都应该从清晰的配置定义开始。GPTConfig类使用Python的dataclass装饰器,定义了模型的核心超参数:
python复制@dataclass
class GPTConfig:
block_size: int = 1024 # 最大上下文长度
vocab_size: int = 50257 # 词表大小(GPT-2标准)
n_layer: int = 12 # Transformer层数
n_head: int = 12 # 注意力头数
n_embd: int = 768 # 嵌入维度
这些参数的选择并非随意:
block_size=1024:这是GPT-2处理的最大token序列长度,超过此长度需要特殊处理vocab_size=50257:对应GPT-2分词器的词汇量(50,000基础词+256字节+1特殊token)n_layer=12和n_head=12:平衡模型深度和计算效率的经验值n_embd=768:每个token的向量表示维度GPT类继承自PyTorch的nn.Module,构成了模型的主体框架:
python复制class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
'wte': nn.Embedding(config.vocab_size, config.n_embd), # token嵌入
'wpe': nn.Embedding(config.block_size, config.n_embd), # 位置嵌入
'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
'ln_f': nn.LayerNorm(config.n_embd) # 最终层归一化
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
关键设计要点:
每个Transformer块包含以下关键组件:
python复制class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.mlp = MLP(config)
与原始论文的主要区别:
CausalSelfAttention类实现了带掩码的多头注意力:
python复制class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) # Q,K,V投影
self.c_proj = nn.Linear(config.n_embd, config.n_embd) # 输出投影
self.register_buffer('bias', torch.tril(torch.ones(config.block_size, config.block_size)))
实现技巧:
MLP类实现了Transformer中的位置感知前馈网络:
python复制class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) # 扩展
self.gelu = nn.GELU(approximate='tanh')
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) # 收缩
设计考虑:
正确的初始化对训练稳定性至关重要:
python复制def _init_weights(self, module):
if isinstance(module, nn.Linear):
std = 0.02
if hasattr(module, "NANOGPT_SCALE_INIT"):
std *= (2 * self.config.n_layer) ** -0.5
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
初始化规则:
模型的前向传播清晰展示了信息流动:
python复制def forward(self, idx, targets=None):
B, T = idx.size() # 批大小,序列长度
pos = torch.arange(0, T, device=idx.device)
pos_emb = self.transformer.wpe(pos) # 位置嵌入
tok_emb = self.transformer.wte(idx) # token嵌入
x = tok_emb + pos_emb # 合并嵌入
for block in self.transformer.h: # 通过所有Transformer块
x = block(x)
x = self.transformer.ln_f(x) # 最终归一化
logits = self.lm_head(x) # 输出投影
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
关键点:
高效的DataLoader对训练至关重要:
python复制class DataLoaderLite:
def __init__(self, B, T):
self.B, self.T = B, T
with open('shakespeare.txt') as f:
text = f.read()
tokens = tiktoken.get_encoding('gpt2').encode(text)
self.tokens = torch.tensor(tokens)
def next_batch(self):
buf = self.tokens[self.current_position : self.current_position + self.B*self.T + 1]
x = buf[:-1].view(self.B, self.T)
y = buf[1:].view(self.B, self.T)
# 更新位置并处理循环逻辑
return x, y
设计特点:
自回归生成是LLM的核心能力:
python复制def generate(self, idx, max_new_tokens):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.config.block_size:] # 截断到最大长度
logits, _ = self(idx_cond)
logits = logits[:, -1, :] # 取最后一个时间步
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
生成策略:
在实现过程中,以下调试方法非常有用:
形状断言:在每个关键步骤后添加形状检查
python复制assert q.shape == (B, self.n_head, T, C // self.n_head)
梯度检查:监控各层梯度范数,发现消失/爆炸问题
python复制print(f"Gradient norm: {torch.norm(param.grad)}")
激活统计:记录各层激活的均值和方差
python复制print(f"Activation mean: {x.mean()}, std: {x.std()}")
问题1:训练初期损失不下降
问题2:生成文本重复
问题3:GPU内存不足
完成基础实现后,可以考虑以下进阶改进:
这个实现虽然精简,但包含了现代语言模型的所有核心概念。通过深入理解这些基础组件,你将能够更好地理解和改进更复杂的语言模型架构。