Shortcuts

openrl.runners.common package

Submodules

openrl.runners.common.base_agent module

class openrl.runners.common.base_agent.BaseAgent[源代码]

基类: abc.ABC

abstract load(path: Union[str, pathlib.Path, io.BufferedIOBase])[源代码]
abstract save(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]

openrl.runners.common.chat_agent module

class openrl.runners.common.chat_agent.ChatAgent(model, tokenizer, device=None)[源代码]

基类: openrl.runners.common.base_agent.BaseAgent

chat(input: str, history: List[str])[源代码]
classmethod load(agent_path: Union[str, pathlib.Path, io.BufferedIOBase], tokenizer: Optional[Union[str, pathlib.Path, io.BufferedIOBase]] = None, disable_cuda: Optional[bool] = True) → openrl.runners.common.base_agent.SelfAgent[源代码]
save(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]

openrl.runners.common.ppo_agent module

class openrl.runners.common.ppo_agent.PPOAgent(net: Optional[torch.nn.modules.module.Module] = None, env: Union[gym.core.Env, str] = None, run_dir: Optional[str] = None, env_num: Optional[int] = None, rank: int = 0, world_size: int = 1, use_wandb: bool = False, use_tensorboard: bool = False)[源代码]

基类: openrl.runners.common.base_agent.BaseAgent

act(observation: Union[numpy.ndarray, Dict[str, numpy.ndarray]], deterministic: bool = True) → Tuple[numpy.ndarray, Optional[Tuple[numpy.ndarray, ...]]][源代码]
load(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]
load_policy(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]
save(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]
set_env(env: Union[gym.core.Env, str] = None)[源代码]
train(total_time_steps: int) → None[源代码]

Module contents

class openrl.runners.common.ChatAgent(model, tokenizer, device=None)[源代码]

基类: openrl.runners.common.base_agent.BaseAgent

chat(input: str, history: List[str])[源代码]
classmethod load(agent_path: Union[str, pathlib.Path, io.BufferedIOBase], tokenizer: Optional[Union[str, pathlib.Path, io.BufferedIOBase]] = None, disable_cuda: Optional[bool] = True) → openrl.runners.common.base_agent.SelfAgent[源代码]
save(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]
class openrl.runners.common.PPOAgent(net: Optional[torch.nn.modules.module.Module] = None, env: Union[gym.core.Env, str] = None, run_dir: Optional[str] = None, env_num: Optional[int] = None, rank: int = 0, world_size: int = 1, use_wandb: bool = False, use_tensorboard: bool = False)[源代码]

基类: openrl.runners.common.base_agent.BaseAgent

act(observation: Union[numpy.ndarray, Dict[str, numpy.ndarray]], deterministic: bool = True) → Tuple[numpy.ndarray, Optional[Tuple[numpy.ndarray, ...]]][源代码]
load(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]
load_policy(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]
save(path: Union[str, pathlib.Path, io.BufferedIOBase]) → None[源代码]
set_env(env: Union[gym.core.Env, str] = None)[源代码]
train(total_time_steps: int) → None[源代码]