Shortcuts

Source code for openrl.envs.vec_env.vec_info.simple_vec_info

import time
from typing import Any, Dict

import numpy as np

from openrl.envs.vec_env.vec_info.base_vec_info import BaseVecInfo


[docs]class SimpleVecInfo(BaseVecInfo): def __init__(self, parallel_env_num: int, agent_num: int): super().__init__(parallel_env_num, agent_num) self.infos = [] self.start_time = time.time() self.total_step = 0
[docs] def statistics(self, buffer: Any) -> Dict[str, Any]: # this function should be called each episode rewards = buffer.data.rewards.copy() self.total_step += np.prod(rewards.shape[:2]) rewards = rewards.transpose(2, 1, 0, 3) info_dict = {} ep_rewards = [] for i in range(self.agent_num): agent_reward = rewards[i].mean(0).sum() ep_rewards.append(agent_reward) info_dict["agent_{}/rollout_episode_reward".format(i)] = agent_reward info_dict["FPS"] = int(self.total_step / (time.time() - self.start_time)) info_dict["rollout_episode_reward"] = np.mean(ep_rewards) return info_dict
[docs] def append(self, info: Dict[str, Any]) -> None: self.infos.append(info)
[docs] def reset(self) -> None: self.infos = [] self.rewards = []