5.2 Hook 函数与 CAM 算法
Hook 函数概念
torch.Tensor.register_hook(hook)
hook(grad)w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 保存梯度的 list
a_grad = list()
# 定义 hook 函数,把梯度添加到 list 中
def grad_hook(grad):
a_grad.append(grad)
# 一个张量注册 hook 函数
handle = a.register_hook(grad_hook)
y.backward()
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
# 查看在 hook 函数里 list 记录的梯度
print("a_grad[0]: ", a_grad[0])
handle.remove()torch.nn.Module.register_forward_hook(hook)

torch.Tensor.register_forward_pre_hook()
torch.Tensor.register_backward_hook()
hook函数实现机制
hook函数实现机制Hook 函数提取网络的特征图

CAM(class activation map, 类激活图)

最后更新于