from torch.utils.data import Dataset # 导入PyTorch的Dataset基类,自定义数据集必须继承它
from PIL import Image # 导入PIL库的Image模块,用于读取和处理图像文件
import os # 导入os库,用于处理文件路径、目录操作等系统相关功能class MyData(Dataset)
: # 定义MyData类,继承自PyTorch的Dataset抽象类# 这是数据集的"构造方法",用于初始化数据集的基本信息(如路径、文件列表等)def __init__(self, root_dir, label_dir): # root_dir:根目录路径;label_dir:标签目录路径# 保存根目录路径到实例变量(方便类内部其他方法调用)self.root_dir = root_dir # 保存标签目录路径到实例变量(标签目录名称通常就是该类别的标签,如"ants"、"bees")self.label_dir = label_dir # 拼接根目录和标签目录,得到图像文件所在的完整目录路径(例如"dataset/train/ants")self.path = os.path.join(self.root_dir, self.label_dir) # 获取该目录下所有文件的名称列表(例如["0013035.jpg", "003454.jpg"...])self.img_path = os.listdir(self.path) # 这是数据集的核心方法,用于根据索引idx获取一个样本(图像+标签)def __getitem__(self, idx): # idx:样本的索引(从0开始)# 根据索引idx从图像名称列表中取出对应的图像文件名(例如第0个是"0013035.jpg")img_name = self.img_path[idx] # 拼接根目录、标签目录、图像文件名,得到该图像的完整路径(例如"dataset/train/ants/0013035.jpg")img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 用PIL的Image.open()方法打开图像文件,得到图像对象(可后续转为PyTorch张量)img = Image.open(img_item_path) # 将标签目录名称作为该图像的标签(例如"ants"目录下的图像标签就是"ants")label = self.label_dir # 返回该索引对应的图像和标签(这是PyTorch要求的格式:(数据, 标签))return img, label # 这是数据集的长度方法,返回数据集中样本的总数量def __len__(self): # 图像名称列表的长度就是样本数量(因为每个文件名对应一个图像)return len(self.img_path) ### 3. 实例化数据集并组合
```python
root_dir = "dataset/train" # 定义根目录路径(存放训练集的总目录)
ant_label_dir = "ants" # 定义"蚂蚁"类别的标签目录名称
bee_label_dir = "bees" # 定义"蜜蜂"类别的标签目录名称# 实例化"蚂蚁"数据集:传入根目录和蚂蚁标签目录,得到只包含蚂蚁图像的数据集
antset = MyData(root_dir, ant_label_dir)
# 实例化"蜜蜂"数据集:传入根目录和蜜蜂标签目录,得到只包含蜜蜂图像的数据集
beeset = MyData(root_dir, bee_label_dir) # 将蚂蚁数据集和蜜蜂数据集合并,得到完整的训练集(PyTorch的Dataset支持用"+"拼接)
trainset = antset + beeset
以上代码可通过pycharm的代码逐步调试,并实时查看代码运行是否无误