Source code for openrl.envs.wrappers.extra_wrappers
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from copy import deepcopy
from typing import Any, Dict, Optional, SupportsFloat, Tuple
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.utils.step_api_compatibility import (
convert_to_terminated_truncated_step_api,
)
from gymnasium.wrappers import AutoResetWrapper, StepAPICompatibility
from openrl.envs.wrappers import BaseObservationWrapper, BaseRewardWrapper, BaseWrapper
from openrl.envs.wrappers.base_wrapper import ActType, ArrayType, WrapperObsType
from openrl.envs.wrappers.flatten import flatten
class FrameSkip(BaseWrapper):
    """Repeat every action for several consecutive environment steps.

    The rewards collected during the repeated steps are summed; the
    observation, termination/truncation flags and info returned to the caller
    are those of the last step that was actually executed.
    """

    def __init__(self, env, num_frames: int = 8):
        """Wrap ``env`` so that each action is applied ``num_frames`` times.

        Args:
            env: The environment to wrap.
            num_frames: Number of times each action is repeated; must be >= 1.

        Raises:
            ValueError: If ``num_frames`` is smaller than 1.
        """
        if num_frames < 1:
            # With zero iterations ``step`` would have no transition to return
            # and would fail later with an unbound-variable error.
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        super().__init__(env)
        self.num_frames = num_frames

    def step(self, action):
        """Apply ``action`` up to ``num_frames`` times and sum the rewards.

        The repetition stops early as soon as a step reports termination or
        truncation.

        Args:
            action: The action forwarded unchanged to every repeated step.

        Returns:
            A ``(obs, total_reward, term, trunc, info)`` tuple where every item
            except ``total_reward`` comes from the last executed step.
        """
        total_reward = 0.0
        for _ in range(self.num_frames):
            obs, rew, term, trunc, info = super().step(action)
            total_reward += rew
            if term or trunc:
                break
        return obs, total_reward, term, trunc, info
def convert_to_done_step_api(
    step_returns,
    is_vector_env: bool = False,
):
    """Convert step returns to the 4-item ``(obs, reward, done, info)`` form.

    A 4-item input is returned untouched. A 5-item input
    ``(obs, reward, terminated, truncated, info)`` is collapsed by OR-ing
    ``terminated`` and ``truncated`` into ``done``; whether the episode ended
    by truncation only is recorded under ``info["TimeLimit.truncated"]``.
    The ``info`` object(s) are updated in place.

    Args:
        step_returns: The 4- or 5-item sequence returned by ``step``.
        is_vector_env: Whether ``step_returns`` comes from a vector
            environment. Only the literal ``False`` selects the single
            environment path.

    Returns:
        The tuple ``(observations, rewards, done, infos)``.

    Raises:
        TypeError: If ``is_vector_env`` is set and ``infos`` is neither a list
            nor a dict.
    """
    if len(step_returns) == 4:
        return step_returns

    assert len(step_returns) == 5
    observations, rewards, terminated, truncated, infos = step_returns

    if is_vector_env is False:
        if isinstance(terminated, list):
            # List-valued flags with a single info dict: only the first entry
            # decides the "TimeLimit.truncated" marker.
            infos["TimeLimit.truncated"] = truncated[0] and not terminated[0]
            done_return = np.logical_or(terminated, truncated)
        else:
            if truncated or terminated:
                infos["TimeLimit.truncated"] = truncated and not terminated
            done_return = terminated or truncated
        return observations, rewards, done_return, infos

    if isinstance(infos, list):
        # One info dict per sub-environment; only finished ones get the marker.
        for info, env_truncated, env_terminated in zip(infos, truncated, terminated):
            if env_truncated or env_terminated:
                info["TimeLimit.truncated"] = env_truncated and not env_terminated
    elif isinstance(infos, dict):
        # A single dict holding batched values for all sub-environments.
        if np.logical_or(np.any(truncated), np.any(terminated)):
            infos["TimeLimit.truncated"] = np.logical_and(
                truncated, np.logical_not(terminated)
            )
    else:
        raise TypeError(
            "Unexpected value of infos, as is_vector_env=True, expects `info` to"
            f" be a list or dict, actual type: {type(infos)}"
        )

    return observations, rewards, np.logical_or(terminated, truncated), infos
