Shortcuts

Source code for openrl.utils.callbacks.callbacks

# Modified from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/callbacks.py

import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union

import gym
import numpy as np

from openrl.utils.logger import Logger

try:
    from tqdm import TqdmExperimentalWarning

    # Remove experimental warning
    warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
    from tqdm.rich import tqdm
except ImportError:
    # Rich not installed, we only throw an error
    # if the progress bar is used
    tqdm = None

from openrl.envs.vec_env import BaseVecEnv
from openrl.runners.common.base_agent import BaseAgent


class BaseCallback(ABC):
    """
    Abstract foundation shared by every training callback.

    Public ``on_*`` methods keep the bookkeeping up to date and then delegate
    to the matching private ``_on_*`` hook, which subclasses override.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages,
        2 for debug messages
    """

    # Both attributes are assigned by ``init_callback``.
    # Type hints are strings to avoid circular imports.
    agent: "BaseAgent"
    logger: "Logger"

    def __init__(self, verbose: int = 0):
        super().__init__()
        self.verbose = verbose
        # Alias for ``self.agent.get_env()``; filled in by ``init_callback``
        self.training_env: Union["gym.Env", "BaseVecEnv", None] = None
        # How many times ``on_step`` has been invoked on this callback
        self.n_calls: int = 0
        # Mirror of ``agent.num_time_steps`` (n_envs * number of env.step() calls)
        self.num_time_steps: int = 0
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}
        # For event callbacks it is handy to reach the object that owns us
        self.parent: Optional["BaseCallback"] = None

    def init_callback(self, agent: "BaseAgent") -> None:
        """
        Bind the callback to ``agent``: keep a reference to it, to its
        environment and to its logger, then run the ``_init_callback`` hook.

        :param agent: the agent this callback is attached to
        """
        self.agent = agent
        self.training_env = agent.get_env()
        self.logger = agent.logger
        self._init_callback()

    def _init_callback(self) -> None:
        """Hook run at the end of ``init_callback``; no-op by default."""

    def on_training_start(
        self, locals_: Dict[str, Any], globals_: Dict[str, Any]
    ) -> None:
        """
        Store the caller's namespaces, resynchronise the step counter and
        run the ``_on_training_start`` hook.

        :param locals_: local variables of the caller (stored by reference)
        :param globals_: global variables of the caller (stored by reference)
        """
        # Kept by reference, so later mutations of these dicts are visible here
        self.locals, self.globals = locals_, globals_
        # Training may already have been run before: resync the counter
        self.num_time_steps = self.agent.num_time_steps
        self._on_training_start()

    def _on_training_start(self) -> None:
        """Hook run at the end of ``on_training_start``; no-op by default."""

    def on_rollout_start(self) -> None:
        """Run the ``_on_rollout_start`` hook."""
        self._on_rollout_start()

    def _on_rollout_start(self) -> None:
        """Hook run by ``on_rollout_start``; no-op by default."""

    @abstractmethod
    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        return True

    def on_step(self) -> bool:
        """
        Meant to be called by the model after each call to ``env.step()``.

        For a child callback (of an ``EventCallback``), this is called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        self.n_calls += 1
        self.num_time_steps = self.agent.num_time_steps
        return self._on_step()

    def on_training_end(self) -> None:
        """Run the ``_on_training_end`` hook."""
        self._on_training_end()

    def _on_training_end(self) -> None:
        """Hook run by ``on_training_end``; no-op by default."""

    def on_rollout_end(self) -> None:
        """Run the ``_on_rollout_end`` hook."""
        self._on_rollout_end()

    def _on_rollout_end(self) -> None:
        """Hook run by ``on_rollout_end``; no-op by default."""

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Merge ``locals_`` into the stored local variables, then forward
        them to sub callbacks through ``update_child_locals``.

        :param locals_: the local variables during rollout collection
        """
        self.locals.update(locals_)
        self.update_child_locals(locals_)

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Forward the local variables to sub callbacks; no-op by default.

        :param locals_: the local variables during rollout collection
        """
class EventCallback(BaseCallback):
    """
    Base class for callbacks that fire a child callback when an event occurs.

    :param callback: Callback that will be called when an event is triggered.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages,
        2 for debug messages
    """

    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.callback = callback
        if callback is None:
            return
        # Let the child reach back to the event callback that owns it
        self.callback.parent = self

    def init_callback(self, agent: "BaseAgent") -> None:
        """
        Initialize this callback, then the child callback (if any) with the
        same agent.

        :param agent: the agent this callback is attached to
        """
        super().init_callback(agent)
        child = self.callback
        if child is None:
            return
        child.init_callback(self.agent)

    def _on_training_start(self) -> None:
        """Propagate the training-start notification to the child callback."""
        child = self.callback
        if child is None:
            return
        child.on_training_start(self.locals, self.globals)

    def _on_event(self) -> bool:
        """
        Trigger the child callback.

        :return: the child's ``on_step()`` result, or True when there is no child.
        """
        child = self.callback
        return True if child is None else child.on_step()

    def _on_step(self) -> bool:
        """Default step hook: never stops training."""
        return True

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Forward the local variables to the child callback, if any.

        :param locals_: the local variables during rollout collection
        """
        child = self.callback
        if child is None:
            return
        child.update_locals(locals_)
