5.2 Hook 函数与 CAM 算法
这篇文章主要介绍了如何使用 Hook 函数提取网络中的特征图进行可视化,和 CAM(class activation map, 类激活图)

Hook 函数概念

Hook 函数是在不改变主体的情况下,实现额外功能。由于 PyTorch 是基于动态图实现的,因此在一次迭代运算结束后,一些中间变量如非叶子节点的梯度和特征图,会被释放掉。在这种情况下想要提取和记录这些中间变量,就需要使用 Hook 函数。
PyTorch 提供了 4 种 Hook 函数。

torch.Tensor.register_hook(hook)

功能:注册一个反向传播 hook 函数,仅输入一个参数,为张量的梯度。
hook函数:
1
hook(grad)
Copied!
参数:
  • grad:张量的梯度
代码如下:
1
w = torch.tensor([1.], requires_grad=True)
2
x = torch.tensor([2.], requires_grad=True)
3
a = torch.add(w, x)
4
b = torch.add(w, 1)
5
y = torch.mul(a, b)
6
7
# 保存梯度的 list
8
a_grad = list()
9
10
# 定义 hook 函数,把梯度添加到 list 中
11
def grad_hook(grad):
12
a_grad.append(grad)
13
14
# 一个张量注册 hook 函数
15
handle = a.register_hook(grad_hook)
16
17
y.backward()
18
19
# 查看梯度
20
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
21
# 查看在 hook 函数里 list 记录的梯度
22
print("a_grad[0]: ", a_grad[0])
23
handle.remove()
Copied!
结果如下:
1
gradient: tensor([5.]) tensor([2.]) None None None
2
a_grad[0]: tensor([2.])
Copied!
在反向传播结束后,非叶子节点张量的梯度被清空了。而通过hook函数记录的梯度仍然可以查看。
hook函数里面可以修改梯度的值,无需返回也可以作为新的梯度赋值给原来的梯度。代码如下:
1
w = torch.tensor([1.], requires_grad=True)
2
x = torch.tensor([2.], requires_grad=True)
3
a = torch.add(w, x)
4
b = torch.add(w, 1)
5
y = torch.mul(a, b)
6
7
a_grad = list()
8
9
def grad_hook(grad):
10
grad *= 2
11
return grad*3
12
13
handle = w.register_hook(grad_hook)
14
15
y.backward()
16
17
# 查看梯度
18
print("w.grad: ", w.grad)
19
handle.remove()
Copied!
结果是:
1
w.grad: tensor([30.])
Copied!

torch.nn.Module.register_forward_hook(hook)

功能:注册 module 的前向传播hook函数,可用于获取中间的 feature map。
hook函数:
1
hook(module, input, output)
Copied!
参数:
  • module:当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据
