grad_tensors: 多梯度权重。当有多个 loss 混合需要计算梯度时,设置每个 loss 的权重。
retain_graph 参数
代码示例
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
# y=(x+w)*(w+1)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 第一次执行梯度求导
y.backward()
print(w.grad)
# 第二次执行梯度求导,出错
y.backward()
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
# y=(x+w)*(w+1)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 第一次求导,设置 retain_graph=True,保留计算图
y.backward(retain_graph=True)
print(w.grad)
# 第二次求导成功
y.backward()
grad_tensors 参数
代码示例:
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y0 = torch.mul(a, b) # y0 = (x+w) * (w+1)
y1 = torch.add(a, b) # y1 = (x+w) + (w+1) dy1/dw = 2
# 把两个 loss 拼接都到一起
loss = torch.cat([y0, y1], dim=0) # [y0, y1]
# 设置两个 loss 的权重: y0 的权重是 1,y1 的权重是 2
grad_tensors = torch.tensor([1., 2.])
loss.backward(gradient=grad_tensors) # gradient 传入 torch.autograd.backward()中的grad_tensors
# 最终的 w 的导数由两部分组成。∂y0/∂w * 1 + ∂y1/∂w * 2
print(w.grad)
结果为:
tensor([9.])
该 loss 由两部分组成:$y_{0}$ 和 $y_{1}$。其中 $\frac{\partial y_{0}}{\partial w}=5$,$\frac{\partial y_{1}}{\partial w}=2$。而 grad_tensors 设置两个 loss 对 w 的权重分别为 1 和 2。因此最终 w 的梯度为:$\frac{\partial y_{0}}{\partial w} \times 1+ \frac{\partial y_{1}}{\partial w} \times 2=9$。
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
# 进行 4 次反向传播求导,每次最后都没有清零
for i in range(4):
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()
print(w.grad)
for i in range(4):
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()
print(w.grad)
# 每次都把梯度清零
# w.grad.zero_()
依赖于叶子节点的节点,requires_grad 属性默认为 True。
叶子节点不可执行 inplace 操作。
以加法来说,inplace 操作有a += x,a.add_(x),改变后的值和原来的值内存地址是同一个。非inplace 操作有a = a + x,a.add(x),改变后的值和原来的值内存地址不是同一个。
代码示例:
print("非 inplace 操作")
a = torch.ones((1, ))
print(id(a), a)
# 非 inplace 操作,内存地址不一样
a = a + torch.ones((1, ))
print(id(a), a)
print("inplace 操作")
a = torch.ones((1, ))
print(id(a), a)
# inplace 操作,内存地址一样
a += torch.ones((1, ))
print(id(a), a)
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
# y = (x + w) * (w + 1)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 在反向传播之前 inplace 改变了 w 的值,再执行 backward() 会报错
w.add_(1)
y.backward()