class CallbackList(BaseCallback):
    """
    Chain several callbacks and drive them as one.

    Every notification is forwarded to each member, in list order.

    :param callbacks: A list of callbacks that will be called sequentially.
    """

    def __init__(self, callbacks: List[BaseCallback]):
        super().__init__()
        assert isinstance(callbacks, list)
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        """Initialize every member with this list's agent."""
        for member in self.callbacks:
            member.init_callback(self.agent)

    def _on_training_start(self) -> None:
        """Forward the training-start notification to every member."""
        for member in self.callbacks:
            member.on_training_start(self.locals, self.globals)

    def _on_rollout_start(self) -> None:
        """Forward the rollout-start notification to every member."""
        for member in self.callbacks:
            member.on_rollout_start()

    def _on_step(self) -> bool:
        """
        Step every member; all of them run even after one asks to stop.

        :return: falsy (stop training) if at least one member returned a
            falsy value, otherwise True.
        """
        keep_going = True
        for member in self.callbacks:
            step_result = member.on_step()
            keep_going = step_result and keep_going
        return keep_going

    def _on_rollout_end(self) -> None:
        """Forward the rollout-end notification to every member."""
        for member in self.callbacks:
            member.on_rollout_end()

    def _on_training_end(self) -> None:
        """Forward the training-end notification to every member."""
        for member in self.callbacks:
            member.on_training_end()

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Forward the local variables to every member.

        :param locals_: the local variables during rollout collection
        """
        for member in self.callbacks:
            member.update_locals(locals_)

    def __repr__(self):
        """Return the string form of the list of member class names."""
        return str([member.__class__.__name__ for member in self.callbacks])
class ConvertCallback(BaseCallback):
    """
    Adapter wrapping a functional (old-style) callback into a callback object.

    :param callback: callable invoked on every step with the stored
        ``locals`` and ``globals`` dictionaries; its return value is the
        step result.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages,
        2 for debug messages
    """

    def __init__(
        self,
        callback: Callable[[Dict[str, Any], Dict[str, Any]], bool],
        verbose: int = 0,
    ):
        super().__init__(verbose)
        self.callback = callback

    def _on_step(self) -> bool:
        """
        :return: whatever the wrapped callable returns, or True when no
            callable is set.
        """
        wrapped = self.callback
        if wrapped is None:
            return True
        return wrapped(self.locals, self.globals)
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in episodic reward has been reached
    (i.e. when the model is good enough).

    It must be used with the ``EvalCallback``: the threshold is compared
    against ``self.parent.best_mean_reward``.

    :param reward_threshold: Minimum expected reward per episode to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when
        training ended because episodic reward threshold reached
    """

    def __init__(self, reward_threshold: float, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold

    def _on_step(self) -> bool:
        """
        :return: False (stop training) once the parent's best mean reward is
            no longer below ``reward_threshold``, True otherwise.
        :raises AssertionError: if the callback has no parent.
        """
        # The message used to name a non-existent ``StopTrainingOnMinimumReward``
        # class; it now names this class.
        assert self.parent is not None, (
            "``StopTrainingOnRewardThreshold`` callback must be used "
            "with an ``EvalCallback``"
        )
        # Convert np.bool_ to bool, otherwise callback() is False won't work
        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
        if self.verbose >= 1 and not continue_training:
            print(
                "Stopping training because the mean reward"
                f" {self.parent.best_mean_reward:.2f} is above the threshold"
                f" {self.reward_threshold}"
            )
        return continue_training
class EveryNTimesteps(EventCallback):
    """
    Fire the child callback every ``n_steps`` timesteps.

    :param n_steps: Number of timesteps between two triggers.
    :param callback: Callback that will be called when the event is triggered.
    """

    def __init__(self, n_steps: int, callback: BaseCallback):
        super().__init__(callback)
        self.n_steps = n_steps
        # Value of ``num_time_steps`` at the most recent trigger
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        """
        :return: the event result when at least ``n_steps`` timesteps have
            elapsed since the last trigger, True otherwise.
        """
        now = self.num_time_steps
        elapsed = now - self.last_time_trigger
        if elapsed >= self.n_steps:
            self.last_time_trigger = now
            return self._on_event()
        return True
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop the training once a maximum number of episodes has been played.

    With multiple environments the budget is per environment: the agent
    trains for ``max_episodes * n_envs`` episodes in total.

    :param max_episodes: Maximum number of episodes to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating
        information about when training ended by reaching ``max_episodes``
    """

    def __init__(self, max_episodes: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_episodes = max_episodes
        # Rescaled by the number of environments in ``_init_callback``
        self._total_max_episodes = max_episodes
        self.n_episodes = 0

    def _init_callback(self) -> None:
        """Scale the episode budget by the number of environments."""
        n_envs = self.training_env.num_envs
        self._total_max_episodes = self.max_episodes * n_envs

    def _on_step(self) -> bool:
        """
        Add the episodes flagged in ``locals["dones"]`` to the running total.

        :return: True while the total stays below the overall budget.
        :raises AssertionError: if ``dones`` is missing from the locals.
        """
        # Check that the `dones` local variable is defined
        assert "dones" in self.locals, (
            "`dones` variable is not defined, please check your code next to"
            " `callback.on_step()`"
        )
        finished_now = np.sum(self.locals["dones"]).item()
        self.n_episodes += finished_now

        continue_training = self.n_episodes < self._total_max_episodes
        if continue_training or self.verbose < 1:
            return continue_training

        # Budget exhausted and verbose output requested: report before stopping
        n_envs = self.training_env.num_envs
        mean_episodes_per_env = self.n_episodes / n_envs
        if n_envs > 1:
            mean_ep_str = (
                f"with an average of {mean_episodes_per_env:.2f} episodes per env"
            )
        else:
            mean_ep_str = ""
        print(
            f"Stopping training with a total of {self.num_time_steps} steps because"
            f" the {self.locals.get('tb_log_name')} model reached"
            f" max_episodes={self.max_episodes}, by playing for"
            f" {self.n_episodes} episodes {mean_ep_str}"
        )
        return continue_training
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean
    reward) after more than N consecutive evaluations.

    A minimum number of evaluations can be required before evaluations
    without improvement start being counted.

    It must be used with the ``EvalCallback``: improvement is read from
    ``self.parent.best_mean_reward``.

    :param max_no_improvement_evals: Maximum number of consecutive
        evaluations without a new best model.
    :param min_evals: Number of evaluations before start to count
        evaluations without improvements.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when
        training ended because no new best model
    """

    def __init__(
        self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0
    ):
        super().__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        self.last_best_mean_reward = -np.inf
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        """
        :return: False (stop training) once the count of consecutive
            evaluations without a new best exceeds
            ``max_no_improvement_evals``, True otherwise.
        :raises AssertionError: if the callback has no parent.
        """
        assert self.parent is not None, (
            "``StopTrainingOnNoModelImprovement`` callback must be used with an"
            " ``EvalCallback``"
        )

        current_best = self.parent.best_mean_reward
        should_stop = False
        # Counting only begins after the first ``min_evals`` calls
        if self.n_calls > self.min_evals:
            if current_best > self.last_best_mean_reward:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                should_stop = (
                    self.no_improvement_evals > self.max_no_improvement_evals
                )

        self.last_best_mean_reward = current_best

        if should_stop and self.verbose >= 1:
            print(
                "Stopping training because there was no new best model in the last"
                f" {self.no_improvement_evals:d} evaluations"
            )

        return not should_stop
class ProgressBarCallback(BaseCallback):
    """
    Display a progress bar during training, using the tqdm and rich packages.

    :raises ImportError: on construction, when the module-level ``tqdm``
        import failed.
    """

    def __init__(self) -> None:
        super().__init__()
        if tqdm is None:
            raise ImportError(
                "You must install tqdm and rich in order to use the progress bar"
                " callback. "
            )
        # Created in ``_on_training_start``
        self.pbar = None

    def _on_training_start(self) -> None:
        """Create the bar, sized to the timesteps still left to run."""
        # Timesteps done in previous training sessions are excluded
        remaining = self.locals["total_timesteps"] - self.agent.num_time_steps
        self.pbar = tqdm(total=remaining)

    def _on_step(self) -> bool:
        """
        Advance the bar by the number of environments.

        :return: always True; this callback never stops training.
        """
        # Each call to `env.step()` performs num_envs steps
        n_envs = self.training_env.num_envs
        self.pbar.update(n_envs)
        return True

    def _on_training_end(self) -> None:
        """Flush and close the progress bar."""
        bar = self.pbar
        bar.refresh()
        bar.close()