实战指南:用PyTorch和Elastic Weight Consolidation (EWC) 实现一个简单的终身学习模型
# 实战指南:用PyTorch和Elastic Weight Consolidation (EWC) 实现终身学习模型

当机器学习模型需要持续适应新任务而不遗忘旧知识时,终身学习(LifeLong Learning)技术便成为关键解决方案。本文将手把手带你实现基于PyTorch和Elastic Weight Consolidation (EWC)的终身学习系统,解决模型在增量学习过程中的灾难性遗忘问题。

## 1. 环境准备与核心概念

在开始编码前,我们需要明确几个关键术语:

- **终身学习**:模型在连续学习多个任务时保留旧任务知识的能力
- **灾难性遗忘**:神经网络在学习新任务时快速丢失旧任务知识的现象
- **EWC原理**:通过约束重要参数的变化来保护已有知识

推荐使用Python 3.8和PyTorch 1.12环境。安装依赖:

```bash
pip install torch torchvision matplotlib numpy
```

EWC的核心数学公式:

$$L(\theta) = L_{\text{new}}(\theta) + \lambda \sum_i F_i \left(\theta_i - \theta^*_i\right)^2$$

其中:

- $L_{\text{new}}$ 是新任务损失函数
- $\theta^*$ 是旧任务最优参数
- $F$ 是Fisher信息矩阵
- $\lambda$ 控制正则化强度

## 2. 数据集处理与模型架构

我们将使用Split MNIST作为基准数据集,将原始10类MNIST分为5个二元分类任务:

```python
from torchvision import datasets, transforms

def get_task_data(task_id):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_set = datasets.MNIST('../data', train=True, download=True, transform=transform)

    # 每个任务包含两个数字类别
    class_pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)]
    mask = (full_set.targets == class_pairs[task_id][0]) | (full_set.targets == class_pairs[task_id][1])
    task_data = torch.utils.data.Subset(full_set, torch.where(mask)[0])
    task_data.targets = (full_set.targets[mask] == class_pairs[task_id][1]).long()
    return task_data
```

模型采用简单的MLP架构:

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2)  # 二元分类
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.layers(x)
```

## 3. 
EWC关键实现

EWC需要跟踪两个关键量:最优参数θ*和Fisher信息矩阵F。以下是核心实现:

```python
class EWC:
    def __init__(self, model, lambda_=5000):
        self.model = model
        self.lambda_ = lambda_
        self.params = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}
        self.fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}

    def compute_fisher(self, dataset, samples=200):
        self.model.eval()
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

        for n, p in self.model.named_parameters():
            self.fisher[n].zero_()

        for i, (data, target) in enumerate(dataloader):
            if i >= samples:
                break
            self.model.zero_grad()
            output = self.model(data)
            prob = torch.softmax(output, dim=1)[0, target.item()]
            prob.backward()

            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    self.fisher[n] += p.grad.data ** 2 / samples

    def penalty(self):
        loss = 0
        for n, p in self.model.named_parameters():
            loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return self.lambda_ * loss
```

> 注意:Fisher矩阵估计需要足够样本才能准确,实践中建议使用200-500个样本。

## 4. 训练流程设计

EWC训练与传统训练的主要区别在于损失函数的构造:

```python
def train_ewc(model, train_loader, optimizer, ewc=None, current_task=None):
    model.train()
    total_loss = 0

    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)

        if ewc is not None and current_task is not None:
            loss += ewc.penalty()

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)
```

完整训练循环示例:

```python
def run_experiment(num_tasks=5, epochs_per_task=5):
    model = MLP()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    ewc = None

    for task in range(num_tasks):
        train_data = get_task_data(task)
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

        # 普通训练当前任务
        for epoch in range(epochs_per_task):
            train_ewc(model, train_loader, optimizer)

        # 计算Fisher信息并保存当前参数
        if task < num_tasks - 1:  # 最后一个任务不需要
            ewc = EWC(model)
            ewc.compute_fisher(train_data)
```

## 5. 
评估与可视化

评估模型在所有任务上的表现:

```python
def evaluate(model, task_id):
    test_data = get_task_data(task_id)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    return correct / len(test_loader.dataset)

# 绘制任务准确率矩阵
def plot_acc_matrix(acc_matrix):
    plt.figure(figsize=(8,6))
    sns.heatmap(acc_matrix, annot=True, fmt='.2f', cmap='YlGnBu',
                xticklabels=[f'Task {i}' for i in range(5)],
                yticklabels=[f'After {i}' for i in range(5)])
    plt.xlabel('Test Task')
    plt.ylabel('Trained Up To')
    plt.title('Accuracy Matrix')
    plt.show()
```

典型结果分析:

| 训练任务数 | Task 0 | Task 1 | Task 2 | Task 3 | Task 4 |
|---|---|---|---|---|---|
| 1 | 0.98 | 0.50 | 0.50 | 0.50 | 0.50 |
| 2 | 0.95 | 0.97 | 0.51 | 0.50 | 0.50 |
| 3 | 0.93 | 0.96 | 0.96 | 0.50 | 0.50 |
| 4 | 0.91 | 0.94 | 0.95 | 0.97 | 0.50 |
| 5 | 0.89 | 0.92 | 0.93 | 0.96 | 0.98 |

## 6. 高级技巧与优化

**学习率调整策略**:

```python
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [
        torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=5),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
    ],
    milestones=[5]
)
```

**参数重要性自适应**:

```python
def update_ewc_strength(ewc, current_task_perf, previous_tasks_avg_perf):
    if current_task_perf < 0.8:  # 新任务表现不佳
        ewc.lambda_ *= 0.9  # 降低约束强度
    elif previous_tasks_avg_perf < 0.7:  # 旧任务遗忘严重
        ewc.lambda_ *= 1.1  # 增强约束
```

**多任务基准测试**:

- 在CIFAR-100上测试20个任务的连续学习
- 比较EWC与LwF、SI等方法的遗忘率
- 添加蒸馏损失进一步提升性能

实际项目中,我发现EWC的超参数λ对最终效果影响极大。经过多次实验,当λ在3000-8000范围内时,Split MNIST上通常能取得最佳平衡。另一个实用技巧是:在计算Fisher矩阵时,对最后几层使用更大的采样量,因为这些层通常包含更多任务特定信息。