5.2 Hook 函数与 CAM 算法
本章代码:
这篇文章主要介绍了如何使用 Hook 函数提取网络中的特征图进行可视化,和 CAM(class activation map, 类激活图)
Hook 函数概念
Hook 函数是在不改变主体的情况下,实现额外功能。由于 PyTorch 是基于动态图实现的,因此在一次迭代运算结束后,一些中间变量如非叶子节点的梯度和特征图,会被释放掉。在这种情况下想要提取和记录这些中间变量,就需要使用 Hook 函数。
PyTorch 提供了 4 种 Hook 函数。
torch.Tensor.register_hook(hook)
功能:注册一个反向传播 hook 函数,仅输入一个参数,为张量的梯度。
hook函数:
hook(grad)参数:
grad:张量的梯度
代码如下:
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 保存梯度的 list
a_grad = list()
# 定义 hook 函数,把梯度添加到 list 中
def grad_hook(grad):
a_grad.append(grad)
# 一个张量注册 hook 函数
handle = a.register_hook(grad_hook)
y.backward()
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
# 查看在 hook 函数里 list 记录的梯度
print("a_grad[0]: ", a_grad[0])
handle.remove()结果如下:
在反向传播结束后,非叶子节点张量的梯度被清空了。而通过hook函数记录的梯度仍然可以查看。
hook函数里面可以修改梯度的值,无需返回也可以作为新的梯度赋值给原来的梯度。代码如下:
结果是:
torch.nn.Module.register_forward_hook(hook)
功能:注册 module 的前向传播hook函数,可用于获取中间的 feature map。
hook函数:
参数:
module:当前网络层
input:当前网络层输入数据
output:当前网络层输出数据
下面代码执行的功能是 $3 \times 3$ 的卷积和 $2 \times 2$ 的池化。我们使用register_forward_hook()记录中间卷积层输入和输出的 feature map。

输出如下:
torch.Tensor.register_forward_pre_hook()
功能:注册 module 的前向传播前的hook函数,可用于获取输入数据。
hook函数:
参数:
module:当前网络层
input:当前网络层输入数据
torch.Tensor.register_backward_hook()
功能:注册 module 的反向传播的hook函数,可用于获取梯度。
hook函数:
参数:
module:当前网络层
input:当前网络层输入的梯度数据
output:当前网络层输出的梯度数据
代码如下:
输出如下:
hook函数实现机制
hook函数实现机制hook函数实现的原理是在module的__call()__函数进行拦截,__call()__函数可以分为 4 个部分:
第 1 部分是实现 _forward_pre_hooks
第 2 部分是实现 forward 前向传播
第 3 部分是实现 _forward_hooks
第 4 部分是实现 _backward_hooks
由于卷积层也是一个module,因此可以记录_forward_hooks。
Hook 函数提取网络的特征图
下面通过hook函数获取 AlexNet 每个卷积层的所有卷积核参数,以形状作为 key,value 对应该层多个卷积核的 list。然后取出每层的第一个卷积核,形状是 [1, in_channle, h, w],转换为 [in_channle, 1, h, w],使用 TensorBoard 进行可视化,代码如下:
使用 TensorBoard 进行可视化如下:

CAM(class activation map, 类激活图)
暂未完成。列出两个参考文章。
参考资料
如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。
我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学。

最后更新于
这有帮助吗?