
openrl.utils.callbacks package

Submodules

openrl.utils.callbacks.callbacks module

class openrl.utils.callbacks.callbacks.BaseCallback(verbose: int = 0) [source]

Bases: abc.ABC

Base class for callbacks.

Parameters

verbose -- Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages

agent: openrl.runners.common.base_agent.BaseAgent
init_callback(agent: openrl.runners.common.base_agent.BaseAgent) → None [source]

Initialize the callback by saving references to the RL model and the training environment for convenience.

logger: openrl.utils.logger.Logger
on_rollout_end() → None [source]
on_rollout_start() → None [source]
on_step() → bool [source]

This method will be called by the model after each call to env.step().

For a child callback (of an EventCallback), this will be called when the event is triggered.

Returns

If the callback returns False, training is aborted early.

on_training_end() → None [source]
on_training_start(locals_: Dict[str, Any], globals_: Dict[str, Any]) → None [source]
set_parent(parent: openrl.utils.callbacks.callbacks.BaseCallback) → None [source]

Set the parent of the callback.

Parameters

parent -- The parent callback.

update_child_locals(locals_: Dict[str, Any]) → None [source]

Update the references to the local variables on sub-callbacks.

Parameters

locals -- the local variables during rollout collection

update_locals(locals_: Dict[str, Any]) → None [source]

Update the references to the local variables.

Parameters

locals -- the local variables during rollout collection
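The self-contained sketch below illustrates the hook order described above. It does not import OpenRL. SketchCallback, StopAfterNCalls and toy_training_loop are hypothetical names. They mimic a runner that calls on_training_start(), then update_locals() and on_step() once per env.step(), and aborts as soon as on_step() returns False. The real BaseCallback has more hooks (rollout start/end, parent handling) than shown here.

```python
from abc import ABC
from typing import Any, Dict


class SketchCallback(ABC):
    """Hypothetical stand-in for BaseCallback (not the OpenRL class)."""

    def __init__(self, verbose: int = 0) -> None:
        self.verbose = verbose
        self.n_calls = 0
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}

    def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        # Keep references to the runner's variables for later hooks.
        self.locals = locals_
        self.globals = globals_

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        self.locals.update(locals_)

    def on_step(self) -> bool:
        # Public hook: count the call, then defer to the subclass logic.
        self.n_calls += 1
        return self._on_step()

    def _on_step(self) -> bool:
        return True  # default: never ask to stop

    def on_training_end(self) -> None:
        pass


class StopAfterNCalls(SketchCallback):
    """Hypothetical subclass that asks to stop after max_calls steps."""

    def __init__(self, max_calls: int) -> None:
        super().__init__()
        self.max_calls = max_calls

    def _on_step(self) -> bool:
        return self.n_calls < self.max_calls


def toy_training_loop(callback: SketchCallback, total_steps: int) -> int:
    """Hypothetical runner: call on_step() after each env.step(), stop on False."""
    callback.on_training_start({"total_steps": total_steps}, {})
    steps_done = 0
    for step in range(total_steps):
        steps_done += 1  # stands in for one env.step() call
        callback.update_locals({"step": step})
        if not callback.on_step():
            break  # the callback returned False: abort training early
    callback.on_training_end()
    return steps_done
```

For example, toy_training_loop(StopAfterNCalls(3), 10) stops after 3 of the 10 steps.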

class openrl.utils.callbacks.callbacks.CallbackList(callbacks: List[openrl.utils.callbacks.callbacks.BaseCallback], stop_logic: str = 'OR') [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Class for chaining callbacks.

Parameters

callbacks -- A list of callbacks that will be called sequentially.

set_parent(parent: openrl.utils.callbacks.callbacks.BaseCallback) → None [source]

Set the parent of the callback.

Parameters

parent -- The parent callback.

update_child_locals(locals_: Dict[str, Any]) → None [source]

Update the references to the local variables.

Parameters

locals -- the local variables during rollout collection
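The reference above does not explain stop_logic. The sketch below illustrates one plausible reading. It assumes that 'OR' stops training when at least one chained callback returns False, and that 'AND' stops only when every callback returns False. Verify this against the OpenRL source. Each child is modeled as a hypothetical zero-argument function standing in for its on_step().

```python
from typing import Callable, List

ChildStep = Callable[[], bool]  # hypothetical: stands in for child.on_step


def combine_on_step(children: List[ChildStep], stop_logic: str = "OR") -> bool:
    """Return True to continue training, False to stop (assumed semantics)."""
    # Call every child in order, without short-circuiting, so each one
    # still sees every step even after another child has asked to stop.
    results = [child() for child in children]
    if not results:
        return True  # no children: nothing asks to stop
    if stop_logic == "OR":
        return all(results)  # stop if at least one child returned False
    if stop_logic == "AND":
        return any(results)  # stop only if every child returned False
    raise ValueError(f"unknown stop_logic: {stop_logic!r}")
```

The sketch evaluates all children before combining the results. This keeps per-callback counters consistent regardless of which child asks to stop.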

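class openrl.utils.callbacks.callbacks.ConvertCallback(callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0) [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Convert a functional (old-style) callback into a callback object.

Parameters
  • callback -- Function that receives the local and global variable dictionaries and returns a bool; returning False stops training.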

  • verbose -- Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
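The sketch below shows the idea behind converting a function with the signature Callable[[Dict[str, Any], Dict[str, Any]], bool] into an object. ConvertCallbackSketch and continue_while_reward_low are hypothetical names, and the "reward" key is an illustrative assumption. The wrapped function receives the object's stored locals and globals on every on_step() call.

```python
from typing import Any, Callable, Dict

FunctionalCallback = Callable[[Dict[str, Any], Dict[str, Any]], bool]


class ConvertCallbackSketch:
    """Hypothetical re-implementation: wraps an old-style function as an object."""

    def __init__(self, callback: FunctionalCallback, verbose: int = 0) -> None:
        self.callback = callback
        self.verbose = verbose
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}

    def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        self.locals = locals_
        self.globals = globals_

    def on_step(self) -> bool:
        # Delegate to the wrapped function; False means "stop training".
        return bool(self.callback(self.locals, self.globals))


def continue_while_reward_low(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> bool:
    """Hypothetical old-style callback: keep training while reward < 1.0."""
    return locals_.get("reward", 0.0) < 1.0
```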

class openrl.utils.callbacks.callbacks.EventCallback(callback: Optional[openrl.utils.callbacks.callbacks.BaseCallback] = None, verbose: int = 0) [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Base class for triggering a callback on an event.

Parameters
  • callback -- Callback that will be called when an event is triggered.

  • verbose -- Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages

init_callback(agent: openrl.runners.common.base_agent.BaseAgent) → None [source]

Initialize the callback by saving references to the RL model and the training environment for convenience.

update_child_locals(locals_: Dict[str, Any]) → None [source]

Update the references to the local variables.

Parameters

locals -- the local variables during rollout collection
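The sketch below illustrates the parent/child relationship. The event callback registers itself as the child's parent, and the child's on_step() runs only when the event fires. All class names are hypothetical. The "fire on every second call" rule is an arbitrary stand-in for whatever event a real subclass checks.

```python
from typing import Optional


class ChildSketch:
    """Hypothetical child callback that counts how often it is triggered."""

    def __init__(self) -> None:
        self.parent: Optional["EventCallbackSketch"] = None
        self.n_calls = 0

    def set_parent(self, parent: "EventCallbackSketch") -> None:
        self.parent = parent

    def on_step(self) -> bool:
        self.n_calls += 1
        return True


class EventCallbackSketch:
    """Hypothetical re-implementation: fires its child only when an event occurs."""

    def __init__(self, callback: Optional[ChildSketch] = None) -> None:
        self.callback = callback
        if callback is not None:
            callback.set_parent(self)  # the child can read the parent's state
        self.n_calls = 0

    def _on_event(self) -> bool:
        # The child's on_step() runs only here, i.e. when the event triggers.
        if self.callback is not None:
            return self.callback.on_step()
        return True

    def on_step(self) -> bool:
        self.n_calls += 1
        return self._on_step()

    def _on_step(self) -> bool:
        # Arbitrary example event: fire on every second call.
        if self.n_calls % 2 == 0:
            return self._on_event()
        return True
```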

class openrl.utils.callbacks.callbacks.EveryNTimesteps(n_steps: int, callbacks: Union[List[Dict[str, Any]], Dict[str, Any], openrl.utils.callbacks.callbacks.BaseCallback], stop_logic: str = 'OR') [source]

Bases: openrl.utils.callbacks.callbacks.EventCallback

Trigger a callback every n_steps timesteps.

Parameters
  • n_steps -- Number of timesteps between two triggers.

  • callbacks -- Callback(s) that will be called when the event is triggered: a BaseCallback instance, a config dict, or a list of config dicts.
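The sketch below shows a plausible trigger rule for "every n_steps timesteps". It assumes the callback remembers the timestep of its last trigger and fires once at least n_steps timesteps have passed since then. EveryNTimestepsSketch is a hypothetical name. Because one call can advance the timestep counter by n_envs, triggers land on the first call at or past each threshold, not exactly on multiples of n_steps.

```python
from typing import Callable


class EveryNTimestepsSketch:
    """Hypothetical re-implementation of an every-n_steps trigger."""

    def __init__(self, n_steps: int, on_event: Callable[[int], bool]) -> None:
        self.n_steps = n_steps
        self.on_event = on_event
        self.last_time_trigger = 0

    def on_step(self, num_timesteps: int) -> bool:
        # Fire once at least n_steps timesteps have elapsed since the last trigger.
        if num_timesteps - self.last_time_trigger >= self.n_steps:
            self.last_time_trigger = num_timesteps
            return self.on_event(num_timesteps)
        return True
```

With 4 environments (timesteps 4, 8, 12, ... per call) and n_steps=10, the event fires at timesteps 12, 24 and 36 during the first 10 calls.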

openrl.utils.callbacks.callbacks_factory module

class openrl.utils.callbacks.callbacks_factory.CallbackFactory [source]

Bases: object

static get_callback(callback: Dict[str, Any]) → openrl.utils.callbacks.callbacks.BaseCallback [source]
static get_callbacks(callbacks: Union[Dict[str, Any], List[Dict[str, Any]]], stop_logic: str = 'OR') → openrl.utils.callbacks.callbacks.CallbackList [source]
static register(id: str, callback_class: Type[openrl.utils.callbacks.callbacks.BaseCallback]) [source]
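The three static methods suggest a registry pattern: register() maps an id string to a callback class, and get_callback() builds an instance from a config dict. The sketch below assumes a config shape of {"id": ..., "args": {...}}, which is an assumption rather than something this reference states. FactorySketch, Callback and SaveEveryN are hypothetical names. Unlike the real get_callbacks(), which returns a CallbackList, the sketch returns a plain list.

```python
from typing import Any, Dict, List, Type, Union


class Callback:
    """Hypothetical minimal base class standing in for BaseCallback."""


class SaveEveryN(Callback):
    """Hypothetical callback used only to exercise the registry."""

    def __init__(self, save_freq: int, save_path: str = "./ckpt") -> None:
        self.save_freq = save_freq
        self.save_path = save_path


class FactorySketch:
    """Hypothetical registry mapping an id string to a callback class."""

    _registry: Dict[str, Type[Callback]] = {}

    @staticmethod
    def register(id: str, callback_class: Type[Callback]) -> None:
        FactorySketch._registry[id] = callback_class

    @staticmethod
    def get_callback(callback: Dict[str, Any]) -> Callback:
        # Assumed config shape: {"id": "<registered id>", "args": {...}}
        cls = FactorySketch._registry[callback["id"]]
        return cls(**callback.get("args", {}))

    @staticmethod
    def get_callbacks(callbacks: Union[Dict[str, Any], List[Dict[str, Any]]]) -> List[Callback]:
        # Accept a single config dict or a list of them.
        if isinstance(callbacks, dict):
            callbacks = [callbacks]
        return [FactorySketch.get_callback(c) for c in callbacks]
```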

openrl.utils.callbacks.checkpoint_callback module

class openrl.utils.callbacks.checkpoint_callback.CheckpointCallback(save_freq: int, save_path: Union[str, pathlib.Path], name_prefix: str = 'rl_model', save_replay_buffer: bool = False, verbose: int = 0) [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Callback for saving a model every save_freq calls to env.step(). By default, it only saves model checkpoints; pass save_replay_buffer=True to also save replay buffer checkpoints.

Warning

When using multiple environments, each call to env.step() effectively corresponds to n_envs steps. To account for that, you can use save_freq = max(save_freq // n_envs, 1).

Parameters
  • save_freq -- Save checkpoints every save_freq calls of the callback.

  • save_path -- Path to the folder where the model will be saved.

  • name_prefix -- Common prefix for the saved models.

  • save_replay_buffer -- Whether to also save the model's replay buffer.

  • verbose -- Verbosity level: 0 for no output, 2 for indicating when saving model checkpoint
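The save_freq adjustment in the warning above can be worked through in a short sketch. It assumes a checkpoint is written on every save_freq-th call of the callback. The file-name pattern in checkpoint_path is a hypothetical illustration only, not necessarily the name CheckpointCallback produces.

```python
import os


def adjusted_save_freq(save_freq: int, n_envs: int) -> int:
    """Convert a budget in environment timesteps into callback calls."""
    return max(save_freq // n_envs, 1)


def should_save(n_calls: int, save_freq: int) -> bool:
    # Assumed trigger: save on every save_freq-th call of the callback.
    return n_calls % save_freq == 0


def checkpoint_path(save_path: str, name_prefix: str, num_timesteps: int) -> str:
    # Hypothetical naming scheme, for illustration only.
    return os.path.join(save_path, f"{name_prefix}_{num_timesteps}_steps")
```

With 8 environments and a target of one checkpoint per 10,000 timesteps, adjusted_save_freq(10000, 8) gives 1250 calls, and 1250 calls × 8 envs = 10,000 timesteps. The max(..., 1) guard keeps the frequency valid when save_freq is smaller than n_envs.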

openrl.utils.callbacks.eval_callback module

class openrl.utils.callbacks.eval_callback.EvalCallback(eval_env: Union[str, Dict[str, Any], gymnasium.core.Env, openrl.envs.vec_env.base_venv.BaseVecEnv], callbacks_on_new_best: Optional[Union[List[Dict[str, Any]], Dict[str, Any], openrl.utils.callbacks.callbacks.BaseCallback]] = None, callbacks_after_eval: Optional[Union[List[Dict[str, Any]], Dict[str, Any], openrl.utils.callbacks.callbacks.BaseCallback]] = None, n_eval_episodes: int = 5, eval_freq: int = 10000, log_path: Optional[Union[str, pathlib.Path]] = None, best_model_save_path: Optional[Union[str, pathlib.Path]] = None, deterministic: bool = True, render: bool = False, asynchronous: bool = True, verbose: int = 1, warn: bool = True, stop_logic: str = 'OR', close_env_at_end: bool = True) [source]

Bases: openrl.utils.callbacks.callbacks.EventCallback

Callback for evaluating an agent.

Warning

When using multiple environments, each call to env.step() effectively corresponds to n_envs steps. To account for that, you can use eval_freq = max(eval_freq // n_envs, 1).

Parameters
  • eval_env -- The environment used for evaluation.

  • callbacks_on_new_best -- Callback(s) to trigger when there is a new best model according to the mean_reward.

  • callbacks_after_eval -- Callback(s) to trigger after every evaluation.

  • n_eval_episodes -- The number of episodes to test the agent.

  • eval_freq -- Evaluate the agent every eval_freq calls of the callback.

  • log_path -- Path to a folder where the evaluations (evaluations.npz) will be saved. It will be updated at each evaluation.

  • best_model_save_path -- Path to a folder where the best model according to performance on the eval env will be saved.

  • deterministic -- Whether the evaluation should use stochastic or deterministic actions.

  • render -- Whether to render the environment during evaluation.

  • verbose -- Verbosity level: 0 for no output, 1 for indicating information about evaluation results

  • warn -- Passed to evaluate_policy (warns if eval_env has not been wrapped with a Monitor wrapper)
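The core control flow implied by eval_freq, n_eval_episodes, callbacks_on_new_best and callbacks_after_eval can be sketched without any environment. EvalLoopSketch is a hypothetical name. Here evaluate is an injected function that returns per-episode rewards, and each hook is a plain function that returns False to request a stop. The sketch omits logging, model saving and environment handling.

```python
from statistics import mean
from typing import Callable, List, Optional

Hook = Callable[[float], bool]  # receives the mean reward, returns "continue?"


class EvalLoopSketch:
    """Hypothetical core of an evaluation callback (not OpenRL's EvalCallback)."""

    def __init__(
        self,
        evaluate: Callable[[int], List[float]],
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        on_new_best: Optional[Hook] = None,
        after_eval: Optional[Hook] = None,
    ) -> None:
        self.evaluate = evaluate  # runs n episodes, returns their rewards
        self.eval_freq = eval_freq
        self.n_eval_episodes = n_eval_episodes
        self.on_new_best = on_new_best
        self.after_eval = after_eval
        self.best_mean_reward = float("-inf")
        self.n_calls = 0

    def on_step(self) -> bool:
        self.n_calls += 1
        if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
            return True  # not an evaluation step
        mean_reward = mean(self.evaluate(self.n_eval_episodes))
        continue_training = True
        if mean_reward > self.best_mean_reward:
            self.best_mean_reward = mean_reward
            if self.on_new_best is not None:
                continue_training = self.on_new_best(mean_reward)
        if self.after_eval is not None:
            # Run the after-eval hook even if on_new_best asked to stop.
            continue_training = self.after_eval(mean_reward) and continue_training
        return continue_training
```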

update_child_locals(locals_: Dict[str, Any]) → None [source]

Update the references to the local variables.

Parameters

locals -- the local variables during rollout collection

openrl.utils.callbacks.processbar_callback module

class openrl.utils.callbacks.processbar_callback.ProgressBarCallback [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Display a progress bar during training, using the tqdm and rich packages.

openrl.utils.callbacks.stop_callback module

class openrl.utils.callbacks.stop_callback.StopTrainingOnMaxEpisodes(max_episodes: int, verbose: int = 0) [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Stop the training once a maximum number of episodes has been played.

With multiple environments, it presumes that the desired behavior is for the agent to train on each env for max_episodes episodes, i.e. max_episodes * n_envs episodes in total.

Parameters
  • max_episodes -- Maximum number of episodes after which training stops.

  • verbose -- Verbosity level: 0 for no output, 1 for indicating information about when training ended by reaching max_episodes
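The multi-environment rule above (max_episodes per env, max_episodes * n_envs in total) can be sketched as an episode counter. It assumes the callback receives one done flag per environment on each step. MaxEpisodesSketch is a hypothetical name.

```python
from typing import List


class MaxEpisodesSketch:
    """Hypothetical re-implementation: max_episodes per env, summed over envs."""

    def __init__(self, max_episodes: int, n_envs: int) -> None:
        self.total_max_episodes = max_episodes * n_envs
        self.n_episodes = 0

    def on_step(self, dones: List[bool]) -> bool:
        # Each True in `dones` marks one finished episode in one environment.
        self.n_episodes += sum(dones)
        return self.n_episodes < self.total_max_episodes
```

With max_episodes=2 and 3 environments, training stops once 6 episodes have finished in total, however they are distributed across the environments.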

class openrl.utils.callbacks.stop_callback.StopTrainingOnNoModelImprovement(max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 1) [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.

It must be used with the EvalCallback.

Parameters
  • max_no_improvement_evals -- Maximum number of consecutive evaluations without a new best model.

  • min_evals -- Number of evaluations before starting to count evaluations without improvement.

  • verbose -- Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
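The counting rule can be sketched as follows. This assumes the callback is invoked once per evaluation with the parent EvalCallback's current best mean reward, and that it stops after more than max_no_improvement_evals consecutive evaluations without improvement, ignoring the first min_evals evaluations. NoImprovementSketch is a hypothetical name.

```python
class NoImprovementSketch:
    """Hypothetical re-implementation of the no-improvement counter."""

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0) -> None:
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        self.n_calls = 0  # one call per evaluation
        self.last_best = float("-inf")
        self.no_improvement_evals = 0

    def on_step(self, best_mean_reward: float) -> bool:
        """Called after each evaluation with the parent's best mean reward."""
        self.n_calls += 1
        continue_training = True
        if self.n_calls > self.min_evals:  # warm-up evaluations are not counted
            if best_mean_reward > self.last_best:
                self.no_improvement_evals = 0  # new best: reset the counter
            else:
                self.no_improvement_evals += 1
                if self.no_improvement_evals > self.max_no_improvement_evals:
                    continue_training = False
        self.last_best = best_mean_reward
        return continue_training
```

With max_no_improvement_evals=2 and min_evals=1, the best-reward sequence 1.0, 2.0, 2.0, 2.0, 2.0 stops training at the fifth evaluation, the third in a row without a new best.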

class openrl.utils.callbacks.stop_callback.StopTrainingOnRewardThreshold(reward_threshold: float, verbose: int = 0) [source]

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Stop the training once a threshold in episodic reward has been reached (i.e. when the model is good enough).

It must be used with the EvalCallback.

Parameters
  • reward_threshold -- Minimum expected reward per episode to stop training.

  • verbose -- Verbosity level: 0 for no output, 1 for indicating when training ended because episodic reward threshold reached
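The threshold rule can be illustrated by replaying a history of evaluation results. This assumes the check compares the EvalCallback's best mean reward so far against reward_threshold and treats "reached" as greater than or equal. stop_index is a hypothetical helper that returns the index of the evaluation at which training would stop.

```python
from typing import List, Optional


def stop_index(eval_means: List[float], reward_threshold: float) -> Optional[int]:
    """Hypothetical: index of the evaluation that stops training, or None."""
    best = float("-inf")
    for i, mean_reward in enumerate(eval_means):
        best = max(best, mean_reward)  # the parent tracks the best mean reward
        if best >= reward_threshold:  # threshold reached: stop training
            return i
    return None
```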

Module contents

class openrl.utils.callbacks.CallbackFactory [source]

Bases: object

static get_callback(callback: Dict[str, Any]) → openrl.utils.callbacks.callbacks.BaseCallback [source]
static get_callbacks(callbacks: Union[Dict[str, Any], List[Dict[str, Any]]], stop_logic: str = 'OR') → openrl.utils.callbacks.callbacks.CallbackList [source]
static register(id: str, callback_class: Type[openrl.utils.callbacks.callbacks.BaseCallback]) [source]