# Source code for openrl.utils.callbacks.callbacks
# Modified from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/callbacks.py
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
import gym
import numpy as np
from openrl.utils.logger import Logger
try:
from tqdm import TqdmExperimentalWarning
# Remove experimental warning
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
from tqdm.rich import tqdm
except ImportError:
# Rich not installed, we only throw an error
# if the progress bar is used
tqdm = None
from openrl.envs.vec_env import BaseVecEnv
from openrl.runners.common.base_agent import BaseAgent
class BaseCallback(ABC):
    """
    Base class for callback.

    Subclasses implement the private ``_on_*`` hooks; the training loop (and
    container callbacks such as a callback list) call the public ``on_*``
    methods, which do the shared bookkeeping and then delegate to the
    matching private hook.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    # The RL model
    # Type hints as strings to avoid circular imports and so that the
    # annotations are not evaluated when the class body is executed.
    agent: "BaseAgent"
    logger: "Logger"

    def __init__(self, verbose: int = 0):
        super().__init__()
        # An alias for self.agent.get_env(), the environment used for training
        self.training_env = None  # type: Union[gym.Env, BaseVecEnv, None]
        # Number of time the callback was called
        self.n_calls = 0  # type: int
        # n_envs * n times env.step() was called
        self.num_time_steps = 0  # type: int
        self.verbose = verbose
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        self.parent = None  # type: Optional[BaseCallback]

    # Type hint as string to avoid circular import
    def init_callback(self, agent: "BaseAgent") -> None:
        """
        Initialize the callback by saving references to the
        RL model and the training environment for convenience.

        :param agent: the agent this callback is attached to
        """
        self.agent = agent
        self.training_env = agent.get_env()
        self.logger = agent.logger
        self._init_callback()

    def _init_callback(self) -> None:
        pass

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """
        Store the caller's variable dictionaries, sync the step counter with
        the agent and run the ``_on_training_start`` hook.

        :param locals_: the local variables of the caller
        :param globals_: the global variables of the caller
        """
        # Those are reference and will be updated automatically
        self.locals = locals_
        self.globals = globals_
        # Update num_time_steps in case training was done before
        self.num_time_steps = self.agent.num_time_steps
        self._on_training_start()

    def _on_training_start(self) -> None:
        pass

    def on_rollout_start(self) -> None:
        """Public entry point for the rollout-start event; runs ``_on_rollout_start``."""
        self._on_rollout_start()

    def _on_rollout_start(self) -> None:
        pass

    @abstractmethod
    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        return True

    def on_step(self) -> bool:
        """
        This method will be called by the model after each call to ``env.step()``.

        For child callback (of an ``EventCallback``), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        self.n_calls += 1
        self.num_time_steps = self.agent.num_time_steps
        return self._on_step()

    def on_training_end(self) -> None:
        """Public entry point for the training-end event; runs ``_on_training_end``."""
        self._on_training_end()

    def _on_training_end(self) -> None:
        pass

    def on_rollout_end(self) -> None:
        """Public entry point for the rollout-end event; runs ``_on_rollout_end``."""
        self._on_rollout_end()

    def _on_rollout_end(self) -> None:
        pass

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables.

        :param locals_: the local variables during rollout collection
        """
        self.locals.update(locals_)
        self.update_child_locals(locals_)

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables on sub callbacks.

        The base implementation has no sub callbacks and does nothing.

        :param locals_: the local variables during rollout collection
        """
        pass
class EventCallback(BaseCallback):
    """
    Base class for callbacks that fire an optional child callback on an event.

    :param callback: Child callback run by ``_on_event``; may be ``None``.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.callback = callback
        if callback is None:
            return
        # Let the child reach back to the callback that triggers it.
        self.callback.parent = self

    def init_callback(self, agent: "BaseAgent") -> None:
        """
        Initialize this callback, then the child callback (if any), with the agent.

        :param agent: the agent this callback is attached to
        """
        super().init_callback(agent)
        child = self.callback
        if child is not None:
            child.init_callback(self.agent)

    def _on_training_start(self) -> None:
        # Forward the training-start event, sharing our variable dictionaries.
        child = self.callback
        if child is not None:
            child.on_training_start(self.locals, self.globals)

    def _on_event(self) -> bool:
        # Without a child there is nothing to run, so training continues.
        if self.callback is None:
            return True
        return self.callback.on_step()

    def _on_step(self) -> bool:
        return True

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Pass the updated local variables on to the child callback, if any.

        :param locals_: the local variables during rollout collection
        """
        child = self.callback
        if child is not None:
            child.update_locals(locals_)
