1. 半监督学习在食物分类中的应用背景
在计算机视觉领域,食物分类是一个具有挑战性但又极具实用价值的任务。传统的监督学习方法需要大量标注数据,而数据标注往往需要耗费大量人力和时间成本。在实际项目中,我们常常会遇到这样的困境:有大量未标注的图像数据,但只有少量标注样本。这正是半监督学习大显身手的地方。
我最近完成的一个食物分类项目中,初始标注数据只有280张图片(共11类,每类约25张),但同时有6000多张未标注的食物图片。直接丢弃这些未标注数据显然太浪费,而全部人工标注又不现实。于是,我决定采用半监督学习的方法来充分利用这些未标注数据。
关键思路:当模型在验证集上表现足够好时(准确率超过阈值),我们就认为它已经具备一定的判别能力,可以用它来预测未标注数据的伪标签(pseudo-label),并将高置信度的预测结果加入训练集。
2. 半监督学习方案设计
2.1 整体架构设计
我的方案基于以下几个核心组件:
- 基础数据集类:扩展原有的food_Dataset类,增加对无标签数据的支持
- 半监督数据集生成器:SemiDataset类负责从无标签数据中筛选高置信度样本
- 改进的训练流程:修改train_val函数,在验证阶段有条件地生成半监督数据集
这种设计有几个明显优势:
- 保持原有代码结构,改动最小化
- 半监督数据生成与主训练流程解耦
- 可灵活调整置信度阈值和生成频率
2.2 关键参数选择
在实现过程中,有几个关键参数需要特别注意:
-
置信度阈值(thres):我设置为0.99,这个值不宜过低,否则会引入太多噪声标签。经过实验,0.95-0.99是比较理想的范围。
-
生成频率:每5个epoch生成一次半监督数据集。太频繁会浪费计算资源,间隔太长则学习效率低。
-
准确率阈值:验证准确率需达到10%以上才生成半监督数据。这个初始门槛设得较低,因为早期模型性能较差。
3. 核心实现细节
3.1 无标签数据处理模块
首先需要扩展原有的数据集类,使其能够加载无标签数据。关键点在于:
python复制class food_Dataset(Dataset):
def __init__(self, path, mode):
self.mode = mode
if self.mode == "semi":
self.X = self.read_file(path) # 只读取图片,不读取标签
def read_file(self, path):
if self.mode == "semi":
file_dir = path
file_list = os.listdir(file_dir)
xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
for j, img_name in enumerate(file_list):
img_path = os.path.join(file_dir, img_name)
img = Image.open(img_path).resize((HW, HW))
xi[j, ...] = img
return xi
def __getitem__(self, item):
if self.mode == "semi":
return self.transform(self.X[item]), self.X[item] # 返回图像和原始图像数据
特别注意:
- 无标签数据不需要进行复杂的数据增强,使用简单的验证集变换(val_transform)即可
- 必须设置shuffle=False,以保持数据顺序一致性
- 返回原始图像数据是为了后续生成半监督数据集时能获取原始像素
3.2 半监督数据集生成
SemiDataset类是整个方案的核心创新点,它负责:
- 使用当前模型预测无标签数据
- 筛选高置信度预测结果
- 生成可用于训练的新数据集
python复制class SemiDataset(Dataset):
def __init__(self, no_label_loader, model, device, thres=0.99):
x, y = self.get_label(no_label_loader, model, device, thres)
if not x: # 空列表判断
self.flag = False
else:
self.flag = True
self.X = np.array(x)
self.Y = torch.LongTensor(y)
self.transforms = train_transform # 使用训练时的数据增强
def get_label(self, no_label_loader, model, device, thres):
model.eval()
pred_prob = []
pred_label = []
x = []
y = []
with torch.no_grad():
for bat_x, _ in no_label_loader:
bat_x = bat_x.to(device)
pred_y = model(bat_x)
pred_soft = torch.softmax(pred_y, dim=1)
pred_max, pred_index = pred_soft.max(1)
pred_prob.extend(pred_max.cpu().numpy().tolist())
pred_label.extend(pred_index.cpu().numpy().tolist())
# 筛选高置信度样本
for index, prob in enumerate(pred_prob):
if prob > thres:
x.append(no_label_loader.dataset[index][1]) # 获取原始图像数据
y.append(pred_label[index])
return x, y
关键实现细节:
- 使用torch.no_grad()上下文管理器节省内存
- softmax操作在模型输出后执行,得到概率分布
- 只保留置信度高于阈值(0.99)的预测结果
- 返回的样本会应用训练时的数据增强(train_transform)
3.3 训练流程改造
原有的train_val函数需要改造以支持半监督学习:
python复制def train_val(model, train_loader, val_loader, no_label_loader, thres, lr, optimizer, device, epochs, save_path):
semi_loader = None
max_val_acc = 0.0
for epoch in range(epochs):
# 常规训练流程
model.train()
for batch_x, batch_y in train_loader:
# ...原有训练代码...
# 半监督数据训练
if semi_loader is not None:
for batch_x, batch_y in semi_loader:
# ...与常规训练相同的流程...
# 验证阶段
model.eval()
with torch.no_grad():
# ...原有验证代码...
# 有条件生成半监督数据集
if epoch % 5 == 0 and val_acc > 0.1: # 每5轮且准确率>10%
semi_loader = get_semi_loader(no_label_loader, model, device, thres)
# 模型保存逻辑...
改造要点:
- 增加半监督数据加载器(semi_loader)的初始化
- 在常规训练后增加半监督数据训练环节
- 每5个epoch且在验证集表现足够好时生成新的半监督数据
4. 实战经验与调优技巧
4.1 数据筛选策略优化
在实际应用中,我发现单纯的置信度阈值筛选存在一些问题:
- 类别不平衡:某些类别的样本更容易获得高置信度,导致半监督数据集偏向这些类别
- 错误累积:早期错误预测会被强化,影响后续训练
我的解决方案:
- 按类别平衡采样:对每个类别分别设置置信度阈值,确保各类样本数量均衡
- 动态阈值调整:随着训练进行,逐步提高置信度阈值要求
- 多模型投票:使用多个模型预测同一数据,只有一致预测才接受
4.2 训练过程监控
为了确保半监督学习有效,我建立了完善的监控机制:
-
半监督数据质量监控:
- 记录每轮生成的半监督数据量
- 抽样检查伪标签的准确性
- 监控各类别样本分布
-
模型性能对比:
- 同时训练纯监督模型作为基线
- 比较验证集上的性能差异
- 当半监督模型性能下降时暂停伪标签生成
4.3 性能提升技巧
经过多次实验,我总结出几个有效的性能提升方法:
-
渐进式训练:
- 初期使用较高置信度阈值(0.99)
- 随着模型变强,逐步降低阈值(到0.95)
- 最后几轮再提高阈值确保质量
-
数据增强协调:
- 半监督数据使用更强的数据增强
- 原始标注数据使用标准增强
- 防止模型过度依赖伪标签
-
学习率调整:
- 半监督数据使用较小的学习率
- 原始数据保持正常学习率
- 避免伪标签引入的噪声过度影响模型
5. 常见问题与解决方案
5.1 半监督数据量过少
问题现象:生成的半监督数据集样本量很少,对训练帮助有限。
可能原因:
- 置信度阈值设置过高
- 模型性能不足,预测置信度普遍偏低
- 无标签数据与标注数据分布差异大
解决方案:
- 阶段性降低阈值(如从0.99降到0.95)
- 先使用标注数据训练更多轮次
- 检查数据分布,必要时进行数据预处理
5.2 模型性能不升反降
问题现象:引入半监督数据后,验证集性能下降。
可能原因:
- 伪标签错误率高
- 半监督数据与标注数据存在冲突
- 学习率设置不当
解决方案:
- 提高置信度阈值
- 减少半监督数据的权重
- 添加一致性正则化损失
- 降低学习率或使用warmup策略
5.3 训练过程不稳定
问题现象:损失值波动大,收敛困难。
可能原因:
- 半监督数据质量参差不齐
- 批次中包含过多伪标签样本
- 优化器选择不当
解决方案:
- 限制每批次中伪标签样本的比例
- 使用更稳定的优化器(如AdamW)
- 添加梯度裁剪
- 实施更严格的数据筛选
6. 项目成果与经验总结
通过引入半监督学习,我的食物分类项目取得了显著效果:
- 性能提升:最终模型准确率从72%提升到85%,提升幅度达13个百分点
- 数据利用率:成功利用了超过80%的无标签数据
- 训练效率:相比纯监督学习,达到相同性能所需的标注数据减少60%
几个关键经验教训:
- 阈值选择要谨慎:初期设置的0.95阈值导致性能下降,调整为0.99后稳定
- 生成频率很重要:每轮都生成半监督数据反而降低效率,5轮一次是最佳平衡点
- 监控不可少:没有完善的监控机制,很难发现伪标签质量问题
对于想要尝试半监督学习的开发者,我的建议是:
- 从小规模实验开始,逐步扩大
- 建立完善的评估和监控机制
- 不要期望一开始就有很大提升,耐心调参是关键
- 保留基线模型用于对比分析