在机器学习领域,如何量化两个概率分布之间的差异是一个基础而关键的问题。无论是生成模型的训练、域适配任务,还是简单的分布比较,我们都需要可靠的度量工具。这些工具可以分为三大类:
每种方法都有其独特的性质和应用场景。例如在生成对抗网络(GAN)中,Wasserstein距离因其良好的梯度特性而被广泛采用;而在双样本检验问题中,MMD因其计算简便和理论保证常成为首选。
实际应用中,选择哪种度量取决于:① 计算复杂度要求 ② 对分布支撑集不匹配的敏感度 ③ 梯度传播的需求 ④ 样本效率
KL(Kullback-Leibler)散度定义为:
code复制KL(P||Q) = ∫ p(x) log(p(x)/q(x)) dx
其核心特性包括:
这些特性导致KL散度在实际应用中存在明显局限。例如在生成模型中,当生成分布Q的支撑集小于真实分布P时,KL散度会发散,使得训练不稳定。
JS(Jensen-Shannon)散度是对称化的KL散度:
code复制JS(P,Q) = 0.5*KL(P||M) + 0.5*KL(Q||M), 其中M=0.5*(P+Q)
虽然JS散度解决了对称性问题,但仍然存在:
这些缺陷促使研究者寻找更鲁棒的分布距离度量。
Wasserstein距离源于最优传输理论,考虑将分布P"搬运"到分布Q的最小成本。对于离散情况,可以表示为:
code复制W_p(P,Q) = (inf_γ∈Γ(P,Q) ∫||x-y||^p dγ(x,y))^(1/p)
其中Γ(P,Q)是所有联合分布,其边缘分布分别为P和Q。
code复制W_1(P,Q) = sup_{||f||_L≤1} |E_P[f] - E_Q[f]|
原始Wasserstein距离的计算涉及线性规划,复杂度为O(n^3)。对于大规模数据,这显然不可行。解决方法包括:
在Kantorovich问题中加入熵项:
code复制OT_ε = min_γ∈Γ(P,Q)
其中H(γ) = -∫γ(x,y)logγ(x,y)dxdy
ε控制正则化强度:
通过交替归一化实现高效计算:
code复制输入:成本矩阵C,分布a,b,正则化参数ε
初始化:K = exp(-C/ε), u = ones(n), v = ones(m)
重复:
v = b / (K^T u)
u = a / (K v)
直到收敛
输出:P = diag(u) K diag(v)
MMD通过再生核希尔伯特空间(RKHS)中的均值嵌入来比较分布:
code复制MMD^2(P,Q) = ||μ_P - μ_Q||_H^2
其中μ_P = E_{x∼P}[φ(x)]是核φ下的均值嵌入。
无偏估计形式:
code复制MMD^2 = 1/(n(n-1))∑_{i≠j}k(x_i,x_j)
+ 1/(m(m-1))∑_{i≠j}k(y_i,y_j)
- 2/(nm)∑_{i,j}k(x_i,y_j)
有向Chamfer距离:
code复制d_{CD}(X,Y) = 1/|X|∑_{x∈X} min_{y∈Y} ||x-y||^2
特点:
定义为:
code复制d_H(X,Y) = max{sup_{x∈X} inf_{y∈Y} ||x-y||,
sup_{y∈Y} inf_{x∈X} ||x-y||}
反映的是两个点集之间的最大不匹配程度。
3D重建质量评估:
训练技巧:
常用指标:
推荐组合:
python复制def sinkhorn(C, a, b, eps, max_iter=100):
# C: cost matrix (n,m)
# a: source distribution (n,)
# b: target distribution (m,)
# eps: regularization strength
K = torch.exp(-C/eps)
u = torch.ones_like(a)
v = torch.ones_like(b)
for _ in range(max_iter):
v = b / (K.T @ u)
u = a / (K @ v)
P = torch.diag(u) @ K @ torch.diag(v)
return P
关键优化:
python复制def mmd_rbf(X, Y, sigma):
XX = torch.cdist(X, X) ** 2
YY = torch.cdist(Y, Y) ** 2
XY = torch.cdist(X, Y) ** 2
K_XX = torch.exp(-XX / (2*sigma**2))
K_YY = torch.exp(-YY / (2*sigma**2))
K_XY = torch.exp(-XY / (2*sigma**2))
term1 = K_XX.mean() - K_XX.diag().mean()
term2 = K_YY.mean() - K_YY.diag().mean()
term3 = 2 * K_XY.mean()
return term1 + term2 - term3
处理质量不守恒的情况:
code复制OT_ε^u = min_γ
其中D_φ是φ-散度,控制质量变化惩罚
应用场景:
比较结构相似性:
code复制GW(P,Q) = min_γ
其中L是结构差异函数
特别适用于:
指标选择指南:
参数调优经验:
常见陷阱:
在实际项目中,我通常会同时计算多个互补的度量,以全面评估模型性能。例如在3D点云生成任务中,同时监控Chamfer距离、EMD和法向一致性,可以分别捕捉局部精度、全局匹配和表面质量的不同方面。