# 7.1 模型保存与加载

> 本章代码：
>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_load.py>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/checkpoint_resume.py>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/save_checkpoint.py>

这篇文章主要介绍了序列化与反序列化，以及 PyTorch 中的模型保存于加载的两种方式，模型的断点续训练。

## 序列化与反序列化

模型在内存中是以对象的逻辑结构保存的，但是在硬盘中是以二进制流的方式保存的。

* 序列化是指将内存中的数据以二进制序列的方式保存到硬盘中。PyTorch 的模型保存就是序列化。
* 反序列化是指将硬盘中的二进制序列加载到内存中，得到模型的对象。PyTorch 的模型加载就是反序列化。

## PyTorch 中的模型保存与加载

### torch.save

```
torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False)
```

主要参数：

* obj：保存的对象，可以是模型。也可以是 dict。因为一般在保存模型时，不仅要保存模型，还需要保存优化器、此时对应的 epoch 等参数。这时就可以用 dict 包装起来。
* f：输出路径

其中模型保存还有两种方式：

#### 保存整个 Module

这种方法比较耗时，保存的文件大

```
torch.savev(net, path)
```

#### 只保存模型的参数

推荐这种方法，运行比较快，保存的文件比较小

```
state_sict = net.state_dict()
torch.savev(state_sict, path)
```

下面是保存 LeNet 的例子。在网络初始化中，把权值都设置为 2020，然后保存模型。

```
import torch
import numpy as np
import torch.nn as nn
from common_tools import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(2020)


net = LeNet2(classes=2019)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
```

运行完之后，文件夹中生成了\`\`\`model.pkl\`\`和`model_state_dict.pkl`，分别保存了整个网络和网络的参数

### torch.load

```
torch.load(f, map_location=None, pickle_module, **pickle_load_args)
```

主要参数：

* f：文件路径
* map\_location：指定存在 CPU 或者 GPU。

加载模型也有两种方式

#### 加载整个 Module

如果保存的时候，保存的是整个模型，那么加载时就加载整个模型。这种方法不需要事先创建一个模型对象，也不用知道模型的结构，代码如下：

```
path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load)
```

输出如下：

```
LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2019, bias=True)
  )
)
```

#### 只加载模型的参数

如果保存的时候，保存的是模型的参数，那么加载时就参数。这种方法需要事先创建一个模型对象，再使用模型的`load_state_dict()`方法把参数加载到模型中，代码如下：

```
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
net_new = LeNet2(classes=2019)

print("加载前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])
```

## 模型的断点续训练

在训练过程中，可能由于某种意外原因如断点等导致训练终止，这时需要重新开始训练。断点续练是在训练过程中每隔一定次数的 epoch 就保存**模型的参数和优化器的参数**，这样如果意外终止训练了，下次就可以重新加载最新的**模型参数和优化器的参数**，在这个基础上继续训练。

下面的代码中，每隔 5 个 epoch 就保存一次，保存的是一个 dict，包括模型参数、优化器的参数、epoch。然后在 epoch 大于 5 时，就`break`模拟训练意外终止。关键代码如下：

```
    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
```

在 epoch 大于 5 时，就`break`模拟训练意外终止

```
    if epoch > 5:
        print("训练意外中断...")
        break
```

断点续训练的恢复代码如下：

```
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch
```

需要注意的是，还要设置`scheduler.last_epoch`参数为保存的 epoch。模型训练的起始 epoch 也要修改为保存的 epoch。

**参考资料**

* [深度之眼 PyTorch 框架班](https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6)

如果你觉得这篇文章对你有帮助，不妨点个赞，让我有更多动力写出好文章。

我的文章会首发在公众号上，欢迎扫码关注我的公众号**张贤同学**。

![](https://image.zhangxiann.com/QRcode_8cm.jpg)
