万象(WanVideo)是一套开源的视频生成模型系列,其核心创新在于提出了全新的时空变分自编码器(VAE)架构和可扩展的预训练策略。这个14B参数规模的模型在数十亿图片和视频数据上进行了训练,展现出明显的scaling law特性。模型支持多种下游任务,包括图生视频(image-to-video)和指令引导的视频编辑等,特别值得一提的是它原生支持中文输入。
在实际部署中,1.3B版本的模型仅需8.19GB的VRAM即可运行,这使得它在消费级GPU上也能有不错的表现。本文将重点解析其核心组件DiT(Diffusion Transformer)的模型架构和前向计算过程。
提示:DiT作为扩散模型的核心组件,其设计直接影响了生成质量和计算效率。万象的DiT实现特别考虑了大规模训练和推理的优化,这在后续章节会详细展开。
文本输入首先经过UMT5编码器处理,输出形状固定为[B, 512, 4096],其中B是batch size,512是序列长度,4096是文本嵌入维度。具体实现如下:
python复制text_emb = get_umt5_embedding(
checkpoint_path=args.text_encoder_path,
prompts=args.prompt
).to(dtype=torch.bfloat16).cuda()
这种设计有几个关键考虑:
对于图生视频任务,需要将首帧图像编码为潜空间表示。处理流程如下:
python复制# 将单帧扩展为伪视频张量
frames_to_encode = torch.cat([
image_tensor.unsqueeze(2), # [B,3,H,W]->[B,3,1,H,W]
torch.zeros(1,3,F-1,h,w, device=image_tensor.device)
], dim=2) # -> [B,3,F,H,W]
# VAE编码
encoded_latents = tokenizer.encode(frames_to_encode) # -> [B,C,T,H,W]
编码过程中的关键参数:
将文本嵌入和图像潜表示组合成模型的条件输入:
python复制# 创建掩码并标记条件帧
msk = torch.zeros(1,4,lat_t,lat_h,lat_w, device=tensor_kwargs["device"])
msk[:,:,0,:,:] = 1.0 # 高亮第一帧
# 拼接掩码和潜向量
y = torch.cat([msk, encoded_latents], dim=1) # [1,4+C,T,H,W]
这种设计使得模型能明确区分条件帧和需要生成的帧,在实际应用中,调整掩码模式可以实现不同的生成控制效果。
输入潜向量的形状为[B,C,T,H,W],默认C=16。与条件y拼接后得到36通道的输入:
python复制x = torch.cat([x_B_C_T_H_W, y_B_C_T_H_W], dim=1) # [B,36,T,H,W]
分片处理采用patch_size=(1,2,2),将输入转换为序列形式:
python复制# 分片变换
x = rearrange(x, "b c (t kt) (h kh) (w kw) -> b (t h w) (c kt kh kw)",
kt=1, kh=2, kw=2)
# 线性投影
x = self.patch_embedding(x) # [B,L,d_in] -> [B,L,d=5120]
这种时空分片策略平衡了计算效率和局部性保留,1×2×2的patch大小在实践中被证明对视频数据特别有效。
时间步编码采用经典的sinusoidal位置编码加MLP的方案:
python复制# 1D正弦位置编码
t_emb = sinusoidal_embedding_1d(self.freq_dim, t_B) # [B,] -> [B,256]
# 两层MLP投影
e_B_D = nn.Sequential(
nn.Linear(256, 5120),
nn.SiLU(),
nn.Linear(5120, 5120)
)(t_emb) # [B,5120]
# 最终投影为6个调制参数
e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, 5120)) # [B,6,5120]
6个调制参数分别用于控制:
万象采用创新的3D RoPE(Rotary Position Embedding)来处理视频数据的三维结构:
python复制class VideoRopePosition3DEmb:
def __init__(self, head_dim, len_h=128, len_w=128, len_t=32):
# 划分头部维度给时空三个方向
d_h = d_w = (head_dim // 6) * 2
d_t = head_dim - d_h - d_w
# 生成各方向频率
self.freqs_h = self._get_freqs(len_h, d_h//2)
self.freqs_w = self._get_freqs(len_w, d_w//2)
self.freqs_t = self._get_freqs(len_t, d_t//2)
def generate_embeddings(self, shape):
B,T,H,W,D = shape
# 组合三维频率
freqs = torch.cat([
repeat(self.freqs_t, "t d -> t h w d", h=H, w=W),
repeat(self.freqs_h, "h d -> t h w d", t=T, w=W),
repeat(self.freqs_w, "w d -> t h w d", t=T, h=H),
], dim=-1) # [T,H,W,D/2]
return freqs
这种设计使得位置编码能够同时捕获时空关系,相比传统的1D位置编码更适合视频数据。
万象的注意力实现会根据GPU架构自动选择最优后端:
python复制def attention(q, k, v, compute_cap, dtype):
if compute_cap == 90 and FLASH_ATTN_3_AVAILABLE:
return flash_attn_3(q, k, v) # H100等SM90架构
elif compute_cap in [80, 86, 89]:
return flash_attn_2(q, k, v) # A100/RTX40等
else:
return xformers_attention(q, k, v) # 通用后备方案
支持的硬件架构包括:
为支持大规模训练,实现了高效的序列并行方案:
python复制class DistributedAttention(nn.Module):
def forward(self, query, key, value):
if self.pg is None:
return self.local_attn(query, key, value)
# 序列并行三阶段
# 1. 从"局部序列完整头"转为"完整序列局部头"
q, k, v = _SeqAllToAllQKV.apply(
self.pg, query, key, value,
self.pg.size(), self.stream, True)
# 2. 本地注意力计算
context = self.local_attn(q, k, v)
# 3. 转回"局部序列完整头"
output = _SeqAllToAll.apply(
self.pg, context, False)
return output
这种设计使得注意力计算可以分布在多个GPU上,显著提升了长序列处理能力。
每个DiT块都包含时间步调制的FFN:
python复制class WanAttention(nn.Module):
def forward(self, x, e):
# e包含6个调制参数
e_shift1, e_scale1, e_gate1, e_shift2, e_scale2, e_gate2 = e.chunk(6, dim=1)
# 调制自注意力
x_attn = self.self_attn(
(norm1(x) * (1 + e_scale1) + e_shift1),
freqs
)
x = x + x_attn * e_gate1
# 调制FFN
x_ffn = self.ffn(
(norm2(x) * (1 + e_scale2) + e_shift2)
)
x = x + x_ffn * e_gate2
return x
调制机制允许模型根据时间步动态调整各层的行为,这在扩散模型中尤为重要。
对于文本条件生成,DiT块中还集成了交叉注意力:
python复制class WanCrossAttention(WanSelfAttention):
def forward(self, x, context):
q = self.norm_q(self.q(x)) # 来自潜变量
k = self.norm_k(self.k(context)) # 来自文本嵌入
v = self.v(context)
return self.attn_op(q, k, v)
这种设计使得文本条件能够直接影响每一层的特征表示,增强了模型对文本指令的响应能力。
完整的前向传播包含以下步骤:
当启用序列并行时,关键处理流程如下:
python复制def forward(self, x, timesteps, text_emb, y=None):
if self.cp_enabled:
x = broadcast(x, self.cp_group)
# 分片处理
if self.cp_enabled:
x = split_inputs_cp(x, seq_dim=1, cp_group=self.cp_group)
# DiT块处理
for block in self.blocks:
x = block(x, e, freqs, context)
# 结果聚合
if self.cp_enabled:
x = cat_outputs_cp(x, seq_dim=1, cp_group=self.cp_group)
return x
这种设计使得模型可以灵活地在单卡和多卡模式下运行,无需修改核心逻辑。
万象使用Triton实现了高效的旋转位置编码内核:
python复制@triton.autotune(configs=[
triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
], key=["head_size", "interleaved"])
@triton.jit
def _rotary_embedding_kernel(
out_ptr, x_ptr, cos_ptr, sin_ptr,
n, d, s, stride_x_row, stride_cos_row, stride_sin_row,
BLOCK_HS_HALF: tl.constexpr
):
# 每个线程处理一个位置
row_idx = tl.program_id(0)
token_idx = (row_idx // n) % s
# 加载数据
x1 = tl.load(x_ptr + 2*offset)
x2 = tl.load(x_ptr + 2*offset + 1)
cos = tl.load(cos_ptr + offset)
sin = tl.load(sin_ptr + offset)
# 应用旋转
o1 = x1 * cos - x2 * sin
o2 = x1 * sin + x2 * cos
# 写回结果
tl.store(out_ptr + 2*offset, o1)
tl.store(out_ptr + 2*offset + 1, o2)
这种实现相比纯PyTorch版本可获得3-5倍的加速。
模型广泛使用混合精度训练技术:
python复制with amp.autocast("cuda", dtype=torch.float32):
e = self.time_embedding(t_emb.float())
x = x + y * e_gate.type_as(x)
关键策略包括:
根据硬件条件选择合适配置:
显存不足:
生成质量差:
训练不稳定: