Source code for openrl.envs.vec_env.wrappers.gen_data
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
import copy
import pickle
import numpy as np
from gymnasium.core import ActType
from tqdm.rich import tqdm
from openrl.envs.vec_env.base_venv import BaseVecEnv
from openrl.envs.vec_env.wrappers.base_wrapper import VecEnvWrapper
from openrl.envs.wrappers.monitor import Monitor
class TrajectoryData:
    """Collect per-environment transitions and group them into episodes.

    Each parallel environment has its own buffer (a dict of lists keyed by
    ``all_keys``). When an environment reports ``done``, its buffer is
    appended to ``all_trajectories`` as one complete episode and a fresh
    buffer is started from the post-reset observation.

    Args:
        env_num: Number of parallel environments being recorded.
        total_episode: Number of finished episodes after which ``step``
            reports that collection is complete.
        observation_space: Stored verbatim in the dumped ``env_info``.
        action_space: Stored verbatim in the dumped ``env_info``.
        agent_num: Stored verbatim in the dumped ``env_info``.
    """

    def __init__(
        self, env_num, total_episode, observation_space, action_space, agent_num: int
    ):
        self.env_num = env_num
        self.all_keys = ["obs", "action", "reward", "done", "info"]
        self.total_episode = total_episode
        self.all_trajectories = None
        self.observation_space = observation_space
        self.action_space = action_space
        self.agent_num = agent_num

    def init_empty_dict(self, source_dict=None):
        """Return a dict mapping every key in ``all_keys`` to a new empty list.

        Args:
            source_dict: Optional dict to (re)initialise in place. Its entries
                for ``all_keys`` are overwritten with empty lists. When
                omitted, a new dict is created for this call.

        Returns:
            A shallow copy of the initialised dict.
        """
        # A ``{}`` default would be one dict shared by (and mutated on) every
        # call, so the default is created per call instead.
        if source_dict is None:
            source_dict = {}
        for key in self.all_keys:
            source_dict[key] = []
        return copy.copy(source_dict)

    def reset(self, reset_data):
        """Clear all collected data and record the initial reset values.

        Args:
            reset_data: Dict whose keys are a subset of ``all_keys`` (in
                practice ``"obs"`` and ``"info"``); each value is indexed by
                environment. An empty mapping is accepted as a "no info"
                placeholder and recorded as one empty dict per environment.
        """
        self.current_total_episode = 0
        self.episode_lengths = []
        self.episode_rewards = []
        self.all_trajectories = self.init_empty_dict()
        self.data = [self.init_empty_dict() for _ in range(self.env_num)]
        for key, values in reset_data.items():
            # An empty dict cannot be indexed per environment; it stands for
            # "the env returned no info at reset".
            if isinstance(values, dict) and not values:
                values = [{} for _ in range(self.env_num)]
            for i in range(self.env_num):
                self.data[i][key].append(copy.copy(values[i]))

    def step(self, step_data):
        """Record one vectorised transition.

        Args:
            step_data: Dict with the keys ``"obs"``, ``"action"``,
                ``"reward"``, ``"done"`` and ``"info"``; each value is indexed
                by environment. For an environment whose ``done`` entry is
                all-true, ``info[i]`` must contain ``"final_observation"``
                and ``"final_info"``, and ``final_info["episode"]`` must
                provide ``"l"`` (length) and ``"r"`` (reward).

        Returns:
            A tuple ``(finished, need_finish)``: the indices of the
            environments whose episode ended during this call, and whether
            ``total_episode`` episodes have now been collected.

        Note:
            ``step_data`` is only shallow-copied, so the ``final_observation``
            and ``final_info`` entries are also removed from the caller's
            info dicts.
        """
        step_data = copy.copy(step_data)
        finished = []
        need_finish = False

        for i in range(self.env_num):
            # Once the episode budget is reached, the remaining environments
            # of this step are intentionally not recorded.
            if need_finish:
                break

            if not np.all(step_data["done"][i]):
                for key in step_data:
                    self.data[i][key].append(copy.copy(step_data[key][i]))
                continue

            info = step_data["info"][i]
            assert "final_info" in info and "episode" in info["final_info"]
            episode = info["final_info"]["episode"]
            self.episode_lengths.append(episode["l"])
            self.episode_rewards.append(episode["r"])
            finished.append(i)
            episode_length = self.episode_lengths[-1]

            for key, buffer in self.data[i].items():
                if key == "obs":
                    # step_data["obs"][i] already belongs to the next
                    # episode; the terminal observation is carried in info.
                    buffer.append(copy.copy(info["final_observation"]))
                    expected_len = episode_length + 1
                elif key == "info":
                    buffer.append(copy.copy(info["final_info"]))
                    expected_len = episode_length + 1
                else:
                    buffer.append(copy.copy(step_data[key][i]))
                    expected_len = episode_length
                # obs/info hold one more entry than action/reward/done
                # because they also contain the reset values.
                assert len(buffer) == expected_len, (
                    f"key: {key}, len: {len(buffer)},"
                    f" episode_lengths: {episode_length}"
                )
                self.all_trajectories[key].append(copy.copy(buffer))

            # Start the next episode's buffer from the post-reset values.
            self.data[i] = self.init_empty_dict()
            del info["final_observation"]
            del info["final_info"]
            for key in ("obs", "info"):
                self.data[i][key].append(copy.copy(step_data[key][i]))

            self.current_total_episode += 1
            if self.current_total_episode >= self.total_episode:
                need_finish = True

        return finished, need_finish

    def dump(self, save_path):
        """Pickle all finished trajectories plus summary data to ``save_path``.

        ``episode_lengths``, ``episode_rewards`` and ``env_info`` are added to
        ``all_trajectories`` before it is written.

        Args:
            save_path: Path of the file to write (opened in binary mode).
        """
        self.all_trajectories["episode_lengths"] = self.episode_lengths
        self.all_trajectories["episode_rewards"] = self.episode_rewards

        trajectory_num = len(self.all_trajectories["obs"])
        # Only per-episode entries are checked; ``env_info`` is metadata and
        # would otherwise break a second call to ``dump``.
        for key in self.all_keys + ["episode_lengths", "episode_rewards"]:
            assert len(self.all_trajectories[key]) == trajectory_num, (
                f"key: {key}, len: {len(self.all_trajectories[key])}, trajectory_num:"
                f" {trajectory_num}"
            )

        self.all_trajectories["env_info"] = {
            "agent_num": self.agent_num,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
        }
        with open(save_path, "wb") as f:
            pickle.dump(self.all_trajectories, f, protocol=pickle.HIGHEST_PROTOCOL)
        print("data saved to: ", save_path)
