Source code for openrl.envs.vec_env.utils.share_memory
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
"""Utility functions for vector environments to share memory between processes."""
import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Union
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]


@singledispatch
def create_shared_memory(
    space: "Space", n: int = 1, agent_num: int = 1, ctx=mp
) -> Union[dict, tuple, mp.Array]:
    """Create shared memory able to hold ``space`` samples for ``n`` environments.

    This is the generic entry point; concrete space types register their own
    implementation via ``create_shared_memory.register``. This fallback is
    reached only for space types that have no registered implementation.

    Args:
        space: Observation space whose samples will be stored.
        n: Number of environments sharing the memory.
        agent_num: Number of agents per environment.
        ctx: Multiprocessing context (or the ``multiprocessing`` module) used
            to allocate shared arrays.

    Returns:
        A shared array, or a tuple / dict of them for composite spaces.

    Raises:
        CustomSpaceError: always, because ``space`` has no registered handler.
    """
    raise CustomSpaceError(
        "Cannot create a shared memory for space with "
        f"type `{type(space)}`. Shared memory only supports "
        "default Gymnasium spaces (e.g. `Box`, `Tuple`, "
        "`Dict`, etc.), and does not support custom "
        "Gymnasium spaces."
    )


@singledispatch
def read_from_shared_memory(
    space: "Space", shared_memory, n: int = 1, agent_num: int = 1
) -> Union[dict, tuple, np.ndarray]:
    """Return views of the data stored in ``shared_memory`` for ``space``.

    This is the generic entry point; concrete space types register their own
    implementation via ``read_from_shared_memory.register``. This fallback is
    reached only for space types that have no registered implementation.

    Args:
        space: Observation space the memory was created for.
        shared_memory: Object produced by ``create_shared_memory`` for ``space``.
        n: Number of environments sharing the memory.
        agent_num: Number of agents per environment.

    Raises:
        CustomSpaceError: always, because ``space`` has no registered handler.
    """
    raise CustomSpaceError(
        "Cannot read from a shared memory for space with "
        f"type `{type(space)}`. Shared memory only supports "
        "default Gymnasium spaces (e.g. `Box`, `Tuple`, "
        "`Dict`, etc.), and does not support custom "
        "Gymnasium spaces."
    )


