在深度学习领域,图像分类任务一直是计算机视觉的基础课题。PyTorch作为当前主流的深度学习框架,其灵活的张量计算和自动微分机制为各类分类任务提供了高效实现方案。交叉熵损失函数(Cross-Entropy Loss)作为分类任务中最常用的损失函数之一,能够有效衡量模型预测概率分布与真实标签之间的差异。
本文将深入探讨如何在PyTorch中利用交叉熵损失函数实现多分类与二分类任务。不同于简单的API调用教程,我会结合自己在大规模图像分类项目中的实战经验,详细解析损失函数背后的数学原理、PyTorch中的实现机制,以及实际应用中的关键技巧。无论你是刚入门深度学习的新手,还是希望优化现有分类模型效果的从业者,都能从中获得可直接落地的解决方案。
交叉熵源于信息论中的KL散度(Kullback-Leibler Divergence),用于衡量两个概率分布之间的差异。在分类任务中,我们期望模型的预测概率分布尽可能接近真实的标签分布。对于单个样本,交叉熵损失的计算公式为:
$$
L = -\sum_{c=1}^{C} y_c \log(p_c)
$$
其中,$C$表示类别总数,$y_c$是样本属于类别$c$的真实标签(one-hot编码),$p_c$是模型预测该样本属于类别$c$的概率。
在多分类任务中(如CIFAR-10、ImageNet),这个公式直接适用;而在二分类任务中(如医学图像中的病灶检测),公式可以简化为:
$$
L = -[y \log(p) + (1-y) \log(1-p)]
$$
关键理解:交叉熵损失对错误预测施加了"对数惩罚",预测概率与真实标签差异越大,损失值增长越显著。这种特性使其特别适合分类问题。
PyTorch提供了两种主要的交叉熵实现方式,对应不同的使用场景:
nn.CrossEntropyLoss (适用于多分类)
nn.BCEWithLogitsLoss (适用于二分类)
python复制# 多分类任务典型用法
criterion = nn.CrossEntropyLoss()
outputs = model(inputs) # 未经softmax的原始输出
loss = criterion(outputs, labels)
# 二分类任务典型用法
criterion = nn.BCEWithLogitsLoss()
outputs = model(inputs) # 未经sigmoid的原始输出
loss = criterion(outputs, labels.float())
以CIFAR-10数据集为例,我们首先需要正确处理数据并构建适合的模型架构:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

# Data loading with train-time augmentation
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(32, padding=4),  # pad by 4, crop back to 32x32
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)
# Simple CNN model
class Net(nn.Module):
    """Small CNN for 10-way CIFAR-10 classification.

    Returns raw logits (no softmax layer): nn.CrossEntropyLoss applies
    log-softmax internally.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)  # 32x32 input halved twice -> 8x8
        self.fc2 = nn.Linear(256, 10)  # one logit per CIFAR-10 class

    def forward(self, x):
        # Two conv -> ReLU -> 2x2 max-pool stages, then a small MLP head.
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        flat = torch.flatten(x, 1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)  # raw logits -- intentionally no softmax


model = Net()
在多分类任务中,正确设置损失函数和优化器是关键:
# Multi-class training setup: CrossEntropyLoss expects raw logits of shape
# [N, C] and integer class labels of shape [N].
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(20):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()  # .item() detaches -- avoids retaining the graph
        if i % 100 == 99:  # report the average loss every 100 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0
实战技巧:在训练初期可以添加学习率warmup策略,逐步提高学习率以避免初期的不稳定。同时,对于大型数据集,建议使用混合精度训练(AMP)来加速训练过程。
训练完成后,我们需要在测试集上评估模型性能:
# Evaluation on the held-out test split.
# FIX: the training `transform` contains random augmentation
# (RandomHorizontalFlip / RandomCrop) which must NOT be applied at test
# time -- it randomizes the inputs and distorts the measured accuracy.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2)

correct = 0
total = 0
model.eval()  # FIX: switch to eval mode (dropout off, BN running stats)
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # index of the max logit is the predicted class (`.data` is deprecated)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy on test images: {100 * correct / total:.2f}%')
对于更全面的评估,建议计算每个类别的精确度(precision)、召回率(recall)和F1分数:
# Per-class precision / recall / F1 on the test set.
from sklearn.metrics import classification_report

all_preds = []
all_labels = []
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # move to CPU numpy so sklearn can consume plain arrays
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
print(classification_report(all_labels, all_preds))
二分类任务的数据准备与多分类有所不同,主要体现在标签处理上。以猫狗分类为例:
from torchvision.datasets import ImageFolder

# Expected directory layout (ImageFolder assigns one class per subfolder):
# data/
#   train/
#     cat/
#     dog/
#   val/
#     cat/
#     dog/
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    # ImageNet channel statistics -- the standard choice for natural images
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
trainset = ImageFolder('data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=2)

# Validation set -- no shuffling needed for evaluation
valset = ImageFolder('data/val', transform=transform)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=32, shuffle=False, num_workers=2)
对于二分类任务,模型最后一层只需要一个输出单元:
class BinaryClassifier(nn.Module):
    """CNN for binary classification emitting a single raw logit per sample.

    Pair with nn.BCEWithLogitsLoss -- no sigmoid is applied here.
    Expects 3x224x224 inputs: two 2x2 max-pools reduce 224 -> 56, which
    fixes the 128 * 56 * 56 flatten size below.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 56 * 56, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1),  # single output unit
        )

    def forward(self, x):
        feature_maps = self.features(x)
        flattened = torch.flatten(feature_maps, 1)
        logit = self.classifier(flattened)
        return logit  # raw logit; sigmoid is left to the loss function
