计算机视觉领域近年来最激动人心的变革,莫过于Transformer架构的跨界应用。2017年诞生的Transformer原本是为自然语言处理设计的,但2020年Google Research发表的《An Image is Worth 16x16 Words》彻底改变了游戏规则。本文将深入剖析视觉Transformer的核心技术路线,重点解读ViT、MAE、DETR和Deformable DETR四大里程碑模型,揭示它们如何重塑图像理解范式。
ViT(Vision Transformer)的核心突破在于将图像视为由16x16像素块组成的"视觉词序列"。这种处理方式完全摒弃了传统CNN的归纳偏置(如局部性、平移不变性),纯粹依靠注意力机制建立全局关系。具体实现包含四个关键步骤:
图像分块处理:将224x224的输入图像划分为14x14个16x16的patch(共196个token),每个patch通过线性投影变为768维向量(Base版本)
位置编码注入:采用可学习的1D位置编码,为每个patch添加空间位置信息。这与原始Transformer的固定正弦编码不同,实验表明在图像领域可学习编码更具优势
类别token引入:借鉴BERT的[CLS]token,用于聚合全局信息。最终该token的输出作为图像表征,用于下游分类任务
Transformer编码器:标准的多头自注意力结构,包含12层(Base版本),每层包含MSA(多头注意力)和MLP(前馈网络)模块
关键理解:ViT的成功证明了当数据量足够大时(需在JFT-300M等超大数据集预训练),纯粹的注意力机制可以超越精心设计的CNN归纳偏置
在实际应用中,我们发现了几个影响ViT性能的关键因素:
patch大小选择:16x16是精度与计算量的平衡点。32x32会显著降低计算量但损害细粒度特征,8x8则大幅增加计算开销(序列长度变为784)
混合架构尝试:在中小规模数据集上,可采用CNN+Transformer混合结构(如用CNN stem生成feature map再输入Transformer),能缓解数据不足问题
学习率策略:ViT对学习率非常敏感,推荐使用线性warmup(10k步)配合cosine衰减。Base模型初始lr建议3e-4,Large模型则需降至1e-4
python复制# ViT的patch嵌入层典型实现
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, C, H, W) -> (B, D, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, D, N) -> (B, N, D)
return x
MAE(Masked Autoencoder)将NLP中的掩码语言建模成功迁移到视觉领域,其核心创新在于:
非对称编码-解码设计:编码器仅处理25%的可见patch,轻量级解码器则重建全部patch。这种设计大幅降低了计算成本
高掩码比例:75%的掩码率远高于BERT的15%,迫使模型学习更强的语义表征
像素级重建目标:直接预测归一化后的像素值,而非离散token,保持模型通用性
MAE的PyTorch风格伪代码揭示其精妙之处:
python复制class MAE(nn.Module):
def forward_encoder(self, x, mask_ratio=0.75):
# 1. 嵌入与位置编码
x = self.patch_embed(x) + self.pos_embed[:, 1:, :]
# 2. 随机掩码生成
ids_shuffle = torch.randperm(N, device=x.device)
ids_keep = ids_shuffle[:int(N*(1-mask_ratio))]
x_masked = torch.gather(x, dim=1, index=ids_keep)
# 3. 添加CLS token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
x = torch.cat([cls_token, x_masked], dim=1)
# 4. Transformer编码
return self.blocks(x), ids_restore
def forward_decoder(self, x, ids_restore):
# 1. 嵌入转换
x = self.decoder_embed(x)
# 2. 补全mask tokens
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(x_, dim=1, index=ids_restore)
x = torch.cat([x[:, :1, :], x_], dim=1)
# 3. 添加解码器位置编码
x = x + self.decoder_pos_embed
# 4. Transformer解码
return self.decoder_pred(self.decoder_blocks(x))
norm_pix_loss(对每个patch内部进行标准化)能提升约0.5%的线性探测准确率DETR(Detection Transformer)的革命性在于:
Object Queries是DETR最精妙的设计:
python复制# 典型实现方式
self.query_embed = nn.Embedding(num_queries, hidden_dim)
outputs = decoder(queries=self.query_embed.weight.unsqueeze(0).repeat(bs, 1, 1))
这些可学习的查询向量在训练过程中会自发形成空间分工:
传统Transformer的计算复杂度随图像尺寸平方增长,而Deformable Attention通过:
数学表达为:
$$
\text{DeformAttn}(q,p) = \sum_{m=1}^M W_m \left[ \sum_{k=1}^K A_{mqk} \cdot W_m' x(p + \Delta p_{mqk}) \right]
$$
python复制class DeformableAttention(nn.Module):
def forward(self, query, reference_points, value):
# 1. 预测采样偏移和注意力权重
offset = self.offset_proj(query) # (B, Nq, M*K*2)
attn = self.attn_proj(query) # (B, Nq, M*K)
# 2. 多尺度采样
sampled_value = multi_scale_sampling(value, reference_points, offset)
# 3. 加权聚合
return torch.einsum('bnmk,bnmd->bnmd', attn.softmax(dim=-1), sampled_value)
| 场景 | 推荐模型 | 理由 |
|---|---|---|
| 大数据预训练 | ViT或MAE | 纯Transformer架构上限高 |
| 中小规模分类 | Swin Transformer | 层次化设计更高效 |
| 目标检测 | Deformable DETR | 收敛快、精度高 |
| 实时应用 | MobileViT | 优化移动端部署 |
训练不稳定:
显存不足:
小目标检测差:
在实际项目中,我们发现视觉Transformer虽然强大但仍需谨慎使用。对于计算资源有限的团队,建议从Swin Transformer或MobileViT等高效架构入手,逐步探索更复杂的模型。记住,没有放之四海而皆准的解决方案,关键是根据任务需求找到精度与效率的最佳平衡点。