下面代码执行的功能是 $3 \times 3$ 的卷积和 $2 \times 2$ 的池化。我们使用register_forward_hook()记录中间卷积层输入和输出的 feature map。
1
class Net(nn.Module):
2
def __init__(self):
3
super(Net, self).__init__()
4
self.conv1 = nn.Conv2d(1, 2, 3)
5
self.pool1 = nn.MaxPool2d(2, 2)
6
7
def forward(self, x):
8
x = self.conv1(x)
9
x = self.pool1(x)
10
return x
11
12
def forward_hook(module, data_input, data_output):
13
fmap_block.append(data_output)
14
input_block.append(data_input)
15
16
# 初始化网络
17
net = Net()
18
net.conv1.weight[0].detach().fill_(1)
19
net.conv1.weight[1].detach().fill_(2)
20
net.conv1.bias.data.detach().zero_()
21
22
# 注册hook
23
fmap_block = list()
24
input_block = list()
25
net.conv1.register_forward_hook(forward_hook)
26
27
# inference
28
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
29
output = net(fake_img)
30
31
32
# 观察
33
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
34
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
35
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
Copied!
输出如下:
1
output shape: torch.Size([1, 2, 1, 1])
2
output value: tensor([[[[ 9.]],
3
[[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
4
feature maps shape: torch.Size([1, 2, 2, 2])
5
output value: tensor([[[[ 9., 9.],
6
[ 9., 9.]],
7
[[18., 18.],
8
[18., 18.]]]], grad_fn=<ThnnConv2DBackward>)
9
input shape: torch.Size([1, 1, 4, 4])
10
input value: (tensor([[[[1., 1., 1., 1.],
11
[1., 1., 1., 1.],
12
[1., 1., 1., 1.],
13
[1., 1., 1., 1.]]]]),)
Copied!

torch.Tensor.register_forward_pre_hook()

功能:注册 module 的前向传播前的hook函数,可用于获取输入数据。
hook函数:
1
hook(module, input)
Copied!
参数:
  • module:当前网络层
  • input:当前网络层输入数据

torch.Tensor.register_backward_hook()

功能:注册 module 的反向传播的hook函数,可用于获取梯度。
hook函数:
1
hook(module, grad_input, grad_output)
Copied!
参数:
  • module:当前网络层
  • input:当前网络层输入的梯度数据
  • output:当前网络层输出的梯度数据
代码如下:
1
class Net(nn.Module):
2
def __init__(self):
3
super(Net, self).__init__()
4
self.conv1 = nn.Conv2d(1, 2, 3)
5
self.pool1 = nn.MaxPool2d(2, 2)
6
7
def forward(self, x):
8
x = self.conv1(x)
9
x = self.pool1(x)
10
return x
11
12
def forward_hook(module, data_input, data_output):
13
fmap_block.append(data_output)
14
input_block.append(data_input)
15
16
def forward_pre_hook(module, data_input):
17
print("forward_pre_hook input:{}".format(data_input))
18
19
def backward_hook(module, grad_input, grad_output):
20
print("backward hook input:{}".format(grad_input))
21
print("backward hook output:{}".format(grad_output))
22
23
# 初始化网络
24
net = Net()
25
net.conv1.weight[0].detach().fill_(1)
26
net.conv1.weight[1].detach().fill_(2)
27
net.conv1.bias.data.detach().zero_()
28
29
# 注册hook
30
fmap_block = list()
31
input_block = list()
32
net.conv1.register_forward_hook(forward_hook)
33
net.conv1.register_forward_pre_hook(forward_pre_hook)
34
net.conv1.register_backward_hook(backward_hook)
35
36
# inference
37
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
38
output = net(fake_img)
39
40
loss_fnc = nn.L1Loss()
41
target = torch.randn_like(output)
42
loss = loss_fnc(target, output)
43
loss.backward()
Copied!
输出如下:
1
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
2
[1., 1., 1., 1.],
3
[1., 1., 1., 1.],
4
[1., 1., 1., 1.]]]]),)
5
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
6
[0.5000, 0.5000, 0.5000],
7
[0.5000, 0.5000, 0.5000]]],
8
[[[0.5000, 0.5000, 0.5000],
9
[0.5000, 0.5000, 0.5000],
10
[0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
11
backward hook output:(tensor([[[[0.5000, 0.0000],
12
[0.0000, 0.0000]],
13
[[0.5000, 0.0000],
14
[0.0000, 0.0000]]]]),)
Copied!

hook函数实现机制

hook函数实现的原理是在module__call()__函数进行拦截,__call()__函数可以分为 4 个部分:
  • 第 1 部分是实现 _forward_pre_hooks
  • 第 2 部分是实现 forward 前向传播
  • 第 3 部分是实现 _forward_hooks
  • 第 4 部分是实现 _backward_hooks
由于卷积层也是一个module,因此可以记录_forward_hooks
1
def __call__(self, *input, **kwargs):
2
# 第 1 部分是实现 _forward_pre_hooks
3
for hook in self._forward_pre_hooks.values():
4
result = hook(self, input)
5
if result is not None:
6
if not isinstance(result, tuple):
7
result = (result,)
8
input = result
9
10
# 第 2 部分是实现 forward 前向传播
11
if torch._C._get_tracing_state():
12
result = self._slow_forward(*input, **kwargs)
13
else:
14
result = self.forward(*input, **kwargs)
15
16
# 第 3 部分是实现 _forward_hooks
17
for hook in self._forward_hooks.values():
18
hook_result = hook(self, input, result)
19
if hook_result is not None:
20
result = hook_result
21
22
# 第 4 部分是实现 _backward_hooks
23
if len(self._backward_hooks) > 0:
24
var = result
25
while not isinstance(var, torch.Tensor):
26
if isinstance(var, dict):
27
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
28
else:
29
var = var[0]
30
grad_fn = var.grad_fn
31
if grad_fn is not None:
32
for hook in self._backward_hooks.values():
33
wrapper = functools.partial(hook, self)
34
functools.update_wrapper(wrapper, hook)
35
grad_fn.register_hook(wrapper)
36
return result
Copied!

Hook 函数提取网络的特征图

下面通过hook函数获取 AlexNet 每个卷积层的所有卷积核参数,以形状作为 key,value 对应该层多个卷积核的 list。然后取出每层的第一个卷积核,形状是 [1, in_channle, h, w],转换为 [in_channle, 1, h, w],使用 TensorBoard 进行可视化,代码如下:
1
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")
2
3
# 数据
4
path_img = "imgs/lena.png" # your path to image
5
normMean = [0.49139968, 0.48215827, 0.44653124]
6
normStd = [0.24703233, 0.24348505, 0.26158768]
7
8
norm_transform = transforms.Normalize(normMean, normStd)
9
img_transforms = transforms.Compose([
10
transforms.Resize((224, 224)),
11
transforms.ToTensor(),
12
norm_transform
13
])
14
15
img_pil = Image.open(path_img).convert('RGB')
16
if img_transforms is not None:
17
img_tensor = img_transforms(img_pil)
18
img_tensor.unsqueeze_(0) # chw --> bchw
19
20
# 模型
21
alexnet = models.alexnet(pretrained=True)
22
23
# 注册hook
24
fmap_dict = dict()
25
for name, sub_module in alexnet.named_modules():
26
27
if isinstance(sub_module, nn.Conv2d):
28
key_name = str(sub_module.weight.shape)
29
fmap_dict.setdefault(key_name, list())
30
# 由于AlexNet 使用 nn.Sequantial 包装,所以 name 的形式是:features.0 features.1
31
n1, n2 = name.split(".")
32
33
def hook_func(m, i, o):
34
key_name = str(m.weight.shape)
35
fmap_dict[key_name].append(o)
36
37
alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)
38
39
# forward
40
output = alexnet(img_tensor)
41
42
# add image
43
for layer_name, fmap_list in fmap_dict.items():
44
fmap = fmap_list[0]# 取出第一个卷积核的参数
45
fmap.transpose_(0, 1) # 把 BCHW 转换为 CBHW
46
47
nrow = int(np.sqrt(fmap.shape[0]))
48
fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
49
writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)
Copied!
使用 TensorBoard 进行可视化如下:

CAM(class activation map, 类激活图)

暂未完成。列出两个参考文章。
参考资料
如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。
我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学
最近更新 1yr ago