
openrl.utils.callbacks package

Submodules

openrl.utils.callbacks.callbacks module

class openrl.utils.callbacks.callbacks.BaseCallback(verbose: int = 0)

Bases: abc.ABC

Base class for callbacks.

Parameters

verbose – Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages

agent: openrl.runners.common.base_agent.BaseAgent
init_callback(agent: openrl.runners.common.base_agent.BaseAgent) -> None

Initialize the callback by saving references to the RL model and the training environment for convenience.

logger: openrl.utils.logger.Logger
on_rollout_end() -> None
on_rollout_start() -> None
on_step() -> bool

This method will be called by the model after each call to env.step().

For child callback (of an EventCallback), this will be called when the event is triggered.

Returns

If the callback returns False, training is aborted early.

on_training_end() -> None
on_training_start(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None
set_parent(parent: openrl.utils.callbacks.callbacks.BaseCallback) -> None

Set the parent of the callback.

Parameters

parent – The parent callback.

update_child_locals(locals_: Dict[str, Any]) -> None

Update the references to the local variables on sub callbacks.

Parameters

locals_ – the local variables during rollout collection

update_locals(locals_: Dict[str, Any]) -> None

Update the references to the local variables.

Parameters

locals_ – the local variables during rollout collection
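The hook contract above can be illustrated without OpenRL installed. The sketch below is a hypothetical stand-in (ToyCallback, StopAfterReward and toy_train are illustrative names, not part of openrl) showing the documented behaviour: the training loop calls on_step() after every env.step(), and a False return aborts training early. It assumes the loop exposes the latest reward through update_locals(); a real callback would derive from openrl.utils.callbacks.callbacks.BaseCallback instead.

```python
from abc import ABC
from typing import Any, Dict, List


class ToyCallback(ABC):
    """Hypothetical stand-in mirroring the documented BaseCallback hooks."""

    def __init__(self, verbose: int = 0) -> None:
        self.verbose = verbose
        self.locals: Dict[str, Any] = {}
        self.n_calls = 0

    def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        self.locals = dict(locals_)

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        # Refresh the references to the loop's local variables.
        self.locals.update(locals_)

    def on_step(self) -> bool:
        self.n_calls += 1
        return True  # True means "keep training"

    def on_training_end(self) -> None:
        pass


class StopAfterReward(ToyCallback):
    """Hypothetical callback: stop once the latest reward reaches a threshold."""

    def __init__(self, threshold: float, verbose: int = 0) -> None:
        super().__init__(verbose)
        self.threshold = threshold

    def on_step(self) -> bool:
        super().on_step()
        # Returning False asks the training loop to abort early.
        return self.locals.get("reward", 0.0) < self.threshold


def toy_train(callback: ToyCallback, rewards: List[float]) -> int:
    """Hypothetical loop: one env.step() per reward, then callback.on_step()."""
    callback.on_training_start({"total_steps": len(rewards)}, {})
    steps = 0
    for reward in rewards:
        steps += 1  # stands in for env.step()
        callback.update_locals({"reward": reward})
        if not callback.on_step():
            break
    callback.on_training_end()
    return steps


cb = StopAfterReward(threshold=1.0)
print(toy_train(cb, [0.1, 0.5, 1.2, 0.3]))  # 3: the third step hits the threshold
```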

class openrl.utils.callbacks.callbacks.CallbackList(callbacks: List[openrl.utils.callbacks.callbacks.BaseCallback], stop_logic: str = 'OR')

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Class for chaining callbacks.

Parameters

callbacks – A list of callbacks that will be called sequentially.

set_parent(parent: openrl.utils.callbacks.callbacks.BaseCallback) -> None

Set the parent of the callback.

Parameters

parent – The parent callback.

update_child_locals(locals_: Dict[str, Any]) -> None

Update the references to the local variables.

Parameters

locals_ – the local variables during rollout collection
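CallbackList also takes a stop_logic argument (default 'OR') that this page does not describe. The sketch below assumes, from the name alone, that 'OR' stops training as soon as any chained callback returns False from on_step(), while 'AND' stops only when all of them do; check the source before relying on this. combine_on_step is a hypothetical helper that reduces each callback to its on_step() call.

```python
from typing import Callable, List

# Each chained callback is reduced to its on_step() call for illustration.
OnStep = Callable[[], bool]


def combine_on_step(callbacks: List[OnStep], stop_logic: str = "OR") -> bool:
    """Return True to continue training, False to stop.

    Assumed semantics (not confirmed by this page):
    "OR"  -> stop if any callback asks to stop,
    "AND" -> stop only if every callback asks to stop.
    """
    if stop_logic not in ("OR", "AND"):
        raise ValueError(f"unknown stop_logic: {stop_logic!r}")
    if not callbacks:
        return True
    # Call every callback in order so each one sees every step,
    # even after an earlier one has asked to stop.
    wants_stop = [not cb() for cb in callbacks]
    if stop_logic == "OR":
        return not any(wants_stop)
    return not all(wants_stop)


keep, stop = (lambda: True), (lambda: False)
print(combine_on_step([keep, stop], "OR"))   # False
print(combine_on_step([keep, stop], "AND"))  # True
```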

class openrl.utils.callbacks.callbacks.ConvertCallback(callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0)

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Convert a functional (old-style) callback into a callback object.

Parameters
  • callback – Old-style callback function; it receives two Dict[str, Any] arguments and returns a bool.

  • verbose – Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
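What the conversion amounts to can be sketched as follows, assuming the wrapped function is called with the callback's local and global variable dictionaries (matching the Callable[[Dict[str, Any], Dict[str, Any]], bool] type in the signature) and that its bool result is returned from on_step(). FunctionToCallback and stop_at_step_5 are hypothetical names.

```python
from typing import Any, Callable, Dict

# Matches the type in the ConvertCallback signature.
FunctionalCallback = Callable[[Dict[str, Any], Dict[str, Any]], bool]


class FunctionToCallback:
    """Hypothetical stand-in for ConvertCallback."""

    def __init__(self, callback: FunctionalCallback, verbose: int = 0) -> None:
        self.callback = callback
        self.verbose = verbose
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}

    def on_step(self) -> bool:
        # Assumption: the function gets (locals_, globals_) and its bool
        # result is forwarded as the "continue training" flag.
        return bool(self.callback(self.locals, self.globals))


def stop_at_step_5(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> bool:
    """Old-style callback: keep training while step < 5."""
    return locals_.get("step", 0) < 5


wrapped = FunctionToCallback(stop_at_step_5)
wrapped.locals = {"step": 5}
print(wrapped.on_step())  # False
```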

class openrl.utils.callbacks.callbacks.EventCallback(callback: Optional[openrl.utils.callbacks.callbacks.BaseCallback] = None, verbose: int = 0)

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Base class for callbacks that trigger another callback when an event occurs.

Parameters
  • callback – Callback that will be called when an event is triggered.

  • verbose – Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages

init_callback(agent: openrl.runners.common.base_agent.BaseAgent) -> None

Initialize the callback by saving references to the RL model and the training environment for convenience.

update_child_locals(locals_: Dict[str, Any]) -> None

Update the references to the local variables.

Parameters

locals_ – the local variables during rollout collection

class openrl.utils.callbacks.callbacks.EveryNTimesteps(n_steps: int, callbacks: Union[List[Dict[str, Any]], Dict[str, Any], openrl.utils.callbacks.callbacks.BaseCallback], stop_logic: str = 'OR')

Bases: openrl.utils.callbacks.callbacks.EventCallback

Trigger a callback every n_steps timesteps.

Parameters
  • n_steps – Number of timesteps between two triggers.

  • callbacks – Callback(s) to call when the event is triggered: a BaseCallback instance, a single callback config dict, or a list of such dicts.
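The triggering rule can be sketched as follows, assuming EveryNTimesteps compares a running timestep counter against the timestep of its previous trigger (EveryN is a hypothetical stand-in, not the library class). The usage lines simulate 4 parallel environments, where the counter grows by 4 per env.step(), so with n_steps=10 the trigger fires at 12, 24 and 36 rather than at exact multiples of 10.

```python
from typing import List


class EveryN:
    """Hypothetical sketch of an every-n-timesteps trigger."""

    def __init__(self, n_steps: int) -> None:
        if n_steps <= 0:
            raise ValueError("n_steps must be positive")
        self.n_steps = n_steps
        self.last_trigger = 0
        self.fired_at: List[int] = []

    def on_step(self, num_timesteps: int) -> bool:
        # Fire once at least n_steps timesteps have passed since the last trigger.
        if num_timesteps - self.last_trigger >= self.n_steps:
            self.last_trigger = num_timesteps
            self.fired_at.append(num_timesteps)  # the child callback would run here
        return True


# 4 parallel envs: the timestep counter advances by 4 per env.step().
trigger = EveryN(n_steps=10)
for call in range(1, 11):
    trigger.on_step(num_timesteps=call * 4)
print(trigger.fired_at)  # [12, 24, 36]
```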

openrl.utils.callbacks.callbacks_factory module

class openrl.utils.callbacks.callbacks_factory.CallbackFactory

Bases: object

static get_callback(callback: Dict[str, Any]) -> openrl.utils.callbacks.callbacks.BaseCallback
static get_callbacks(callbacks: Union[Dict[str, Any], List[Dict[str, Any]]], stop_logic: str = 'OR') -> openrl.utils.callbacks.callbacks.CallbackList
static register(id: str, callback_class: Type[openrl.utils.callbacks.callbacks.BaseCallback])
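The factory follows a registry pattern: register() maps an id string to a BaseCallback subclass, and get_callback() builds an instance from a dictionary. The sketch below assumes the dictionary holds the registered name under "id" and constructor keyword arguments under "args"; that layout, along with ToyFactory, ToyBaseCallback and PrintCallback, is an assumption for illustration. Unlike the real get_callbacks(), which returns a CallbackList and accepts stop_logic, this version returns a plain list.

```python
from typing import Any, Dict, List, Type, Union


class ToyBaseCallback:
    def __init__(self, verbose: int = 0) -> None:
        self.verbose = verbose


class PrintCallback(ToyBaseCallback):
    """Hypothetical callback class used only to exercise the registry."""

    def __init__(self, message: str = "hi", verbose: int = 0) -> None:
        super().__init__(verbose)
        self.message = message


class ToyFactory:
    """Hypothetical registry mirroring register/get_callback/get_callbacks."""

    _registry: Dict[str, Type[ToyBaseCallback]] = {}

    @staticmethod
    def register(id: str, callback_class: Type[ToyBaseCallback]) -> None:
        ToyFactory._registry[id] = callback_class

    @staticmethod
    def get_callback(callback: Dict[str, Any]) -> ToyBaseCallback:
        # Assumed config layout: {"id": <registered name>, "args": {<kwargs>}}.
        cls = ToyFactory._registry[callback["id"]]
        return cls(**callback.get("args", {}))

    @staticmethod
    def get_callbacks(
        callbacks: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> List[ToyBaseCallback]:
        # A single dict is treated as a one-element list.
        if isinstance(callbacks, dict):
            callbacks = [callbacks]
        return [ToyFactory.get_callback(c) for c in callbacks]


ToyFactory.register("PrintCallback", PrintCallback)
built = ToyFactory.get_callbacks(
    [{"id": "PrintCallback", "args": {"message": "saved"}}, {"id": "PrintCallback"}]
)
print([c.message for c in built])  # ['saved', 'hi']
```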

openrl.utils.callbacks.checkpoint_callback module

class openrl.utils.callbacks.checkpoint_callback.CheckpointCallback(save_freq: int, save_path: Union[str, pathlib.Path], name_prefix: str = 'rl_model', save_replay_buffer: bool = False, verbose: int = 0)

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Callback for saving a model every save_freq calls to env.step(). By default, it only saves model checkpoints; pass save_replay_buffer=True to also save replay buffer checkpoints.

Warning

When using multiple environments, each call to env.step() effectively corresponds to n_envs steps. To account for that, you can use save_freq = max(save_freq // n_envs, 1).

Parameters
  • save_freq – Save a checkpoint every save_freq calls of the callback.

  • save_path – Path to the folder where the model will be saved.

  • name_prefix – Common prefix for the saved model files.

  • save_replay_buffer – Whether to also save the model's replay buffer.

  • verbose – Verbosity level: 0 for no output, 2 to report each time a model checkpoint is saved.
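The adjustment suggested in the warning above, written as a small function (per_call_save_freq is a hypothetical name). It converts a target number of environment steps between checkpoints into a number of callback calls, given that one callback call covers n_envs environment steps.

```python
def per_call_save_freq(save_freq: int, n_envs: int) -> int:
    """Turn a target step interval into a number of callback calls.

    One env.step() on n_envs environments is one callback call but n_envs
    environment steps, so divide, and never go below one call.
    """
    if save_freq <= 0 or n_envs <= 0:
        raise ValueError("save_freq and n_envs must be positive")
    return max(save_freq // n_envs, 1)


print(per_call_save_freq(10_000, 8))  # 1250 calls, i.e. 1250 * 8 = 10000 steps
print(per_call_save_freq(3, 8))       # 1 (clamped)
```

The same arithmetic applies to eval_freq in EvalCallback.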

openrl.utils.callbacks.eval_callback module

class openrl.utils.callbacks.eval_callback.EvalCallback(eval_env: Union[str, Dict[str, Any], gymnasium.core.Env, openrl.envs.vec_env.base_venv.BaseVecEnv], callbacks_on_new_best: Optional[Union[List[Dict[str, Any]], Dict[str, Any], openrl.utils.callbacks.callbacks.BaseCallback]] = None, callbacks_after_eval: Optional[Union[List[Dict[str, Any]], Dict[str, Any], openrl.utils.callbacks.callbacks.BaseCallback]] = None, n_eval_episodes: int = 5, eval_freq: int = 10000, log_path: Optional[Union[str, pathlib.Path]] = None, best_model_save_path: Optional[Union[str, pathlib.Path]] = None, deterministic: bool = True, render: bool = False, asynchronous: bool = True, verbose: int = 1, warn: bool = True, stop_logic: str = 'OR', close_env_at_end: bool = True)

Bases: openrl.utils.callbacks.callbacks.EventCallback

Callback for evaluating an agent.

Warning

When using multiple environments, each call to env.step() effectively corresponds to n_envs steps. To account for that, you can use eval_freq = max(eval_freq // n_envs, 1).

Parameters
  • eval_env – The environment used for evaluation (a str, a Dict[str, Any], a gymnasium.core.Env, or a BaseVecEnv).

  • callbacks_on_new_best – Callback(s) to trigger when there is a new best model according to the mean reward.

  • callbacks_after_eval – Callback(s) to trigger after every evaluation.

  • n_eval_episodes – Number of episodes used to evaluate the agent.

  • eval_freq – Evaluate the agent every eval_freq calls of the callback.

  • log_path – Path to a folder where the evaluations (evaluations.npz) will be saved. It is updated at each evaluation.

  • best_model_save_path – Path to a folder where the best model, according to performance on the eval env, will be saved.

  • deterministic – Whether to use deterministic or stochastic actions during evaluation.

  • render – Whether to render the environment during evaluation.

  • verbose – Verbosity level: 0 for no output, 1 to print information about evaluation results.

  • warn – Passed to evaluate_policy (warns if eval_env has not been wrapped with a Monitor wrapper).
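The bookkeeping behind callbacks_on_new_best can be sketched as below. This is an assumption-based illustration, not the EvalCallback implementation: it takes "best" to mean the highest mean episode reward seen so far and runs the new-best hooks only on a strict improvement. BestTracker and after_evaluation are hypothetical names, and the hooks are plain functions rather than callback objects.

```python
from statistics import mean
from typing import Callable, List


class BestTracker:
    """Hypothetical new-best bookkeeping for an evaluation callback."""

    def __init__(self, on_new_best: List[Callable[[float], None]]) -> None:
        self.best_mean_reward = float("-inf")
        self.on_new_best = on_new_best

    def after_evaluation(self, episode_rewards: List[float]) -> float:
        # One entry per evaluation episode (n_eval_episodes in total).
        mean_reward = mean(episode_rewards)
        if mean_reward > self.best_mean_reward:  # strict improvement only
            self.best_mean_reward = mean_reward
            for hook in self.on_new_best:
                hook(mean_reward)  # e.g. save the model to best_model_save_path
        return mean_reward


history: List[float] = []
tracker = BestTracker([history.append])
for rewards in ([1.0, 3.0], [1.0, 1.0], [4.0, 4.0]):
    tracker.after_evaluation(rewards)
print(history)  # [2.0, 4.0]
```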

update_child_locals(locals_: Dict[str, Any]) -> None

Update the references to the local variables.

Parameters

locals_ – the local variables during rollout collection

openrl.utils.callbacks.processbar_callback module

class openrl.utils.callbacks.processbar_callback.ProgressBarCallback

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Display a progress bar while training an agent, using the tqdm and rich packages.

openrl.utils.callbacks.stop_callback module

class openrl.utils.callbacks.stop_callback.StopTrainingOnMaxEpisodes(max_episodes: int, verbose: int = 0)

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Stop the training once a maximum number of episodes has been played.

With multiple environments, the agent is assumed to train for max_episodes episodes on each environment, i.e. max_episodes * n_envs episodes in total.

Parameters
  • max_episodes – Maximum number of episodes after which training stops.

  • verbose – Verbosity level: 0 for no output, 1 to report when training ends because max_episodes was reached.
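A sketch of the episode accounting described above, assuming the callback counts finished episodes from per-environment done flags after each env.step() (MaxEpisodesStopper is a hypothetical name). With max_episodes=2 and n_envs=2, training stops once 4 episodes have finished in total.

```python
from typing import List


class MaxEpisodesStopper:
    """Hypothetical sketch: count finished episodes across n_envs environments."""

    def __init__(self, max_episodes: int, n_envs: int) -> None:
        # Documented convention: max_episodes per env, max_episodes * n_envs in total.
        self.total_max_episodes = max_episodes * n_envs
        self.n_episodes = 0

    def on_step(self, dones: List[bool]) -> bool:
        # dones holds one flag per environment for the latest env.step().
        self.n_episodes += sum(dones)
        return self.n_episodes < self.total_max_episodes


stopper = MaxEpisodesStopper(max_episodes=2, n_envs=2)
print([stopper.on_step(d) for d in ([True, False], [True, True], [False, True])])  # [True, True, False]
```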

class openrl.utils.callbacks.stop_callback.StopTrainingOnNoModelImprovement(max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 1)

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.

It must be used with the EvalCallback.

Parameters
  • max_no_improvement_evals – Maximum number of consecutive evaluations without a new best model.

  • min_evals – Number of evaluations before starting to count evaluations without improvement.

  • verbose – Verbosity level: 0 for no output, 1 to report when training ends because no new best model was found.
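The counting rule can be sketched as follows. NoImprovementStopper is a hypothetical stand-in that assumes it is invoked once per evaluation with a flag saying whether EvalCallback found a new best model; evaluations up to min_evals are ignored, and training stops once more than max_no_improvement_evals consecutive evaluations have passed without improvement.

```python
class NoImprovementStopper:
    """Hypothetical sketch of the no-improvement counter."""

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0) -> None:
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        self.n_evals = 0
        self.no_improvement_evals = 0

    def on_evaluation(self, new_best: bool) -> bool:
        """Return True to keep training, False to stop."""
        self.n_evals += 1
        if self.n_evals <= self.min_evals:
            return True  # warm-up: these evaluations are not counted
        if new_best:
            self.no_improvement_evals = 0
        else:
            self.no_improvement_evals += 1
        # Stop after MORE than max_no_improvement_evals consecutive misses.
        return self.no_improvement_evals <= self.max_no_improvement_evals


s = NoImprovementStopper(max_no_improvement_evals=2, min_evals=1)
print([s.on_evaluation(b) for b in [False, False, True, False, False, False]])
# [True, True, True, True, True, False]
```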

class openrl.utils.callbacks.stop_callback.StopTrainingOnRewardThreshold(reward_threshold: float, verbose: int = 0)

Bases: openrl.utils.callbacks.callbacks.BaseCallback

Stop the training once a threshold in episodic reward has been reached (i.e. when the model is good enough).

It must be used with the EvalCallback.

Parameters
  • reward_threshold – Minimum expected reward per episode to stop training.

  • verbose – Verbosity level: 0 for no output, 1 to report when training ends because the episodic reward threshold has been reached.
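A sketch of the stopping condition over a sequence of evaluations, assuming the check compares the best mean reward seen so far against reward_threshold after each evaluation (first_stop_eval is a hypothetical helper). With a threshold of 200, mean rewards of 50, 120 and 210 stop training at the third evaluation.

```python
from typing import List, Optional


def first_stop_eval(mean_rewards: List[float], reward_threshold: float) -> Optional[int]:
    """Return the 1-based evaluation index at which training would stop, or None."""
    best = float("-inf")
    for i, mean_reward in enumerate(mean_rewards, start=1):
        best = max(best, mean_reward)
        # Training continues while best < reward_threshold.
        if best >= reward_threshold:
            return i
    return None


print(first_stop_eval([50.0, 120.0, 210.0, 90.0], reward_threshold=200.0))  # 3
```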

Module contents

class openrl.utils.callbacks.CallbackFactory

Bases: object

static get_callback(callback: Dict[str, Any]) -> openrl.utils.callbacks.callbacks.BaseCallback
static get_callbacks(callbacks: Union[Dict[str, Any], List[Dict[str, Any]]], stop_logic: str = 'OR') -> openrl.utils.callbacks.callbacks.CallbackList
static register(id: str, callback_class: Type[openrl.utils.callbacks.callbacks.BaseCallback])