- Dataset类
- DataLoader类
- Dataset和DataLoader之间的联系
Dataset类
首先关于 Dataset ,直接翻译过来就是数据集。当然就目前我学习的感悟来看这个东西翻译的完全没错,关于这个 Dataset 类中的一些对象就是数据集。
所谓数据集,就是数据的集合,这也是我们之后训练神经网络必不可少的原料。
Dataset 类是数据加载的核心组件之一,它主要用于封装数据集,提供统一的接口来访问数据样本。Dataset 位于 torch.utils.data 模块中,是一个抽象基类,用户通常需要通过继承它来实现自定义的数据集类。
它的核心作用有3点:
1.封装数据:将数据(如图像、文本、标签等)组织成可迭代访问的形式
2.提供统一接口:通过 getitem 方法按索引获取样本,通过 len 方法获取数据集大小
3.方便与 DataLoader 结合使用:实现数据的批量加载、打乱顺序、多进程加载等功能
如果需要自己创建自定义数据集,需要继承 Dataset 并实现以下两个核心方法:
__len__()
:返回数据集的样本总数__getitem__(index)
:根据索引返回一个样本(通常是数据和对应的标签)
下面是一段示例代码:
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels):"""初始化数据集:param data: 输入数据:param labels: 数据对应的标签"""self.data = dataself.labels = labelsdef __len__(self):"""返回数据集大小"""return len(self.data)def __getitem__(self, idx):"""根据索引返回样本"""sample_data = self.data[idx]sample_label = self.labels[idx]# 可以在这里进行数据预处理# 例如:转换为Tensor、归一化等return {'data': torch.tensor(sample_data, dtype=torch.float32),'label': torch.tensor(sample_label, dtype=torch.long)}# 使用示例
if __name__ == "__main__":# 模拟数据data = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]labels = [0, 1, 0, 1]# 创建数据集实例dataset = CustomDataset(data, labels)# 查看数据集大小print(f"数据集大小: {len(dataset)}")# 获取单个样本sample = dataset[0]print(f"第一个样本数据: {sample['data']}")print(f"第一个样本标签: {sample['label']}")
当然pytorch中还提供了一些预定义的Dataset子类(工具类),可以根据自己的实际需求方便处理常见的数据格式(当然我没用过),先列在下面吧,大家有个印象就行(反正ai会解决一切问题)。
- TensorDataset:用于包装张量数据,当数据以张量形式存在时非常方便
- ConcatDataset:用于拼接多个数据集
- ChainDataset:用于按顺序链式访问多个数据集
DataLoader类
DataLoader!Dataset的最好搭子,它们之间搭配使用可以让Dataset里面的相关数据在神经网络之中得到最大的利用。
在 PyTorch 中,DataLoader 是与 Dataset 配合使用的关键组件,主要负责数据的批量加载、打乱顺序、多进程加载等功能,是构建高效数据管道的核心工具。它位于 torch.utils.data 模块中,能够将 Dataset 提供的单个样本转换为模型训练所需的批量数据。
DataLoader的作用
1.批量处理数据:将单个样本组合成批次(batch),匹配模型训练时的批量输入需求
2.数据打乱:在每个 epoch 开始时随机打乱数据顺序,提高模型泛化能力
3.多进程加载:使用多进程并行加载数据,解决数据加载成为训练瓶颈的问题
4.内存优化:通过迭代器方式加载数据,避免一次性将所有数据载入内存
5.支持自定义_collate:灵活处理非结构化数据(如长度不一的文本、图像等) --->这个我暂时还没有用到
DataLoader 的构造函数需要传入一个 Dataset 对象,并通过参数配置批量加载的规则。这么说起来可能有点抽象,还是一样的,示例代码如下:
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, # 传入Dataset对象batch_size=32, # 批大小,每次加载32个样本shuffle=True, # 是否在每个epoch打乱数据num_workers=4, # 多进程加载的进程数collate_fn=None, # 自定义批量处理函数(可选)drop_last=False # 是否丢弃最后一个不完整的批次
)
DataLoader中的相关参数介绍:
- dataset:必须传入的 Dataset 对象,是数据的来源
- batch_size:每个批次包含的样本数(默认值为 1)。根据模型大小和显存容量调整,例如显存较大时可设为 64 或 128。
- shuffle:布尔值,指定是否在每个 epoch 开始时打乱数据顺序。训练集通常设为 True,验证集 / 测试集设为 False。
- num_workers:用于数据加载的子进程数(默认值为 0,表示使用主进程加载)。适当增大可加速数据加载,但不宜超过 CPU 核心数。
- collate_fn:自定义的函数,用于将多个样本组合成一个批次(默认会自动处理张量的拼接)。当样本结构复杂(如长度不一的序列)时,需要自定义该函数。
- drop_last:布尔值,若数据集大小不能被 batch_size 整除,是否丢弃最后一个不完整的批次(默认值为 False)。训练时可设为 True 避免批次大小不一致。
- pin_memory:布尔值,若设为 True,会将加载的数据复制到 CUDA 固定内存中,加速后续 GPU 传输(仅在使用 GPU 时有效)。
Dataset和DataLoader之间的联系
DataLoader 和 Dataset 是 PyTorch 中数据加载 pipeline 的两个核心组件,它们紧密配合,共同构成了高效数据处理 pipeline。两者的关系可以概括为:Dataset 负责数据的封装和单个样本的获取,DataLoader 负责将单个样本组织成批量数据并高效加载。
具体的联系和它们之间的分工
1.数据来源与封装:Dataset 的角色
Dataset 是数据的 “源头”,它的核心作用是:
- 封装原始数据(如图像文件、文本、标签等),提供统一的接口
- 通过 getitem(index) 方法按索引返回单个样本(如一张图像和对应的标签)
- 通过 len() 方法返回数据集的总样本数
简单说,Dataset 就像一个 “仓库管理员”,知道数据在哪里、有多少,并且能按编号取出指定的物品。
2.批量处理与加载:DataLoader 的角色
DataLoader 是 Dataset 的 “搬运工”,它接收 Dataset 提供的单个样本,进行以下操作:
- 将多个样本组合成批次(batch) 数据(如一次返回 32 个样本)
- 支持数据打乱(shuffle=True),避免模型学习到数据顺序规律
- 通过多进程(num_workers)并行加载数据,提升效率
- 处理不规则数据的批量拼接(通过 collate_fn) --->是的这个我没用过
也就是说,DataLoader 会从 Dataset 中 “定期取货”,并将零散的 “货物” 打包成适合模型输入的 “标准包裹”。
总的来说这两者,Dataset解决了数据源的问题,DataLoader解决了怎么搞笑利用数据源的问题。它们结合起来(我们联合!),才能两面包夹芝士一起共同高效、有质量的构建出一个好的网络模型。