Pytorch

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
                (28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

在 PyTorch 中,对 tensor 执行 detach() 会返回一个新的 tensor,该 tensor 与原始 tensor 共享数据,但不会参与反向传播计算。具体来说:

  1. 共享数据:新 tensor 与原始 tensor 共享相同的数据,修改其中一个会影响另一个。
  2. 断开计算图:新 tensor 不会记录计算历史,因此不会在反向传播中计算梯度。
  3. 梯度计算:原始 tensor 的梯度计算不受影响,仍可正常进行。

示例代码

import torch

# 创建一个需要梯度的 tensor
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# 对 x 进行一些操作
y = x * 2

# 对 y 执行 detach
z = y.detach()

# 打印结果
print("y:", y)
print("z:", z)

# 进行反向传播
y.sum().backward()

# 打印 x 的梯度
print("x.grad:", x.grad)

输出

y: tensor([2., 4., 6.], grad_fn=<MulBackward0>)
z: tensor([2., 4., 6.])
x.grad: tensor([2., 2., 2.])

解释

  • y 是通过 x 计算得到的,保留了计算历史。
  • zydetach() 结果,与 y 共享数据但不参与反向传播。
  • 反向传播后,x 的梯度正常计算,而 z 不影响梯度计算。

总结

detach() 用于从计算图中分离 tensor,常用于冻结部分模型参数或避免不必要的梯度计算。

如果只是对 GPU 上的 tensor 执行 .detach() 而不调用 .cpu(),那么 zy 仍然会共享数据,因此对 z 的修改会反映到 y 上。


原因

  1. .detach() 的作用
  • 返回一个新的 tensor,与原始 tensor 共享数据。
  • tensor 不会参与反向传播,但仍然与原始 tensor 共享底层存储。
  1. 共享数据
  • 如果 z 是通过 y.detach() 创建的,那么 zy 共享相同的数据存储。
  • z 的修改会直接反映到 y 上,因为它们指向同一块内存。

验证代码

import torch

# 创建一个在 GPU 上的 tensor,并启用梯度计算
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True).cuda()

# 对 x 进行一些操作
y = x * 2

# 执行 .detach(),不调用 .cpu()
z = y.detach()

# 修改 z
z[0] = 100.0

# 打印 y 和 z
print("y:", y)  # y 的值被修改
print("z:", z)  # z 的值被修改

输出

y: tensor([100.,   4.,   6.], device='cuda:0', grad_fn=<MulBackward0>)
z: tensor([100.,   4.,   6.], device='cuda:0')

解释

  • y:
  • 虽然 y 仍然在 GPU 上,并且保留了计算历史(grad_fn),但它的值被修改为 [100., 4., 6.]
  • 这是因为 zy 共享数据。
  • z:
  • z 的值被修改为 [100., 4., 6.]
  • 由于 zy 共享数据,y 的值也被同步修改。

总结

  • 如果只调用 .detach()
  • zy 共享数据,对 z 的修改会反映到 y 上。
  • 如果调用 .detach().cpu()
  • z 会被移动到 CPU,并且与 y 不再共享数据,对 z 的修改不会影响 y

因此,是否调用 .cpu() 是决定 zy 是否共享数据的关键