def step_api_compatibility(
    step_returns,
    output_truncation_bool: bool = True,
    is_vector_env: bool = False,
):
    """Convert step returns to the requested step API.

    Args:
        step_returns: The sequence returned by an environment ``step`` call.
        output_truncation_bool: When truthy, convert to the
            terminated/truncated API; otherwise convert to the done API.
        is_vector_env: Forwarded unchanged to the selected converter.

    Returns:
        Whatever the selected converter returns for ``step_returns``.
    """
    # Pick the converter first, then make a single call through it.
    converter = (
        convert_to_terminated_truncated_step_api
        if output_truncation_bool
        else convert_to_done_step_api
    )
    return converter(step_returns, is_vector_env)
class RemoveTruncated(StepAPICompatibility, BaseWrapper):
    """Wrapper whose ``step`` is routed through ``step_api_compatibility``.

    The wrapper is constructed with ``output_truncation_bool=False``, i.e. it
    requests the done-style step API rather than terminated/truncated.
    """

    def __init__(self, env: gym.Env):
        """Initialise the parent wrappers with truncation output disabled.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env, output_truncation_bool=False)

    def step(self, action):
        """Step the wrapped environment and convert the returned tuple.

        Args:
            action: The action passed to the wrapped environment.

        Returns:
            The step result after conversion by ``step_api_compatibility``.
        """
        # ``output_truncation_bool`` / ``is_vector_env`` are not assigned in
        # this class; presumably the parent initialiser sets them — confirm.
        return step_api_compatibility(
            self.env.step(action),
            self.output_truncation_bool,
            self.is_vector_env,
        )
class FlattenObservation(BaseObservationWrapper):
    """Observation wrapper that exposes flattened observations."""

    def __init__(self, env: gym.Env):
        """Set up the wrapper and advertise the flattened observation space.

        Args:
            env: The environment to apply the wrapper to.
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        BaseObservationWrapper.__init__(self, env)
        flat_space = spaces.flatten_space(env.observation_space)
        self.observation_space = flat_space

    def observation(self, observation):
        """Flatten one observation of the wrapped environment.

        Args:
            observation: The observation to flatten.

        Returns:
            The result of ``flatten`` for the wrapped environment's
            observation space, ``self.agent_num`` and ``observation``.
        """
        # Flatten against the *inner* space, not the already flattened one.
        source_space = self.env.observation_space
        return flatten(source_space, self.agent_num, observation)
