Shortcuts

Source code for openrl.envs.common.registration

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory function ``make`` for building vectorized OpenRL environments."""
from typing import Callable, Optional

import gymnasium as gym

import openrl
from openrl.envs.PettingZoo.registration import pettingzoo_env_dict
from openrl.envs.vec_env import (
    AsyncVectorEnv,
    BaseVecEnv,
    RewardWrapper,
    SyncVectorEnv,
    VecMonitorWrapper,
)
from openrl.envs.vec_env.vec_info import VecInfoFactory
from openrl.rewards import RewardFactory


def _resolve_render_modes(render_mode: Optional[str], env_num: int):
    """Translate the user-facing ``render_mode`` into the modes actually used.

    Args:
        render_mode: One of ``None``, ``"human"``, ``"rgb_array"``,
            ``"group_human"``, ``"group_rgb_array"``, ``"single_human"`` or
            ``"single_rgb_array"``.
        env_num: Number of environments; used to size the per-env mode list
            for the ``single_*`` modes.

    Returns:
        A ``(env_render_mode, vec_render_mode)`` tuple. ``env_render_mode`` is
        what gets handed to the environment factories (a single value, or a
        list with one entry per env for the ``single_*`` modes).
        ``vec_render_mode`` is what gets handed to the vector env wrapper.

    Raises:
        NotImplementedError: If ``render_mode`` is not one of the values above.
    """
    if render_mode in (None, "human", "rgb_array"):
        return render_mode, render_mode
    if render_mode in ("group_human", "group_rgb_array"):
        # Will display all the envs (when render_mode == "group_human")
        # or return all the envs' images (when render_mode == "group_rgb_array"),
        # so every individual env is created in "rgb_array" mode.
        return "rgb_array", render_mode
    if render_mode == "single_human":
        # Will only display the first env. The vector env itself is created
        # with render_mode=None in this case.
        return ["human"] + [None] * (env_num - 1), None
    if render_mode == "single_rgb_array":
        # env.render() will only return the first env's image.
        return ["rgb_array"] + [None] * (env_num - 1), render_mode
    raise NotImplementedError(f"render_mode {render_mode} is not supported.")


def _make_env_fns(id: str, env_num: int, render_mode, cfg, kwargs: dict):
    """Pick the built-in environment factory matching ``id`` and call it.

    The branches are checked in order, so an id matching several families is
    handled by the first one listed. Backend modules are imported inside the
    matching branch only (presumably so optional dependencies are required
    only when that backend is actually used).

    Args:
        id: Environment identifier.
        env_num: Number of environments to create.
        render_mode: Per-env render mode (value or list) forwarded to the factory.
        cfg: Value of ``kwargs["cfg"]`` or ``None``; only read for offline envs.
        kwargs: Extra keyword arguments forwarded to the factory. For offline
            envs a ``"seed"`` entry is written into this dict from ``cfg.seed``.

    Returns:
        Whatever the selected factory returns (the env constructors that are
        later passed to the vector env).

    Raises:
        NotImplementedError: If ``id`` matches no known environment family.
        AssertionError: If an offline env is requested without
            ``cfg.expert_data``.
    """
    if id.startswith("pybullet_drones/"):
        from openrl.envs.gym_pybullet_drones import (
            make_single_agent_drone_envs as make_envs,
        )
    elif id.startswith("snakes_"):
        from openrl.envs.snake import make_snake_envs as make_envs
    elif id.startswith(("GymV21Environment-v0:", "GymV26Environment-v0:")):
        from openrl.envs.gymnasium import make_old_gym_envs as make_envs
    elif id in gym.envs.registry:
        from openrl.envs.gymnasium import make_gym_envs as make_envs
    elif id in openrl.envs.mpe_all_envs:
        from openrl.envs.mpe import make_mpe_envs as make_envs
    elif id in openrl.envs.nlp_all_envs:
        from openrl.envs.nlp import make_nlp_envs as make_envs
    elif id in openrl.envs.toy_all_envs:
        from openrl.envs.toy_envs import make_toy_envs as make_envs
    elif id[:14] in openrl.envs.super_mario_all_envs:
        # Only the first 14 characters of the id are matched for Super Mario.
        from openrl.envs.super_mario import make_super_mario_envs as make_envs
    elif id in openrl.envs.connect_all_envs:
        from openrl.envs.connect_env import make_connect_envs as make_envs
    elif id in openrl.envs.gridworld_all_envs:
        from openrl.envs.gridworld import make_gridworld_envs as make_envs
    elif id in openrl.envs.offline_all_envs:
        from openrl.envs.offline import make_offline_envs

        # Explicit check instead of `assert`: it also covers a missing cfg
        # (previously an opaque AttributeError) and is not stripped by
        # `python -O`. AssertionError is kept for backward compatibility.
        if cfg is None or cfg.expert_data is None:
            raise AssertionError(
                "expert_data must be provided for offline envs, you can set it in a"
                " YAML file or with `--expert_data dataset_path`"
            )
        kwargs["seed"] = cfg.seed
        # Offline envs are built from the dataset rather than from the id.
        return make_offline_envs(
            dataset=cfg.expert_data,
            env_num=env_num,
            render_mode=render_mode,
            **kwargs,
        )
    elif id in openrl.envs.pettingzoo_all_envs or id in pettingzoo_env_dict:
        from openrl.envs.PettingZoo import make_PettingZoo_envs as make_envs
    else:
        raise NotImplementedError(f"env {id} is not supported.")

    return make_envs(id=id, env_num=env_num, render_mode=render_mode, **kwargs)


def make(
    id: str,
    env_num: int = 1,
    asynchronous: bool = False,
    add_monitor: bool = True,
    render_mode: Optional[str] = None,
    make_custom_envs: Optional[Callable] = None,
    auto_reset: bool = True,
    **kwargs,
) -> BaseVecEnv:
    """Create a vectorized environment for the given environment id.

    The environments are created by ``make_custom_envs`` when given, otherwise
    by the built-in factory matching ``id``. They are wrapped in a sync or
    async vector env, then in a ``RewardWrapper`` and, optionally, a
    ``VecMonitorWrapper``.

    Args:
        id: Environment identifier used to select the environment factory.
        env_num: Number of environments to create.
        asynchronous: Use ``AsyncVectorEnv`` when true, ``SyncVectorEnv``
            otherwise.
        add_monitor: When true, wrap the result in ``VecMonitorWrapper``.
        render_mode: ``None``, ``"human"``, ``"rgb_array"``,
            ``"group_human"``, ``"group_rgb_array"``, ``"single_human"`` or
            ``"single_rgb_array"``.
        make_custom_envs: Optional callable invoked as
            ``make_custom_envs(id=..., env_num=..., render_mode=..., **kwargs)``
            in place of the built-in factories.
        auto_reset: Forwarded to the vector env constructor.
        **kwargs: Forwarded to the environment factory. An optional ``cfg``
            entry is also read here for ``reward_class`` and
            ``vec_info_class`` (and ``expert_data`` / ``seed`` for offline
            envs); it remains in ``kwargs`` when they are forwarded.

    Returns:
        The wrapped vector environment.

    Raises:
        NotImplementedError: If ``render_mode`` or ``id`` is not supported.
        AssertionError: If an offline env is requested without
            ``cfg.expert_data``.
    """
    cfg = kwargs.get("cfg", None)

    convert_render_mode, render_mode = _resolve_render_modes(render_mode, env_num)

    if make_custom_envs is not None:
        env_fns = make_custom_envs(
            id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
        )
    else:
        env_fns = _make_env_fns(id, env_num, convert_render_mode, cfg, kwargs)

    vec_env_class = AsyncVectorEnv if asynchronous else SyncVectorEnv
    env = vec_env_class(env_fns, render_mode=render_mode, auto_reset=auto_reset)

    # Without a (truthy) cfg the factories are asked for their default class.
    reward_class = cfg.reward_class if cfg else None
    reward_class = RewardFactory.get_reward_class(reward_class, env)
    env = RewardWrapper(env, reward_class)

    if add_monitor:
        vec_info_class = cfg.vec_info_class if cfg else None
        vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env)
        env = VecMonitorWrapper(vec_info_class, env)

    return env