Shortcuts

加载Hugging Face上Stable-baselines3的模型

Stable-baseline3 实现了很多强化学习算法,并在 Hugging Face 上分享了训练出来的模型。 OpenRL可以加载这些模型,然后使用OpenRL进行测试和训练。我们在 这里 给出了一个加载Stable-baseline3模型的示例。

环境安装

首先我们需要通过 pip 来安装一些必要的包:

pip install huggingface-tool rl_zoo3

下载模型

安装好 huggingface-tool 工具后,我们可以通过 htool 命令来下来我们需要的模型。 这里,我们以 sb3/ppo-CartPole-v1 为例,其下载命令如下:

htool save-repo sb3/ppo-CartPole-v1 ppo-CartPole-v1

这里 sb3/ppo-CartPole-v1 是模型的地址, ppo-CartPole-v1 是我们保存的模型的名字。 htool 会自动下载模型并保存到 ppo-CartPole-v1 目录下。

加载Stable-baselines3模型并用于测试

模型下载完后,我们便可以用过OpenRL来加载该模型并进行测试,该部分完整代码可见 这里

# test_model.py
import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common.ppo_net import PPONet as Net
from openrl.modules.networks.policy_value_network_sb3 import (
    PolicyValueNetworkSB3 as PolicyValueNetwork,
)
from openrl.runners.common import PPOAgent as Agent

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
env = make("CartPole-v1",  env_num=9, asynchronous=True)
model_dict = {"model": PolicyValueNetwork}
net = Net(
    env,
    cfg=cfg,
    model_dict=model_dict,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
agent = Agent(net)
agent.set_env(env)
obs, info = env.reset()
done = False

while not np.any(done):
    action, _ = agent.act(obs, deterministic=True)
    obs, r, done, info = env.step(action)
env.close()

另外还需要写一个配置文件,用于配置模型的路径:

# ppo.yaml
sb3_model_path: ppo-CartPole-v1/ppo-CartPole-v1.zip
sb3_algo: ppo
use_share_model: true

这样,我们就可以通过 python test_model.py 来加载Stable-baselines3的模型并进行测试了。

加载Stable-baselines3模型并用于训练

模型下载完后,我们还可以用过OpenRL来加载该模型并用于训练,该部分完整代码可见 这里

# train_ppo.py
import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common.ppo_net import PPONet as Net
from openrl.modules.networks.policy_value_network_sb3 import (
    PolicyValueNetworkSB3 as PolicyValueNetwork,
)
from openrl.runners.common import PPOAgent as Agent

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
env = make("CartPole-v1", env_num=8, asynchronous=True)
model_dict = {"model": PolicyValueNetwork}
net = Net(
    env,
    cfg=cfg,
    model_dict=model_dict,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
agent = Agent(net)
agent.train(total_time_steps=100000)
agent.save("./ppo_sb3_agent")

另外还需要写一个配置文件,用于配置模型的路径和训练超参数:

# ppo.yaml
sb3_model_path: ppo-CartPole-v1/ppo-CartPole-v1.zip
sb3_algo: ppo
use_share_model: true
entropy_coef: 0.0
gae_lambda: 0.8
gamma: 0.98
lr: 0.001
episode_length: 32
ppo_epoch: 20

这样,我们就可以通过 python train_ppo.py 来加载Stable-baselines3的模型并进行训练了。

Read the Docs v: latest
Versions
latest
stable
main
v0.0.13
v0.0.6
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.