# Shortcuts
#
# Source code for openrl.datasets.expert_dataset

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import pickle

import numpy as np
import torch.utils.data


class ExpertDataset(torch.utils.data.Dataset):
    """Map-style dataset of ``(obs, action)`` pairs from pickled expert trajectories.

    The pickle file is expected to hold a dict with at least the keys
    ``"episode_lengths"``, ``"obs"`` and ``"action"``; every value is indexed
    by trajectory. An optional ``"env_info"`` dict may provide
    ``"observation_space"``, ``"action_space"`` and ``"agent_num"``. The
    ``"episode_rewards"`` entry, if present, is not kept.

    Attributes set by the constructor:
        env_id, env_num: the values passed in.
        observation_space, action_space, agent_num: values taken from
            ``"env_info"``, or ``None`` when they are not provided.
        trajectories: dict of the selected, subsampled per-trajectory data,
            plus ``"episode_lengths"`` as a NumPy array.
        length: total number of samples (``int``).
        get_idx: list mapping a flat sample index to
            ``(trajectory_index, step_index)``.
    """

    def __init__(
        self,
        file_name,
        num_trajectories=None,
        subsample_frequency=1,
        seed=None,
        env_id=0,
        env_num=1,
    ):
        """Load, select and subsample trajectories from ``file_name``.

        For example, with ``num_trajectories=4`` and
        ``subsample_frequency=20`` the data of 4 trajectories is used and the
        amount of data of each trajectory is reduced by 20 times.

        Args:
            file_name: path of the pickle file to read.
            num_trajectories: number of (randomly permuted) trajectories to
                use. When ``None``, the trajectories are split evenly into
                ``env_num`` shares and share ``env_id`` is used.
            subsample_frequency: keep every ``subsample_frequency``-th step of
                each trajectory, starting at a random offset.
            seed: if not ``None``, passed to ``torch.manual_seed`` before any
                random draw.
            env_id: index of the share used by this dataset.
            env_num: total number of shares.

        Raises:
            AssertionError: if ``num_trajectories`` is given together with
                ``env_num != 1``, if ``env_id >= env_num``, or if ``env_num``
                exceeds the number of trajectories in the file.
            IndexError: if ``num_trajectories`` exceeds the number of
                trajectories in the file.
        """
        if seed is not None:
            torch.manual_seed(seed)
        assert num_trajectories is None or env_num == 1
        assert (
            env_id < env_num
        ), "env_id must be less than env_num, but got env_id={}, env_num={}".format(
            env_id, env_num
        )
        self.env_id = env_id
        self.env_num = env_num

        # SECURITY: unpickling can execute arbitrary code; only load
        # trajectory files that come from a trusted source.
        with open(file_name, "rb") as trajectory_file:
            all_trajectories = pickle.load(trajectory_file)

        all_trajectory_num = len(all_trajectories["episode_lengths"])
        if num_trajectories is None:
            assert env_num <= all_trajectory_num, (
                "env_num must be less than all_trajectory_num, but got env_num={},"
                " all_trajectory_num={}".format(env_num, all_trajectory_num)
            )
            # Integer division: trailing trajectories that do not fill a whole
            # share are not assigned to any env_id.
            share_size = all_trajectory_num // env_num
            start_traj_idx = share_size * env_id
            end_traj_idx = share_size * (env_id + 1)
        else:
            start_traj_idx = 0
            end_traj_idx = num_trajectories
        if end_traj_idx > all_trajectory_num:
            # Fail early with a clear message instead of an opaque IndexError
            # while indexing the selected trajectories further down.
            raise IndexError(
                "num_trajectories={} exceeds the {} trajectories available in {}"
                .format(end_traj_idx, all_trajectory_num, file_name)
            )
        num_trajectories = end_traj_idx - start_traj_idx

        # Draw the permutation before the subsampling offsets so that a given
        # seed keeps producing the same selection.
        perm = torch.randperm(all_trajectory_num)

        # "env_info" is metadata, not per-trajectory data, so it is removed
        # before the per-key loop. Missing entries are stored as None.
        env_info = all_trajectories.pop("env_info", {})
        self.observation_space = env_info.get("observation_space")
        self.action_space = env_info.get("action_space")
        self.agent_num = env_info.get("agent_num")

        selected = perm[start_traj_idx:end_traj_idx].tolist()
        start_offsets = (
            torch.randint(0, subsample_frequency, size=(num_trajectories,))
            .long()
            .tolist()
        )

        self.trajectories = {}
        for key, per_trajectory in all_trajectories.items():
            chosen = [per_trajectory[traj] for traj in selected]
            if key == "episode_lengths":
                # Floor division; assuming each sequence has as many steps as
                # its episode length, this never exceeds the number of steps
                # kept by the strided slice below, whatever the offset.
                self.trajectories[key] = np.array(
                    [length // subsample_frequency for length in chosen]
                )
            elif key != "episode_rewards":
                self.trajectories[key] = [
                    sequence[offset::subsample_frequency]
                    for sequence, offset in zip(chosen, start_offsets)
                ]

        self.i2traj_idx = {}
        self.i2i = {}

        episode_lengths = self.trajectories["episode_lengths"]
        # Plain int so that len() and range() work even for an empty selection
        # (np.sum of an empty array is a float).
        self.length = int(np.sum(episode_lengths))

        # Flat index -> (trajectory index, step inside that trajectory).
        self.get_idx = [
            (traj_idx, step)
            for traj_idx, steps in enumerate(episode_lengths.tolist())
            for step in range(steps)
        ]

    def __len__(self):
        """Return the total number of samples over the selected trajectories."""
        return self.length

    def __getitem__(self, i):
        """Return the ``(obs, action)`` pair stored at flat sample index ``i``."""
        traj_idx, step_i = self.get_idx[i]
        return (
            self.trajectories["obs"][traj_idx][step_i],
            self.trajectories["action"][traj_idx][step_i],
        )