1. 视觉Transformer(ViT)核心思想解析
视觉Transformer(Vision Transformer,简称ViT)的核心创新在于彻底摒弃了传统计算机视觉中卷积神经网络(CNN)的固有架构,将自然语言处理(NLP)中成功的Transformer模型直接应用于图像数据。这种看似激进的做法背后有着深刻的洞见:
1.1 图像到序列的转换策略
ViT处理图像的关键第一步是将二维图像结构转化为一维序列。具体实现方式如下:
-
图像分块处理:假设输入图像尺寸为H×W×C(高度×宽度×通道数),ViT将其分割为N个尺寸为P×P的正方形图像块。例如,对于224×224的输入图像,若采用16×16的图像块大小,则得到N=196个图像块。
-
线性投影:每个图像块被展平为P²·C维的向量,然后通过可训练的线性投影矩阵E映射到D维空间(典型值D=768)。这个投影过程可以理解为将每个图像块转换为一个"视觉词元"(visual token),类似于NLP中将单词转换为词向量。
-
位置编码:由于Transformer本身不具备处理序列顺序的能力,ViT引入了可学习的一维位置编码,为每个图像块添加位置信息。有趣的是,尽管图像本质上是二维结构,但实验表明简单的一维位置编码已经足够,更复杂的二维编码并未带来明显提升。
1.2 类BERT的架构设计
ViT的架构设计大量借鉴了BERT的成功经验:
-
[class]标记:在图像块序列前添加一个可学习的[class]标记,其最终输出状态作为整个图像的表示。这与BERT中使用[CLS]标记进行句子分类的思路完全一致。
-
Transformer编码器:使用标准的Transformer编码器堆叠,每个编码器层包含多头自注意力机制(MSA)和多层感知机(MLP),并采用层归一化(LayerNorm)和残差连接。
-
预训练+微调范式:先在大型数据集(如JFT-300M)上进行预训练,然后在目标任务(如ImageNet)上进行微调。微调时通常使用更高分辨率输入以提升性能。
1.3 与CNN的关键差异
ViT与CNN在归纳偏置(inductive bias)方面存在根本区别:
-
局部性:CNN通过卷积核大小(如3×3)显式约束感受野范围,强制模型关注局部特征。而ViT的自注意力机制从第一层开始就可以关注全局信息。
-
平移等变性:CNN的卷积操作天然具有平移等变性(物体移动后特征表示不变),而ViT需要从头学习这种特性。
-
二维结构:CNN通过滑动窗口操作隐式保持图像的二维结构,ViT则需要通过位置编码来学习空间关系。
这些差异使得在小规模数据集上,CNN通常优于ViT;但当数据量足够大时(如JFT-300M),ViT可以超越CNN,表明大规模数据训练可以弥补缺乏显式归纳偏置的不足。
2. ViT架构实现细节
2.1 输入处理流程
ViT的输入处理流程包含几个关键步骤:
-
图像分块:使用
torch.nn.Unfold等操作将图像分割为不重叠的块。例如,对于224×224的RGB图像,采用16×16的块大小,得到196个768维的向量(16×16×3=768)。 -
线性投影:通过
nn.Linear层将每个图像块投影到模型维度D。这个投影矩阵E是可学习的参数,形状为(P²·C)×D。 -
位置编码:位置编码E_pos的形状为(N+1)×D(N是图像块数量,+1对应[class]标记)。在实现时,可以初始化为正态分布随机数,然后随模型一起训练。
python复制# PyTorch风格的伪代码
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, C, H, W) -> (B, D, H/P, W/P)
x = x.flatten(2) # -> (B, D, N)
x = x.transpose(1, 2) # -> (B, N, D)
return x
2.2 Transformer编码器实现
ViT的Transformer编码器与标准实现几乎完全相同,主要包含以下组件:
-
多头自注意力(MSA):计算查询(Q)、键(K)、值(V)的注意力权重。ViT通常使用12个头,每个头的维度为64(对于D=768的情况)。
-
MLP块:包含两个全连接层,中间使用GELU激活函数。通常第一个层将维度扩展到4D(如768→3072),第二个层压缩回D(3072→768)。
-
层归一化与残差连接:每个子层(MSA和MLP)前应用层归一化,后接残差连接。
python复制class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = MultiHeadAttention(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = MLP(dim, int(dim*mlp_ratio))
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
2.3 混合架构实现
ViT还提出了混合架构(Hybrid Architecture),结合了CNN和Transformer的优点:
-
CNN特征提取:使用ResNet等CNN网络提取图像特征图,替代原始图像块。
-
特征图序列化:将CNN输出的特征图(如14×14×1024)展平为序列(196×1024),然后投影到Transformer维度。
-
Transformer处理:后续流程与标准ViT相同。
这种设计在小规模数据集上表现更好,因为CNN的局部性偏置有助于缓解数据不足的问题。
3. 训练策略与技巧
3.1 预训练配置
ViT的成功很大程度上依赖于大规模预训练。关键配置包括:
-
优化器:使用AdamW优化器(β₁=0.9,β₂=0.999),权重衰减设为0.1。这与CNN常用的SGD不同,AdamW对Transformer训练更稳定。
-
学习率调度:采用线性预热(linear warmup)和余弦衰减(cosine decay)。典型配置是10,000步预热,总训练步数500,000。
-
正则化:
- Dropout率:0.1(注意力权重和MLP)
- 标签平滑:0.1
- 随机深度(stochastic depth):对于深层模型(如ViT-H/14)使用0.1的概率随机跳过某些层
-
批量大小:通常使用4096的大批量训练,配合梯度累积(gradient accumulation)策略。
3.2 微调技巧
在目标任务上微调ViT时,有几个关键技巧:
-
更高分辨率:微调时输入分辨率通常高于预训练(如预训练224,微调384或512)。这需要:
- 保持图像块大小不变(如16×16),导致序列长度增加
- 对位置编码进行双线性插值,适应新的网格大小
-
优化器选择:微调通常使用带动量的SGD(如momentum=0.9),比AdamW表现更好。
-
学习率策略:采用较小的基础学习率(如0.003),配合余弦衰减。
-
权重平均:使用Polyak-Ruppert平均(EMA)可以提升最终性能约0.5%。
3.3 实际训练经验
在实际训练ViT时,我们发现以下经验特别重要:
-
学习率预热:Transformer训练对初始学习率敏感,必须充分预热。我们通常设置5-10%的训练步数用于预热。
-
梯度裁剪:即使使用AdamW,梯度裁剪(如clipnorm=1.0)也有助于稳定训练。
-
混合精度训练:使用AMP(Automatic Mixed Precision)可以显著减少显存占用并加速训练,但要注意:
- 保持主权重(master weights)为FP32
- 对LayerNorm使用FP32精度
-
硬件利用:ViT在TPU上训练效率最高,GPU上需要注意:
- 使用
torch.scaled_dot_product_attention优化注意力计算 - 适当增大批量大小以提高硬件利用率
- 使用
4. 性能分析与优化
4.1 计算效率分析
ViT相比CNN在计算效率上有独特特点:
-
FLOPs比较:ViT-B/16(D=768)与ResNet50的FLOPs相近(约10G),但ViT通常需要更少的训练迭代次数。
-
内存占用:ViT的自注意力层内存需求与序列长度平方成正比(O(N²)),这限制了最大输入分辨率。
-
实际速度:在GPU上,ViT的吞吐量通常低于相同FLOPs的CNN,因为:
- 自注意力操作对硬件不友好
- 长序列导致内存带宽成为瓶颈
4.2 模型缩放规律
ViT论文研究了不同规模的模型变体:
| 模型变体 | 层数 | 隐藏维度D | MLP大小 | 头数 | 参数量 |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
缩放规律表明:
- 增大模型尺寸持续提升性能,尚未观察到饱和
- 计算开销增加与性能提升基本呈线性关系
- 在足够大数据集上,越大模型优势越明显
4.3 注意力机制分析
通过可视化注意力权重,我们发现:
-
低层注意力:部分头表现出局部性,类似于CNN的卷积操作;另一些头则关注全局信息。
-
高层注意力:注意力模式与语义内容相关,例如:
- 分类头关注判别性区域
- 某些头专门关注物体边界
- 背景和前景通常被不同头处理
-
注意力距离:随着网络深度增加,平均注意力距离逐渐增大,表明高层整合更全局的信息。
4.4 实际部署考量
在实际部署ViT时需要考虑:
-
输入分辨率:更高的分辨率提升性能但增加计算开销,需要权衡。常见选择:
- 移动端:224-384
- 服务器端:384-512
-
量化:ViT对8bit量化友好,精度损失通常<1%。可采用:
- PTQ(训练后量化)
- QAT(量化感知训练)
-
剪枝:可以移除部分注意力头或MLP维度,压缩模型大小。通常:
- 低层的局部注意力头更重要
- 高层的全局注意力头冗余度更高
-
编译器优化:使用TensorRT等工具可以显著提升推理速度,特别是通过:
- 融合LayerNorm和残差连接
- 优化矩阵乘法顺序
- 利用Flash Attention等优化技术
5. 应用扩展与未来方向
5.1 超越图像分类
ViT的思想可以扩展到各种视觉任务:
-
目标检测:将ViT作为特征提取器,配合检测头(如DETR)。关键挑战是处理高分辨率特征图的高计算成本。
-
语义分割:采用编码器-解码器结构,使用ViT作为编码器。需要设计高效的上采样策略。
-
视频理解:将时间维度作为额外序列,构建时空Transformer。计算复杂度成为主要瓶颈。
-
多模态任务:联合处理图像和文本,如CLIP模型。ViT作为视觉编码器表现优异。
5.2 自监督学习
ViT在自监督学习方面有巨大潜力:
-
掩码图像建模:类似BERT的掩码语言建模,随机掩码图像块并预测缺失内容。这种方法(如MAE)已取得很好效果。
-
对比学习:如MoCo v3,将ViT作为编码器学习不变特征表示。
-
蒸馏方法:使用教师-学生框架,让小模型从大ViT中学习。
5.3 架构改进方向
未来可能的改进方向包括:
-
高效注意力机制:如稀疏注意力、轴向注意力等,降低O(N²)复杂度。
-
层次化设计:引入类似CNN的层次结构,逐步降低分辨率。
-
动态计算:根据输入内容自适应调整计算量,如跳过某些层或注意力头。
-
神经架构搜索:自动搜索最优的Transformer配置,如层数、头数等。
ViT代表了计算机视觉领域的一次范式转变,它证明了纯Transformer架构在视觉任务中的潜力。随着硬件和算法的进步,ViT及其变体有望在更多视觉应用中取代传统CNN,成为新一代的基础模型架构。