class CallbackList(BaseCallback):
    """
    Container that forwards every event to a list of callbacks, in order.

    :param callbacks: A list of callbacks that will be called
        sequentially.
    """

    def __init__(self, callbacks: List[BaseCallback]):
        super().__init__()
        assert isinstance(callbacks, list)
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        for child in self.callbacks:
            child.init_callback(self.agent)

    def _on_training_start(self) -> None:
        for child in self.callbacks:
            child.on_training_start(self.locals, self.globals)

    def _on_rollout_start(self) -> None:
        for child in self.callbacks:
            child.on_rollout_start()

    def _on_step(self) -> bool:
        keep_training = True
        for child in self.callbacks:
            # Every child is stepped, even after one asked to stop; the result
            # is falsy as soon as at least one child returned a falsy value.
            keep_training = child.on_step() and keep_training
        return keep_training

    def _on_rollout_end(self) -> None:
        for child in self.callbacks:
            child.on_rollout_end()

    def _on_training_end(self) -> None:
        for child in self.callbacks:
            child.on_training_end()

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Pass the updated local variables on to every callback in the list.

        :param locals_: the local variables during rollout collection
        """
        for child in self.callbacks:
            child.update_locals(locals_)

    def __repr__(self):
        # Rendered as the string form of the list of child class names.
        return str([child.__class__.__name__ for child in self.callbacks])
class ConvertCallback(BaseCallback):
    """
    Adapter that wraps a functional (old-style) callback in a callback object.

    :param callback: Callable invoked on each step with the stored ``locals``
        and ``globals`` dictionaries; its return value is the step result.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(
        self,
        callback: Callable[[Dict[str, Any], Dict[str, Any]], bool],
        verbose: int = 0,
    ):
        super().__init__(verbose)
        self.callback = callback

    def _on_step(self) -> bool:
        wrapped = self.callback
        # No wrapped function: nothing can ask to stop, so keep training.
        if wrapped is None:
            return True
        return wrapped(self.locals, self.globals)
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in episodic reward
    has been reached (i.e. when the model is good enough).

    It must be used with the ``EvalCallback``: the value compared against the
    threshold is read from ``self.parent.best_mean_reward``.

    :param reward_threshold: Minimum expected reward per episode
        to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because episodic reward
        threshold reached
    """

    def __init__(self, reward_threshold: float, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold

    def _on_step(self) -> bool:
        """
        :return: ``False`` (stop training) once the parent's best mean reward
            is no longer below ``reward_threshold``, ``True`` otherwise.
        """
        # The message names this class (it previously referred to a
        # non-existent ``StopTrainingOnMinimumReward``).
        assert self.parent is not None, (
            "``StopTrainingOnRewardThreshold`` callback must be used "
            "with an ``EvalCallback``"
        )
        # Convert np.bool_ to bool, otherwise callback() is False won't work
        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
        if self.verbose >= 1 and not continue_training:
            print(
                "Stopping training because the mean reward"
                f" {self.parent.best_mean_reward:.2f} is above the threshold"
                f" {self.reward_threshold}"
            )
        return continue_training
class EveryNTimesteps(EventCallback):
    """
    Fire the child callback each time at least ``n_steps`` timesteps have
    elapsed since the previous trigger.

    :param n_steps: Number of timesteps between two trigger.
    :param callback: Callback that will be called
        when the event is triggered.
    """

    def __init__(self, n_steps: int, callback: BaseCallback):
        super().__init__(callback)
        self.n_steps = n_steps
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        steps_since_trigger = self.num_time_steps - self.last_time_trigger
        if steps_since_trigger < self.n_steps:
            # Not due yet: nothing to run, training continues.
            return True
        # Remember when we fired so the next interval is measured from here.
        self.last_time_trigger = self.num_time_steps
        return self._on_event()
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop the training once a maximum number of episodes are played.

    For multiple environments presumes that, the desired behavior is that the agent trains on each env for ``max_episodes``
    and in total for ``max_episodes * n_envs`` episodes.

    :param max_episodes: Maximum number of episodes to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating information about when training ended by
        reaching ``max_episodes``
    """

    def __init__(self, max_episodes: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_episodes = max_episodes
        # Placeholder until the number of environments is known.
        self._total_max_episodes = max_episodes
        self.n_episodes = 0

    def _init_callback(self) -> None:
        # Scale the episode budget by the number of environments.
        n_envs = self.training_env.num_envs
        self._total_max_episodes = self.max_episodes * n_envs

    def _on_step(self) -> bool:
        # The episode count is derived from the `dones` local variable.
        assert "dones" in self.locals, (
            "`dones` variable is not defined, please check your code next to"
            " `callback.on_step()`"
        )
        finished_now = np.sum(self.locals["dones"]).item()
        self.n_episodes += finished_now

        continue_training = self.n_episodes < self._total_max_episodes
        if continue_training or self.verbose < 1:
            return continue_training

        # Stopping with verbose output enabled: report what was played.
        n_envs = self.training_env.num_envs
        mean_episodes_per_env = self.n_episodes / n_envs
        if n_envs > 1:
            mean_ep_str = (
                f"with an average of {mean_episodes_per_env:.2f} episodes per env"
            )
        else:
            mean_ep_str = ""
        print(
            f"Stopping training with a total of {self.num_time_steps} steps because"
            f" the {self.locals.get('tb_log_name')} model reached"
            f" max_episodes={self.max_episodes}, by playing for"
            f" {self.n_episodes} episodes {mean_ep_str}"
        )
        return continue_training
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

    It is possible to define a minimum number of evaluations before start to count evaluations without improvement.
    It must be used with the ``EvalCallback``: the best mean reward is read
    from ``self.parent.best_mean_reward``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before start to count evaluations without improvements.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
    """

    def __init__(
        self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0
    ):
        super().__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        # Any first reward counts as an improvement over minus infinity.
        self.last_best_mean_reward = -np.inf
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        assert self.parent is not None, (
            "``StopTrainingOnNoModelImprovement`` callback must be used with an"
            " ``EvalCallback``"
        )
        current_best = self.parent.best_mean_reward
        continue_training = True

        # Evaluations are only counted once the warm-up period is over.
        if self.n_calls > self.min_evals:
            if current_best > self.last_best_mean_reward:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                continue_training = (
                    self.no_improvement_evals <= self.max_no_improvement_evals
                )

        self.last_best_mean_reward = current_best

        if self.verbose >= 1 and not continue_training:
            print(
                "Stopping training because there was no new best model in the last"
                f" {self.no_improvement_evals:d} evaluations"
            )
        return continue_training
class ProgressBarCallback(BaseCallback):
    """
    Display a progress bar during training,
    using the tqdm and rich packages.

    :raises ImportError: at construction time when ``tqdm`` could not be imported.
    """

    def __init__(self) -> None:
        super().__init__()
        # The module-level import guard leaves ``tqdm`` as None when the
        # import failed; the error is only raised once the bar is requested.
        if tqdm is None:
            raise ImportError(
                "You must install tqdm and rich in order to use the progress bar"
                " callback. "
            )
        self.pbar = None

    def _on_training_start(self) -> None:
        # Size the bar to the remaining work: timesteps already done by the
        # agent are subtracted from the requested total.
        remaining_steps = self.locals["total_timesteps"] - self.agent.num_time_steps
        self.pbar = tqdm(total=remaining_steps)

    def _on_step(self) -> bool:
        # Each call to `env.step()` advances the bar by num_envs steps.
        step_increment = self.training_env.num_envs
        self.pbar.update(step_increment)
        return True

    def _on_training_end(self) -> None:
        # Draw the final state, then release the bar.
        bar = self.pbar
        bar.refresh()
        bar.close()