boxmoe_header_banner_img

Hello! 欢迎来到悠悠畅享网!

文章导读

PyTorch中冻结中间层参数的策略与实践


avatar
作者 2025年8月22日 13

PyTorch中冻结中间层参数的策略与实践

本文深入探讨了在pytorch神经网络中冻结特定中间层参数的两种主要方法:使用torch.no_grad()上下文管理器和设置参数的requires_grad=False属性。通过实验对比,我们揭示了这两种方法在梯度回传机制上的关键差异,并明确指出在需要精确冻结特定层而允许其他层更新的场景下,应优先采用requires_grad=False策略,以实现灵活高效的模型训练。

导言:理解层冻结的需求

在深度学习模型训练中,我们有时需要冻结网络中的某些层,即阻止这些层的参数在反向传播过程中被更新。这在多种场景下非常有用,例如:

  • 迁移学习(Transfer Learning):使用预训练模型作为特征提取器,只微调顶层分类器。
  • 模型稳定性:在训练的某些阶段,固定部分层以稳定训练过程。
  • 实验控制:隔离特定层的影响,以便更好地理解模型行为。

然而,如何正确地冻结一个中间层,同时确保其前后层能够正常更新,是一个常见的疑问。本文将详细探讨两种常用的方法,并通过实验分析它们的实际效果。

方法一:使用 torch.no_grad() 上下文管理器

torch.no_grad() 是PyTorch提供的一个上下文管理器,其作用是在其内部执行的代码块中,禁用梯度计算。这意味着,在该代码块中创建的任何张量都不会追踪其操作历史,也不会计算梯度。

考虑一个简单的三层线性网络:lin0 -> lin1 -> lin2。如果我们的目标是冻结 lin1,同时允许 lin0 和 lin2 更新,一个直观的想法是在 lin1 的前向传播中使用 torch.no_grad():

import torch import torch.nn as nn  class SimpleModelNoGrad(nn.Module):     def __init__(self):         super(SimpleModelNoGrad, self).__init__()         self.lin0 = nn.Linear(1, 2)         self.lin1 = nn.Linear(2, 2)         self.lin2 = nn.Linear(2, 10)      def forward(self, x):         x = self.lin0(x)         # 在lin1的前向传播中使用no_grad         with torch.no_grad():             x = self.lin1(x)         x = self.lin2(x)         return x  # 实例化模型 model_nograd = SimpleModelNoGrad()  # 记录初始参数 initial_lin0_weight = model_nograd.lin0.weight.clone() initial_lin1_weight = model_nograd.lin1.weight.clone() initial_lin2_weight = model_nograd.lin2.weight.clone()  # 模拟训练步骤 optimizer = torch.optim.SGD(model_nograd.parameters(), lr=0.01) input_data = torch.randn(1, 1) target = torch.randint(0, 10, (1,)) loss_fn = nn.CrossEntropyLoss()  print("--- 使用 torch.no_grad() 策略 ---") print("初始 lin0 权重:n", initial_lin0_weight) print("初始 lin1 权重:n", initial_lin1_weight) print("初始 lin2 权重:n", initial_lin2_weight)  # 进行一次前向传播、反向传播和优化 optimizer.zero_grad() output = model_nograd(input_data) loss = loss_fn(output, target) loss.backward() optimizer.step()  # 检查参数变化 print("n更新后 lin0 权重:n", model_nograd.lin0.weight) print("更新后 lin1 权重:n", model_nograd.lin1.weight) print("更新后 lin2 权重:n", model_nograd.lin2.weight)  print("nlin0 权重是否改变:", not torch.equal(initial_lin0_weight, model_nograd.lin0.weight)) print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight, model_nograd.lin1.weight)) print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight, model_nograd.lin2.weight))

实验结果分析: 在上述实验中,你会发现 lin0、lin1 和 lin2 的参数都没有更新。这是因为 torch.no_grad() 不仅阻止了 lin1 内部的梯度计算,更重要的是,它切断了从 lin2 到 lin1 再到 lin0 的整个梯度回传路径。一旦某个张量(lin1 的输出)在 no_grad 块中生成,它就没有梯度历史,因此其上游的 lin0 也无法接收到梯度信号,从而导致所有相关参数都无法更新。

结论: torch.no_grad() 适用于完全禁用某个计算分支的梯度计算,例如在推理阶段或特征提取阶段。它不适用于需要精确冻结中间层同时允许其上游层更新的场景。

方法二:设置参数的 requires_grad=False 属性

