Stable-Baselines3
基于 PyTorch 的可靠强化学习算法实现,开箱即用。
可靠稳定
经过严格测试,复现了论文中的性能。如果你跑不通,多半不是算法的问题。
文档丰富
拥有业内最好的文档和教程,对新手极其友好。
API 简洁
统一的接口设计,切换算法只需改一行代码。
3行代码训练 PPO
这就是 SB3 的魅力。你不需要手写 Buffer,不需要手写 Loss,只需要像调包侠一样调用它。
train_sb3.py
import gymnasium as gym from stable_baselines3 import PPO # 1. 创建环境 env = gym.make("CartPole-v1") # 2. 实例化模型 (MlpPolicy 表示使用多层感知机) model = PPO("MlpPolicy", env, verbose=1) # 3. 开始训练 model.learn(total_timesteps=10000) # --- 享受成果 --- obs, _ = env.reset() for _ in range(1000): # deterministic=True 表示使用确定性策略(不加噪声) action, _ = model.predict(obs, deterministic=True) obs, reward, done, truncated, info = env.step(action) env.render() if done or truncated: obs, _ = env.reset()