2017年Transformer架构在NLP领域大获成功后,谁也没想到这个基于自注意力机制的模型会在计算机视觉领域掀起更大风暴。传统卷积神经网络(CNN)统治图像识别十余年后,Vision Transformer(ViT)的出现彻底打破了人们对视觉处理的认知边界。我第一次在ICLR 2021看到这篇论文时,就被其"纯Transformer不靠CNN也能做图像分类"的大胆假设震撼了。
ViT的核心创新在于将图像处理转化为序列建模问题。就像我们把句子拆分成单词token一样,ViT将224x224的图像切割成16x16的196个图像块(patch),每个patch展平后就是长度为768的向量(假设使用ViT-Base模型)。这种看似简单的处理方式,实则颠覆了传统视觉任务必须依赖局部感受野的固有认知。
关键洞见:当训练数据足够大时(JFT-300M数据集),Transformer在图像分类任务上可以超越当时最先进的CNN模型。这个发现直接催生了后续的Swin Transformer、DeiT等一系列视觉Transformer变体。
ViT的输入处理流程堪称精妙:
python复制# 简化版Patch Embedding实现
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, C, H, W) -> (B, D, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, D, N) -> (B, N, D)
return x
ViT的编码器由L个相同的Transformer Block堆叠而成,每个Block包含:
自注意力机制的计算过程:
python复制# Transformer Block实现示例
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads=num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = Mlp(dim, hidden_dim=int(dim*mlp_ratio))
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
原始ViT需要JFT-300M这样的大规模数据集才能发挥优势,这对普通研究者极不友好。后续研究提出了几种改进方案:
知识蒸馏:DeiT模型使用CNN教师模型指导ViT训练
数据增强:
正则化策略:
| 参数 | ViT-Base | ViT-Large | ViT-Huge |
|---|---|---|---|
| Layers | 12 | 24 | 32 |
| Hidden Size D | 768 | 1024 | 1280 |
| MLP Size | 3072 | 4096 | 5120 |
| Heads | 12 | 16 | 16 |
| Params | 86M | 307M | 632M |
使用HuggingFace的transformers库快速搭建ViT:
python复制from transformers import ViTModel, ViTConfig
config = ViTConfig(
image_size=224,
patch_size=16,
num_classes=1000,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12
)
model = ViTModel(config)
# 自定义分类头
class ViTForImageClassification(nn.Module):
def __init__(self, config):
super().__init__()
self.vit = ViTModel(config)
self.classifier = nn.Linear(config.hidden_size, config.num_classes)
def forward(self, x):
outputs = self.vit(x)
logits = self.classifier(outputs.last_hidden_state[:,0])
return logits
ViT在小数据集上的微调策略:
bash复制# 典型训练命令示例
python train.py \
--model vit_base_patch16_224 \
--batch_size 64 \
--lr 1e-4 \
--epochs 30 \
--warmup_epochs 5 \
--weight_decay 0.05
现象:训练时出现CUDA out of memory错误
解决方法:
检查清单:
加速技巧:
实测对比(Tesla T4):
| 方法 | 延迟(ms) | 显存(MB) |
|---|---|---|
| FP32 | 45.2 | 1240 |
| FP16 | 28.7 | 820 |
| INT8量化 | 19.4 | 610 |
Swin Transformer的创新:
python复制# Swin Transformer Block示例
class SwinBlock(nn.Module):
def __init__(self, dim, input_resolution, num_heads):
super().__init__()
self.w_msa = WindowMSA(dim, input_resolution, num_heads)
self.sw_msa = ShiftedWindowMSA(dim, input_resolution, num_heads)
def forward(self, x):
x = self.w_msa(x)
x = self.sw_msa(x)
return x
最新改进方向:
性能对比(ImageNet-1K):
| 模型 | 参数量(M) | Top-1 Acc |
|---|---|---|
| ViT-B/16 | 86 | 77.9 |
| DeiT-S | 22 | 79.8 |
| Swin-T | 29 | 81.3 |
| MobileViT-S | 5.6 | 78.4 |
ViT在医疗领域的创新应用:
实践发现:在数据量不足的医疗领域,采用预训练-微调范式时,使用自然图像预训练的ViT反而比医学专用CNN表现更好,这颠覆了传统认知。
时空ViT变体:
python复制# 视频patch embedding示例
class VideoEmbed(nn.Module):
def __init__(self, temp_kernel=3):
super().__init__()
self.proj = nn.Conv3d(3, 768,
kernel_size=(temp_kernel,16,16),
stride=(1,16,16))
def forward(self, x):
# x: (B,C,T,H,W)
x = self.proj(x) # (B,D,T,H/P,W/P)
x = x.flatten(3).transpose(1,2) # (B,T,N,D)
return x
虽然ViT已经展现出惊人潜力,但在实际工业落地时仍面临三大挑战:
我在多个项目中的体会是:ViT与传统CNN并非替代关系,而是互补工具。当前最佳实践往往是: