中草药识别一直是中医药数字化进程中的重要课题。传统的人工鉴别方法依赖药师经验,存在主观性强、效率低下等问题。我们基于PyTorch框架和EfficientNetV2网络,构建了一个能够自动识别5种常见中草药(百合、党参、山魈、枸杞、槐花、金银花)的深度学习系统。这个项目不仅实现了90%以上的分类准确率,还开发了用户友好的GUI界面,为中医药信息化提供了实用工具。
在医疗AI领域,图像分类技术的应用越来越广泛。相比传统CNN网络,EfficientNetV2在保持高精度的同时大幅提升了训练和推理速度,这对资源受限的中小型医疗机构特别有价值。我们的实现方案在消费级GPU上即可运行,单张图片的推理时间控制在200ms以内。
我们收集了6类共900张中草药高清图片,每类约150张。数据采集时特别注意了以下几点:
数据集目录结构如下:
code复制ChineseMedicine/
├── 百合/
│ ├── image_001.jpg
│ └── ...
├── 党参/
├── 山魈/
├── 枸杞/
├── 槐花/
└── 金银花/
为提高模型泛化能力,我们设计了以下数据增强方案:
python复制data_transform = {
"train": transforms.Compose([
transforms.RandomResizedCrop(300), # 随机裁剪缩放
transforms.RandomHorizontalFlip(), # 水平翻转
transforms.ColorJitter(brightness=0.2, contrast=0.2), # 颜色扰动
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
]),
"val": transforms.Compose([
transforms.Resize(384),
transforms.CenterCrop(384),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
}
注意事项:中草药图像增强需要保持药材的关键特征不被破坏。例如,枸杞的纹理、金银花的形状等关键特征应在增强后仍然清晰可辨。
EfficientNetV2相比V1版本主要做了以下改进:
网络结构对比如下:
| 模块类型 | EfficientNetV1 | EfficientNetV2 |
|---|---|---|
| 基础模块 | MBConv | MBConv + Fused-MBConv |
| 扩展比例 | 固定6 | 动态调整(1-6) |
| 卷积核大小 | 主要5x5 | 主要3x3 |
| 训练策略 | 固定尺寸 | 渐进式学习 |
我们基于官方预训练模型进行微调,关键代码如下:
python复制from torchvision.models import efficientnet_v2_s
def create_model(num_classes=6):
model = efficientnet_v2_s(pretrained=True)
# 修改最后一层全连接
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
return model
实操技巧:冻结除分类头外的所有层可以显著加快训练速度。当验证集准确率停滞时再解冻部分深层网络。
我们采用以下超参数设置:
学习率调度策略采用余弦退火:
python复制lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
使用TensorBoard记录训练指标:
python复制tb_writer.add_scalar("train_loss", train_loss, epoch)
tb_writer.add_scalar("val_acc", val_acc, epoch)
典型训练曲线显示:
将训练好的PyTorch模型转换为TorchScript格式,便于生产环境部署:
python复制model = create_model()
model.load_state_dict(torch.load('best_model.pth'))
model.eval()
traced_script = torch.jit.script(model)
traced_script.save('medicine_classifier.pt')
GUI主要功能模块:
核心预测代码:
python复制def predict_image(image_path):
img = Image.open(image_path)
img_tensor = val_transform(img).unsqueeze(0)
with torch.no_grad():
output = model(img_tensor)
prob = torch.nn.functional.softmax(output, dim=1)
pred_idx = torch.argmax(prob).item()
return class_names[pred_idx], prob[0][pred_idx].item()
CUDA内存不足:
python复制scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
outputs = model(inputs)
类别不平衡:
推理速度慢:
在实际部署中,我们发现几个有价值的改进点:
这个项目完整代码已开源,包含训练脚本、预训练模型和GUI实现。对于想入门医疗AI的开发者,建议先从少量药材类别开始,逐步扩展分类体系。在实际应用中,持续收集真实场景数据对模型迭代至关重要。