PyTorch Weight Decay 技术指南
目录
- 摘要
- 概念与理论
- 2.1 核心概念
- 2.2 与 L2 正则化的关系
- 2.3 核心作用
- PyTorch 实践
- 3.1 如何设置 λ(权重衰减系数)
- 3.2 不同架构的常见设置
- 3.3 PyTorch 实现方式
- 3.4 高级技巧
1. 摘要
Weight Decay(权重衰减)是深度学习中重要的正则化技术,通过在训练过程中对模型权重施加惩罚,防止过拟合,提升模型泛化能力。
2. 概念与理论
2.1 核心概念
Weight Decay是一种正则化技术,在损失函数中添加与权重大小相关的惩罚项,鼓励模型学习更小的权重值,得到更简单、平滑的模型。
带Weight Decay的总损失函数:
L_total = L_original + λ/2 * ||w||²
其中λ是权重衰减系数,控制惩罚项权重:λ越大,对大幅值权重的惩罚越重,模型越简单。
2.2 与 L2 正则化的关系
在标准随机梯度下降(SGD)中,Weight Decay完全等价于L2正则化。
但在使用自适应优化器(如Adam, AdamW)时,传统实现方式会导致不等价。Adam等优化器会为每个参数计算自适应学习率,如果直接将L2正则项加到损失函数中,会像处理普通梯度一样处理正则项的梯度,导致正则化效果被扭曲。
AdamW(Adam with Weight Decay)解决了这个问题,将Weight Decay项从损失函数中解耦出来,直接在权重更新时添加,而不影响梯度计算。
AdamW的更新规则:
w = w - lr * d(L_original)/dw - lr * λ * w
关键区别:AdamW中的λ * w项不参与梯度、一阶矩、二阶矩的计算,是独立的衰减项,效果更纯粹稳定。
2.3 核心作用
防止过拟合:通过惩罚大的权重,限制模型复杂度,使其无法完美"记忆"训练数据中的噪声和细节。
提升泛化能力:更简单的模型在未见过的数据上通常表现更好。
3. PyTorch 实践
3.1 如何设置 λ(权重衰减系数)
λ是关键超参数,需要仔细调整。没有通用值。
典型范围:λ通常在1e-4到1e-2之间(0.0001到0.01)。
- 1e-4是常用且安全的起始点
- 1e-3和1e-4是最常见的选择
- 1e-2是非常强的衰减,只适用于特定场景
调整策略:
- 从默认值开始:λ = 1e-4或1e-3
- 与学习率协同调整:通常需要将两者一起搜索
- 观察训练与验证曲线:
- 欠拟合(训练误差和验证误差都很大):减小λ或设为0
- 过拟合(训练误差很小,验证误差很大):增大λ
3.2 不同架构的常见设置
计算机视觉(CNN):常用1e-4量级。ResNet、VGG等经典网络通常使用此值。
自然语言处理(Transformer):AdamW是标准优化器。常用值为0.01或0.1。
其他领域:RNN/LSTM通常从1e-4开始尝试。
3.3 PyTorch 实现方式
方式一:使用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9,weight_decay=1e-4)
方式二:使用AdamW优化器(推荐)
optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4,weight_decay=0.01)
注意:避免使用Adam + L2,会导致自适应学习率问题。
3.4 高级技巧
不对偏置和归一化层进行衰减:
只对权重应用Weight Decay,不对偏置和层归一化、批归一化参数应用。
# 示例:将权重和偏置参数分开
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():if any(nd in name for nd in ["bias", "norm.weight", "norm.bias"]):# 偏置和Norm层的参数不衰减no_decay_params.append(param)else:# 其他权重参数衰减decay_params.append(param)optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay': 0.01},{'params': no_decay_params, 'weight_decay': 0.0}
], lr=1e-4)