Source code for openrl.envs.offline.offline_env
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from typing import Optional
import gymnasium as gym
import numpy as np
from gymnasium.utils import seeding
from openrl.datasets.expert_dataset import ExpertDataset
class OfflineEnv(gym.Env):
    """Environment that replays trajectories stored in an ``ExpertDataset``.

    Every call to :meth:`reset` moves on to the next trajectory of a randomly
    permuted order (the order is re-drawn each time all trajectories have been
    visited).  Every call to :meth:`step` returns the next recorded transition
    of the current trajectory; the action supplied by the caller is ignored
    and the recorded action is reported through ``info["data_action"]``.
    """

    # Random generator used to shuffle the trajectory order; created lazily.
    _np_random: Optional[np.random.Generator] = None
    env_name = "OfflineEnv"

    def __init__(self, dataset_path, env_id: int, env_num: int, seed: int):
        """Load the dataset and initialise the replay bookkeeping.

        Args:
            dataset_path: Location of the dataset, forwarded to
                ``ExpertDataset``.
            env_id: Forwarded to ``ExpertDataset``.
            env_num: Forwarded to ``ExpertDataset``.
            seed: Forwarded to ``ExpertDataset`` and also used to seed the
                generator that shuffles the trajectory order.
        """
        self.dataset = ExpertDataset(
            dataset_path, env_id=env_id, env_num=env_num, seed=seed
        )
        self.observation_space = self.dataset.observation_space
        self.action_space = self.dataset.action_space
        self.agent_num = self.dataset.agent_num
        self.traj_num = len(self.dataset.trajectories["episode_lengths"])

        # Replay state; all of it is populated by the first call to reset().
        self.traj_index = None  # index of the trajectory being replayed
        self.epoch_index = None  # position inside the current permutation
        self.traj_length = None  # number of transitions in the trajectory
        self.step_index = None  # next transition to return from step()
        self.sample_indexes = None  # current permutation of trajectories

        self.seed(seed)

    def seed(self, seed=None):
        """Re-create the internal random generator from ``seed``.

        Passing ``None`` leaves the current generator untouched.
        """
        if seed is not None:
            self._np_random, seed = seeding.np_random(seed)

    def reset(self, *, seed=None, options=None):
        """Start replaying the next trajectory.

        Args:
            seed: Optional seed; when given, the internal generator is
                re-created before the trajectory is selected.
            options: Accepted for API compatibility; not used.

        Returns:
            A ``(obs, info)`` tuple holding the first recorded observation and
            the first recorded info entry of the selected trajectory.

        Raises:
            ValueError: If the dataset contains no trajectories.
        """
        if seed is not None:
            self.seed(seed)

        if self.traj_num == 0:
            # Fail with a clear message instead of a ZeroDivisionError or
            # IndexError further down.
            raise ValueError("OfflineEnv dataset contains no trajectories.")

        if self.epoch_index is None:
            self.epoch_index = 0
        else:
            self.epoch_index = (self.epoch_index + 1) % self.traj_num

        if self.epoch_index == 0:
            # A full pass is finished (or this is the first reset): draw a
            # fresh visiting order.
            if self._np_random is None:
                self.seed(0)
            self.sample_indexes = self._np_random.permutation(self.traj_num)
        assert self.sample_indexes is not None

        trajectories = self.dataset.trajectories
        self.traj_index = self.sample_indexes[self.epoch_index]
        self.traj_length = trajectories["episode_lengths"][self.traj_index]

        # Sanity checks: a trajectory stores one more observation than
        # transitions, and exactly one action per transition.
        assert self.traj_length == len(trajectories["obs"][self.traj_index]) - 1, (
            "episode length does not match the number of stored observations"
        )
        assert self.traj_length == len(trajectories["action"][self.traj_index]), (
            "episode length does not match the number of stored actions"
        )

        self.step_index = 0
        return (
            trajectories["obs"][self.traj_index][self.step_index],
            trajectories["info"][self.traj_index][self.step_index],
        )

    def step(self, action):
        """Return the next recorded transition of the current trajectory.

        Args:
            action: Ignored; the recorded action is replayed instead.

        Returns:
            A 4-tuple ``(obs, reward, done, info)`` taken from the dataset.
            ``info`` is always a dict: a recorded list is wrapped as
            ``{"info": <list>}``, and the recorded action is added under the
            ``"data_action"`` key.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self.traj_index is None or self.step_index is None:
            raise RuntimeError("reset() must be called before step().")

        trajectories = self.dataset.trajectories
        traj, idx = self.traj_index, self.step_index

        obs = trajectories["obs"][traj][idx + 1]
        reward = trajectories["reward"][traj][idx]
        data_action = trajectories["action"][traj][idx]
        done = trajectories["done"][traj][idx]
        info = trajectories["info"][traj][idx]

        if isinstance(info, list):
            info = {"info": info}
        else:
            # Shallow copy so that adding "data_action" does not modify the
            # entry stored in the dataset (it is handed out again by reset()
            # and by later epochs).
            info = dict(info)
        info["data_action"] = data_action

        self.step_index += 1
        return obs, reward, done, info