# 7.1 模型保存与加载

> 本章代码：
>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_load.py>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/checkpoint_resume.py>
> * <https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/save_checkpoint.py>

这篇文章主要介绍了序列化与反序列化，以及 PyTorch 中的模型保存于加载的两种方式，模型的断点续训练。

## 序列化与反序列化

模型在内存中是以对象的逻辑结构保存的，但是在硬盘中是以二进制流的方式保存的。

* 序列化是指将内存中的数据以二进制序列的方式保存到硬盘中。PyTorch 的模型保存就是序列化。
* 反序列化是指将硬盘中的二进制序列加载到内存中，得到模型的对象。PyTorch 的模型加载就是反序列化。

## PyTorch 中的模型保存与加载

### torch.save

```
torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False)
```

主要参数：

* obj：保存的对象，可以是模型。也可以是 dict。因为一般在保存模型时，不仅要保存模型，还需要保存优化器、此时对应的 epoch 等参数。这时就可以用 dict 包装起来。
* f：输出路径

其中模型保存还有两种方式：

#### 保存整个 Module

这种方法比较耗时，保存的文件大

```
torch.savev(net, path)
```

#### 只保存模型的参数

推荐这种方法，运行比较快，保存的文件比较小

```
state_sict = net.state_dict()
torch.savev(state_sict, path)
```

下面是保存 LeNet 的例子。在网络初始化中，把权值都设置为 2020，然后保存模型。

```
import torch
import numpy as np
import torch.nn as nn
from common_tools import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(2020)


net = LeNet2(classes=2019)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
```

运行完之后，文件夹中生成了\`\`\`model.pkl\`\`和`model_state_dict.pkl`，分别保存了整个网络和网络的参数

### torch.load

```
torch.load(f, map_location=None, pickle_module, **pickle_load_args)
```

主要参数：

* f：文件路径
* map\_location：指定存在 CPU 或者 GPU。

加载模型也有两种方式

#### 加载整个 Module

如果保存的时候，保存的是整个模型，那么加载时就加载整个模型。这种方法不需要事先创建一个模型对象，也不用知道模型的结构，代码如下：

```
path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load)
```

输出如下：

```
LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2019, bias=True)
  )
)
```

#### 只加载模型的参数

如果保存的时候，保存的是模型的参数，那么加载时就参数。这种方法需要事先创建一个模型对象，再使用模型的`load_state_dict()`方法把参数加载到模型中，代码如下：

```
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
net_new = LeNet2(classes=2019)

print("加载前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])
```

## 模型的断点续训练

在训练过程中，可能由于某种意外原因如断点等导致训练终止，这时需要重新开始训练。断点续练是在训练过程中每隔一定次数的 epoch 就保存**模型的参数和优化器的参数**，这样如果意外终止训练了，下次就可以重新加载最新的**模型参数和优化器的参数**，在这个基础上继续训练。

下面的代码中，每隔 5 个 epoch 就保存一次，保存的是一个 dict，包括模型参数、优化器的参数、epoch。然后在 epoch 大于 5 时，就`break`模拟训练意外终止。关键代码如下：

```
    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
```

在 epoch 大于 5 时，就`break`模拟训练意外终止

```
    if epoch > 5:
        print("训练意外中断...")
        break
```

断点续训练的恢复代码如下：

```
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch
```

需要注意的是，还要设置`scheduler.last_epoch`参数为保存的 epoch。模型训练的起始 epoch 也要修改为保存的 epoch。

**参考资料**

* [深度之眼 PyTorch 框架班](https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6)

如果你觉得这篇文章对你有帮助，不妨点个赞，让我有更多动力写出好文章。

我的文章会首发在公众号上，欢迎扫码关注我的公众号**张贤同学**。

![](https://image.zhangxiann.com/QRcode_8cm.jpg)


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://pytorch.zhangxiann.com/7-mo-xing-qi-ta-cao-zuo/7.1-mo-xing-bao-cun-yu-jia-zai.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
