Source code for openrl.rewards.base_reward
from typing import Any, Dict, List, Tuple, Union

import numpy as np
class BaseReward:
    """Base reward class.

    Holds three (initially empty) registries of reward functions and
    provides a default :meth:`step_reward` that passes the rewards found in
    the incoming ``data`` through unchanged, together with one empty info
    dict per row.
    """

    def __init__(self):
        # Registries of reward functions, all empty in the base class.
        # NOTE(review): presumably filled in by subclasses keyed by name;
        # the key/value schema is not visible here -- confirm with callers.
        self.step_reward_fn: Dict[str, Any] = {}
        self.inner_reward_fn: Dict[str, Any] = {}
        self.batch_reward_fn: Dict[str, Any] = {}

    def step_reward(
        self, data: Dict[str, Any]
    ) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
        """Return the step rewards from ``data`` unchanged, plus empty infos.

        Args:
            data: Mapping that must contain a ``"rewards"`` entry supporting
                ``.copy()`` and ``.shape`` (e.g. a NumPy array with at least
                one dimension).

        Returns:
            A ``(rewards, infos)`` tuple. ``rewards`` is a copy of
            ``data["rewards"]``, so callers may modify it without affecting
            the input. ``infos`` is a list holding one independent empty dict
            per entry along the first axis of ``rewards``.

        Raises:
            KeyError: If ``data`` has no ``"rewards"`` key.
        """
        # Copy so downstream reward shaping cannot mutate the caller's array.
        rewards = data["rewards"].copy()
        # One distinct dict per row (not ``[{}] * n``, which would alias).
        infos = [{} for _ in range(rewards.shape[0])]
        return rewards, infos