1. 项目概述
在自然语言处理领域,文本分类是一项基础而重要的任务。本文将详细介绍如何使用Transformer架构实现一个中文文本分类系统。不同于传统的RNN或CNN方法,Transformer完全基于注意力机制,能够更好地捕捉长距离依赖关系。我们将从数据准备开始,逐步构建完整的模型,并分享在实际实现过程中的经验教训。
这个项目适合已经了解Transformer基础概念(如QKV注意力机制)的开发者。通过本实现,你将掌握:
- 中文文本分类任务的完整处理流程
- Transformer核心模块的代码级实现
- 实际训练中的调优技巧和常见问题解决方法
2. 数据准备与预处理
2.1 数据集选择与清洗
我们使用THUCNews数据集的一个子集,包含财经、家居、科技和教育四个类别。原始数据是分文件夹存储的文本文件,需要先转换为CSV格式以便处理。
数据清洗是文本处理中至关重要的一步。中文文本常见的清洗需求包括:
- 去除特殊字符和乱码
- 统一标点符号
- 处理换行和空白字符
python复制import os, csv, random, re, pathlib
def clean_text(text):
# 保留中英文、数字、中文标点符号
clean_pat = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9'
'\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\uff01\u2014\u2026\u2018\u2019\uff0d'
',。!?、;:"'()——…:]')
# 统一替换各种空白为中文句号
text = re.sub(r'\s+', '。', text)
return clean_pat.sub('', text).strip()
注意:在实际项目中,文本清洗规则需要根据具体数据特点调整。过于严格的清洗可能会损失有意义的语义信息。
2.2 数据集划分与存储
我们将数据按8:2的比例划分为训练集和验证集,并保存为CSV文件:
python复制def get_train_val_csv():
root = 'THUCNews' # 数据根目录
samples_per_class = 500 # 每类样本数
train_ratio = 0.8 # 训练集比例
out_dir = 'split_csv' # 输出目录
pathlib.Path(out_dir).mkdir(exist_ok=True)
with open(f'{out_dir}/train.csv', 'w', encoding='utf-8') as f_train, \
open(f'{out_dir}/val.csv', 'w', encoding='utf-8') as f_val:
writer_train = csv.writer(f_train)
writer_val = csv.writer(f_val)
writer_train.writerow(['text', 'label'])
writer_val.writerow(['text', 'label'])
for label in ['财经', '家居', '科技', '教育']:
files = os.listdir(f'{root}/{label}')
random.shuffle(files)
if samples_per_class:
files = files[:samples_per_class]
split = int(len(files) * train_ratio)
for i, filename in enumerate(files):
with open(f'{root}/{label}/{filename}', encoding='utf-8') as f:
text = clean_text(f.read())
if not text:
continue
if i < split:
writer_train.writerow([text, label])
else:
writer_val.writerow([text, label])
这样处理后,我们得到两个CSV文件,分别包含训练集和验证集的文本及其标签。
3. 数据集类实现
3.1 词表构建
中文文本需要先分词才能构建词表。我们使用jieba进行分词,并统计词频:
python复制from collections import Counter
import jieba
class Vocabulary:
def __init__(self, min_freq=10):
self.word2idx = {'<pad>': 0, '<unk>': 1}
self.idx2word = {0: '<pad>', 1: '<unk>'}
self.min_freq = min_freq
def build(self, texts):
counter = Counter()
for text in texts:
words = jieba.lcut(text)
counter.update(words)
for word, freq in counter.items():
if freq >= self.min_freq and word not in self.word2idx:
idx = len(self.word2idx)
self.word2idx[word] = idx
self.idx2word[idx] = word
词表构建的几个关键点:
- 设置最小词频(min_freq)过滤低频词,减少词表大小
- 保留特殊token:
<pad>用于填充,<unk>表示未知词 - 同时维护word2idx和idx2word两个字典方便双向查找
3.2 Dataset类实现
完整的Dataset类需要实现以下功能:
- 加载CSV数据
- 文本分词和编码
- 序列截断和填充
- 标签映射
python复制from torch.utils.data import Dataset
import pandas as pd
import torch
class TextClassificationDataset(Dataset):
def __init__(self, csv_path, vocab, max_len=100):
self.data = pd.read_csv(csv_path)
self.vocab = vocab
self.max_len = max_len
# 构建label到id的映射
self.labels = self.data['label'].unique()
self.label2idx = {label: idx for idx, label in enumerate(self.labels)}
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
text = self.data.iloc[idx]['text']
label = self.data.iloc[idx]['label']
# 分词并转换为id序列
words = jieba.lcut(text)
word_ids = [self.vocab.word2idx.get(word, self.vocab.word2idx['<unk>'])
for word in words]
# 截断或填充
if len(word_ids) > self.max_len:
word_ids = word_ids[:self.max_len]
else:
word_ids = word_ids + [self.vocab.word2idx['<pad>']] * (self.max_len - len(word_ids))
return torch.tensor(word_ids, dtype=torch.long), self.label2idx[label]
实际经验:在构建Dataset时,建议将词表和标签映射单独保存,这样在预测时可以直接加载使用,避免重新构建。
4. Transformer模型实现
4.1 位置编码
Transformer没有内置的位置信息感知能力,需要通过位置编码注入序列位置信息:
python复制import math
import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
位置编码的特点:
- 使用正弦和余弦函数的组合,可以学习到相对位置关系
- 不同维度使用不同的频率,可以捕捉不同粒度的位置信息
- 直接与词向量相加,不影响原始语义
4.2 多头注意力机制
多头注意力是Transformer的核心组件,允许模型同时关注不同位置的多种关系:
python复制class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
batch_size, seq_len, _ = x.size()
return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, Q, K, V, mask=None):
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
# 缩放点积注意力
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
attn_probs = torch.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_probs, V)
# 合并多头
output = output.transpose(1, 2).contiguous()
output = output.view(output.size(0), -1, self.d_model)
return self.W_o(output)
关键实现细节:
- 使用线性变换生成Q、K、V矩阵
- 缩放点积注意力防止梯度消失
- 支持注意力掩码,可用于处理填充位置
- 最后合并多头输出并通过线性层
4.3 编码器层
每个编码器层包含多头注意力和前馈网络,并应用残差连接和层归一化:
python复制class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力子层
attn_output = self.self_attn(x, x, x, mask)
x = x + self.dropout(attn_output)
x = self.norm1(x)
# 前馈网络子层
ffn_output = self.ffn(x)
x = x + self.dropout(ffn_output)
x = self.norm2(x)
return x
编码器层的设计要点:
- 每个子层都有残差连接,缓解梯度消失
- 层归一化在残差连接之后应用
- 使用Dropout防止过拟合
- 前馈网络提供非线性变换能力
4.4 完整Transformer模型
将各个组件组合成完整的文本分类模型:
python复制class TransformerClassifier(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, num_classes, dropout=0.1):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len)
self.encoder_layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.classifier = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_model, num_classes)
)
def forward(self, x):
# 生成padding mask
mask = (x != 0).unsqueeze(1).unsqueeze(2)
# 嵌入层
x = self.embedding(x)
x = self.pos_encoding(x)
# 编码器
for layer in self.encoder_layers:
x = layer(x, mask)
# 平均池化
x = x.mean(dim=1)
# 分类器
return self.classifier(x)
模型特点:
- 使用平均池化聚合序列信息,比最大池化更稳定
- 分类器使用两层MLP增强表达能力
- 自动生成padding mask忽略填充位置
- 支持自定义模型深度和宽度
5. 模型训练与评估
5.1 训练流程实现
完整的训练过程包括数据加载、模型初始化、训练循环和验证:
python复制def train_model():
# 初始化词表
train_df = pd.read_csv('train.csv')
vocab = Vocabulary(min_freq=10)
vocab.build(train_df['text'].tolist())
# 创建数据集
train_dataset = TextClassificationDataset('train.csv', vocab, max_len=100)
val_dataset = TextClassificationDataset('val.csv', vocab, max_len=100)
# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
# 模型配置
model = TransformerClassifier(
vocab_size=len(vocab.word2idx),
d_model=256,
num_heads=8,
num_layers=3,
d_ff=512,
max_len=100,
num_classes=4
)
# 训练设置
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# 训练循环
for epoch in range(20):
model.train()
total_loss = 0
for batch in train_loader:
inputs, labels = batch
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
# 验证
model.eval()
val_acc = 0
with torch.no_grad():
for batch in val_loader:
inputs, labels = batch
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
val_acc += (preds == labels).sum().item()
val_acc /= len(val_dataset)
print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Val Acc: {val_acc:.4f}')
scheduler.step()
训练技巧:
- 使用学习率调度器动态调整学习率
- 每个epoch后验证模型性能
- 记录训练损失和验证准确率
- 适当增加训练epochs直到验证准确率稳定
5.2 常见问题与解决方案
在实际训练中可能会遇到以下问题:
问题1:模型完全不学习(准确率随机)
- 原因:可能是池化方式不当(如使用最大池化导致梯度消失)
- 解决:改用平均池化或尝试其他聚合方式
问题2:验证准确率波动大
- 原因:学习率可能过高或batch size太小
- 解决:降低学习率,增大batch size,或使用梯度裁剪
问题3:训练集表现好但验证集差
- 原因:模型过拟合
- 解决:增加Dropout比例,添加L2正则化,或使用早停法
问题4:GPU内存不足
- 原因:序列长度或batch size太大
- 解决:减小max_len或batch size,或使用梯度累积
6. 模型优化与扩展
6.1 性能优化技巧
-
学习率预热:在训练初期使用较小的学习率,逐步增大
python复制def warmup_lr(step, warmup_steps=4000, d_model=256): return min(step ** -0.5, step * (warmup_steps ** -1.5)) * (d_model ** -0.5) -
标签平滑:防止模型对预测结果过于自信
python复制class LabelSmoothingLoss(nn.Module): def __init__(self, smoothing=0.1): super().__init__() self.smoothing = smoothing def forward(self, logits, targets): log_probs = F.log_softmax(logits, dim=-1) nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1)) smooth_loss = -log_probs.mean(dim=-1) loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss return loss.mean() -
混合精度训练:减少显存占用,加快训练速度
python复制from torch.cuda.amp import GradScaler, autocast scaler = GradScaler() with autocast(): outputs = model(inputs) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
6.2 模型扩展方向
-
使用预训练词向量:
python复制
embedding = nn.Embedding.from_pretrained(load_pretrained_vectors()) -
添加CNN分支:结合局部特征和全局注意力
python复制class CNNBranch(nn.Module): def __init__(self, d_model): super().__init__() self.conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1) def forward(self, x): return self.conv(x.transpose(1, 2)).transpose(1, 2) -
层次化Transformer:不同层使用不同注意力头数
python复制self.encoder_layers = nn.ModuleList([ EncoderLayer(d_model, num_heads[i], d_ff, dropout) for i in range(num_layers) ])
7. 实际应用建议
-
生产环境部署:
- 使用ONNX或TorchScript导出模型
- 实现批处理预测提高吞吐量
- 添加缓存机制减少重复计算
-
持续改进:
- 定期用新数据重新训练模型
- 监控预测结果的分布变化
- 建立A/B测试框架评估模型改进
-
错误分析:
- 收集错误预测样本进行分析
- 识别模型的主要错误模式
- 针对性改进数据或模型结构
这个Transformer文本分类实现虽然相对简单,但包含了核心组件和完整流程。在实际项目中,可以根据具体需求调整模型结构、优化训练策略,并持续迭代改进。