"""Source code for ``openrl.rewards``: the reward registry and ``RewardFactory``."""

from typing import Any

from openrl.envs.vec_env.base_venv import BaseVecEnv
from openrl.rewards.base_reward import BaseReward

# Registry mapping a reward name to its reward class; RewardFactory looks
# classes up here. The "default" entry is what RewardFactory builds when no
# reward id is configured.
registed_rewards = dict(default=BaseReward)


class RewardFactory:
    """Registry-backed factory that builds reward objects for a vectorized env.

    Reward classes are stored in the module-level ``registed_rewards`` dict,
    keyed by name. ``"default"`` maps to :class:`BaseReward`. The built-in
    ``"NLPReward"`` and ``"GAILReward"`` rewards are imported lazily, so their
    modules are only imported when one of them is actually requested.
    """

    # Built-in rewards imported on demand: reward id -> (module path, attribute).
    # Adding a new lazily-loaded built-in only requires a new entry here.
    _LAZY_REWARDS = {
        "NLPReward": ("openrl.rewards.nlp_reward", "NLPReward"),
        "GAILReward": ("openrl.rewards.gail_reward", "GAILReward"),
    }

    @staticmethod
    def get_reward_class(reward_class: Any, env: BaseVecEnv):
        """Build the reward described by ``reward_class`` for ``env``.

        Despite the name, this returns a reward *instance*, not a class.

        Args:
            reward_class: Reward config object exposing ``id`` and ``args``
                attributes, or ``None``. When it is ``None`` or its ``id`` is
                ``None``, the ``"default"`` reward is built from ``env`` alone.
            env: The vectorized environment passed to the reward constructor.

        Returns:
            ``registed_rewards[reward_class.id](env, **reward_class.args)``, or
            ``registed_rewards["default"](env)`` when no id is configured.

        Raises:
            KeyError: If ``reward_class.id`` is not a registered reward name.
        """
        # Make sure lazily-loaded built-in rewards are in the registry first.
        RewardFactory.auto_register(reward_class)
        if reward_class is None or reward_class.id is None:
            return registed_rewards["default"](env)

        reward_id = reward_class.id
        try:
            reward_cls = registed_rewards[reward_id]
        except KeyError:
            # Same exception type as before, but tell the user what is available.
            # list() rather than sorted(): registered names need not be comparable.
            raise KeyError(
                f"Unknown reward id {reward_id!r}; "
                f"registered rewards: {list(registed_rewards)}"
            ) from None

        # A None/empty ``args`` means "no extra keyword arguments"; unpacking
        # None directly would raise TypeError.
        reward_args = reward_class.args or {}
        return reward_cls(env, **reward_args)

    @staticmethod
    def register(reward_name, reward_class):
        """Register ``reward_class`` under ``reward_name``, replacing any existing entry.

        Note: ``auto_register`` re-registers the built-in ``"NLPReward"`` and
        ``"GAILReward"`` ids whenever they are requested, so a custom class
        registered under one of those names is overwritten on lookup.
        """
        registed_rewards[reward_name] = reward_class

    @staticmethod
    def auto_register(reward_class: Any):
        """Import and register the built-in reward named by ``reward_class.id``.

        Does nothing when ``reward_class`` is ``None`` or when its ``id`` is not
        one of the lazily-loaded built-in rewards (see ``_LAZY_REWARDS``).
        """
        if reward_class is None:
            return
        reward_id = reward_class.id
        target = RewardFactory._LAZY_REWARDS.get(reward_id)
        if target is None:
            return

        import importlib

        module_name, attr_name = target
        module = importlib.import_module(module_name)
        registed_rewards[reward_id] = getattr(module, attr_name)