关于load_data_fashion_mnist函数运行原理以及运行速度慢解决方案
需要了解具体解决方案的可以直接跳至第三块内容前言在 PyTorch 深度学习入门实践中Fashion-MNIST 时尚服饰数据集是新手必练的经典数据集而load_data_fashion_mnist函数几乎是加载该数据集的标配工具。但很多小伙伴在使用时都会遇到两个痛点函数运行超慢等待时间长不清楚函数底层原理只会照搬调用今天这篇文章就从函数简介、运行原理、卡顿原因 极致优化方案三个维度把这个函数彻底讲透新手也能轻松掌握一、load_data_fashion_mnist 函数核心简介1. 函数定位load_data_fashion_mnist不是 Python 原生函数也不是 PyTorch 官方内置函数而是《动手学深度学习》等经典教材、深度学习项目中高频使用的自定义封装函数。它的核心价值一行代码完成 Fashion-MNIST 数据集的下载、预处理、迭代器生成极大简化新手数据加载流程。2. 标准完整代码可直接复制使用import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms def load_data_fashion_mnist(batch_size, resizeNone): 加载Fashion-MNIST数据集返回训练集和测试集的DataLoader :param batch_size: 每个批次的样本数量 :param resize: 可选将图片resize到指定尺寸如resize64 :return: train_iter训练集迭代器, test_iter测试集迭代器 # 1. 定义数据预处理操作转张量 标准化 trans [transforms.ToTensor()] # 转成PyTorch张量 if resize: trans.insert(0, transforms.Resize(resize)) # 可选resize trans transforms.Compose(trans) # 组合预处理 # 2. 加载数据集自动下载root为存储路径 mnist_train datasets.FashionMNIST( root../data, trainTrue, transformtrans, downloadTrue ) mnist_test datasets.FashionMNIST( root../data, trainFalse, transformtrans, downloadTrue ) # 3. 生成数据迭代器训练集打乱测试集不打乱 train_iter DataLoader( mnist_train, batch_sizebatch_size, shuffleTrue, num_workers0 ) test_iter DataLoader( mnist_test, batch_sizebatch_size, shuffleFalse, num_workers0 ) return train_iter, test_iter二、load_data_fashion_mnist 运行原理解析这个函数本质是三步流水线作业清晰易懂1. 数据集自动下载函数会检测指定路径../data下是否存在 Fashion-MNIST 数据集不存在则自动从官方服务器下载包含 60000 张训练图 10000 张测试图。2. 数据标准化与预处理将原始图片转换为 PyTorch 模型可识别的张量格式支持自定义图片缩放完成数据归一化预处理。3. 生成训练 / 测试迭代器最终返回两个DataLoader迭代器训练集迭代器打乱数据防止模型学习顺序特征测试集迭代器不打乱数据保证评估结果稳定迭代器会按batch_size分批输出数据直接对接模型训练。三、关于load_data_fashion_mnist为什么会慢和如何提高效率这是大家最关心的核心问题。函数运行慢主要分两种场景对应不同解决方案场景 1首次运行极慢 → 数据集在线下载导致问题原因函数默认开启downloadTrue首次运行时会从国外服务器下载数据集网速受限导致等待时间极长。解决方案本地手动下载数据集提前下载 Fashion-MNIST 数据集压缩包4 个文件在项目根目录创建data文件夹再新建FashionMNIST\raw子文件夹将下载的 4 个文件直接放入raw文件夹中再次运行函数会直接读取本地文件场景 2本地已有数据集依旧读取缓慢 → CPU 单核读取导致问题原因原生函数默认num_workers0当数据集在本地但依旧读取速度慢还有一种原因那就是load_data_fashion_mnist函数本身默认的读取是采用cpu单核读取的而大部分代码常常只有一个batch_size参数而没有num_workers参数而这个参数影响有多少个cpu核心参与本次读取工作因此需要将num_workers设置一下推荐是自己cpu核心数量的1/3~2/3。 *window操作系统下不支持修改num_workers 最后Linux和Mac系统下num_workers设置如果出现报错那么可能是你使用的是d2l下的load_data_fashion_mnist早期版本的d2l是不支持设置load_data_fashion_mnist的num_workers而最新的部分版本是支持的。更新指令为pip install --upgrade d2l