[ PROMPT_NODE_27426 ]
stable-baselines3
[ SKILL_DOCUMENTATION ]
# Stable Baselines3
## 概述
Stable Baselines3 (SB3) 是一个基于 PyTorch 的库,提供可靠的强化学习算法实现。此技能为训练强化学习智能体、创建自定义环境、实现回调以及使用 SB3 统一 API 优化训练工作流提供全面指导。
## 核心能力
### 1. 训练强化学习智能体
**基础训练模式:**
python
import gymnasium as gym
from stable_baselines3 import PPO
# 创建环境
env = gym.make("CartPole-v1")
# 初始化智能体
model = PPO("MlpPolicy", env, verbose=1)
# 训练智能体
model.learn(total_timesteps=10000)
# 保存模型
model.save("ppo_cartpole")
# 加载模型(无需预先实例化)
model = PPO.load("ppo_cartpole", env=env)
**重要说明:**
- `total_timesteps` 是下限;由于批处理收集,实际训练可能会超过此值
- 使用 `model.load()` 作为静态方法,而不是在现有实例上调用
- 为节省空间,重放缓冲区(replay buffer)不会随模型一起保存
**算法选择:**
使用 `references/algorithms.md` 获取详细的算法特性和选择指南。快速参考:
- **PPO/A2C**: 通用型,支持所有动作空间类型,适合多进程
- **SAC/TD3**: 连续控制,离策略(off-policy),样本效率高
- **DQN**: 离散动作,离策略
- **HER**: 目标条件任务
查看 `scripts/train_rl_agent.py` 获取包含最佳实践的完整训练模板。
### 2. 自定义环境
**要求:**
自定义环境必须继承自 `gymnasium.Env` 并实现:
- `__init__()`: 定义 action_space 和 observation_space
- `reset(seed, options)`: 返回初始观测值和 info 字典
- `step(action)`: 返回观测值、奖励、terminated、truncated、info
- `render()`: 可视化(可选)
- `close()`: 清理资源
**关键约束:**
- 图像观测值必须是范围在 [0, 255] 的 `np.uint8` 类型
- 尽可能使用通道优先格式(通道,高度,宽度)
- SB3 通过除以 255 自动归一化图像
- 如果已预先归一化,请在 policy_kwargs 中设置 `normalize_images=False`
- SB3 不支持 `start!=0` 的 `Discrete` 或 `MultiDiscrete` 空间
**验证:**
python
from stable_baselines3.common.env_checker import check_env
check_env(env, warn=True)
查看 `scripts/custom_env_template.py` 获取完整的自定义环境模板,并查看 `references/custom_environments.md` 获取全面指导。
### 3. Vecto