01-PyTorch加载数据初认识(dataset运用)
一、先看整体结构这是一个标准的 PyTorch 自定义数据集模板核心分为 3 个部分类定义 __init__初始化路径和数据列表__getitem__按索引读取单张图片和标签__len__返回数据集总长度二、逐行代码讲解1. 导入依赖python运行from torch.utils.data import Dataset from PIL import Image import osDatasetPyTorch 提供的抽象基类所有自定义数据集都要继承它这样才能被DataLoader识别Image来自 PIL 库用来读取、处理图片os用来拼接文件路径、读取目录下的文件名处理本地文件系统。2. 类定义与初始化方法__init__python运行class MyData(Dataset): def __init__(self, root_dir, label_dir): self.root_dir root_dir self.label_dir label_dir self.path os.path.join(self.root_dir, self.label_dir) self.img_path os.listdir(self.path)class MyData(Dataset)定义一个新的类MyData继承自Datasetdef __init__(self, root_dir, label_dir)类的构造函数创建数据集对象时会自动执行接收两个参数root_dir数据集的根目录比如dataset/trainlabel_dir类别目录比如ants代表蚂蚁的图片文件夹self.root_dir root_dir把根目录保存到实例变量中后续可以在类的其他方法里调用self.label_dir label_dir把类别目录保存到实例变量中self.path os.path.join(self.root_dir, self.label_dir)拼接根目录和类别目录得到完整的图片文件夹路径比如dataset/train/antsself.img_path os.listdir(self.path)读取dataset/train/ants目录下的所有文件名存入self.img_path列表后续可以按索引读取。3. 核心方法__getitem__python运行def __getitem__(self, idx): img_name self.img_path[idx] img_item_path os.path.join(self.root_dir, self.label_dir, img_name) img Image.open(img_item_path) label self.label_dir return img, labeldef __getitem__(self, idx)PyTorch 规定的方法按索引读取数据idx就是索引从 0 开始img_name self.img_path[idx]根据索引idx从self.img_path列表中取出对应的图片文件名img_item_path os.path.join(self.root_dir, self.label_dir, img_name)拼接根目录、类别目录和图片文件名得到单张图片的完整路径比如dataset/train/ants/001.jpgimg Image.open(img_item_path)用 PIL 读取图片得到一个 Image 对象label self.label_dir把类别目录名比如ants作为标签return img, label返回图片和对应的标签后续模型训练时会接收这两个值。4. 长度方法__len__python运行def __len__(self): return len(self.img_path)def __len__(self)PyTorch 规定的方法返回数据集的总样本数return len(self.img_path)self.img_path是图片文件名列表len(self.img_path)就是图片总数比如dataset/train/ants目录下有 124 张图片就返回 124。三、代码执行流程结合你的控制台python运行root_dir dataset/train ants_label_dir ants ants_dataset MyData(root_dir, ants_label_dir)创建MyData对象传入根目录和类别目录自动执行__init__拼接路径、读取图片列表当你调用len(ants_dataset)时会执行__len__返回图片总数当你调用ants_dataset[0]时会执行__getitem__(0)返回第 1 张图片和标签。四、补充说明与小优化标签处理这段代码里直接用label self.label_dir后续训练时模型需要的是数字标签比如ants0、bees1可以改成python运行# 比如 ants 标签设为 0 label 0路径拼接os.path.join是跨平台的Windows、Linux 都能正常拼接路径避免手动写/或\出错遥感影像适配如果你后续要处理.tif格式的遥感影像把Image.open换成rasterio.open即可核心逻辑不变。