7.2 模型 Finetune

本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/finetune_resnet18.py

这篇文章主要介绍了模型的 Finetune。

迁移学习:把在 source domain 任务上的学习到的模型应用到 target domain 的任务。

Finetune 就是一种迁移学习的方法。比如做人脸识别,可以把 ImageNet 看作 source domain,人脸数据集看作 target domain。通常来说 source domain 要比 target domain 大得多。可以利用 ImageNet 训练好的网络应用到人脸识别中。

对于一个模型,通常可以分为前面的 feature extractor (卷积层)和后面的 classifier,在 Finetune 时,通常不改变 feature extractor 的权值,也就是冻结卷积层;并且改变最后一个全连接层的输出来适应目标任务,训练后面 classifier 的权值,这就是 Finetune。通常 target domain 的数据比较小,不足以训练全部参数,容易导致过拟合,因此不改变 feature extractor 的权值。

Finetune 步骤如下:

  1. 获取预训练模型的参数

  2. 使用load_state_dict()把参数加载到模型中

  3. 修改输出层

  4. 固定 feature extractor 的参数。这部分通常有 2 种做法:

    1. 固定卷积层的预训练参数。可以设置requires_grad=False或者lr=0

    2. 可以通过params_group给 feature extractor 设置一个较小的学习率

下面微调 ResNet18,用于蜜蜂和蚂蚁图片的二分类。训练集每类数据各 120 张,验证集每类数据各 70 张图片。

数据下载地址:http://download.pytorch.org/tutorial/hymenoptera_data.zip

预训练好的模型参数下载地址:http://download.pytorch.org/models/resnet18-5c106cde.pth

不使用 Finetune

第一次我们首先不使用 Finetune,而是从零开始训练模型,这时只需要修改全连接层即可:

输出如下:

训练了 25 个 epoch 后的准确率为:70.59%。

训练的 loss 曲线如下:

使用 Finetune

然后我们把下载的模型参数加载到模型中:

不冻结卷积层

这时我们不冻结卷积层,所有层都是用相同的学习率,输出如下:

训练了 25 个 epoch 后的准确率为:96.08%。

训练的 loss 曲线如下:

冻结卷积层

设置requires_grad=False

这里先冻结所有参数,然后再替换全连接层,相当于冻结了卷积层的参数:

这里不提供实验结果。

设置学习率为 0

这里把卷积层的学习率设置为 0,需要在优化器里设置不同的学习率。首先获取全连接层参数的地址,然后使用 filter 过滤不属于全连接层的参数,也就是保留卷积层的参数;接着设置优化器的分组学习率,传入一个 list,包含 2 个元素,每个元素是字典,对应 2 个参数组。其中卷积层的学习率设置为 全连接层的 0.1 倍。

这里不提供实验结果。

使用分组学习率

这里不冻结卷积层,而是对卷积层使用较小的学习率,对全连接层使用较大的学习率,需要在优化器里设置不同的学习率。首先获取全连接层参数的地址,然后使用 filter 过滤不属于全连接层的参数,也就是保留卷积层的参数;接着设置优化器的分组学习率,传入一个 list,包含 2 个元素,每个元素是字典,对应 2 个参数组。其中卷积层的学习率设置为 全连接层的 0.1 倍。

这里不提供实验结果。

使用 GPU 的 tips

PyTorch 模型使用 GPU,可以分为 3 步:

  1. 首先获取 device:device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  2. 把模型加载到 device:model.to(device)

  3. 在 data_loader 取数据的循环中,把每个 mini-batch 的数据和 label 加载到 device:inputs, labels = inputs.to(device), labels.to(device)

参考资料

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学

最后更新于

这有帮助吗?