class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.ReLU()
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
在 PyTorch 中,对 tensor 执行 detach() 会返回一个新的 tensor,该 tensor 与原始 tensor 共享数据,但不会参与反向传播计算。具体来说:
- 共享数据:新
tensor与原始tensor共享相同的数据,修改其中一个会影响另一个。 - 断开计算图:新
tensor不会记录计算历史,因此不会在反向传播中计算梯度。 - 梯度计算:原始
tensor的梯度计算不受影响,仍可正常进行。
示例代码
import torch
# 创建一个需要梯度的 tensor
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 对 x 进行一些操作
y = x * 2
# 对 y 执行 detach
z = y.detach()
# 打印结果
print("y:", y)
print("z:", z)
# 进行反向传播
y.sum().backward()
# 打印 x 的梯度
print("x.grad:", x.grad)
输出
y: tensor([2., 4., 6.], grad_fn=<MulBackward0>)
z: tensor([2., 4., 6.])
x.grad: tensor([2., 2., 2.])
解释
y是通过x计算得到的,保留了计算历史。z是y的detach()结果,与y共享数据但不参与反向传播。- 反向传播后,
x的梯度正常计算,而z不影响梯度计算。
总结
detach() 用于从计算图中分离 tensor,常用于冻结部分模型参数或避免不必要的梯度计算。
如果只是对 GPU 上的 tensor 执行 .detach() 而不调用 .cpu(),那么 z 和 y 仍然会共享数据,因此对 z 的修改会反映到 y 上。
原因
.detach()的作用:
- 返回一个新的
tensor,与原始tensor共享数据。 - 新
tensor不会参与反向传播,但仍然与原始tensor共享底层存储。
- 共享数据:
- 如果
z是通过y.detach()创建的,那么z和y共享相同的数据存储。 - 对
z的修改会直接反映到y上,因为它们指向同一块内存。
验证代码
import torch
# 创建一个在 GPU 上的 tensor,并启用梯度计算
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True).cuda()
# 对 x 进行一些操作
y = x * 2
# 执行 .detach(),不调用 .cpu()
z = y.detach()
# 修改 z
z[0] = 100.0
# 打印 y 和 z
print("y:", y) # y 的值被修改
print("z:", z) # z 的值被修改
输出
y: tensor([100., 4., 6.], device='cuda:0', grad_fn=<MulBackward0>)
z: tensor([100., 4., 6.], device='cuda:0')
解释
y:- 虽然
y仍然在 GPU 上,并且保留了计算历史(grad_fn),但它的值被修改为[100., 4., 6.]。 - 这是因为
z和y共享数据。 z:z的值被修改为[100., 4., 6.]。- 由于
z和y共享数据,y的值也被同步修改。
总结
- 如果只调用
.detach(): z和y共享数据,对z的修改会反映到y上。- 如果调用
.detach().cpu(): z会被移动到 CPU,并且与y不再共享数据,对z的修改不会影响y。
因此,是否调用 .cpu() 是决定 z 和 y 是否共享数据的关键。