class AddStep(BaseObservationWrapper):
    """Append the number of steps taken since the last reset to observations."""

    def __init__(self, env: gym.Env):
        """Extend the observation space by one slot for the step counter.

        Args:
            env: The environment to apply the wrapper to. Its observation
                space must be a one-dimensional ``Box`` or a ``Discrete``.
        """
        BaseObservationWrapper.__init__(self, env)
        inner_space = self.env.observation_space
        assert isinstance(inner_space, (spaces.Box, spaces.Discrete))
        if isinstance(inner_space, spaces.Box):
            assert len(inner_space.shape) == 1
            # The extra trailing entry holds the step count: bounded below by
            # zero and unbounded above.
            self.observation_space = spaces.Box(
                np.append(inner_space.low, 0),
                np.append(inner_space.high, np.inf),
                shape=(inner_space.shape[0] + 1,),
            )
        else:
            # NOTE(review): ``observation`` appends the counter with
            # ``np.append``, which yields an array rather than a single
            # integer — confirm that ``Discrete(n + 1)`` is the intended space.
            self.observation_space = spaces.Discrete(n=inner_space.n + 1)
        # Start from zero so that ``step`` also works when it is called
        # before the first ``reset``.
        self.step_count = 0

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[WrapperObsType, Dict[str, Any]]:
        """Reset the step counter to zero, then reset via the parent class.

        Args:
            seed: Forwarded to the parent ``reset``.
            options: Forwarded to the parent ``reset``.

        Returns:
            Whatever the parent ``reset`` returns.
        """
        self.step_count = 0
        return super().reset(seed=seed, options=options)

    def step(
        self, action: ActType
    ) -> Tuple[WrapperObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Increment the step counter, then step via the parent class.

        Args:
            action: The action forwarded to the parent ``step``.

        Returns:
            Whatever the parent ``step`` returns.
        """
        # Count first so the counter already includes the current step if the
        # parent class calls ``observation`` while stepping.
        self.step_count += 1
        return super().step(action)

    def observation(self, observation):
        """Append the current step count to an observation.

        Args:
            observation: The observation produced by the wrapped environment.

        Returns:
            A new array consisting of ``observation`` followed by the current
            step count.
        """
        return np.append(observation, self.step_count)
class MoveActionMask2InfoWrapper(BaseWrapper):
    """Expose the action mask as ``info["action_masks"]``.

    Two layouts are handled:

    * The observation space has an ``"action_mask"`` entry: the mask is taken
      out of the observation and only the ``"observation"`` entry is returned
      as the observation.
    * Otherwise, if ``info`` contains ``"action_mask"``, that key is renamed
      to ``"action_masks"``.
    """

    def __init__(
        self,
        env: gym.Env,
    ):
        """Detect whether observations carry an ``"action_mask"`` entry.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env)
        self.need_convert = False
        sub_spaces = getattr(self.env.observation_space, "spaces", None)
        # Composite spaces may keep their sub-spaces in a non-mapping
        # container (e.g. a tuple); only a keyed container can hold a mask.
        if hasattr(sub_spaces, "keys") and "action_mask" in sub_spaces.keys():
            self.need_convert = True
            self.observation_space = sub_spaces["observation"]

    def step(self, action):
        """Step the wrapped environment and move the action mask into info.

        Args:
            action: The action passed to the wrapped environment.

        Returns:
            The step result with the mask stored under
            ``info["action_masks"]``, or the unmodified result when no mask
            is present.
        """
        results = self.env.step(action)
        info = results[-1]
        if self.need_convert:
            raw_obs = results[0]
            info["action_masks"] = raw_obs["action_mask"]
            return (raw_obs["observation"], *results[1:-1], info)
        if "action_mask" in info:
            info["action_masks"] = info.pop("action_mask")
            return (*results[:-1], info)
        return results

    def reset(self, **kwargs):
        """Reset the wrapped environment and move the action mask into info.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's reset.

        Returns:
            The ``(obs, info)`` pair with the mask, if any, stored under
            ``info["action_masks"]``.
        """
        obs, info = self.env.reset(**kwargs)
        if self.need_convert:
            info["action_masks"] = obs["action_mask"]
            obs = obs["observation"]
        elif "action_mask" in info:
            info["action_masks"] = info.pop("action_mask")
        return obs, info
class AutoReset(AutoResetWrapper, BaseWrapper):
    """Combine ``AutoResetWrapper`` with ``BaseWrapper`` and flag auto reset."""

    def __init__(self, env: gym.Env):
        """Initialise the parent wrappers around ``env``.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env)

    @property
    def has_auto_reset(self):
        """Always ``True`` for this wrapper."""
        return True
class DictWrapper(BaseObservationWrapper):
    """Present observations as a dict with ``"policy"`` and ``"critic"`` keys.

    Environments whose observation space class name already contains
    ``"Dict"`` are left untouched: neither the space nor the observations
    are changed.
    """

    def __init__(self, env):
        """Decide whether conversion is needed and adjust the space if so.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env)
        inner_space = self.env.observation_space
        # Remember the decision so that ``observation`` stays consistent with
        # the observation space advertised by this wrapper.
        self.need_convert = "Dict" not in inner_space.__class__.__name__
        if self.need_convert:
            self.observation_space = gym.spaces.Dict(
                {
                    "policy": inner_space,
                    "critic": inner_space,
                }
            )

    def observation(self, observation):
        """Duplicate a non-dict observation under ``"policy"`` and ``"critic"``.

        Args:
            observation: The observation from the wrapped environment.

        Returns:
            ``observation`` unchanged when the wrapped environment already
            uses a dict observation space; otherwise a dict whose ``"policy"``
            entry is the observation and whose ``"critic"`` entry is a deep
            copy of it.
        """
        if not self.need_convert:
            return observation
        # The critic gets its own copy so that in-place edits of one entry do
        # not leak into the other.
        return {"policy": observation, "critic": deepcopy(observation)}
