当前位置: 首页 > news >正文

DRL模型训练:原始奖励函数记录以及绘制

一些参考图片:

image

image

1. 使用sb3库,

调用callback,会记录每个episode结束时的reward;

使用tensorboard记录的rollout/ep_rew_mean,会自动每4个ep平均,并进行平滑,得到的不是原始数据。

from stable_baselines3.common.callbacks import BaseCallback
import os
import numpy as np
class RewardLoggingCallback(BaseCallback):def __init__(self, save_path, verbose=0):super().__init__(verbose)self.save_path = save_pathself.episode_rewards = []def _on_step(self) -> bool:# SB3 会在 episode 结束时把 episode info 放在 infos 中if len(self.locals.get("infos", [])) > 0:for info in self.locals["infos"]:if "episode" in info.keys():self.episode_rewards.append(info["episode"]["r"])return Truedef _on_training_end(self) -> None:os.makedirs(os.path.dirname(self.save_path), exist_ok=True)np.save(self.save_path, np.array(self.episode_rewards))if self.verbose > 0:print(f"Saved episodic rewards to {self.save_path}")

2.调用seaborn库

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd# 假设你通过 callback 保存的数据是多个实验 run 的结果
# 例如保存成: run1_rewards.npy, run2_rewards.npy, ...
files = [
'run1_rewards.npy',
]# 定义滑动平均函数
def moving_average(x, window=50):return np.convolve(x, np.ones(window)/window, mode="valid")# 收集所有数据
data = []
for run_id, f in enumerate(files):rewards = np.load(f)smoothed = moving_average(rewards, window=20)for i, r in enumerate(smoothed):data.append({"timestep": i, "reward": r, "run": run_id})df = pd.DataFrame(data)# seaborn 绘制:均值曲线 + 阴影表示方差区间
plt.figure(figsize=(8, 5))
sns.lineplot(data=df,x="timestep",y="reward",hue=None,estimator="mean",errorbar="sd"  # 可选 "ci" 表示置信区间,"sd" 表示标准差
)plt.title("Episode Reward (Smoothed, Multiple Runs)")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.tight_layout()
plt.show()

参考

https://zhuanlan.zhihu.com/p/635706668
https://www.deeprlhub.com/d/114
https://zhuanlan.zhihu.com/p/75477750

http://www.hskmm.com/?act=detail&tid=19785

相关文章:

  • 中国DevOps平台竞品分析:安全合规与技术生态的双重较量
  • experiment 1
  • 图领域的METIS算法介绍 - zhang
  • CANOpen safety SRDO相关问题总结
  • Prometheus源码专题【左扬精讲】—— 监控系统 Prometheus 3.4.0 源码解析:head_wal.go 的 WAL 写入策略与缓存管理源码解读
  • 电子通信词汇中英文对照
  • 平衡树
  • 完整教程:【有源码】基于Hadoop+Spark的AI就业影响数据分析与可视化系统-AI驱动下的就业市场变迁数据分析与可视化研究-基于大数据的AI就业趋势分析可视化平台
  • Tomcat中启用h3的方法是什么
  • k8s-Namespace
  • 国产化Excel开发组件Spire.XLS教程:C# 写入 Excel ,轻松将数据导出到工作表
  • 牛客刷题-Day6
  • 数字化转型浪潮下:10款主流项目管理工具横向测评与选型指南
  • 借助Aspose.Email,使用 Python 将 EML 转换为 MHTML
  • python+springboot+django/flask的医院食堂订餐系统 菜单发布 在线订餐 餐品管理与订单统计系统 - 教程
  • 计算机网络学习笔记 - 浪矢
  • 数据结构以及LeetCode常用方法 - 浪矢
  • App Store 上架完整流程解析,iOS 应用发布步骤、ipa 文件上传工具、TestFlight 测试与苹果审核经验
  • 使用 Zig 编写英文数字验证码识别工具
  • 数数学习笔记
  • 6 个替代 Microsoft Access 的开源数据库工具推荐
  • 20250626_黔西南网信杯_wireshark
  • Ubuntu STA+AP 开机自启完整方案
  • PDE和CFD的区别?
  • Gitee:中国开发者生态的基石与数字化转型的加速器
  • 20号胶
  • MQTT协议
  • 完整教程:带你了解STM32:TIM定时器(第四部分)
  • 邮件怎么发送超大附件的实用解决方案
  • 告别无效对话:五个让AI输出质量提升10倍的提示词框架