在计算机视觉领域,卷积神经网络(CNN)长期占据主导地位。2020年,Google Research团队发表论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》,首次将纯Transformer架构成功应用于图像分类任务,这就是Vision Transformer(ViT)的诞生。ViT完全摒弃了传统CNN的卷积操作,仅使用标准的Transformer编码器处理图像数据,在多个基准数据集上达到了与当时最优CNN模型相当甚至更好的性能。
ViT的核心创新在于将图像视为一系列"视觉词元"(visual tokens)。就像自然语言处理中把句子拆分为单词一样,ViT将输入图像分割为固定大小的图像块(patches),每个块经过线性投影后成为Transformer的输入序列。这种处理方式使得原本为序列数据设计的Transformer架构能够直接处理二维图像数据,而无需引入任何针对视觉任务的特定归纳偏置(如CNN的局部感受野和平移不变性)。
关键突破:ViT证明了在大规模数据集预训练条件下,纯Transformer架构在视觉任务中可以超越精心设计的CNN模型,这颠覆了计算机视觉领域长期以来的设计范式。
ViT处理图像的第一步是将二维图像转换为适合Transformer处理的一维序列。具体实现方式如下:
图像分块:假设输入图像大小为H×W×C(高度×宽度×通道数),ViT将其分割为N个大小为P×P×C的图像块。每个块在展平后将变为长度为P²C的向量。例如,对于224×224×3的ImageNet图像,使用16×16的分块大小,将得到(224/16)²=196个图像块。
线性投影:通过可训练的线性投影层(全连接层)将每个展平的图像块映射到模型维度D。这个投影层实际上等同于一个步长等于块大小的P×P卷积核,其输出通常称为"patch embeddings"。
位置编码:与原始Transformer类似,ViT需要添加位置信息以保留图像的空间结构。位置编码可以是标准的可学习1D位置编码,也可以是更复杂的2D-aware编码。这些编码与patch embeddings相加,形成最终的输入序列。
python复制# PyTorch风格的伪代码实现
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, C, H, W) -> (B, D, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, D, N) -> (B, N, D)
return x
ViT使用的Transformer编码器与原始Transformer几乎完全相同,由交替的多头自注意力(MSA)和多层感知机(MLP)块组成,每个块前应用层归一化(LayerNorm),后接残差连接:
多头自注意力机制:将输入序列划分为h个"头",在每个头上独立计算查询(Q)、键(K)、值(V)的注意力权重。这种分头机制允许模型在不同表示子空间中共同关注来自不同位置的信息。
自注意力计算公式:
[
\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
]
其中d_k是键向量的维度。
MLP块:通常由两个全连接层组成,中间包含GELU非线性激活。第一个层将维度扩展(通常4倍),第二个层将其投影回原始维度。
层归一化与残差连接:每个子层(MSA和MLP)都采用残差结构,有助于缓解深层网络的梯度消失问题。公式表示为:
[
z_{l+1} = \text{MLP}(\text{LayerNorm}(z_l')) + z_l' \
z_l' = \text{MSA}(\text{LayerNorm}(z_l)) + z_l
]
ViT在序列开始处添加了一个可学习的[class]标记(类似于BERT的[CLS]标记),该标记的最终状态被用作整个图像的表示,输入到分类头中进行预测:
可学习的分类标记:在patch embeddings前拼接一个随机初始化的向量,作为整个序列的全局表示。这个标记在训练过程中会学习整合整个图像的信息。
MLP分类头:通常由一个层归一化层和一个线性层组成。在预训练的大模型中,有时会使用更复杂的头部设计,但微调时简单线性层通常就足够。
python复制# 分类头实现示例
class VisionTransformer(nn.Module):
def __init__(self, num_classes=1000):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.blocks = nn.ModuleList([TransformerBlock() for _ in range(depth)])
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return self.head(x[:, 0])
ViT论文的一个重要发现是:当训练数据不足时(如只在ImageNet上训练),ViT的表现通常不如同等大小的CNN模型。这是因为:
实验表明,当使用足够大的数据集(如JFT-300M,包含3亿张图像)预训练时,ViT才能展现出超越CNN的优势。这引出了ViT的典型使用范式:先在大型数据集上预训练,然后在目标数据集上微调。
在目标数据集上微调ViT时,有几个关键考虑因素:
分辨率调整:微调时通常使用比预训练时更高的图像分辨率。这需要调整位置编码,通常采用双三次插值来适应新的序列长度。
学习率设置:通常使用分层学习率,分类头使用更高的学习率(通常是基础学习率的10倍),因为它是随机初始化的。
正则化技术:常用权重衰减、dropout和随机深度(stochastic depth)。特别是随机深度,它在训练过程中随机跳过某些Transformer块,起到正则化作用。
实践技巧:当目标数据集较小时,冻结所有Transformer层,只训练分类头往往也能获得不错的结果,这可以防止过拟合。
训练大型ViT模型需要特别的优化技术:
混合精度训练:使用FP16精度可以显著减少显存占用并加速训练。现代框架如PyTorch的AMP(自动混合精度)可以自动管理精度转换。
梯度检查点:通过只保存部分激活值并在反向传播时重新计算中间结果,可以大幅减少显存使用(约60-70%),代价是增加约30%的计算时间。
数据并行策略:对于非常大的模型(如ViT-Huge),需要使用模型并行技术,如将注意力头或矩阵乘法操作分布到多个设备上。
原始ViT的全局注意力计算复杂度与图像块数量的平方成正比(O(N²)),这限制了其在更高分辨率图像上的应用。后续研究提出了多种改进方案:
DeiT(Data-efficient Image Transformer):通过知识蒸馏和更好的训练策略,使ViT可以在ImageNet级别数据集上有效训练,无需超大规模预训练。
Swin Transformer:引入层次化特征图和局部窗口注意力,计算复杂度降为线性(O(N)),更适合密集预测任务如目标检测和分割。
PVT(Pyramid Vision Transformer):构建特征金字塔,在不同尺度上处理特征,适用于需要多尺度特征的任务。
与CNN类似,ViT也可以应用于自监督学习场景:
MAE(Masked Autoencoder):随机mask掉大部分图像块(如75%),然后训练模型重建被mask的区域。这种方法可以学习到强大的视觉表示。
MoCo v3:将对比学习应用于ViT,通过最大化同一图像的不同augmentation之间的一致性来学习表示。
DINO:通过自蒸馏方法,在没有任何标签的情况下训练ViT,学习到的特征可以直接用于图像分割等任务。
ViT的灵活性使其可以轻松扩展到多模态任务:
CLIP:同时训练图像ViT和文本Transformer,通过对比学习对齐两种模态的表示空间,实现强大的零样本分类能力。
Flamingo:将预训练的视觉和语言模型结合起来,处理复杂的图文交互任务。
BEiT:统一了图像和文本的表示学习框架,使用共享的Transformer架构处理两种模态。
理解ViT如何"看到"图像是一个重要课题。常用的分析方法包括:
注意力权重可视化:展示[class]标记对不同图像块的注意力分布,揭示模型关注哪些区域进行决策。
注意力流分析:跟踪信息如何在不同的注意力头之间流动,理解模型内部的推理过程。
遮挡测试:系统地遮挡图像的不同部分,观察对模型输出的影响。
研究发现:ViT的浅层注意力通常较为局部,类似于CNN;而深层注意力则表现出明显的语义相关性,能够关联图像中语义相似但空间分离的区域。
ViT模型对计算资源的需求显著高于传统CNN:
模型大小:标准ViT模型参数数量从Base(86M)到Large(307M)再到Huge(632M)不等,更大的模型通常需要分布式训练。
内存占用:注意力矩阵的显存需求与序列长度平方成正比,处理高分辨率图像时需要特别优化。
推理延迟:虽然ViT的FLOPs可能与CNN相当,但由于注意力机制的内存访问模式,实际推理速度可能更慢。
将ViT应用于特定领域时可能遇到的挑战:
医学图像分析:医学图像通常具有与自然图像完全不同的统计特性,直接应用预训练ViT可能效果不佳。解决方案包括领域特定的预训练或适配器模块。
遥感图像:超高分辨率图像需要特殊的分块策略,可能需要结合CNN进行局部特征提取。
视频处理:直接将ViT扩展到视频会面临极大的计算开销,需要开发高效的时空注意力变体。
CNN和ViT的根本区别在于它们内置的归纳偏置(模型对数据结构的假设):
| 特性 | CNN | ViT |
|---|---|---|
| 局部性 | 强(卷积核的有限感受野) | 无(全局注意力) |
| 平移等变性 | 强 | 弱(通过位置编码学习) |
| 尺度不变性 | 弱(需要多尺度处理) | 理论上可以学习 |
| 参数效率 | 高(权重共享) | 较低(注意力权重不共享) |
在标准基准测试中的表现对比:
ImageNet分类:在大规模预训练后,ViT通常能比同等FLOPs的CNN高出1-2%的top-1准确率。
迁移学习:ViT在跨领域迁移时通常表现更好,特别是在目标领域与源领域差异较大时。
对抗鲁棒性:研究发现ViT通常比CNN对对抗样本更鲁棒,可能因为其全局注意力机制更难被局部扰动欺骗。
根据任务特点选择适合的架构:
选择CNN的场景:
选择ViT的场景:
推荐使用PyTorch和HuggingFace的Transformers库:
bash复制pip install torch torchvision transformers timm
数据加载使用标准的ImageFolder格式:
python复制from torchvision import datasets, transforms
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = datasets.ImageFolder(
'path/to/train',
transform=train_transform
)
使用timm库可以方便地加载预训练ViT模型:
python复制import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=1000)
自定义ViT实现的核心部分:
python复制class ViTBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim)
)
def forward(self, x):
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
x = x + self.mlp(self.norm2(x))
return x
典型的训练循环结构:
python复制def train_epoch(model, loader, optimizer, criterion, device):
model.train()
total_loss = 0
for inputs, targets in loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(loader)
评估函数示例:
python复制@torch.no_grad()
def evaluate(model, loader, device):
model.eval()
correct = 0
total = 0
for inputs, targets in loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
correct += predicted.eq(targets).sum().item()
total += targets.size(0)
return correct / total
推理时单张图像处理:
python复制def predict(image_path, model, transform, device):
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
output = model(image)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
return probabilities.cpu().numpy()
ViT训练对学习率非常敏感,推荐策略:
python复制from torch.optim import AdamW
optimizer = AdamW([
{'params': model.cls_token, 'lr': lr * 10},
{'params': model.pos_embed, 'lr': lr * 10},
{'params': model.head.parameters(), 'lr': lr * 10},
{'params': model.blocks.parameters(), 'lr': lr}
], weight_decay=0.01)
随机深度(Stochastic Depth):每个Transformer块有一定概率被跳过
python复制def forward(self, x):
if self.training and torch.rand(1) < self.drop_prob:
return x
return x + self.mlp(self.norm2(x + self.attn(self.norm1(x))))
MixUp和CutMix:图像数据增强技术,对ViT特别有效
python复制from timm.data import Mixup
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0)
inputs, targets = mixup_fn(inputs, targets)
Label Smoothing:减轻模型过度自信
python复制criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
梯度检查点:
python复制from torch.utils.checkpoint import checkpoint
x = checkpoint(block, x)
激活检查点:
python复制torch.utils.checkpoint.checkpoint_sequential(
[block for block in model.blocks],
chunks=4,
input=x
)
混合精度训练:
python复制from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
现象:损失出现NaN或突然增大
解决方案:
torch.nn.utils.clip_grad_norm_)现象:训练准确率高但验证准确率低
解决方案:
现象:注意力图过于分散或只关注极小区域
解决方案:
现象:预训练模型微调后性能提升有限
解决方案:
在实际项目中,我发现ViT对超参数的选择比CNN更为敏感,特别是学习率和warmup步数需要精心调整。另一个实用技巧是在微调时逐步解冻网络层,从分类头开始,然后逐渐解冻更深的Transformer块,这通常比一次性微调所有层效果更好。对于计算资源有限的情况,从较小的ViT模型(如ViT-Tiny或ViT-Small)开始,配合适当的数据增强,往往能在有限资源下获得不错的性能。