class ConvertEmptyBoxWrapper(BaseObservationWrapper):
    """Turn ``Box`` observations of shape ``()`` into shape ``(1,)``.

    Works on a plain ``Box`` space as well as on the entries of a dict
    observation space. The wrapped environment's observation space is
    modified in place.
    """

    @staticmethod
    def _is_scalar_box(space):
        """Return whether ``space`` is a ``Box`` with the empty shape ``()``."""
        return isinstance(space, gym.spaces.Box) and space.shape == ()

    @staticmethod
    def _as_1d_box(scalar_box):
        """Build the shape ``(1,)`` counterpart of a shape ``()`` ``Box``."""
        return gym.spaces.Box(
            low=np.array([scalar_box.low]),
            high=np.array([scalar_box.high]),
            shape=(1,),
            dtype=scalar_box.dtype,
        )

    def __init__(self, env):
        """Find scalar ``Box`` spaces and replace them by 1-d ones.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env)
        self.need_convert = False
        self.convert_keys = []
        obs_space = self.env.observation_space
        self.is_dict = "Dict" in obs_space.__class__.__name__
        if self.is_dict:
            # Snapshot the items because entries are reassigned while looping.
            for key, sub_space in list(obs_space.spaces.items()):
                if not self._is_scalar_box(sub_space):
                    continue
                self.need_convert = True
                self.convert_keys.append(key)
                obs_space.spaces[key] = self._as_1d_box(sub_space)
        elif self._is_scalar_box(obs_space):
            self.need_convert = True
            self.env.observation_space = self._as_1d_box(obs_space)

    def observation(self, observation):
        """Wrap scalar observation values into one-element arrays.

        Args:
            observation: The observation from the wrapped environment. For
                dict observations the converted entries are replaced in place.

        Returns:
            The observation with every converted entry (or, for a plain
            space, the whole observation) wrapped in a one-element array;
            unchanged when no conversion is needed.
        """
        if not self.need_convert:
            return observation
        if not self.is_dict:
            return np.array([observation])
        for key in self.convert_keys:
            observation[key] = np.array([observation[key]])
        return observation
class GIFWrapper(BaseWrapper):
    """Record the rendered frames of an environment into a GIF file.

    A frame is rendered and appended after every ``reset`` and ``step``.
    Call ``close`` when done so that the underlying writer is released.
    """

    def __init__(self, env, gif_path: str, fps: int = 30):
        """Open an imageio writer for ``gif_path``.

        Args:
            env: The environment to wrap; its ``render()`` result is passed
                to the writer as a frame.
            gif_path: Path of the GIF file to write.
            fps: Frames per second; the per-frame duration handed to the
                writer is ``int(1000 / fps)``.
        """
        super().__init__(env)
        self.gif_path = gif_path
        # Imported lazily so that imageio is only needed when this wrapper
        # is actually used.
        import imageio

        self.writter = imageio.get_writer(
            self.gif_path, mode="I", duration=int(1000 / fps)
        )

    def reset(self, **kwargs):
        """Reset the wrapped environment and record the resulting frame.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's reset.

        Returns:
            The unmodified result of the wrapped environment's reset.
        """
        results = self.env.reset(**kwargs)
        self.writter.append_data(self.env.render())
        return results

    def step(self, action):
        """Step the wrapped environment and record the resulting frame.

        Args:
            action: The action passed to the wrapped environment.

        Returns:
            The unmodified result of the wrapped environment's step.
        """
        results = self.env.step(action)
        self.writter.append_data(self.env.render())
        return results

    def close(self):
        """Close the GIF writer, then close via the parent class.

        Returns:
            Whatever the parent ``close`` returns.
        """
        try:
            # Release the writer first so the file handle is not leaked even
            # if closing the environment fails afterwards.
            self.writter.close()
        finally:
            result = super().close()
        return result