River segmentation in remote sensing imagery has long been a key technique for environmental monitoring, disaster early warning, and water resource management. Traditional approaches rely on manual interpretation or simple threshold segmentation, which are slow and of limited accuracy. TransUNet, a hybrid architecture combining the strengths of Transformers and CNNs, has already proven itself in medical image segmentation; transferring it to remote sensing scenes, however, requires solving a set of domain-specific challenges.
This open-source project implements a PyTorch-based TransUNet pipeline for remote sensing river segmentation, addressing three core problems:
Applying the original TransUNet's ViT encoder directly to remote sensing images runs into two issues:
This project adopts the following improvements:
```python
import torch
import torch.nn as nn
# ResNet34 and PatchEmbed are project-internal modules

class HybridEncoder(nn.Module):
    def __init__(self, img_size=256, in_chans=3, patch_size=16):
        super().__init__()
        # Stage 1: CNN feature extraction
        self.cnn_backbone = ResNet34(pretrained=True)
        # Stage 2: patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=768
        )
        # Learnable positional embedding, sized to the actual patch
        # count: (256/16)^2 = 256 tokens here, not ViT's default 196
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, 768))
```
Key improvements:
Remote sensing river segmentation must handle imagery at scales ranging from 1:500 to 1:10000, which the single upsampling path of a traditional U-Net struggles with. We designed a multi-branch feature fusion mechanism:
```python
class MultiScaleFusion(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv3x3 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv5x5 = nn.Conv2d(channels, channels, 5, padding=2)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 8, 1),
            nn.ReLU(),
            nn.Conv2d(channels // 8, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        branch1 = self.conv3x3(x)    # fine-scale branch
        branch2 = self.conv5x5(x)    # coarse-scale branch
        fused = branch1 + branch2
        att = self.attention(fused)  # per-channel gating weights
        return x * att
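The attention branch follows the standard squeeze-and-excitation pattern. As a standalone sanity check of that pattern (my own re-creation with toy sizes, not the project module), the gate produces one weight in (0, 1) per channel and the output keeps the input shape:

```python
import torch
import torch.nn as nn

channels = 16
# Same gating pattern as the attention branch above: global average
# pool, channel bottleneck, sigmoid -> one scalar weight per channel
gate = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Conv2d(channels, channels // 8, 1),
    nn.ReLU(),
    nn.Conv2d(channels // 8, channels, 1),
    nn.Sigmoid(),
)
x = torch.randn(2, channels, 32, 32)
att = gate(x)
out = x * att  # weights broadcast over H and W

assert att.shape == (2, channels, 1, 1)
assert out.shape == x.shape
```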
Unlike natural images, remote sensing data requires special preprocessing:
```python
import numpy as np

def atmospheric_correction(image, dark_object_value=50):
    # Dark-object subtraction; clip so pixel values stay non-negative
    return np.clip(image - dark_object_value, 0, None)
```
Conventional rotation/flip augmentation has limited effect on linear landforms, so we designed a geometric deformation augmentation:
```python
class RiverAugmentation:
    def __init__(self):
        # TPS_Sampler is a project-specific thin-plate-spline warper
        self.thin_plate_spline = TPS_Sampler(
            grid_size=(5, 5),
            target_std=0.1
        )

    def __call__(self, img, mask):
        # Sample a random displacement field over the 5x5 control grid
        displacement = torch.randn(2, 5, 5) * 0.1
        # Apply the same warp to image and mask to keep them aligned
        warped_img = self.thin_plate_spline(img, displacement)
        warped_mask = self.thin_plate_spline(mask, displacement)
        return warped_img, warped_mask
```
Class imbalance is severe in river segmentation (river pixels typically make up less than 5% of an image), so we combine three losses:
```python
# Weighted sum of the three loss terms, applied to (pred, target);
# multiplying the uncalled loss modules by floats would raise an error
dice, focal, boundary = DiceLoss(), FocalLoss(gamma=2), BoundaryLoss()
loss = 0.4 * dice(pred, target) + 0.3 * focal(pred, target) \
     + 0.3 * boundary(pred, target)
```
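`DiceLoss`, `FocalLoss`, and `BoundaryLoss` are project classes not shown here. As a minimal runnable sketch of the first two terms (standard textbook formulations in my own words, binary case, probabilities as input):

```python
import torch

def dice_loss(pred, target, eps=1e-6):
    # pred: probabilities in [0, 1]; target: {0, 1}; both (B, H, W)
    inter = (pred * target).sum(dim=(1, 2))
    union = pred.sum(dim=(1, 2)) + target.sum(dim=(1, 2))
    return (1 - (2 * inter + eps) / (union + eps)).mean()

def focal_loss(pred, target, gamma=2.0, eps=1e-6):
    # Down-weights easy, well-classified pixels by (1 - p_t)^gamma
    p_t = pred * target + (1 - pred) * (1 - target)
    return (-((1 - p_t) ** gamma) * torch.log(p_t + eps)).mean()

pred = torch.full((1, 4, 4), 0.9)
target = torch.ones(1, 4, 4)
combined = 0.4 * dice_loss(pred, target) + 0.3 * focal_loss(pred, target)
assert combined.item() < 0.1  # near-perfect prediction -> small loss
```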
We use a warmup + cosine decay schedule:
```python
# torch has no LinearWarmupLR; LinearLR serves as the warmup phase.
# SequentialLR also requires `milestones` marking the switch point.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.01, total_iters=500),
        torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=10000),
    ],
    milestones=[500]
)
```
In our tests, with batch_size=16, an initial lr of 3e-4 worked best.
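The warmup-then-decay behavior can be verified with a self-contained sketch (a dummy parameter and shortened step counts, purely for illustration):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=3e-4)
sched = torch.optim.lr_scheduler.SequentialLR(
    opt,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(
            opt, start_factor=0.01, total_iters=50),
        torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000),
    ],
    milestones=[50],
)
lrs = []
for _ in range(1100):
    opt.step()       # step the optimizer before the scheduler
    sched.step()
    lrs.append(opt.param_groups[0]["lr"])

assert lrs[0] < lrs[49]    # warmup: lr rises toward 3e-4
assert lrs[49] > lrs[1049] # cosine phase: lr decays afterwards
```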
The original TransUNet has 85M parameters; we compress it in the following ways:
When exporting, custom operators need special handling:
```python
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    opset_version=12,
    # dynamic_axes keys must match declared input/output names
    input_names=['input'],
    output_names=['output'],
    custom_opsets={'CustomOp': 1},
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}}
)
```
With TensorRT acceleration, inference time on a 1080Ti drops from 45 ms to 12 ms.
During a flood in a river basin in 2022, the model was used to process Sentinel-2 imagery:
Channel width change is computed from the time series of segmentation results:
```python
import numpy as np

def detect_sand_mining(masks_series, threshold=0.15):
    # Track frame-to-frame loss of river pixels across the series
    width_changes = []
    for i in range(1, len(masks_series)):
        delta = (masks_series[i - 1] - masks_series[i]).sum()
        width_changes.append(delta)
    # Flag abrupt jumps in the change signal as anomalies
    anomalies = np.where(np.diff(width_changes) > threshold)[0]
    return anomalies
```
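A toy usage example with synthetic binary masks (the masks and the function re-definition below are illustrative only, repeated to keep the sketch self-contained):

```python
import numpy as np

def detect_sand_mining(masks_series, threshold=0.15):
    width_changes = []
    for i in range(1, len(masks_series)):
        delta = (masks_series[i - 1] - masks_series[i]).sum()
        width_changes.append(delta)
    anomalies = np.where(np.diff(width_changes) > threshold)[0]
    return anomalies

# Four frames: stable, stable, then a sudden loss of river pixels
m = np.ones((4, 8, 8))
m[3, :, :4] = 0  # half the channel disappears in the last frame
anomalies = detect_sand_mining([m[i] for i in range(4)])
assert list(anomalies) == [1]  # the jump between frames 2 and 3
```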
Symptom: segmentation of tributaries narrower than 3 pixels is discontinuous.
Fix:
```python
import cv2

# Morphological closing bridges small gaps along thin tributaries
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
refined = cv2.morphologyEx(pred, cv2.MORPH_CLOSE, kernel)
```
Symptom: building shadows are misclassified as water.
Improvement:
```python
class SpectralLoss(nn.Module):
    def __init__(self, band_weights=(0.5, 0.3, 0.2)):
        super().__init__()
        # Buffer so the weights follow the module across devices
        self.register_buffer('weights', torch.tensor(band_weights))

    def forward(self, pred, target, image):
        # Absolute per-band difference so errors of opposite sign
        # cannot cancel; shadows differ spectrally from true water
        spec_diff = (image * pred - image * target).abs().mean(dim=(2, 3))
        return (spec_diff * self.weights).sum()
```
An LSTM module is added to handle time-series data:
```python
class ChangeDetector(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.lstm = nn.LSTM(768, 256, bidirectional=True)
        self.decoder = nn.Conv2d(512, 2, 1)

    def forward(self, x_series):
        # Pool each frame's feature map to a 768-d vector so the
        # stacked sequence is (T, B, 768), as the LSTM expects
        features = [self.encoder(x).mean(dim=(2, 3)) for x in x_series]
        temporal, _ = self.lstm(torch.stack(features))  # (T, B, 512)
        # Restore singleton spatial dims for the 1x1 conv head
        return self.decoder(temporal[-1].unsqueeze(-1).unsqueeze(-1))
```
Fusing SAR and optical imagery:
```python
import cv2

def despeckle(sar_img, window_size=3):
    # Median filtering suppresses SAR speckle noise
    return cv2.medianBlur(sar_img, window_size)
```
```python
class CrossModalAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.query = nn.Conv2d(channels, channels // 8, 1)
        self.key = nn.Conv2d(channels, channels // 8, 1)
        self.value = nn.Conv2d(channels, channels, 1)

    def forward(self, optical, sar):
        B, C, H, W = optical.shape
        # Queries from the optical branch, keys/values from SAR
        q = self.query(optical).view(B, -1, H * W).permute(0, 2, 1)  # (B, HW, C//8)
        k = self.key(sar).view(B, -1, H * W)                         # (B, C//8, HW)
        v = self.value(sar).view(B, -1, H * W)                       # (B, C, HW)
        att = torch.softmax(q @ k, dim=-1)                           # (B, HW, HW)
        out = v @ att.transpose(1, 2)                                # (B, C, HW)
        return out.view(B, C, H, W)
```
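The attention arithmetic can be checked in functional form, without the conv projections (standalone sketch with toy sizes): each query position attends over all SAR positions, so the attention rows sum to 1 and the output recovers the input spatial shape.

```python
import torch

# Same matmul pattern as CrossModalAttention.forward, toy sizes
B, C, H, W = 2, 16, 8, 8
q = torch.randn(B, C // 8, H * W).permute(0, 2, 1)  # (B, HW, C//8)
k = torch.randn(B, C // 8, H * W)                   # (B, C//8, HW)
v = torch.randn(B, C, H * W)                        # (B, C, HW)

att = torch.softmax(q @ k, dim=-1)                  # (B, HW, HW)
out = (v @ att.transpose(1, 2)).view(B, C, H, W)

assert out.shape == (B, C, H, W)
assert torch.allclose(att.sum(dim=-1), torch.ones(B, H * W))
```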
In deployment we found that for plateau-region imagery, interference from snow and ice cover must also be considered. We later added elevation data as an auxiliary input, with a simple threshold filter:
```python
def filter_by_dem(pred, dem, max_elevation=4000):
    # Zero out predictions above the elevation cutoff (snow/ice areas)
    return pred * (dem < max_elevation).float()
```
For edge-device deployment, we recommend a C++ wrapper built on LibTorch. A practical memory optimization is to preallocate a fixed-size tensor buffer and reuse it, avoiding repeated allocations:
```cpp
// Preallocate once; reuse this buffer on every inference call
torch::Tensor buffer = torch::empty({1, 3, 256, 256}, torch::kFloat32);
memcpy(buffer.data_ptr(), input_data, input_size);
```