@singledispatch
def write_to_shared_memory(
    space: "Space",
    agent_num: int,
    index: int,
    value,
    shared_memory,
) -> None:
    """Write one environment's ``value`` into slot ``index`` of ``shared_memory``.

    This is the generic entry point; concrete space types register their own
    implementation via ``write_to_shared_memory.register``. This fallback is
    reached only for space types that have no registered implementation.

    Args:
        space: Observation space the memory was created for.
        agent_num: Number of agents per environment.
        index: Index of the environment whose slot is written.
        value: Data for that environment, structured like ``space``.
        shared_memory: Object produced by ``create_shared_memory`` for ``space``.

    Raises:
        CustomSpaceError: always, because ``space`` has no registered handler.
    """
    raise CustomSpaceError(
        "Cannot write to a shared memory for space with "
        f"type `{type(space)}`. Shared memory only supports "
        "default Gymnasium spaces (e.g. `Box`, `Tuple`, "
        "`Dict`, etc.), and does not support custom "
        "Gymnasium spaces."
    )
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n: int = 1, agent_num: int = 1, ctx=mp):
    """Allocate one flat shared array for a fixed-shape space.

    The array holds ``n * agent_num`` samples of ``space`` stored back to
    back, so it can later be viewed with shape ``(n, agent_num) + space.shape``.

    Args:
        space: A ``Box``, ``Discrete``, ``MultiDiscrete`` or ``MultiBinary`` space.
        n: Number of environments sharing the array.
        agent_num: Number of agents per environment.
        ctx: Multiprocessing context (or the ``multiprocessing`` module) whose
            ``Array`` factory performs the allocation.

    Returns:
        The object returned by ``ctx.Array`` with
        ``n * agent_num * prod(space.shape)`` elements.
    """
    typecode = space.dtype.char
    # numpy's boolean char "?" is not an array-module typecode, so boolean
    # spaces are backed by a ctypes ``c_bool`` array instead.
    if typecode == "?":
        typecode = c_bool
    # np.prod(()) == 1.0, so scalar spaces such as Discrete take one element.
    return ctx.Array(typecode, n * agent_num * int(np.prod(space.shape)))
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n: int = 1, agent_num: int = 1, ctx=mp):
    """Allocate shared memory for each subspace of a ``Tuple`` space.

    Each subspace is dispatched back through ``create_shared_memory`` with the
    same ``n``, ``agent_num`` and ``ctx``.

    Returns:
        A tuple holding one shared-memory object per subspace, in the order of
        ``space.spaces``.
    """
    buffers = []
    for child in space.spaces:
        buffers.append(create_shared_memory(child, n=n, agent_num=agent_num, ctx=ctx))
    return tuple(buffers)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space, n: int = 1, agent_num: int = 1, ctx=mp):
    """Allocate shared memory for each subspace of a ``Dict`` space.

    Each subspace is dispatched back through ``create_shared_memory`` with the
    same ``n``, ``agent_num`` and ``ctx``.

    Returns:
        An ``OrderedDict`` mapping every key of ``space.spaces`` to the shared
        memory created for its subspace, in ``space.spaces`` iteration order.
    """
    buffers = OrderedDict()
    for name, child in space.spaces.items():
        buffers[name] = create_shared_memory(child, n=n, agent_num=agent_num, ctx=ctx)
    return buffers
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n: int = 1, agent_num: int = 1):
    """Wrap a flat shared array of a fixed-shape space as a numpy array.

    ``np.frombuffer`` wraps the buffer returned by ``shared_memory.get_obj()``
    without copying it, so the result aliases the shared storage.

    Returns:
        An array of dtype ``space.dtype`` and shape
        ``(n, agent_num) + space.shape``.
    """
    full_shape = (n, agent_num) + space.shape
    flat = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    return flat.reshape(full_shape)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(
    space, shared_memory, n: int = 1, agent_num: int = 1
):
    """Read every subspace of a ``Tuple`` space from its shared memory.

    Subspaces and memory objects are paired positionally; pairing stops at the
    shorter of the two sequences.

    Returns:
        A tuple of per-subspace results from ``read_from_shared_memory``.
    """
    views = [
        read_from_shared_memory(child, buffer, n=n, agent_num=agent_num)
        for buffer, child in zip(shared_memory, space.spaces)
    ]
    return tuple(views)
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1, agent_num: int = 1):
    """Read every subspace of a ``Dict`` space from its shared memory.

    ``shared_memory`` is indexed with the same keys as ``space.spaces``.

    Returns:
        An ``OrderedDict`` mapping each key to the result of
        ``read_from_shared_memory`` for that subspace, in ``space.spaces``
        iteration order.
    """
    views = OrderedDict()
    for name, child in space.spaces.items():
        views[name] = read_from_shared_memory(
            child, shared_memory[name], n=n, agent_num=agent_num
        )
    return views
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(space, agent_num, index, value, shared_memory):
    """Copy one environment's data into its slot of a flat shared array.

    The shared array is treated as consecutive slots of
    ``agent_num * prod(space.shape)`` elements each; slot ``index`` is
    overwritten.

    Args:
        space: A ``Box``, ``Discrete``, ``MultiDiscrete`` or ``MultiBinary`` space.
        agent_num: Number of agents per environment.
        index: Environment index selecting the slot to overwrite.
        value: Array-like data for all agents of that environment. It is cast
            to ``space.dtype`` and flattened before copying; ``np.copyto``
            raises if it cannot be broadcast to the slot.
        shared_memory: Object exposing ``get_obj()``, as produced by
            ``create_shared_memory`` for ``space``.
    """
    slot_size = agent_num * int(np.prod(space.shape))
    destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    start = index * slot_size
    # ravel() only copies when it must. np.copyto below already performs the
    # copy into shared memory, so flatten()'s unconditional copy was wasted.
    np.copyto(
        destination[start : start + slot_size],
        np.asarray(value, dtype=space.dtype).ravel(),
    )
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(space, agent_num, index, values, shared_memory):
    """Write each element of ``values`` into the matching subspace memory.

    ``values``, ``shared_memory`` and ``space.spaces`` are paired positionally;
    iteration stops at the shortest of the three.
    """
    triples = zip(values, shared_memory, space.spaces)
    for item, buffer, child in triples:
        write_to_shared_memory(child, agent_num, index, item, buffer)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(space, agent_num, index, values, shared_memory):
    """Write each keyed entry of ``values`` into the matching subspace memory.

    Both ``values`` and ``shared_memory`` are indexed with the keys of
    ``space.spaces``, which are visited in iteration order.
    """
    for name, child in space.spaces.items():
        item = values[name]
        buffer = shared_memory[name]
        write_to_shared_memory(child, agent_num, index, item, buffer)