# NOTE: this wrapper produces a dataset about 4x larger than the one produced by GenDataWrapper_v1, so it still needs to be optimized.
class GenDataWrapper(VecEnvWrapper):
    """Vec-env wrapper that records episodes and pickles them on ``close``.

    Transitions are grouped into per-episode trajectories by a
    ``TrajectoryData`` instance that is re-created on every ``reset``.

    Args:
        env: Vectorised environment to wrap. The sub-environment at index 0
            must be wrapped by ``Monitor``; ``TrajectoryData`` relies on the
            ``final_info["episode"]`` statistics at the end of each episode.
        data_save_path: Path of the pickle file written by ``close``.
        total_episode: Number of episodes to collect before ``step`` reports
            that collection is finished.
    """

    def __init__(self, env: BaseVecEnv, data_save_path: str, total_episode: int):
        assert env.env_is_wrapped(Monitor, indices=0)[0]

        super().__init__(env)

        self.data_save_path = data_save_path
        self.total_episode = total_episode
        self.pbar = None
        self.trajectory_data = None

    def reset(self, **kwargs):
        """Reset the wrapped env and start a fresh recording.

        Returns:
            Whatever the wrapped env's ``reset`` returned, unchanged.
        """
        self.trajectory_data = TrajectoryData(
            self.env.parallel_env_num,
            self.total_episode,
            self.env.observation_space,
            self.env.action_space,
            self.env.agent_num,
        )

        returns = self.env.reset(**kwargs)
        # Only a tuple is treated as ``(obs, info)``: a bare observation
        # batch from two parallel envs also has ``len(...) == 2``.
        if isinstance(returns, tuple) and len(returns) == 2:
            obs, info = returns
        else:
            obs = returns
            # TrajectoryData indexes info per environment, so the placeholder
            # must provide one entry for each of them.
            info = [{} for _ in range(self.env.parallel_env_num)]

        self.trajectory_data.reset({"obs": obs, "info": info})

        if self.pbar is not None:
            self.pbar.refresh()
            self.pbar.close()
        self.pbar = tqdm(total=self.total_episode)

        return returns

    def step(self, action: ActType, *args, **kwargs):
        """Step the wrapped env and record the transition.

        Returns:
            ``(obs, reward, need_finish, info)``. The third element is NOT
            the env's ``done`` array: it is a bool that becomes true once
            ``total_episode`` episodes have been collected.
        """
        obs, r, done, info = self.env.step(action, *args, **kwargs)
        step_data = {
            "action": action,
            "obs": obs,
            "reward": r,
            "done": done,
            "info": info,
        }

        finished, need_finish = self.trajectory_data.step(step_data)
        self.pbar.update(len(finished))

        if need_finish:
            assert self.trajectory_data.current_total_episode == self.total_episode

        return obs, r, need_finish, info

    def close(self, **kwargs):
        """Print collection statistics, dump the data and close the env.

        If ``reset`` never completed there is nothing to report or save, so
        only the wrapped env is closed.
        """
        # ``pbar`` is the last thing ``reset`` assigns, so it doubles as the
        # "a recording was started" flag.
        if self.pbar is None:
            return self.env.close(**kwargs)

        self.pbar.refresh()
        self.pbar.close()

        average_length = np.mean(self.trajectory_data.episode_lengths)
        average_reward = np.mean(self.trajectory_data.episode_rewards)
        print(
            "collect total episode: {}".format(
                self.trajectory_data.current_total_episode
            )
        )
        print("average episode length: {}".format(average_length))
        print("average reward: {}".format(average_reward))

        self.trajectory_data.dump(self.data_save_path)
        return self.env.close(**kwargs)