更精确地冻结特定层的方法是直接修改其参数的 requires_grad 属性。PyTorch中的每个张量都有一个 requires_grad 属性,默认为 True。如果将其设置为 False,PyTorch将不会为该张量计算梯度,并且在反向传播时,任何依赖于该张量的操作的梯度都不会传播到该张量。

为了冻结 lin1,我们需要在模型定义之后,但在优化器初始化之前,将其所有参数(权重和偏置)的 requires_grad 属性设置为 False。

import torch import torch.nn as nn  class SimpleModelRequiresGrad(nn.Module):     def __init__(self):         super(SimpleModelRequiresGrad, self).__init__()         self.lin0 = nn.Linear(1, 2)         self.lin1 = nn.Linear(2, 2)         self.lin2 = nn.Linear(2, 10)      def forward(self, x):         x = self.lin0(x)         x = self.lin1(x)         x = self.lin2(x)         return x  # 实例化模型 model_req_grad = SimpleModelRequiresGrad()  # 在优化器定义之前,冻结lin1的参数 for param in model_req_grad.lin1.parameters():     param.requires_grad = False  # 记录初始参数 initial_lin0_weight_rg = model_req_grad.lin0.weight.clone() initial_lin1_weight_rg = model_req_grad.lin1.weight.clone() initial_lin2_weight_rg = model_req_grad.lin2.weight.clone()  # 只有requires_grad=True的参数才会被优化器考虑 optimizer_rg = torch.optim.SGD(Filter(Lambda p: p.requires_grad, model_req_grad.parameters()), lr=0.01) input_data_rg = torch.randn(1, 1) target_rg = torch.randint(0, 10, (1,)) loss_fn_rg = nn.CrossEntropyLoss()  print("n--- 使用 requires_grad=False 策略 ---") print("初始 lin0 权重:n", initial_lin0_weight_rg) print("初始 lin1 权重:n", initial_lin1_weight_rg) print("初始 lin2 权重:n", initial_lin2_weight_rg)  # 进行一次前向传播、反向传播和优化 optimizer_rg.zero_grad() output_rg = model_req_grad(input_data_rg) loss_rg = loss_fn_rg(output_rg, target_rg) loss_rg.backward() optimizer_rg.step()  # 检查参数变化 print("n更新后 lin0 权重:n", model_req_grad.lin0.weight) print("更新后 lin1 权重:n", model_req_grad.lin1.weight) print("更新后 lin2 权重:n", model_req_grad.lin2.weight)  print("nlin0 权重是否改变:", not torch.equal(initial_lin0_weight_rg, model_req_grad.lin0.weight)) print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight_rg, model_req_grad.lin1.weight)) print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight_rg, model_req_grad.lin2.weight))

实验结果分析: 通过这种方法,你会发现 lin0 和 lin2 的参数得到了更新,而 lin1 的参数保持不变。这是因为 lin1 的 requires_grad 被设置为 False,其梯度不会被计算,也不会参与优化。但 lin2 的梯度会正常计算并回传到 lin1 的输入,由于 lin1 的参数不需要梯度,梯度会继续回传到 lin0,从而使得 lin0 也能正常更新。

关键注意事项:

  • 优化器参数过滤:在创建优化器时,务必只传入 requires_grad=True 的参数。optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01) 是一种常见且推荐的做法。如果直接传入 model.parameters(),优化器会尝试为所有参数分配内存,即使它们不会更新,这可能导致不必要的资源消耗,虽然最终它们不会被更新。
  • 批量操作:对于包含多个子模块的复杂模型,可以通过循环遍历子模块或使用 named_parameters() 来批量设置 requires_grad。

总结与最佳实践

特性/方法 torch.no_grad() param.requires_grad = False
作用范围 局部,作用于上下文管理器内的所有计算操作 全局,作用于特定参数本身
梯度回传 切断梯度回传路径,其上游和自身均无法更新 允许梯度通过,但不会为 requires_grad=False 的参数计算和存储梯度,其上游层可正常更新
适用场景 推理阶段、性能评估、特征提取等不需要梯度计算的场景 冻结特定层进行迁移学习、微调、或实验控制等需要精确控制参数更新的场景
推荐程度 不推荐用于精确冻结中间层并允许前后层更新的场景 强烈推荐用于精确冻结特定层的场景

综上所述,当您需要在PyTorch中冻结一个中间层,同时确保其前后层能够正常训练和更新时,设置目标层的参数 requires_grad=False 是最准确和推荐的方法。torch.no_grad() 更适用于完全禁用某个计算路径的梯度追踪,它会影响到整个计算链条,导致意外的冻结效果。理解这两种机制的差异,对于高效和准确地进行模型训练至关重要。



评论(已关闭)

评论已关闭