Shortcuts

openrl.rewards.base_reward 源代码

from typing import Any, Dict, List, Tuple, Union

import numpy as np


class BaseReward(object):
    """Base class that combines the outputs of registered reward functions.

    Reward functions are stored in dictionaries and invoked in dictionary
    iteration order. ``step_reward`` calls each step-level function with the
    step data; ``batch_rewards`` calls each batch-level function with no
    arguments.
    """

    def __init__(self):
        # ``step_reward`` / ``batch_rewards`` look the registries up under the
        # ``*_rew_funcs`` names, so those attributes must exist on a fresh
        # instance. The ``*_reward_fn`` names are kept as aliases of the very
        # same dict objects so code using the older spelling keeps working.
        self.step_rew_funcs = self.step_reward_fn = dict()
        self.inner_rew_funcs = self.inner_reward_fn = dict()
        self.batch_rew_funcs = self.batch_reward_fn = dict()

    def step_reward(
        self, data: Dict[str, Any]
    ) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
        """Add every registered step reward to the base reward.

        Args:
            data: Step data. ``data["reward"]`` is copied and used as the
                starting value; the whole mapping is passed unchanged to each
                function in ``self.step_rew_funcs``. Each function must return
                a ``(reward, infos)`` pair.
                NOTE(review): ``data["reward"]`` is presumably a numpy array
                (per the return annotation) -- confirm against callers.

        Returns:
            A ``(rewards, infos)`` tuple. ``rewards`` is the copied base
            reward with every function's reward added in place. ``infos`` is
            the per-entry merge of the info dicts returned by the functions,
            or an empty list when no function is registered.

        Raises:
            KeyError: If ``data`` has no ``"reward"`` entry.
            IndexError: If a later function returns fewer info entries than
                the first one did.
        """
        rewards = data["reward"].copy()
        infos: List[Dict[str, Any]] = []
        for rew_func in self.step_rew_funcs.values():
            new_rew, new_info = rew_func(data)
            if not infos:
                # Shallow-copy so that merging later results does not mutate
                # the dicts owned by the first reward function.
                infos = [dict(info) for info in new_info]
            else:
                for i, info in enumerate(infos):
                    info.update(new_info[i])
            # In-place add keeps the shape and dtype of the copied base reward.
            rewards += new_rew
        return rewards, infos

    def batch_rewards(self, buffer: Any) -> Dict[str, Any]:
        """Invoke every registered batch reward function.

        Args:
            buffer: Not used by this implementation; the registered functions
                are called without arguments.

        Returns:
            Always an empty dict. The info dicts returned by the functions are
            merged locally but not returned.
            NOTE(review): this looks like an unfinished stub -- confirm
            whether the merged infos (and rewards) should be propagated.
        """
        infos: Dict[str, Any] = dict()
        for rew_func in self.batch_rew_funcs.values():
            _new_rew, new_info = rew_func()
            if not infos:
                # Copy so the merge below never mutates a function's own dict.
                infos = dict(new_info)
            else:
                infos.update(new_info)
            # update rewards, and infos here
        return dict()