2017年Transformer架构在NLP领域大获成功后,研究者们开始探索其在计算机视觉领域的应用潜力。传统CNN通过局部感受野逐步构建特征层次,而ViT的核心创新在于将图像视为序列数据,用全局注意力机制捕捉长距离依赖关系。这种范式转换带来了三个显著优势:
ViT将输入图像分割为N个固定大小的patch(典型值为16×16像素),每个patch通过线性投影转换为D维嵌入向量。这个过程的数学表达为:
python复制# 假设输入图像尺寸为H×W×C
patch_size = 16
num_patches = (H * W) // (patch_size ** 2)
projection = nn.Linear(patch_size**2 * C, D)
实际操作时,可以通过卷积高效实现:
python复制self.proj = nn.Conv2d(in_channels=C,
out_channels=D,
kernel_size=patch_size,
stride=patch_size)
关键细节:位置编码采用可学习的1D向量而非原始Transformer的固定编码,这是因为2D位置信息在patch展开为1D序列时会部分丢失空间关系。
当训练数据有限时,可以采用CNN-ViT混合架构:
这种设计在ImageNet-1k上比纯ViT节省约40%训练数据量。
使用PyTorch实现基础ViT模型的核心组件:
python复制class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12):
super().__init__()
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads=12) for _ in range(depth)
])
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # [B, num_patches, embed_dim]
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
return x[:, 0] # 返回class token作为图像表示
不同规模的ViT模型典型配置:
| 模型变体 | 层数 | 隐藏层维度 | MLP大小 | 头数 | 参数量 |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
训练技巧:对于中小型数据集,建议:
- 使用AdamW优化器(β1=0.9, β2=0.999)
- 学习率线性warmup(10-20个epoch)
- 权重衰减0.05
- 标签平滑系数0.1
处理非标准尺寸图像的两种方案:
python复制def adaptive_patching(image, target_pixels=256):
h, w = image.shape[1:]
patch_size = int((h * w / target_pixels) ** 0.5)
patch_size = max(4, patch_size - patch_size % 4) # 保持能被4整除
return patch_size
python复制def smart_pad(image, target_size=224):
# 保持长宽比进行填充
_, h, w = image.shape
scale = min(target_size/h, target_size/w)
new_h, new_w = int(h * scale), int(w * scale)
padded = torch.zeros(3, target_size, target_size)
pad_h = (target_size - new_h) // 2
pad_w = (target_size - new_w) // 2
padded[:, pad_h:pad_h+new_h, pad_w:pad_w+new_w] = F.interpolate(
image.unsqueeze(0), size=(new_h, new_w), mode='bilinear')[0]
return padded
渐进式训练策略:
注意力蒸馏(适用于小数据集):
python复制class DistillWrapper(nn.Module):
def __init__(self, teacher, student):
super().__init__()
self.teacher = teacher
self.student = student
def forward(self, x):
with torch.no_grad():
t_attn = self.teacher.get_attention_maps(x)
s_attn = self.student.get_attention_maps(x)
# 计算注意力蒸馏损失
loss_attn = F.mse_loss(s_attn, t_attn)
# 常规分类损失
logits = self.student(x)
loss_cls = F.cross_entropy(logits, labels)
return 0.7 * loss_cls + 0.3 * loss_attn
将ViT与DETR结合的关键修改点:
python复制class MultiScaleViT(nn.Module):
def __init__(self):
super().__init__()
self.stage1 = PatchEmbed(stride=4, out_dim=64)
self.stage2 = PatchEmbed(stride=8, out_dim=128)
self.stage3 = PatchEmbed(stride=16, out_dim=256)
def forward(self, x):
f1 = self.stage1(x) # 高分辨率浅层特征
f2 = self.stage2(x) # 中等分辨率特征
f3 = self.stage3(x) # 低分辨率深层特征
return [f1, f2, f3]
python复制class ObjectQueries(nn.Module):
def __init__(self, num_queries=100, dim=256):
super().__init__()
self.queries = nn.Parameter(torch.randn(num_queries, dim))
self.spatial_prior = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
def forward(self, features):
# features: 多尺度特征列表
spatial_weights = []
for feat in features:
if len(feat.shape) == 3: # [B, N, C]
h = int(feat.shape[1]**0.5)
feat = feat.view(feat.shape[0], h, h, -1).permute(0,3,1,2)
spatial_weights.append(self.spatial_prior(feat))
# 融合多尺度空间权重
combined_weights = F.interpolate(spatial_weights[-1], scale_factor=2, mode='bilinear')
for w in spatial_weights[:-1][::-1]:
combined_weights += F.interpolate(w, size=combined_weights.shape[-2:], mode='bilinear')
return self.queries.unsqueeze(0) * combined_weights.flatten(2).transpose(1,2)
python复制class EfficientAttention(nn.Module):
def __init__(self, dim, num_heads=8, window_size=7):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
def forward(self, x):
B, N, C = x.shape
H = W = int(N ** 0.5)
# 划分局部窗口
x = x.view(B, H, W, C)
x = x.permute(0, 3, 1, 2) # [B, C, H, W]
# 使用深度可分离卷积近似注意力
conv = nn.Sequential(
nn.Conv2d(C, C, kernel_size=window_size,
padding=window_size//2, groups=C),
nn.Conv2d(C, C, kernel_size=1)
)
return conv(x).permute(0, 2, 3, 1).view(B, N, C)
python复制def quantize_model(model, calib_data):
model.eval()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# 特别处理LayerNorm和Softmax
for module in model.modules():
if isinstance(module, nn.LayerNorm):
module.qconfig = None
if isinstance(module, nn.Softmax):
module.qconfig = None
quant_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear, nn.Conv2d},
dtype=torch.qint8
)
# 校准
with torch.no_grad():
for data in calib_data:
quant_model(data[0])
return quant_model
针对PCB板缺陷检测的完整流程:
python复制train_transform = transforms.Compose([
transforms.RandomApply(
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2)], p=0.8
),
transforms.RandomGrayscale(p=0.2),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
RandomCutout(max_size=32, p=0.5) # 模拟遮挡
])
python复制class MultiTaskHead(nn.Module):
def __init__(self, feat_dim, num_defect_types=6):
super().__init__()
# 缺陷分类头
self.classifier = nn.Sequential(
nn.Linear(feat_dim, feat_dim//2),
nn.GELU(),
nn.Linear(feat_dim//2, num_defect_types)
)
# 缺陷定位头
self.locator = nn.Sequential(
nn.Conv2d(feat_dim//16, 256, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(256, 1, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
# x: [B, N, C]
B, N, C = x.shape
H = W = int(N ** 0.5)
# 分类分支
cls_logits = self.classifier(x.mean(dim=1))
# 定位分支
loc_feat = x.transpose(1,2).view(B, C, H, W)
heatmap = self.locator(loc_feat[:, :C//16, :, :])
return cls_logits, heatmap
python复制def visualize_attention(image, model, layer_idx=6, head_idx=0):
# 注册hook获取注意力权重
attentions = []
def hook(module, input, output):
attentions.append(output[1]) # output: (output, attn_weights)
handle = model.blocks[layer_idx].attn.register_forward_hook(hook)
# 前向传播
model(image.unsqueeze(0))
handle.remove()
# 处理注意力权重
attn = attentions[0][0, head_idx] # [1, num_heads, N+1, N+1]
cls_attn = attn[0, 1:] # 取class token对其他patch的注意力
# 生成热力图
h = w = int(cls_attn.shape[0]**0.5)
heatmap = cls_attn.reshape(h, w).cpu().numpy()
# 叠加原图
img = image.permute(1,2,0).cpu().numpy()
img = (img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])) * 255
img = img.astype(np.uint8)
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
superimposed_img = heatmap * 0.4 + img * 0.6
return superimposed_img
python复制class AnomalyScorer:
def __init__(self, model, train_features):
self.model = model
self.train_features = train_features # 正常样本特征库
self.pca = PCA(n_components=64)
self.pca.fit(F.normalize(torch.cat(train_features), p=2, dim=1))
def __call__(self, x):
with torch.no_grad():
feat = self.model(x) # 提取特征
feat = F.normalize(feat, p=2, dim=1)
pca_feat = self.pca.transform(feat.cpu())
# 计算与最近邻的距离
dists = torch.cdist(torch.tensor(pca_feat),
torch.tensor(self.pca.transform(
torch.cat(self.train_features).cpu()
)))
min_dist = dists.min(dim=1)[0]
# 获取注意力异常值
attn_entropy = self.compute_attention_entropy(x)
return 0.7 * min_dist + 0.3 * attn_entropy
def compute_attention_entropy(self, x):
attentions = []
def hook(module, input, output):
attentions.append(output[1])
handles = []
for blk in self.model.blocks[-3:]: # 最后三层
handles.append(blk.attn.register_forward_hook(hook))
self.model(x)
for h in handles:
h.remove()
# 计算注意力熵
entropy = 0
for attn in attentions:
attn = attn.mean(dim=1) # 平均多头注意力
entropy += (-(attn * torch.log(attn + 1e-9)).sum(dim=-1)).mean()
return entropy / len(attentions)