# Source code for openrl.utils.callbacks.callbacks_factory
from typing import Any, Dict, List, Type, Union
from openrl.selfplay.callbacks.selfplay_callback import SelfplayCallback
from openrl.utils.callbacks.callbacks import BaseCallback, CallbackList, EveryNTimesteps
from openrl.utils.callbacks.checkpoint_callback import CheckpointCallback
from openrl.utils.callbacks.eval_callback import EvalCallback
from openrl.utils.callbacks.processbar_callback import ProgressBarCallback
from openrl.utils.callbacks.stop_callback import (
StopTrainingOnMaxEpisodes,
StopTrainingOnNoModelImprovement,
StopTrainingOnRewardThreshold,
)
# Registry mapping a callback "id" string to the class that implements it.
# CallbackFactory looks classes up here and CallbackFactory.register() adds
# entries to it. "SelfplayAPI" is intentionally not listed: the factory
# imports it and inserts it into this dict the first time it is requested.
callbacks_dict: Dict[str, Type[BaseCallback]] = {
    "CheckpointCallback": CheckpointCallback,
    "EvalCallback": EvalCallback,
    "StopTrainingOnRewardThreshold": StopTrainingOnRewardThreshold,
    "StopTrainingOnMaxEpisodes": StopTrainingOnMaxEpisodes,
    "StopTrainingOnNoModelImprovement": StopTrainingOnNoModelImprovement,
    "ProgressBarCallback": ProgressBarCallback,
    "EveryNTimesteps": EveryNTimesteps,
    "SelfplayCallback": SelfplayCallback,
}
class CallbackFactory:
    """Build callback objects from dictionary-style configuration.

    Callback classes are looked up by name in the module-level
    ``callbacks_dict`` registry; :meth:`register` adds entries to it.
    """

    @staticmethod
    def _resolve_class(callback_id: str) -> Type[BaseCallback]:
        """Return the class registered under ``callback_id``.

        ``"SelfplayAPI"`` is imported and added to the registry the first
        time it is requested instead of at module import time.

        Raises:
            ValueError: If nothing is registered under ``callback_id``.
        """
        if callback_id == "SelfplayAPI" and "SelfplayAPI" not in callbacks_dict:
            from openrl.selfplay.callbacks.selfplay_api import SelfplayAPI

            callbacks_dict["SelfplayAPI"] = SelfplayAPI
        if callback_id not in callbacks_dict:
            raise ValueError(f"Callback {callback_id} not found")
        return callbacks_dict[callback_id]

    @staticmethod
    def get_callback(
        callback: Dict[str, Any],
    ) -> BaseCallback:
        """Instantiate a single callback from its configuration.

        Args:
            callback: Mapping with an ``"id"`` entry naming a registered
                callback class and an optional ``"args"`` entry holding the
                keyword arguments passed to that class.

        Returns:
            The object produced by calling the class registered under
            ``callback["id"]``.

        Raises:
            KeyError: If ``callback`` has no ``"id"`` entry.
            ValueError: If ``callback["id"]`` is not registered.
        """
        callback_class = CallbackFactory._resolve_class(callback["id"])
        # A missing "args" entry and an explicit None both mean "no keyword
        # arguments"; unpacking None with ** would raise TypeError.
        kwargs = callback.get("args")
        if kwargs is None:
            kwargs = {}
        return callback_class(**kwargs)

    @staticmethod
    def get_callbacks(
        callbacks: Union[Dict[str, Any], List[Dict[str, Any]]],
        stop_logic: str = "OR",
    ) -> CallbackList:
        """Instantiate one or more callbacks and wrap them in a CallbackList.

        Args:
            callbacks: A single callback configuration or a list of them,
                each in the form accepted by :meth:`get_callback`.
            stop_logic: Forwarded unchanged to ``CallbackList`` as its
                ``stop_logic`` keyword argument.

        Returns:
            A ``CallbackList`` holding the callbacks in the order given.

        Raises:
            KeyError: If a configuration has no ``"id"`` entry.
            ValueError: If a configuration names an unregistered callback.
        """
        if isinstance(callbacks, dict):
            # Accept a bare configuration as shorthand for a one-item list.
            callbacks = [callbacks]
        # get_callback performs the registry lookup (including the lazy
        # SelfplayAPI registration) and the "not found" check for each item.
        return CallbackList(
            [CallbackFactory.get_callback(config) for config in callbacks],
            stop_logic=stop_logic,
        )

    @staticmethod
    def register(
        id: str,
        callback_class: Type[BaseCallback],
    ) -> None:
        """Register ``callback_class`` under the name ``id``.

        An existing entry with the same name is replaced.

        Args:
            id: Name that configurations use in their ``"id"`` entry. The
                parameter name shadows the builtin but is kept so existing
                keyword callers continue to work.
            callback_class: Class to instantiate for that name.
        """
        callbacks_dict[id] = callback_class