# Source code for openrl.envs.vec_env.utils.util
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from typing import Any, Dict, List, Optional, Sequence
import numpy as np
# source from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/base_vec_env.py#L22
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray:  # pragma: no cover
    """
    Tile N images into one big PxQ image.

    (P, Q) are chosen to be as close as possible, and if N is square, then P=Q.
    Grid cells left over when N < P*Q are filled with zeros of the input dtype.

    :param img_nhwc: list or array of images, ndim=4 once turned into array. img nhwc
        n = batch index, h = height, w = width, c = channel
    :return: img_HWc, ndim=3. Always a new array (never a view of the input).
    :raises ValueError: if no images are given.
    """
    img_nhwc = np.asarray(img_nhwc)
    n_images, height, width, n_channels = img_nhwc.shape
    if n_images == 0:
        # Without this guard the grid size computation divides by zero.
        raise ValueError("tile_images requires at least one image")
    # new_height (P, grid rows) was named H before
    new_height = int(np.ceil(np.sqrt(n_images)))
    # new_width (Q, grid columns) was named W before; P * Q >= n_images
    new_width = int(np.ceil(float(n_images) / new_height))
    n_padding = new_height * new_width - n_images
    # Explicit zeros instead of ``img[0] * 0``: multiplying by zero keeps
    # NaN/inf values and upcasts bool images, which polluted the blank cells.
    padding = np.zeros((n_padding, height, width, n_channels), dtype=img_nhwc.dtype)
    # np.concatenate always allocates, so the result never aliases the caller's data.
    img_nhwc = np.concatenate([img_nhwc, padding], axis=0)
    # img_HWhwc
    out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels))
    # img_HhWwc
    out_image = out_image.transpose(0, 2, 1, 3, 4)
    # img_Hh_Ww_c
    out_image = out_image.reshape((new_height * height, new_width * width, n_channels))
    return out_image
def prepare_action_masks(
    info: Optional[List[Optional[Dict[str, Any]]]] = None,
    agent_num: int = 1,
    as_batch: bool = True,
) -> Optional[np.ndarray]:
    """
    Collect per-agent action masks from the ``info`` dicts of a vectorized env.

    Each env's ``info["action_masks"]`` may be 2-D (one row per agent, indexed by
    agent) or 1-D (one mask shared by every agent of that env).

    :param info: one info dict per environment (an entry may be None), or None.
    :param agent_num: number of agents per environment.
    :param as_batch: if True, merge the env and agent axes so the result has shape
        ``(n_envs * agent_num, n_actions)``; otherwise keep
        ``(n_envs, agent_num, n_actions)``.
    :return: an int8 mask array, or None when ``info`` is None/empty or any
        environment provides no ``"action_masks"`` (all actions are then assumed
        available).
    :raises ValueError: if an ``"action_masks"`` entry is neither 1-D nor 2-D.
    """
    if not info:
        return None

    action_masks = []
    for env_info in info:
        # If an env reports no info, or no action_masks, we assume all actions
        # are available and therefore return no mask for the whole batch.
        # (Previously a None info produced a None mask that crashed the int8 cast.)
        if env_info is None or "action_masks" not in env_info:
            return None
        # Convert once per env instead of once per agent.
        env_masks = np.asarray(env_info["action_masks"])
        if env_masks.ndim == 2:
            # One mask row per agent.
            action_masks_env = [env_masks[agent_index] for agent_index in range(agent_num)]
        elif env_masks.ndim == 1:
            # A single mask shared by every agent of this env.
            action_masks_env = [env_masks] * agent_num
        else:
            raise ValueError(
                f"action_masks must be 1-D or 2-D, got {env_masks.ndim}-D"
            )
        action_masks.append(action_masks_env)

    action_masks = np.array(action_masks, dtype=np.int8)
    if as_batch:
        action_masks = action_masks.reshape(-1, action_masks.shape[-1])
    return action_masks