生成对抗网络(Generative Adversarial Networks)的核心思想可以用一个简单的比喻来理解:想象一个造假币的罪犯(Generator)和一个经验丰富的警察(Discriminator)之间的猫鼠游戏。造假者不断改进假币质量试图骗过警察,而警察也在不断提升鉴别能力。这种对抗过程最终会促使造假者生产出几乎无法辨别的假币。
GAN由两个深度神经网络组成:
二者的损失函数设计体现了对抗本质:
python复制# 判别器损失 = 真实样本判断误差 + 生成样本判断误差
d_loss = -torch.mean(torch.log(D(real)) + torch.log(1 - D(G(z))))
# 生成器损失 = 判别器对生成样本的误判程度
g_loss = -torch.mean(torch.log(D(G(z))))
关键提示:实际实现时更常用BCELossWithLogits,数值稳定性更好
理想情况下,训练过程会经历三个阶段:
实际训练中常见的问题是判别器过早占据优势(准确率>85%),这时需要通过以下方法调节:
硬件建议配置:
CelebA数据集预处理流程:
bash复制# 官方数据集解压后执行
python preprocess.py \
--input_dir ./raw_images \
--output_dir ./processed \
--size 64 \ # 统一缩放尺寸
--normalize \ # 归一化到[-1,1]
--split 0.8 # 训练验证集分割
DCGAN的标准结构规范:
python复制class Generator(nn.Module):
def __init__(self, latent_dim=100):
super().__init__()
self.main = nn.Sequential(
# 输入: latent_dim维噪声
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# 上采样路径...
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
nn.Tanh() # 输出[-1,1]范围图像
)
python复制class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
# 输入3通道图像
nn.Conv2d(3, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 下采样路径...
nn.Conv2d(512, 1, 4, 1, 0, bias=False),
nn.Sigmoid() # 输出真伪概率
)
工程细节:使用spectral_norm()包装卷积层可显著提升训练稳定性
建议记录的关键指标:
python复制# 在训练循环中添加监控
metrics = {
'g_loss': [],
'd_loss': [],
'd_real_acc': [], # 真实样本判别准确率
'd_fake_acc': [], # 生成样本判别准确率
'fid_score': [] # 每10epoch计算一次
}
可视化工具推荐:
python复制def save_sample_grid(epoch):
with torch.no_grad():
z = torch.randn(16, latent_dim, 1, 1, device=device)
samples = generator(z)
save_image(samples, f"samples/epoch_{epoch}.png",
nrow=4, normalize=True)
模式崩溃(Mode Collapse)的典型表现:
解决方案对比表:
| 方法 | 实现方式 | 适用场景 | 效果评估 |
|---|---|---|---|
| Mini-batch Discrimination | 在判别器最后添加特征统计层 | 轻微模式崩溃 | +15%多样性 |
| Unrolled GAN | 展开判别器k步优化 | 周期性崩溃 | 训练速度↓30% |
| PacGAN | 打包输入样本 | 严重崩溃 | 需要修改架构 |
除常见的FID外,推荐组合使用:
Inception Score (IS)
python复制# 使用预训练Inception_v3
preds = inception_model(gen_imgs)
kl_div = preds * (torch.log(preds) - torch.log(torch.mean(preds, 0)))
is_score = torch.exp(torch.mean(kl_div.sum(1)))
人工评估体系设计
多GPU训练配置示例:
yaml复制# config.yaml
training:
batch_size: 256
nodes: 4
gpus_per_node: 8
sync_bn: True # 使用同步BN
gradient_accumulation: 2
optimization:
lr:
generator: 0.0001
discriminator: 0.00005
betas: [0.5, 0.999]
scheduler:
type: cosine
warmup_epochs: 10
性能优化技巧:
python复制class GenerationServer:
def __init__(self):
self.pool = ThreadPoolExecutor(max_workers=4)
self.request_queue = Queue()
async def generate(self, z):
future = self.pool.submit(self._generate, z)
return await asyncio.wrap_future(future)
def _generate(self, z):
with torch.no_grad():
return generator(z)
最新研究方向的性能对比:
| 架构 | 训练稳定性 | 生成质量 | 计算成本 |
|---|---|---|---|
| StyleGAN3 | ★★★★☆ | ★★★★★ | 高 |
| Diffusion+GAN | ★★★★☆ | ★★★★☆ | 极高 |
| Lightweight GAN | ★★★☆☆ | ★★★☆☆ | 低 |
在电商图片生成项目中遇到的典型问题:
解决方案:
经过6个月的迭代优化,我们的生成系统达到了: