Shortcuts

Source code for openrl.envs.vec_env.sync_venv

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Serial vectorized environment: :class:`SyncVectorEnv` steps its sub-environments one after another."""
import time
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union

import numpy as np
from gymnasium import Env
from gymnasium.core import ActType
from gymnasium.spaces import Space

from openrl.envs.vec_env.base_venv import BaseVecEnv
from openrl.envs.vec_env.utils.numpy_utils import (
    concatenate,
    create_empty_array,
    iterate_action,
)


class SyncVectorEnv(BaseVecEnv):
    """Vectorized environment that serially runs multiple environments."""

    def __init__(
        self,
        env_fns: Iterable[Callable[[], Env]],
        observation_space: Space = None,
        action_space: Space = None,
        copy: bool = True,
        render_mode: Optional[str] = None,
        auto_reset: bool = True,
    ):
        """Vectorized environment that serially runs multiple environments.

        Args:
            env_fns: iterable of callable functions that create the environments.
            observation_space: Observation space of a single environment. If ``None``,
                then the observation space of the first environment is taken.
            action_space: Action space of a single environment. If ``None``,
                then the action space of the first environment is taken.
            copy: If ``True``, then the :meth:`reset` and :meth:`step` methods
                return a copy of the observations.
            render_mode: Render mode forwarded to :class:`BaseVecEnv`.
            auto_reset: Forwarded to :class:`BaseVecEnv`. When truthy, ``_step``
                resets a finished sub-environment, unless that sub-environment
                reports ``has_auto_reset`` and resets itself.

        Raises:
            RuntimeError: If the observation space of some sub-environment does not
                match observation_space (or, by default, the observation space of
                the first sub-environment).
        """
        self.viewer = None
        self.env_fns = env_fns
        # Instantiate every sub-environment eagerly, in iteration order.
        self.envs = [env_fn() for env_fn in env_fns]
        self.copy = copy
        self.metadata = self.envs[0].metadata
        # Sub-environments flagged with ``has_auto_reset`` reset themselves, so
        # ``_step`` must not reset them a second time.
        self._subenv_auto_reset = getattr(self.envs[0], "has_auto_reset", False)

        # Fall back to the first sub-environment's spaces. Compare with ``None``
        # explicitly: composite spaces (e.g. an empty ``Dict``) define ``__len__``
        # and can be falsy, so ``or`` would discard a caller-supplied space.
        if observation_space is None:
            observation_space = self.envs[0].observation_space
        if action_space is None:
            action_space = self.envs[0].action_space

        super().__init__(
            parallel_env_num=len(self.envs),
            observation_space=observation_space,
            action_space=action_space,
            render_mode=render_mode,
            auto_reset=auto_reset,
        )

        self._check_spaces()
        self._agent_num = self.envs[0].agent_num
        # Preallocated batch buffers that reset/step write into.
        self.observations = create_empty_array(
            self.observation_space,
            n=self.parallel_env_num,
            agent_num=self._agent_num,
            fn=np.zeros,
        )
        self._rewards = np.zeros(
            (self.parallel_env_num, self._agent_num, 1), dtype=np.float64
        )
        self._terminateds = np.zeros(
            (self.parallel_env_num, self._agent_num), dtype=np.bool_
        )
        self._truncateds = np.zeros(
            (self.parallel_env_num, self._agent_num), dtype=np.bool_
        )
        self._actions = None

    def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
        """Sets the seed in all sub-environments.

        Args:
            seed: ``None`` leaves every sub-environment unseeded. An ``int``
                seeds sub-environment ``k`` with ``seed + k``. A sequence
                supplies one seed per sub-environment.
        """
        super().seed(seed=seed)
        if seed is None:
            per_env_seeds = [None] * self.parallel_env_num
        elif isinstance(seed, int):
            per_env_seeds = [seed + offset for offset in range(self.parallel_env_num)]
        else:
            per_env_seeds = seed
        assert len(per_env_seeds) == self.parallel_env_num
        for sub_env, env_seed in zip(self.envs, per_env_seeds):
            sub_env.seed(env_seed)

    def _reset(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Reset every sub-environment and batch the initial observations.

        Args:
            seed: ``None`` means no seeding. An ``int`` seeds sub-environment
                ``k`` with ``seed + k * 10086``. A list supplies one seed per
                sub-environment.
            options: Forwarded to each sub-environment's ``reset`` when not ``None``.

        Returns:
            The batched observations, followed by the list of per-environment
            infos when the sub-environments return ``(obs, info)`` pairs.

        Raises:
            NotImplementedError: If a sub-environment's ``reset`` returns a tuple
                whose length is not 2.
        """
        if seed is None:
            seeds = [None] * self.parallel_env_num
        elif isinstance(seed, int):
            seeds = [seed + k * 10086 for k in range(self.parallel_env_num)]
        else:
            seeds = seed
        assert len(seeds) == self.parallel_env_num

        self._terminateds[:] = False
        self._truncateds[:] = False

        observations = []
        infos = []
        for sub_env, env_seed in zip(self.envs, seeds):
            reset_kwargs = {}
            if env_seed is not None:
                reset_kwargs["seed"] = env_seed
            if options is not None:
                reset_kwargs["options"] = options
            result = sub_env.reset(**reset_kwargs)
            if not isinstance(result, tuple):
                # Observation-only reset API.
                observations.append(result)
                continue
            if len(result) != 2:
                raise NotImplementedError(
                    "Not support reset return length: {}".format(len(result))
                )
            observations.append(result[0])
            infos.append(result[1])

        batched = self.format_obs(observations)
        return (batched, infos) if infos else batched

    def format_obs(self, observations: Iterable) -> Union[tuple, dict, np.ndarray]:
        """Write per-environment observations into the batch buffer.

        Args:
            observations: One observation per sub-environment.

        Returns:
            The batched observations: a deep copy when ``self.copy`` is set,
            otherwise the internal buffer itself.
        """
        self.observations = concatenate(
            self.observation_space, observations, self.observations
        )
        if not self.copy:
            return self.observations
        return deepcopy(self.observations)

    def _step(self, actions: ActType):
        """Steps through each of the environments returning the batched results.

        Args:
            actions: Batched actions, split per sub-environment by ``iterate_action``.

        Returns:
            ``(observations, rewards, terminateds, truncateds, infos)`` when the
            sub-environments return 5-tuples, or
            ``(observations, rewards, terminateds, infos)`` for 4-tuples. The
            length follows the last sub-environment stepped. The reward and done
            arrays are copies of the internal buffers.

        Raises:
            NotImplementedError: If a sub-environment's ``step`` returns a tuple
                whose length is neither 4 nor 5, or an auto-reset ``reset``
                returns a tuple whose length is not 2.
        """
        per_env_actions = iterate_action(self.action_space, actions)
        # Reset finished sub-environments here only when auto-reset is enabled
        # and the sub-environments do not already reset themselves.
        may_reset = self.auto_reset and not self._subenv_auto_reset

        observations, infos = [], []
        step_len = None
        for i, (env, action) in enumerate(zip(self.envs, per_env_actions)):
            returns = env.step(action)
            assert isinstance(
                returns, tuple
            ), "step return must be tuple, but got: {}".format(type(returns))
            step_len = len(returns)
            if step_len == 5:
                (
                    observation,
                    self._rewards[i],
                    self._terminateds[i],
                    self._truncateds[i],
                    info,
                ) = returns
                done = all(self._terminateds[i]) or all(self._truncateds[i])
            elif step_len == 4:
                (
                    observation,
                    self._rewards[i],
                    self._terminateds[i],
                    info,
                ) = returns
                done = all(self._terminateds[i])
            else:
                # Reject unsupported lengths before unpacking would fail with
                # an unhelpful ValueError.
                raise NotImplementedError(
                    "Not support step return length: {}".format(step_len)
                )

            if may_reset and done:
                final_observation, final_info = observation, info
                reset_returns = env.reset()
                # Accept the same reset return shapes as ``_reset``.
                if not isinstance(reset_returns, tuple):
                    observation, info = reset_returns, {}
                elif len(reset_returns) == 2:
                    observation, info = reset_returns
                else:
                    raise NotImplementedError(
                        "Not support reset return length: {}".format(
                            len(reset_returns)
                        )
                    )
                info = deepcopy(info)
                info["final_observation"] = final_observation
                info["final_info"] = final_info

            observations.append(observation)
            infos.append(info)

        batched_obs = self.format_obs(observations)
        if step_len == 5:
            return (
                batched_obs,
                np.copy(self._rewards),
                np.copy(self._terminateds),
                np.copy(self._truncateds),
                infos,
            )
        return (
            batched_obs,
            np.copy(self._rewards),
            np.copy(self._terminateds),
            infos,
        )

    def close_extras(self, **kwargs):
        """Close the environments.

        Args:
            **kwargs: Accepted for interface compatibility; unused here.
        """
        for env in self.envs:
            env.close()

    def _check_spaces(self) -> bool:
        """Verify that every sub-environment shares the batched spaces.

        Returns:
            ``True`` when all spaces match.

        Raises:
            RuntimeError: If a sub-environment's observation or action space
                differs from ``self.observation_space`` / ``self.action_space``.
        """
        for env in self.envs:
            if env.observation_space != self.observation_space:
                raise RuntimeError(
                    "Some environments have an observation space different from "
                    f"`{self.observation_space}`. In order to batch observations, "
                    "the observation spaces from all environments must be equal."
                )
            if env.action_space != self.action_space:
                raise RuntimeError(
                    "Some environments have an action space different from "
                    f"`{self.action_space}`. In order to batch actions, the "
                    "action spaces from all environments must be equal."
                )
        return True

    def _get_images(self) -> Sequence[np.ndarray]:
        """Render the sub-environments.

        In ``"single_rgb_array"`` mode, only the first sub-environment is rendered.
        """
        if self.render_mode == "single_rgb_array":
            return [self.envs[0].render()]
        return [env.render() for env in self.envs]

    @property
    def env_name(self):
        """Name of the wrapped environment.

        Resolution order:
        1. the first sub-environment's ``env_name`` attribute;
        2. ``metadata["name"]``;
        3. the first sub-environment's ``unwrapped.spec.id``.
        """
        if hasattr(self.envs[0], "env_name"):
            return self.envs[0].env_name
        if "name" in self.metadata:
            # Keep caching on ``_env_name``, but also return the value (this
            # branch previously fell through and returned ``None``).
            self._env_name = self.metadata["name"]
            return self._env_name
        return self.envs[0].unwrapped.spec.id

    def exec_func(
        self, func: Callable, indices: Optional[List[int]] = None, *args, **kwargs
    ) -> tuple:
        """Apply ``func`` to the selected sub-environments.

        Args:
            func: Called as ``func(env, *args, **kwargs)`` for each selected
                sub-environment. A non-callable ``func`` is returned as-is for
                each selected sub-environment.
            indices: Indices of the sub-environments to include; ``None`` selects
                all of them.
            *args: Extra positional arguments passed to ``func``.
            **kwargs: Extra keyword arguments passed to ``func``.

        Returns:
            Tuple with one entry per sub-environment, ``None`` for those not
            selected by ``indices``.
        """
        # Materialise once: O(1) membership per env, and one-shot iterators
        # are not consumed by the first lookup.
        selected = None if indices is None else set(indices)
        results = []
        for i, env in enumerate(self.envs):
            if selected is not None and i not in selected:
                results.append(None)
            elif callable(func):
                results.append(func(env, *args, **kwargs))
            else:
                results.append(func)
        return tuple(results)

    def call(self, name, *args, **kwargs) -> tuple:
        """Calls the method with name and applies args and kwargs.

        Args:
            name: Attribute or method name looked up on each sub-environment.
            *args: Positional arguments used when the attribute is callable.
            **kwargs: Keyword arguments used when the attribute is callable.

        Returns:
            Tuple with one result per sub-environment. Non-callable attributes
            are returned as their value.
        """

        def _resolve(env):
            attr = getattr(env, name)
            return attr(*args, **kwargs) if callable(attr) else attr

        return tuple(_resolve(env) for env in self.envs)

    def set_attr(self, name: str, values: Union[list, tuple, Any]):
        """Sets an attribute of the sub-environments.

        Args:
            name: The property name to change.
            values: A list or tuple with one value per sub-environment. Any
                other object is assigned to every sub-environment.

        Raises:
            ValueError: If a list/tuple is given whose length differs from the
                number of environments.
        """
        per_env = (
            values
            if isinstance(values, (list, tuple))
            else [values] * self.parallel_env_num
        )
        if len(per_env) != self.parallel_env_num:
            raise ValueError(
                "Values must be a list or tuple with length equal to the "
                f"number of environments. Got `{len(per_env)}` values for "
                f"{self.parallel_env_num} environments."
            )
        for env, value in zip(self.envs, per_env):
            setattr(env, name, value)