# 医学图像分割实战：手把手教你用PyTorch复现TransUNet（附完整代码与数据集处理）

医学图像分割一直是计算机视觉领域的重要研究方向，尤其在临床诊断和治疗规划中发挥着关键作用。传统的U-Net架构虽然在医学图像分割中表现出色，但其卷积操作的局部感受野限制了全局信息的捕获能力。而Transformer结构的引入为解决这一问题提供了新的思路。本文将带你从零开始，用PyTorch完整复现TransUNet模型，并针对医学图像的特殊性进行优化。

## 1. 环境配置与准备工作

在开始构建TransUNet之前，我们需要确保开发环境配置正确。以下是推荐的配置方案：

```bash
conda create -n transunet python=3.8
conda activate transunet
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install nibabel scikit-image tqdm tensorboard
```

> 注意：CUDA版本需要与显卡驱动兼容，建议使用NVIDIA官方文档检查兼容性。

关键依赖库及其作用：

| 库名称 | 版本要求 | 主要用途 |
| --- | --- | --- |
| PyTorch | ≥1.8.0 | 深度学习框架基础 |
| nibabel | ≥3.2.1 | 医学图像格式处理 |
| scikit-image | ≥0.18.0 | 图像预处理工具 |
| tensorboard | ≥2.5.0 | 训练过程可视化 |

对于硬件配置，建议至少满足以下条件：

- GPU：NVIDIA显卡，显存≥8GB（处理3D医学图像建议≥16GB）
- 内存：≥16GB
- 存储：SSD硬盘，预留≥50GB空间用于数据集缓存

## 2. 医学图像数据预处理

医学图像（如CT、MRI）与自然图像有很大不同，需要特殊处理：

```python
import nibabel as nib
import numpy as np

def load_nii_file(file_path):
    """加载NIfTI格式的医学图像"""
    img = nib.load(file_path)
    data = img.get_fdata()
    # 标准化到[0,1]范围
    data = (data - np.min(data)) / (np.max(data) - np.min(data))
    return np.transpose(data, (2, 0, 1))  # 调整为通道优先格式
```

常见医学图像预处理流程：

- 重采样：统一不同设备采集的图像分辨率
- 窗宽窗位调整：突出特定组织的对比度
- 标准化：消除不同扫描仪间的差异
- 数据增强：旋转、翻转等操作增加数据多样性

对于Synapse多器官分割数据集，建议采用以下预处理步骤：

```python
def preprocess_synapse(data, label):
    # 1. 重采样到统一分辨率(256x256)
    data = resize(data, (256, 256), order=3, preserve_range=True)
    label = resize(label, (256, 256), order=0, preserve_range=True)
    # 2. 强度归一化
    data = (data - data.mean()) / data.std()
    # 3. 随机数据增强
    if np.random.rand() > 0.5:
        data, label = random_rotate(data, label, angle_range=(-15,15))
    if np.random.rand() > 0.5:
        data, label = random_flip(data, label)
    return data, label
```

## 3. TransUNet模型架构实现

TransUNet的核心创新在于将CNN的局部特征提取能力与Transformer的全局建模能力相结合。下面我们分模块实现。

### 3.1 混合编码器实现

```python
import torch
import torch.nn as nn
from einops import rearrange

class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(dim, heads, mlp_dim, dropout)
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class HybridEncoder(nn.Module):
    def __init__(self, in_chans=3, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        # CNN特征提取部分(修改后的ResNet50)
        self.cnn_backbone = ModifiedResNet50()
        # Transformer部分
        self.proj = nn.Conv2d(1024, embed_dim, kernel_size=1)
        self.transformer = TransformerEncoder(
            dim=embed_dim,
            depth=depth,
            heads=num_heads,
            mlp_dim=3072
        )

    def forward(self, x):
        # CNN特征提取
        features = self.cnn_backbone(x)  # [B, 1024, 14, 14]
        # 投影到Transformer维度
        x = self.proj(features)  # [B, 768, 14, 14]
        # 序列化并添加位置编码
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = x + self.pos_embedding
        # Transformer编码
        x = self.transformer(x)
        # 恢复空间维度
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return x, features
```

### 3.2 解码器实现

解码器采用类似U-Net的结构，但加入了Transformer编码特征的融合：

```python
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels=0):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels+skip_channels, out_channels, 3, padding=1),
            nn.GroupNorm(16, out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.GroupNorm(16, out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class TransUNetDecoder(nn.Module):
    def __init__(self, num_classes, embed_dim=768):
        super().__init__()
        self.decoder1 = DecoderBlock(embed_dim, 512, skip_channels=512)
        self.decoder2 = DecoderBlock(512, 256, skip_channels=256)
        self.decoder3 = DecoderBlock(256, 128, skip_channels=128)
        self.decoder4 = DecoderBlock(128, 64, skip_channels=64)
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x, features):
        # features是CNN编码器各阶段输出
        x = self.decoder1(x, features[3])  # 1/16 -> 1/8
        x = self.decoder2(x, features[2])  # 1/8 -> 1/4
        x = self.decoder3(x, features[1])  # 1/4 -> 1/2
        x = self.decoder4(x, features[0])  # 1/2 -> 1/1
        return self.final(x)
```

## 4. 训练策略与技巧

医学图像分割训练需要特别注意以下几点。

### 4.1 损失函数选择

```python
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = pred.sigmoid()
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice

class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.dice = DiceLoss()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        return self.alpha * self.dice(pred, target) + (1-self.alpha) * self.bce(pred, target)
```

### 4.2 学习率调度策略

```python
def get_optimizer(model, lr=1e-4, weight_decay=1e-4):
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if 'backbone' in n], 'lr': lr/10},
        {'params': [p for n, p in model.named_parameters() if 'backbone' not in n], 'lr': lr}
    ]
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)

def get_scheduler(optimizer, num_epochs):
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-6
    )
```

### 4.3 训练过程中的关键技巧

- 混合精度训练：减少显存占用，加快训练速度
- 梯度裁剪：防止梯度爆炸
- 早停机制：基于验证集性能停止训练
- 模型EMA：使用指数移动平均提升模型稳定性

```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for epoch in range(epochs):
    model.train()
    for images, masks in train_loader:
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, masks)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()
```

## 5. 模型评估与结果可视化

### 5.1 评估指标实现

```python
def compute_iou(pred, target, num_classes):
    ious = []
    pred = pred.argmax(1)
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().float()
        union = (pred_inds | target_inds).sum().float()
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append((intersection / union).item())
    return np.nanmean(ious)

def compute_dice(pred, target, num_classes):
    dices = []
    pred = pred.argmax(1)
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().float()
        if (pred_inds.sum() + target_inds.sum()) == 0:
            dices.append(float('nan'))
        else:
            dices.append((2 * intersection) / (pred_inds.sum() + target_inds.sum()).item())
    return np.nanmean(dices)
```

### 5.2 结果可视化

```python
import matplotlib.pyplot as plt

def visualize_results(image, mask, pred, save_path=None):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image, cmap='gray')
    axes[0].set_title('Input Image')
    axes[0].axis('off')
    axes[1].imshow(mask, cmap='jet')
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')
    axes[2].imshow(pred.argmax(0), cmap='jet')
    axes[2].set_title('Prediction')
    axes[2].axis('off')
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close()
```

在实际项目中，我们发现TransUNet在小型器官（如胰腺）的分割上表现尤为突出，这得益于Transformer捕获的全局上下文信息。一个常见的调优方向是调整编码器中CNN与Transformer的比例，找到适合特定数据集的平衡点。