class GenDataWrapper_v1(VecEnvWrapper):
    """Vec-env wrapper that records raw vectorised steps and pickles them.

    Every ``reset``/``step`` result is appended as-is (one entry per call,
    covering all parallel environments) to the lists in ``self.data``; no
    per-episode grouping is done.

    Args:
        env: Vectorised environment to wrap. The sub-environment at index 0
            must be wrapped by ``Monitor``; ``step`` reads
            ``final_info["episode"]["l"]`` at the end of each episode.
        data_save_path: Path of the pickle file written by ``close``.
        total_episode: Number of episodes to collect before ``step`` reports
            that collection is finished.
    """

    def __init__(self, env: BaseVecEnv, data_save_path: str, total_episode: int):
        assert env.env_is_wrapped(Monitor, indices=0)[0]

        super().__init__(env)

        self.data_save_path = data_save_path
        self.total_episode = total_episode
        self.pbar = None

    def reset(self, **kwargs):
        """Reset the wrapped env and start a fresh recording.

        Returns:
            Whatever the wrapped env's ``reset`` returned, unchanged.
        """
        self.current_total_episode = 0
        self.episode_lengths = []
        self.data = {
            "obs": [],
            "action": [],
            "reward": [],
            "done": [],
            "info": [],
        }

        returns = self.env.reset(**kwargs)
        # Only a tuple is treated as ``(obs, info)``: a bare observation
        # batch from two parallel envs also has ``len(...) == 2``.
        if isinstance(returns, tuple) and len(returns) == 2:
            obs, info = returns
        else:
            obs = returns
            info = {}

        # The reset entry has no action/reward/done; ``None`` keeps all five
        # lists aligned index-by-index.
        self.data["action"].append(None)
        self.data["obs"].append(obs)
        self.data["reward"].append(None)
        self.data["done"].append(None)
        self.data["info"].append(info)

        if self.pbar is not None:
            self.pbar.refresh()
            self.pbar.close()
        self.pbar = tqdm(total=self.total_episode)

        return returns

    def step(self, action: ActType, *args, **kwargs):
        """Step the wrapped env and record the raw vectorised transition.

        Returns:
            ``(obs, reward, finished, info)``. The third element is NOT the
            env's ``done`` array: it is a bool that becomes true once
            ``total_episode`` episodes have been collected.
        """
        self.data["action"].append(action)
        obs, r, done, info = self.env.step(action, *args, **kwargs)
        self.data["obs"].append(obs)
        self.data["reward"].append(r)
        self.data["done"].append(done)
        self.data["info"].append(info)

        for i in range(self.env.parallel_env_num):
            if not np.all(done[i]):
                continue
            self.current_total_episode += 1
            self.pbar.update(1)
            assert "final_info" in info[i] and "episode" in info[i]["final_info"]
            self.episode_lengths.append(info[i]["final_info"]["episode"]["l"])

        collection_finished = self.current_total_episode >= self.total_episode
        return obs, r, collection_finished, info

    def close(self, **kwargs):
        """Print collection statistics, dump the data and close the env.

        If ``reset`` never completed there is nothing to report or save, so
        only the wrapped env is closed.
        """
        # ``pbar`` is the last thing ``reset`` assigns, so it doubles as the
        # "a recording was started" flag.
        if self.pbar is None:
            return self.env.close(**kwargs)

        self.pbar.refresh()
        self.pbar.close()

        average_length = np.mean(self.episode_lengths)
        print("collect total episode: {}".format(self.current_total_episode))
        print("average episode length: {}".format(average_length))

        self.data["total_episode"] = self.current_total_episode
        self.data["average_length"] = average_length
        # The context manager flushes and closes the file even if pickling
        # fails part-way through.
        with open(self.data_save_path, "wb") as f:
            pickle.dump(self.data, f, protocol=pickle.HIGHEST_PROTOCOL)
        print("data saved to: ", self.data_save_path)

        return self.env.close(**kwargs)