1. 联合嵌入预测架构(JEPA)与变分推断的融合探索
在自监督学习领域,联合嵌入预测架构(Joint-Embedding Predictive Architecture, JEPA)近年来引起了广泛关注。这种架构由Yann LeCun在2022年提出,旨在构建能够理解并预测世界的"世界模型"。与传统生成模型不同,JEPA完全在隐空间工作:它使用上下文信号x的嵌入来预测相关目标视图y的嵌入,而不是直接在原始观测空间中进行重构。
1.1 JEPA的核心设计原理
标准JEPA实现包含三个关键组件:
- 上下文编码器f^ctx:将上下文观测x映射到潜在表示s_x
- 目标编码器f^trg:将目标观测y映射到潜在表示s_y
- 预测器g:从上下文表示s_x和可选辅助变量z生成目标表示ŝ_y
训练过程中,模型最小化预测表示ŝ_y与真实目标表示s_y之间的距离,同时对目标编码器应用停止梯度(stop-gradient)操作,防止模型学习无意义的预测特征。
这种设计虽然效果显著,但也面临两个主要挑战:
- 表示坍缩风险:编码器容易输出恒定向量,需要依赖特定技术(如辅助损失、指数移动平均EMA)来保证表示多样性
- 被定性为非生成方法:与基于似然的自监督学习和概率隐变量建模存在割裂
1.2 从确定性到概率性的转变
本文提出的Var-JEPA框架为JEPA提供了全新的概率解释。研究发现,JEPA与概率生成建模的割裂主要是表述层面的,而非结构层面的。标准JEPA的设计——耦合编码器与上下文到目标的预测器——与对一类耦合隐变量模型做变分推断得到的变分后验、学习到的条件先验高度吻合。
具体而言,当我们将JEPA中的预测嵌入步骤解释为耦合变分自编码器(VAE)的变分后验时,可以自然地推导出一个概率潜在变量框架。在这个框架下:
- JEPA的确定性编码器和预测器被替换为条件分布
- 预测路径被解释为学到的潜在空间条件先验
- 建立了一个清晰的关于上下文、目标和辅助潜在变量的生成过程
这种形式化将JEPA风格的预测学习与重建和条件生成统一起来,同时通过潜在正则化提供了严格的防坍缩机制。
2. Var-JEPA的理论框架与实现
2.1 生成模型的结构设计
Var-JEPA处理上下文观测x∈R^D和目标观测y∈R^D,学习潜在表示:
- s_x∈R^d:上下文潜在表示
- s_y∈R^d:目标潜在表示
- z∈R^{d_z}:辅助预测变量,用于捕捉s_x无法解释的s_y中的变异性
生成过程的有向无环图(DAG)遵循以下结构:
- x → s_x(JEPA推断)
- y → s_y(JEPA推断)
- z连接到s_y
- s_x → s_y
与标准JEPA的单向映射不同,Var-JEPA的变分框架要求双向关系,以同时建模生成过程和推断过程。对于上下文和目标,模型分别学习两个方向:
- 编码(从观测到潜在)
- 重建(从潜在回观测)
单一证据下界(ELBO)目标将这些方向绑定在一起,并联合训练编码器和解码器。
2.2 概率模型参数化
生成模型中的分布被实现为高斯分布。潜在变量的先验为标准高斯分布,而条件分布由神经网络参数化:
-
先验分布:
- p(s_x) = N(s_x; 0, I)
- p(z) = N(z; 0, I)
-
条件分布:
- p_θ(x|s_x) = N(x; U^x_θ(s_x), σ_x^2 I)
- p_θ(y|s_y) = N(y; U^y_θ(s_y), σ_y^2 I)
- p_θ(s_y|s_x,z) = N(s_y; μ^s_y_θ(s_x,z), Σ^s_y_θ(s_x,z))
其中,U^x_θ和U^y_θ是从潜在表示重构观测的解码器网络,μ^s_y_θ和Σ^s_y_θ是由神经网络计算的预测网络输出。
2.3 变分后验设计
由于直接优化边际对数似然log p_θ(x,y)不可行,Var-JEPA引入了易处理的变分后验q_ϕ(s_x,z,s_y|x,y)作为真实后验的近似。变分后验被因子分解为:
q_ϕ(s_x,z,s_y|x,y) = q_ϕ(s_x|x) · q_ϕ(z|s_x) · q_ϕ(s_y|s_x,z,y)
这种因子分解设计确保了:
- 上下文潜在q_ϕ(s_x|x)仅依赖于上下文观测x,使上下文表示能独立于目标信息学习
- 辅助潜在z仅依赖于上下文表示s_x,防止训练期间目标信息泄漏
- 目标后验q_ϕ(s_y|s_x,z,y)同时依赖s_x、z和目标观测y,通过重建项正则化s_y的学习
变分后验的每个组成部分都被参数化为具有可学习均值和协方差的高斯分布:
- q_ϕ(s_x|x) = N(s_x; μ^s_x_ϕ(x), Σ^s_x_ϕ(x))
- q_ϕ(z|s_x) = N(z; μ^z_ϕ(s_x), Σ^z_ϕ(s_x))
- q_ϕ(s_y|s_x,z,y) = N(s_y; μ^s_y_ϕ(s_x,z,y), Σ^s_y_ϕ(s_x,z,y))
其中,推断网络μ^·_ϕ和Σ^·_ϕ实现为神经网络,根据各自输入输出分布参数。
3. Var-JEPA的训练与优化
3.1 证据下界(ELBO)推导
通过Jensen不等式,我们可以推导出边际对数似然的易计算变分下界——证据下界(ELBO):
log p_θ(x,y) ≥ E_{q_ϕ(s_x,z,s_y|x,y)}[log p_θ(x,y,s_x,z,s_y) - log q_ϕ(s_x,z,s_y|x,y)] = L_ELBO
代入p_θ和q_ϕ的因子分解后,ELBO可以展开为:
L_ELBO = E[log p_θ(x|s_x)] + E[log p_θ(y|s_y)] (重建项)
- KL(q_ϕ(s_x|x) || p(s_x)) - KL(q_ϕ(z|s_x) || p(z)) (正则化项)
- E[KL(q_ϕ(s_y|s_x,z,y) || p_θ(s_y|s_x,z))] (条件先验匹配项)
这个ELBO目标有几个关键特点:
- 重建项鼓励学到的表示能够很好地解释观测数据
- KL正则化项将s_x和z的分布推向标准正态先验
- 条件先验匹配项使目标后验接近学到的条件先验
3.2 重参数化技巧实现
为了通过随机潜在变量进行反向传播,Var-JEPA应用了重参数化技巧。潜在变量采样使用以下形式:
s_x = μ^s_x_ϕ(x) + [Σ^s_x_ϕ(x)]^{1/2}·ε_s_x
z = μ^z_ϕ(s_x) + [Σ^z_ϕ(s_x)]^{1/2}·ε_z
s_y = μ^s_y_ϕ(s_x,z,y) + [Σ^s_y_ϕ(s_x,z,y)]^{1/2}·ε_s_y
其中ε_·为独立标准高斯噪声向量,[Σ(·)_ϕ(·)]^{1/2}表示协方差的矩阵平方根。在实践中,我们在每次前向传递中对每个潜在变量采样一次来估计梯度。
3.3 与标准JEPA的对比分析
图1展示了标准JEPA与Var-JEPA之间的关系。两种方法共享相同的核心预测结构,包括:
- 上下文(x)和目标(y)观测
- 它们的潜在表示s_x和s_y
- 辅助潜在变量z
关键区别在于:
- 标准JEPA依赖推断和预测网络,需要代理损失来防止表示坍缩
- Var-JEPA通过添加生成网络(解码器)扩展JEPA,可以用统一的变分目标训练模型
- Var-JEPA的ELBO自然地防止坍缩,无需特殊设计的正则化机制
4. Var-JEPA的理论性质与解释
4.1 防坍缩机制分析
Var-JEPA通过其变分正则化项提供了严格的防坍缩机制。对于具有固定标准正态先验(s_x和z)的潜在变量,ELBO包含每个样本的KL项:
E KL(q_ϕ(·) || N(0,I))
这些项可以分解为聚合后验失配项和信息瓶颈项:
E_x KL(q_ϕ(s_x|x) || N(0,I)) = KL(q_ϕ(s_x) || N(0,I)) + I_{q_ϕ}(x; s_x)
这种分解揭示了Var-JEPA如何:
- 通过KL(q_ϕ(s_x) || N(0,I))项鼓励聚合后验接近各向同性高斯分布
- 通过I_{q_ϕ}(x; s_x)项控制输入与潜在变量之间的互信息,防止过拟合
相比之下,目标潜在变量的ELBO项为:
KL(q_ϕ(s_y|s_x,z,y) || p_θ(s_y|s_x,z))
这正则化到学到的条件先验而非N(0,I),是理论上更合适的目标。
4.2 与LeJEPA的关系
LeJEPA(Balestriero & LeCun, 2025)将各向同性高斯嵌入动机为在广泛探针家族下对下游预测的最小最大最优。它使用SIGReg(通过随机一维投影将聚合嵌入分布匹配到各向同性高斯)来强制执行此分布结构。
Var-JEPA通过其变分正则化项与这一思路相关。模拟研究表明:
- ELBO中的KL散度项为s_x实现了与显式聚合分布正则化(SIGReg)相当的分布性质
- 每个样本KL正则化到固定先验自然强制聚合分布各向同性,无需额外正则化机制
- 对于s_y,ELBO正则化到学到的条件先验是更合适的目标
5. Var-T-JEPA:表格数据的实际应用
5.1 模型设计与实现
我们将Var-JEPA框架实例化为表格数据的实用实现——Var-T-JEPA。该模型结合了:
- 特征级掩码策略
- 统一的变分目标(式10)
- 对异构表格数据的专门处理
具体实现特点包括:
- 将异构数值和类别特征标记化为transformer序列
- 推断高斯潜在嵌入s_x和s_y
- 通过耦合重建-预测目标训练潜在空间预测器
- 产生确定性嵌入(通过后验均值)和来自学到潜在分布的每个样本不确定性估计
5.2 实验设置与评估
我们在多个真实世界表格数据集上评估Var-T-JEPA:
- Adult (AD)
- Covertype (CO)
- Electricity (EL)
- Credit Card (CC)
- Bank Marketing (BM)
以及半合成数据集MNIST和完全合成的模拟数据集(SIM)。评估方法包括:
- 与强原始特征基线的比较
- 与确定性T-JEPA基线的比较
- 使用多种预测器架构(MLP、DCNv2、ResNet等)的下游评估
- 通过不确定性的选择性评估(丢弃不确定性最高的样本后的准确率)
5.3 实验结果分析
实验结果显示:
- 在真实世界表格数据集上,Var-T-JEPA产生了具有竞争力的嵌入
- 选择性评估显示出清晰的覆盖-准确率权衡:丢弃最不确定的测试样本可提升性能
- 确定性T-JEPA基线在某些数据集上出现表示坍缩,导致下游性能下降
- 在合成数据集上,Var-T-JEPA的不确定性信号与底层模拟结构一致
关键发现包括:
- 潜在不确定性不仅对放弃有用,而且与受控设置中已知的不确定性信号对齐
- Var-T-JEPA在不确定性估计和表示质量方面都优于确定性基线
- 统一变分目标成功防止了表示坍缩,无需额外的正则化启发式
6. 讨论与未来方向
6.1 方法论贡献
本文的主要贡献在于:
- 建立了JEPA与变分推断的新形式化联系,将其重新解释为确定性隐变量模型
- 提出了Var-JEPA框架,通过ELBO目标桥接预测式联合嵌入学习与生成式建模
- 证明了ELBO可以自然避免表示坍缩,无需JEPA常用的替代损失
- 开发了针对异构表格数据的实用实现Var-T-JEPA
- 通过实验验证了方法在表示学习和下游任务上的有效性
6.2 实际应用价值
Var-JEPA框架为实际应用带来几个优势:
- 严格的不确定性量化:通过后验协方差提供可靠的置信度估计
- 更稳定的训练:ELBO自然防止表示坍缩,减少对启发式正则化的依赖
- 灵活的架构设计:可以适应不同类型的数据和任务
- 选择性预测能力:基于不确定性估计做出更可靠的决策
6.3 未来研究方向
基于当前工作,有几个有前景的未来方向:
- 扩展到视觉和视频领域:将Var-JEPA框架应用于图像和视频数据
- 处理缺失目标观测:开发在测试时目标观测缺失情况下的条件生成方法
- 探索更复杂的先验分布:超越高斯假设,研究更丰富的潜在空间结构
- 与其他自监督方法的整合:探索Var-JEPA与其他自监督学习范式的结合
Var-JEPA代表了自监督学习领域的一个重要进展,它弥合了预测式和生成式方法之间的鸿沟,为构建更强大、更可靠的世界模型提供了新的理论基础和实践工具。