Source code for openrl.utils.callbacks.callbacks_factory
from typing import Any, Dict, List, Type, Union
from openrl.utils.callbacks.callbacks import BaseCallback, CallbackList
from openrl.utils.callbacks.checkpoint_callback import CheckpointCallback
from openrl.utils.callbacks.eval_callback import EvalCallback
# Registry of known callbacks: maps the ``"id"`` string used in callback
# configs to the class that implements it. New entries are added at runtime
# via ``CallbackFactory.register``.
callbacks_dict: Dict[str, Type[BaseCallback]] = dict(
    CheckpointCallback=CheckpointCallback,
    EvalCallback=EvalCallback,
)
class CallbackFactory:
    """Build callback objects from plain configuration dictionaries.

    A callback config is a mapping with an ``"id"`` key naming a class in the
    module-level ``callbacks_dict`` registry, and an optional ``"args"``
    mapping of keyword arguments forwarded to that class's constructor.
    """

    @staticmethod
    def get_callbacks(
        callbacks: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> BaseCallback:
        """Instantiate every configured callback and wrap them in a ``CallbackList``.

        Args:
            callbacks: A single callback config or a list of configs. Each
                config must contain ``"id"``. ``"args"`` may be omitted or set
                to ``None``; either way the callback is constructed with no
                keyword arguments.

        Returns:
            A ``CallbackList`` holding one instance per config, in input order.

        Raises:
            ValueError: If a config's ``"id"`` is not in ``callbacks_dict``.
            KeyError: If a config has no ``"id"`` key.
        """
        # Accept a bare config as shorthand for a one-element list.
        if isinstance(callbacks, dict):
            callbacks = [callbacks]
        callbacks_list = []
        for callback in callbacks:
            callback_id = callback["id"]
            if callback_id not in callbacks_dict:
                raise ValueError(f"Callback {callback_id} not found")
            # "args" is optional: a missing key or an empty YAML entry (None)
            # both mean "use the constructor's defaults" instead of crashing
            # on KeyError / `** None`.
            callback_args = callback.get("args")
            if callback_args is None:
                callback_args = {}
            callbacks_list.append(callbacks_dict[callback_id](**callback_args))
        return CallbackList(callbacks_list)

    @staticmethod
    def register(
        id: str,
        callback_class: Type[BaseCallback],
    ) -> None:
        """Register ``callback_class`` under ``id`` in ``callbacks_dict``.

        An existing entry with the same id is silently overwritten.

        Args:
            id: The name that configs use in their ``"id"`` field. The
                parameter name shadows the builtin but is kept for API
                compatibility.
            callback_class: The class to instantiate for that id.
        """
        callbacks_dict[id] = callback_class