当前位置: 首页 > news >正文

模型训练中 平均损失值和平均准确率的深入理解

aver_loss

总损失的计算

对于求平均损失来说 需要先求总损失
而求总损失 就需要求一个批次中的损失

对于一个bs来说 损失的计算是利用
loss=criterion(out,labels)计算得出
而criterion 使用的nn.crossentropy
得出来的损失值 已经是对这一个bs传入的所有样本取过平均值了
所以得出来的loss是当前bs的aver_loss

上面标亮的这段话 是求损失值的关键,也是后面两种方法的基础。
则total_loss+=loss 就计算出总损失了。

对于aver_loss 是可以有两种处理方式的。

方法一:累加“总损失”,最后除以“总样本数”

这是更精确、更标准的方法,也是 PyTorch 官方教程中常见的方式。

  1. 循环内的操作:

    Python

    running_loss += loss.item() * inputs.size(0)
    
    • loss.item():这是 PyTorch CrossEntropyLoss 默认返回的一个批次 (batch) 的平均损失

    • inputs.size(0):这是当前批次中的样本数量(也就是 batch_size)。

    • loss.item() * inputs.size(0):用“平均损失”乘以“样本数”,我们得到的实际上是这个批次的“总损失”(即损失值的加和)。

    • running_loss += ...:所以,running_loss 累加的是所有批次的总损失之和,也就是整个 epoch 见过的所有样本的损失总和

  2. 循环外的操作:

    Python

    epoch_loss = running_loss / dataset_size[phase]
    
    • 因为 running_loss所有样本的损失总和,所以我们理应除以所有样本的总数量 (dataset_size[phase]),来得到最精确的“平均到每个样本的损失”
  • 优点:这种方法可以精确地处理最后一个批次样本数不足的情况(当数据集总数不能被 batch_size 整除时),因为 inputs.size(0) 会自动适应最后一个批次的实际大小。

方法二:累加“平均损失”,最后除以“总批次数”(您提出的方式)

您的这个逻辑也是完全正确的!它代表了另一种计算思路。

  1. 要使用您的计算方法,循环内的操作应该是:

    Python

    running_loss += loss.item() 
    
    • 这里,我们累加的是每个批次的“平均损失”running_loss 最终会变成所有批次的平均损失之和
  2. 循环外的操作(如您所写):

    Python

    aver_loss = running_loss / len(dataloaders[phase])
    
    • 因为 running_loss所有批次平均损失的和,所以我们理应除以总的批次数 (len(dataloaders[phase])),来得到“每个批次的平均损失的平均值”
  • 优点:实现起来非常直观。

  • 微小缺点:当最后一个批次样本数不足时,它在计算最终平均值时,给予了这个不完整的批次的“平均损失”与其他完整批次相同的权重,理论上会引入微小的计算偏差。但在实践中,当数据集很大时,这点偏差几乎可以忽略不计。

accuracy

对于准确率来说,他是在每一个批次(bs)中
使用_,preds=torch.max(outputs,1)
torch.max的使用参考https://www.cnblogs.com/zhuzhucheng/p/19109039先求出分类的类别。
然后调用torch.sum(preds == labels.data) 求出正确预测的总数。
preds==labels.data 返回的是一个bool数组。
torch.sum则是把bool数组的true视为1 false视为0 求和 最后返回一个整数

在每一个epoch训练前定义需定义total_acc=0
在每个批次累加正确的数量到total_acc上
再一个轮次所有批次训练完后 total_acc/样本总数(也就是参考损失值计算中的dataset_size[phase]) 即为正确率。

http://www.hskmm.com/?act=detail&tid=15707

相关文章:

  • torch.max函数在分类问题中的使用 学习
  • godot3.6字典遍历
  • 国产DevOps工具链崛起:Gitee领衔的本土化技术生态全景解读
  • 安装 elasticsearch-9.1.4的 IK分词器
  • react性能优化
  • 从研发效能到知识中枢:Gitee Wiki如何重塑企业知识管理范式
  • Gitee DevSecOps平台:军工软件研发的智能化革命
  • 杆状病毒表达系统为何成为蛋白表达首选
  • 日记3
  • Gitee如何重塑中国开发者的代码托管体验
  • 模块化面向对象 2章
  • css `isolation: isolate` - 详解
  • Debezium + Kafka + Flink/Doris Stream Load 实时数仓
  • Gitee DevOps平台:中国企业数字化转型的代码管理新范式
  • Ansible + Docker 部署 Zookeeper 集群
  • 幂运算与航班中转的奇妙旅行:探索算法世界的两极 - 实践
  • Gemini CLI 配置问题
  • 本土化与全球化博弈下的项目管理工具选型:Gitee如何为中国企业破局?
  • 论Linux安装后需要进行的配置
  • 51单片机-驱动DS1302时钟芯片模块教程 - 实践
  • tomato WP复盘
  • SQLite的并发问题
  • 域渗透靶场-vulntarget-a综合靶场
  • 数组和链表读取、插入、删除以及查找的区别
  • day 09 课程
  • 在K8S中,日志分析工具有哪些可以与K8S集群通讯?
  • 在K8S中,网络通信模式有哪些?
  • 一文教你搞定PASS 2025:样本量计算神器安装到使用全流程
  • React 18.2中采用React Router 6.4
  • 题解:AT_abc257_h [ABC257Ex] Dice Sum 2