当前位置: 首页 > news >正文

PyTorch Weight Decay 技术指南

PyTorch Weight Decay 技术指南

目录

  1. 摘要
  2. 概念与理论
    • 2.1 核心概念
    • 2.2 与 L2 正则化的关系
    • 2.3 核心作用
  3. PyTorch 实践
    • 3.1 如何设置 λ(权重衰减系数)
    • 3.2 不同架构的常见设置
    • 3.3 PyTorch 实现方式
    • 3.4 高级技巧

1. 摘要

Weight Decay(权重衰减)是深度学习中重要的正则化技术,通过在训练过程中对模型权重施加惩罚,防止过拟合,提升模型泛化能力。

2. 概念与理论

2.1 核心概念

Weight Decay是一种正则化技术,在损失函数中添加与权重大小相关的惩罚项,鼓励模型学习更小的权重值,得到更简单、平滑的模型。

带Weight Decay的总损失函数:

L_total = L_original + λ/2 * ||w||²

其中λ是权重衰减系数,控制惩罚项权重:λ越大,对大幅值权重的惩罚越重,模型越简单。

2.2 与 L2 正则化的关系

在标准随机梯度下降(SGD)中,Weight Decay完全等价于L2正则化。

但在使用自适应优化器(如Adam, AdamW)时,传统实现方式会导致不等价。Adam等优化器会为每个参数计算自适应学习率,如果直接将L2正则项加到损失函数中,会像处理普通梯度一样处理正则项的梯度,导致正则化效果被扭曲。

AdamW(Adam with Weight Decay)解决了这个问题,将Weight Decay项从损失函数中解耦出来,直接在权重更新时添加,而不影响梯度计算。

AdamW的更新规则:
w = w - lr * d(L_original)/dw - lr * λ * w

关键区别:AdamW中的λ * w项不参与梯度、一阶矩、二阶矩的计算,是独立的衰减项,效果更纯粹稳定。

2.3 核心作用

防止过拟合:通过惩罚大的权重,限制模型复杂度,使其无法完美"记忆"训练数据中的噪声和细节。

提升泛化能力:更简单的模型在未见过的数据上通常表现更好。

3. PyTorch 实践

3.1 如何设置 λ(权重衰减系数)

λ是关键超参数,需要仔细调整。没有通用值。

典型范围:λ通常在1e-4到1e-2之间(0.0001到0.01)。

  • 1e-4是常用且安全的起始点
  • 1e-3和1e-4是最常见的选择
  • 1e-2是非常强的衰减,只适用于特定场景

调整策略

  • 从默认值开始:λ = 1e-4或1e-3
  • 与学习率协同调整:通常需要将两者一起搜索
  • 观察训练与验证曲线:
    • 欠拟合(训练误差和验证误差都很大):减小λ或设为0
    • 过拟合(训练误差很小,验证误差很大):增大λ

3.2 不同架构的常见设置

计算机视觉(CNN):常用1e-4量级。ResNet、VGG等经典网络通常使用此值。

自然语言处理(Transformer):AdamW是标准优化器。常用值为0.01或0.1。

其他领域:RNN/LSTM通常从1e-4开始尝试。

3.3 PyTorch 实现方式

方式一:使用SGD优化器

optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9,weight_decay=1e-4)

方式二:使用AdamW优化器(推荐)

optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4,weight_decay=0.01)

注意:避免使用Adam + L2,会导致自适应学习率问题。

3.4 高级技巧

不对偏置和归一化层进行衰减

只对权重应用Weight Decay,不对偏置和层归一化、批归一化参数应用。

# 示例:将权重和偏置参数分开
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():if any(nd in name for nd in ["bias", "norm.weight", "norm.bias"]):# 偏置和Norm层的参数不衰减no_decay_params.append(param)else:# 其他权重参数衰减decay_params.append(param)optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay': 0.01},{'params': no_decay_params, 'weight_decay': 0.0}
], lr=1e-4)
http://www.hskmm.com/?act=detail&tid=9867

相关文章:

  • AUTOSAR进阶图解==>AUTOSAR_SWS_PDURouter - 实践
  • getDefaultMidwayLoggerConfig报错;解决方法。
  • js获取浏览器语言,以及调用谷歌翻译api翻译成相应的内容
  • 总结RocketMQ中的常见问题
  • The 2025 ICPC Asia EC Regionals Online Contest (II)
  • C++线上练习
  • Python实现Elman RNN与混合RNN神经网络对航空客运量、啤酒产量、电力产量时间序列数据预测可视化对比
  • 4G/Wi-Fi/以太网三网合一,智能融合通信实战案例集
  • 关于介绍自己的第一篇随笔
  • 深入解析:N32G43x Flash 驱动移植与封装实践
  • Backblaze上如何传大文件
  • 解题报告-老逗找基友 (friends)
  • Caused by: java.lang.ClassNotFoundException: org.apache.rocketmq.remoting.common.RemotingUtil
  • VAE In JAX【个人记录向】
  • BLE蓝牙配网双模式实操:STA+SoftAP技术原理与避坑指南
  • 第58天:RCE代码amp;命令执行amp;过滤绕过amp;异或无字符amp;无回显方案amp;黑白盒挖掘
  • 057-Web攻防-SSRFDemo源码Gopher项目等
  • 060-WEB攻防-PHP反序列化POP链构造魔术方法流程漏洞触发条件属性修改
  • 059-Web攻防-XXE安全DTD实体复现源码等
  • 061-WEB攻防-PHP反序列化原生类TIPSCVE绕过漏洞属性类型特征
  • 051-Web攻防-文件安全目录安全测试源码等
  • Dilworth定理及其在算法题中的应用
  • 050-WEB攻防-PHP应用文件包含LFIRFI伪协议编码算法无文件利用黑白盒
  • error: xxxxx does not have a commit checked out
  • 049-WEB攻防-文件上传存储安全OSS对象分站解析安全解码还原目录执行
  • 云原生周刊:MetalBear 融资、Chaos Mesh 漏洞、Dapr 1.16 与 AI 平台新趋势
  • AI一周资讯 250913-250919
  • 045-WEB攻防-PHP应用SQL二次注入堆叠执行DNS带外功能点黑白盒条件-cnblog
  • linux 命令语句
  • 用 Kotlin 实现英文数字验证码识别