# 2.1 DataLoader 与 DataSet

> 本章代码：<https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson2/rmb_classification/>

## 人民币 二分类

实现 1 元人民币和 100 元人民币的图片二分类。前面讲过 PyTorch 的五大模块：数据、模型、损失函数、优化器和迭代训练。

数据模块又可以细分为 4 个部分：

* 数据收集：样本和标签。
* 数据划分：训练集、验证集和测试集
* 数据读取：对应于PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引， DataSet 是根据生成的索引读取样本以及标签。
* 数据预处理：对应于 PyTorch 的 transforms

![](https://image.zhangxiann.com/20220303191543.png)

## DataLoader 与 DataSet

### torch.utils.data.DataLoader()

```
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
```

功能：构建可迭代的数据装载器

* dataset: Dataset 类，决定数据从哪里读取以及如何读取
* batchsize: 批大小
* num\_works:num\_works: 是否多进程读取数据
* sheuffle: 每个 epoch 是否乱序
* drop\_last: 当样本数不能被 batchsize 整除时，是否舍弃最后一批数据

#### Epoch, Iteration, Batchsize

* Epoch: 所有训练样本都已经输入到模型中，称为一个 Epoch
* Iteration: 一批样本输入到模型中，称为一个 Iteration
* Batchsize: 批大小，决定一个 iteration 有多少样本，也决定了一个 Epoch 有多少个 Iteration

假设样本总数有 80，设置 Batchsize 为 8，则共有 $80 \div 8=10$ 个 Iteration。这里 $1 Epoch = 10 Iteration$。

假设样本总数有 86，设置 Batchsize 为 8。如果`drop_last=True`则共有 10 个 Iteration；如果`drop_last=False`则共有 11 个 Iteration。

### torch.utils.data.Dataset

功能：Dataset 是抽象类，所有自定义的 Dataset 都需要继承该类，并且重写`__getitem()__`方法和`__len__()`方法 。`__getitem()__`方法的作用是接收一个索引，返回索引对应的样本和标签，这是我们自己需要实现的逻辑。`__len__()`方法是返回所有样本的数量。

数据读取包含 3 个方面

* 读取哪些数据：每个 Iteration 读取一个 Batchsize 大小的数据，每个 Iteration 应该读取哪些数据。
* 从哪里读取数据：如何找到硬盘中的数据，应该在哪里设置文件路径参数
* 如何读取数据：不同的文件需要使用不同的读取方法和库。

这里的路径结构如下，有两类人民币图片：1 元和 100 元，每一类各有 100 张图片。

* RMB\_data
  * 1
  * 100

首先划分数据集为训练集、验证集和测试集，比例为 8:1:1。

数据划分好后的路径构造如下：

* rmb\_split
  * train
    * 1
    * 100
  * valid
    * 1
    * 100
  * test
    * 1
    * 100

实现读取数据的 Dataset，编写一个`get_img_info()`方法，读取每一个图片的路径和对应的标签，组成一个元组，再把所有的元组作为 list 存放到`self.data_info`变量中，这里需要注意的是标签需要映射到 0 开始的整数: `rmb_label = {"1": 0, "100": 1}`。

```
    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        # data_dir 是训练集、验证集或者测试集的路径
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            # dirs ['1', '100']
            for sub_dir in dirs:
                # 文件列表
                img_names = os.listdir(os.path.join(root, sub_dir))
                # 取出 jpg 结尾的文件
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    # 图片的绝对路径
                    path_img = os.path.join(root, sub_dir, img_name)
                    # 标签，这里需要映射为 0、1 两个类别
                    label = rmb_label[sub_dir]
                    # 保存在 data_info 变量中
                    data_info.append((path_img, int(label)))
        return data_info
```

然后在`Dataset` 的初始化函数中调用`get_img_info()`方法。

```
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform，数据预处理
        """
        # data_info存储所有图片路径和标签，在DataLoader中通过index读取样本
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform
```

然后在`__getitem__()`方法中根据`index` 读取`self.data_info`中路径对应的数据，并在这里做 transform 操作，返回的是样本和标签。

```
    def __getitem__(self, index):
        # 通过 index 读取样本
        path_img, label = self.data_info[index]
        # 注意这里需要 convert('RGB')
        img = Image.open(path_img).convert('RGB')     # 0~255
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform，转为tensor等等
        # 返回是样本和标签
        return img, label
```

在`__len__()`方法中返回`self.data_info`的长度，即为所有样本的数量。

```
    # 返回所有样本的数量
    def __len__(self):
        return len(self.data_info)
```

在`train_lenet.py`中，分 5 步构建模型。

第 1 步设置数据。首先定义训练集、验证集、测试集的路径，定义训练集和测试集的`transforms`。然后构建训练集和验证集的`RMBDataset`对象，把对应的路径和`transforms`传进去。再构建`DataLoder`，设置 batch\_size，其中训练集设置`shuffle=True`，表示每个 Epoch 都打乱样本。

```
# 构建MyDataset实例train_data = RMBDataset(data_dir=train_dir, transform=train_transform)valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
# 其中训练集设置 shuffle=True，表示每个 Epoch 都打乱样本
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
```

第 2 步构建模型，这里采用经典的 Lenet 图片分类网络。

```
net = LeNet(classes=2)
net.initialize_weights()
```

第 3 步设置损失函数，这里使用交叉熵损失函数。

```
criterion = nn.CrossEntropyLoss()
```

第 4 步设置优化器。这里采用 SGD 优化器。

```
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略
```

第 5 步迭代训练模型，在每一个 epoch 里面，需要遍历 train\_loader 取出数据，每次取得数据是一个 batchsize 大小。这里又分为 4 步。第 1 步进行前向传播，第 2 步进行反向传播求导，第 3 步使用`optimizer`更新权重，第 4 步统计训练情况。每一个 epoch 完成时都需要使用`scheduler`更新学习率，和计算验证集的准确率、loss。

```
for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    # 遍历 train_loader 取数据
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率
    # 每个 epoch 计算验证集得准确率和loss
    ...
    ...
```

我们可以看到每个 iteration，我们是从`train_loader`中取出数据的。

```
def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)
```

这里我们没有设置多进程，会执行`_SingleProcessDataLoaderIter`的方法。我们以`_SingleProcessDataLoaderIter`为例。在`_SingleProcessDataLoaderIter`里只有一个方法`_next_data()`，如下：

```
def _next_data(self):
    index = self._next_index()  # may raise StopIteration
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    if self._pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data
```

在该方法中，`self._next_index()`是获取一个 batchsize 大小的 index 列表，代码如下：

```
def _next_index(self):
    return next(self._sampler_iter)  # may raise StopIteration
```

其中调用的`sampler`类的`__iter__()`方法返回 batch\_size 大小的随机 index 列表。

```
def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
```

然后再返回看 `dataloader`的`_next_data()`方法：

```
def _next_data(self):
    index = self._next_index()  # may raise StopIteration
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    if self._pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data
```

在第二行中调用了`self._dataset_fetcher.fetch(index)`获取数据。这里会调用`_MapDatasetFetcher`中的`fetch()`函数：

```
def fetch(self, possibly_batched_index):
    if self.auto_collation:
        data = [self.dataset[idx] for idx in possibly_batched_index]
    else:
        data = self.dataset[possibly_batched_index]
    return self.collate_fn(data)
```

这里调用了`self.dataset[idx]`，这个函数会调用`dataset.__getitem__()`方法获取具体的数据，所以`__getitem__()`方法是我们必须实现的。我们拿到的`data`是一个 list，每个元素是一个 tunple，每个 tunple 包括样本和标签。所以最后要使用`self.collate_fn(data)`把 data 转换为两个 list，第一个 元素 是样本的batch 形式，形状为 \[16, 3, 32, 32] (16 是 batch size，\[3, 32, 32] 是图片像素)；第二个元素是标签的 batch 形式，形状为 \[16]。

所以在代码中，我们使用`inputs, labels = data`来接收数据。

PyTorch 数据读取流程图

![](https://image.zhangxiann.com/20200521101040.png)\
首先在 for 循环中遍历`DataLoader`，然后根据是否采用多进程，决定使用单进程或者多进程的`DataLoaderIter`。在`DataLoaderIter`里调用`Sampler`生成`Index`的 list，再调用`DatasetFetcher`根据`index`获取数据。在`DatasetFetcher`里会调用`Dataset`的`__getitem__()`方法获取真正的数据。这里获取的数据是一个 list，其中每个元素是 (img, label) 的元组，再使用 `collate_fn()`函数整理成一个 list，里面包含两个元素，分别是 img 和 label 的`tenser`。

下图是我们的训练过程的 loss 曲线：

![](https://image.zhangxiann.com/20200521101613.png)\
**参考资料**

* [深度之眼 PyTorch 框架班](https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6)

如果你觉得这篇文章对你有帮助，不妨点个赞，让我有更多动力写出好文章。

我的文章会首发在公众号上，欢迎扫码关注我的公众号**张贤同学**。

![](https://image.zhangxiann.com/QRcode_8cm.jpg)


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://pytorch.zhangxiann.com/2-tu-pian-chu-li-yu-shu-ju-jia-zai/2.1-dataloader-yu-dataset.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