使用BCEWithLogitsLoss需要注意标签的格式和数值范围:
model = BinaryClassifier().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(15):
    model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs = inputs.to(device)
        # BCEWithLogitsLoss needs float targets shaped like the output: [N, 1]
        labels = labels.float().unsqueeze(1).to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in valloader:
            inputs = inputs.to(device)
            labels = labels.float().unsqueeze(1).to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            # logits -> probabilities, then threshold at 0.5 for hard 0/1 preds
            preds = torch.sigmoid(outputs) > 0.5
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    print(f'Epoch {epoch+1}: '
          f'Train Loss: {running_loss/len(trainloader):.4f}, '
          f'Val Loss: {val_loss/len(valloader):.4f}, '
          f'Val Acc: {100.*correct/total:.2f}%')
重要提示:BCEWithLogitsLoss已经包含了sigmoid操作和数值稳定性的优化,因此不要在模型最后添加sigmoid层,也不要在损失计算前手动应用sigmoid。
在实际应用中,数据集经常存在类别不平衡问题。以医学图像分类为例,正常样本可能远多于异常样本。PyTorch提供了几种应对方案:
python复制# 假设类别0和类别1的样本比例为10:1
class_weights = torch.tensor([1.0, 10.0]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
from torch.utils.data.sampler import WeightedRandomSampler

# Oversample minority classes so each batch is roughly balanced.
# `labels` is assumed to be a list with one class index per sample.
class_counts = torch.bincount(torch.tensor(labels))
class_weights = 1. / class_counts.float()  # inverse-frequency weighting
sample_weights = class_weights[labels]     # per-sample draw probability
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True  # required so minority samples can repeat within an epoch
)
# FIX: `DataLoader` was used unqualified but never imported under that name
# in this file -- use the fully qualified path (consistent with the loaders above).
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
class FocalLoss(nn.Module):
    """Binary focal loss on logits (Lin et al., 2017).

    Down-weights easy examples by (1 - p_t) ** gamma so training focuses
    on hard / minority-class samples.

    Args:
        alpha: global scaling factor applied to the loss.
        gamma: focusing exponent; gamma=0 reduces to plain BCE-with-logits.
        reduction: 'mean', 'sum', or anything else for per-element losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # FIX: the original used `F.` but torch.nn.functional was never
        # imported as F in this file; use the existing `nn` alias instead.
        bce_loss = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none')
        pt = torch.exp(-bce_loss)  # pt = p if target==1 else (1 - p)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss
标签平滑(Label Smoothing)是一种正则化技术,可以防止模型对训练标签过度自信,提高泛化能力:
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with label smoothing (Szegedy et al., 2016).

    The target distribution becomes (1 - epsilon) on the true class plus
    epsilon / n_classes spread uniformly over all classes, discouraging
    over-confident predictions and improving generalization.
    """

    def __init__(self, epsilon=0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, logits, targets):
        # FIX: the original used `F.` but torch.nn.functional was never
        # imported as F in this file; use the existing `nn` alias instead.
        n_classes = logits.size(-1)
        log_preds = nn.functional.log_softmax(logits, dim=-1)
        # Uniform component: -sum_c log p_c, scaled by epsilon / n_classes below
        uniform_loss = -log_preds.sum(dim=-1)
        nll = nn.functional.nll_loss(log_preds, targets, reduction='none')
        loss = (1 - self.epsilon) * nll + self.epsilon * uniform_loss / n_classes
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
使用混合精度训练可以显著减少显存占用并加速训练过程:
from torch.cuda.amp import GradScaler, autocast

# Mixed-precision training: autocast runs the forward pass in reduced
# precision where safe; GradScaler scales the loss to avoid fp16
# gradient underflow.
# NOTE(review): torch.cuda.amp is deprecated in recent PyTorch versions in
# favor of torch.amp (autocast('cuda')) -- confirm the installed version.
scaler = GradScaler()
for epoch in range(epochs):
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()  # scale the loss, then backprop
        scaler.step(optimizer)  # unscales grads; skips the step on inf/nan
        scaler.update()  # adapt the scale factor for the next iteration
理解模型如何做出分类决策同样重要。可以使用Grad-CAM等技术可视化模型关注区域:
import matplotlib.pyplot as plt
from torchcam.methods import GradCAM

# Pick the target layer (usually the last convolutional layer)
cam_extractor = GradCAM(model, 'features.3')
# FIX: the original wrapped this forward pass in torch.no_grad(), but
# Grad-CAM needs gradients of the class score w.r.t. the target layer's
# activations -- disabling autograd breaks the CAM extraction.
out = model(inputs.unsqueeze(0).to(device))
activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)

# Overlay the class-activation map on the input image
plt.imshow(inputs.permute(1, 2, 0).cpu().numpy())
plt.imshow(activation_map[0].squeeze(0).cpu().numpy(), alpha=0.5, cmap='jet')
plt.show()
训练损失不下降时,常见的原因包括:
- 学习率设置不当
- 模型容量不足
- 数据预处理问题
- 梯度消失/爆炸
应对过拟合与训练不稳定的常用手段:
- 数据增强
- 正则化技术
- 模型简化
- 梯度裁剪
python复制torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
调试训练过程时的建议:
- 损失值监控
- 中间结果可视化
工程实践要点:
- 数据质量优先
- 基准模型建立
- 持续监控
在长期的项目实践中,我发现交叉熵损失虽然简单,但通过合理的调整和配套技术的使用,能够在绝大多数分类任务中取得优秀的表现。关键在于理解数据特性、选择合适的模型容量,并持续监控训练过程。