1. 时序预测新范式:FEDFormer架构全景解析
在电力负荷预测、气象分析和交通流量预测等领域,传统时序模型常面临长期依赖捕捉不足和计算复杂度高的问题。2022年提出的FEDFormer(傅里叶增强分解Transformer)通过频域增强和季节性分解机制,在保持线性计算复杂度的同时,显著提升了长期预测性能。本文将从算法设计思想到PyTorch实现,带您穿透数学公式直观理解这一创新架构。
1.1 核心创新点拆解
FEDFormer的突破性设计主要体现在三个维度:
- 频域注意力机制:在傅里叶空间进行低秩近似,将标准Transformer的O(N²)复杂度降至O(N)
- 混合季节-趋势分解:借鉴Prophet的思想,通过可学习滤波器分离时序信号的周期性和趋势成分
- 随机傅里叶特征:采用RFF将高维注意力计算投影到低维空间,实现计算效率与精度的平衡
关键洞察:传统Transformer在时序预测中的主要瓶颈在于注意力矩阵的稠密计算,而FEDFormer通过频域稀疏化成功突破了这一限制。
2. 频域增强注意力实现详解
2.1 傅里叶基投影原理
FEDFormer采用离散傅里叶变换(DFT)矩阵作为投影基,将输入序列x∈R^{N×d}转换到频域:
python复制import torch
def DFT_matrix(N):
n = torch.arange(N)
k = n.reshape((N, 1))
M = torch.exp(-2j * torch.pi * k * n / N)
return M / torch.sqrt(N)
dft_mat = DFT_matrix(seq_len) # [seq_len, seq_len]
2.2 低秩频域注意力
通过选择前k个低频分量实现注意力稀疏化:
python复制class FrequencyEnhancedAttention(nn.Module):
def __init__(self, d_model, k=64):
super().__init__()
self.query_proj = nn.Linear(d_model, k)
self.key_proj = nn.Linear(d_model, k)
self.value_proj = nn.Linear(d_model, d_model)
def forward(self, x):
# 投影到频域
Q_f = torch.fft.rfft(self.query_proj(x), dim=1)[..., :k]
K_f = torch.fft.rfft(self.key_proj(x), dim=1)[..., :k]
# 频域点积注意力
attn = torch.softmax(Q_f @ K_f.transpose(-1,-2), dim=-1)
return self.value_proj(x) @ attn
实测技巧:k值通常取序列长度的1/8~1/4,在ETTh1数据集上k=64时效果最佳
3. 混合分解模块工程实现
3.1 可学习滤波器设计
通过卷积核实现自适应季节-趋势分离:
python复制class LearnableDecomp(nn.Module):
def __init__(self, kernel_size=25):
super().__init__()
self.avg_pool = nn.AvgPool1d(kernel_size, stride=1, padding=kernel_size//2)
def forward(self, x):
trend = self.avg_pool(x.permute(0,2,1)).permute(0,2,1)
season = x - trend
return trend, season
3.2 多尺度分解架构
完整的前向传播流程包含三级分解:
mermaid复制graph TD
A[原始序列] --> B(第一级分解)
B --> C[趋势1]
B --> D[季节1]
D --> E(第二级分解)
E --> F[趋势2]
E --> G[季节2]
G --> H(第三级分解)
H --> I[趋势3]
H --> J[季节3]
J --> K[残差项]
4. 完整模型训练技巧
4.1 多阶段训练策略
- 第一阶段:固定频域注意力模块,仅训练分解模块(10-20个epoch)
- 第二阶段:联合微调全部参数,使用余弦退火学习率调度
- 第三阶段:冻结底层特征提取器,微调预测头
4.2 关键超参数配置
| 参数 | ETTh1建议值 | 气象数据集建议值 | 说明 |
|---|---|---|---|
| 学习率 | 3e-4 | 5e-4 | 使用AdamW优化器 |
| 批大小 | 32 | 64 | 根据GPU显存调整 |
| 频域维度k | 64 | 128 | 与序列长度正相关 |
| 分解级数 | 3 | 2 | 复杂数据用更多级数 |
| 丢弃率 | 0.1 | 0.2 | 防止频域过拟合 |
5. 实战常见问题排查
5.1 频域混叠现象
症状:预测结果出现异常高频振荡
解决方案:
- 检查DFT矩阵的归一化系数
- 在频域注意力后添加低通滤波层
- 适当减少选择的频域分量数量k
5.2 分解不充分
症状:趋势项仍包含周期性波动
调试步骤:
python复制# 可视化各分解层级输出
with torch.no_grad():
trend1, season1 = decomp1(x)
trend2, season2 = decomp2(season1)
plt.plot(trend1[0,:,0].cpu().numpy())
6. 扩展应用场景
6.1 多变量时序预测
通过特征维度的频域分析捕捉变量间关联:
python复制class MultivariateFEA(nn.Module):
def __init__(self, d_feature, d_model):
super().__init__()
self.feature_proj = nn.Linear(d_feature, d_model)
self.fea_attn = FrequencyEnhancedAttention(d_model)
def forward(self, x):
# x: [batch, seq_len, num_features]
x = self.feature_proj(x) # 统一到相同维度
return self.fea_attn(x)
6.2 缺失值处理
利用频域特性实现鲁棒预测:
- 对缺失窗口进行零填充
- 在频域计算掩码注意力权重
- 通过逆FFT重建完整序列
python复制def masked_fft(x, mask):
x_masked = x * mask
fft = torch.fft.rfft(x_masked, dim=1)
return fft / (mask.sum(dim=1, keepdim=True) + 1e-6)
7. 模型轻量化改进
7.1 频域维度压缩
通过PCA降低频域投影维度:
python复制from sklearn.decomposition import PCA
class CompressedFEA(nn.Module):
def __init__(self, original_k=64, compressed_k=32):
super().__init__()
self.pca = PCA(n_components=compressed_k)
# 在训练集上拟合PCA参数
self.compressed_k = compressed_k
def forward(self, x):
fft = torch.fft.rfft(x, dim=1)[..., :original_k]
fft_compressed = torch.from_numpy(
self.pca.transform(fft.cpu())).to(x.device)
return fft_compressed
7.2 量化部署方案
- 将频域矩阵转换为8位定点数
- 使用查表法加速三角函数计算
- 对分解模块采用动态量化
python复制quant_model = torch.quantization.quantize_dynamic(
model, {nn.Linear}, dtype=torch.qint8)
8. 不同场景下的调优建议
8.1 电力负荷预测
- 重点加强日/周周期模式的捕捉
- 建议配置:
python复制config = { 'n_levels': 3, 'k_freq': 128, 'decomp_kernel': [31, 15, 7] # 多尺度卷积核 }
8.2 交通流量预测
- 需处理突发性事件的影响
- 关键调整:
python复制nn.Dropout(0.3), # 更高丢弃率 loss_fn = nn.HuberLoss() # 鲁棒损失函数
9. 效果评估与对比
9.1 典型数据集表现
在ETTh1(电力)数据集上96步预测结果:
| 模型 | MSE | MAE | 训练时间 |
|---|---|---|---|
| Informer | 0.365 | 0.401 | 2.1h |
| Autoformer | 0.339 | 0.372 | 2.8h |
| FEDFormer | 0.307 | 0.342 | 1.9h |
9.2 显存占用对比
(batch_size=32, seq_len=96)
| 模型 | 显存占用 | 参数量 |
|---|---|---|
| Vanilla Transformer | 8.7GB | 23M |
| FEDFormer-base | 3.2GB | 15M |
| FEDFormer-small | 1.8GB | 7M |
10. 进阶研究方向
10.1 时频自适应选择
动态调整频域分量选择策略:
python复制class AdaptiveSelector(nn.Module):
def __init__(self, d_model):
super().__init__()
self.selector = nn.Linear(d_model, k)
def forward(self, x):
# x: [batch, seq_len, d_model]
scores = self.selector(x.mean(dim=1)) # [batch, k]
return torch.topk(scores, dim=-1, k=dynamic_k)
10.2 多模态融合
结合外部特征(如天气数据):
- 对数值特征进行傅里叶嵌入
- 分类特征采用注意力交互
- 在频域空间进行特征融合
python复制class MultiModalFusion(nn.Module):
def __init__(self, d_num, d_cat):
super().__init__()
self.num_proj = nn.Linear(d_num, d_model)
self.cat_proj = nn.Embedding(d_cat, d_model)
def forward(self, num_x, cat_x):
num_fft = torch.fft.rfft(self.num_proj(num_x), dim=1)
cat_feat = self.cat_proj(cat_x)
return num_fft * cat_feat.unsqueeze(1)