# Source code for openrl.buffers.offpolicy_replay_data
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
import numpy as np
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from openrl.buffers.replay_data import ReplayData
from openrl.buffers.utils.util import (
get_critic_obs,
get_critic_obs_space,
get_policy_obs,
get_policy_obs_space,
get_shape_from_act_space,
)
class OffPolicyReplayData(ReplayData):
    """Replay storage for off-policy training.

    On top of the buffers created by ``ReplayData`` this class (re)allocates
    ``actions``, ``value_preds`` (one value per action dimension) and
    ``rewards`` with ``episode_length + 1`` time rows, and adds
    ``next_policy_obs`` / ``next_critic_obs`` buffers that hold, for every
    stored step, the observation that followed it.
    """

    def __init__(
        self,
        cfg,
        num_agents,
        obs_space,
        act_space,
        data_client=None,
        episode_length=None,
    ):
        """Create the base buffers, then the off-policy specific ones.

        Args:
            cfg: Configuration object, forwarded unchanged to ``ReplayData``.
            num_agents: Number of agents; third axis of every buffer
                allocated here.
            obs_space: Observation space, forwarded to ``ReplayData`` and used
                to size the next-observation buffers.
            act_space: Action space. ``Box`` (uses ``shape[0]``) and
                ``Discrete`` (uses ``n``) are recognised by class name.
            data_client: Forwarded unchanged to ``ReplayData``.
            episode_length: Forwarded unchanged to ``ReplayData``.

        Raises:
            NotImplementedError: If ``act_space`` is neither ``Box`` nor
                ``Discrete`` and no ``act_space`` attribute is already set
                on the instance.
        """
        super().__init__(
            cfg,
            num_agents,
            obs_space,
            act_space,
            data_client,
            episode_length,
        )

        # ``self.act_space`` becomes the width of the value_preds buffer.
        space_kind = act_space.__class__.__name__
        if space_kind == "Box":
            self.act_space = act_space.shape[0]
        elif space_kind == "Discrete":
            self.act_space = act_space.n
        elif not hasattr(self, "act_space"):
            # Without this the allocation of value_preds below would fail
            # with an opaque AttributeError.
            raise NotImplementedError(
                "Unsupported action space for OffPolicyReplayData: {}".format(
                    space_kind
                )
            )

        # Leading axes shared by every buffer: (time, rollout thread, agent).
        lead_shape = (self.episode_length + 1, self.n_rollout_threads, num_agents)

        act_shape = get_shape_from_act_space(act_space)
        self.actions = np.zeros((*lead_shape, act_shape), dtype=np.float32)
        self.value_preds = np.zeros((*lead_shape, self.act_space), dtype=np.float32)
        self.rewards = np.zeros((*lead_shape, 1), dtype=np.float32)

        if self._mixed_obs:
            # insert(), after_update() and feed_forward_generator() address
            # the next-observation buffers per key in mixed-observation mode,
            # so they must be dicts. Mirror the per-key observation buffers
            # so each step slot has the same shape and dtype as the
            # observation written into it.
            self.next_policy_obs = {
                key: np.zeros_like(buf) for key, buf in self.policy_obs.items()
            }
            self.next_critic_obs = {
                key: np.zeros_like(buf) for key, buf in self.critic_obs.items()
            }
        else:
            policy_obs_shape = get_policy_obs_space(obs_space)
            critic_obs_shape = get_critic_obs_space(obs_space)
            self.next_policy_obs = np.zeros(
                (*lead_shape, *policy_obs_shape), dtype=np.float32
            )
            self.next_critic_obs = np.zeros(
                (*lead_shape, *critic_obs_shape), dtype=np.float32
            )

        self.first_insert_flag = True

    @staticmethod
    def _flatten_steps(buffer):
        """Drop the last time row and merge the three leading axes into one."""
        return buffer[:-1].reshape(-1, *buffer.shape[3:])

    def init_buffer(self, raw_obs, action_masks=None):
        """Write the initial observation (and action masks) into time row 0.

        Args:
            raw_obs: Raw observation; split into critic and policy parts with
                ``get_critic_obs`` / ``get_policy_obs``.
            action_masks: Optional masks stored in row 0; ignored when the
                buffer has no ``action_masks`` storage.
        """
        critic_obs = get_critic_obs(raw_obs)
        policy_obs = get_policy_obs(raw_obs)

        # NumPy slice assignment copies the values into the buffer, so no
        # explicit ``.copy()`` of the source is required.
        if self._mixed_obs:
            for key, buf in self.critic_obs.items():
                buf[0] = critic_obs[key]
            for key, buf in self.policy_obs.items():
                buf[0] = policy_obs[key]
        else:
            self.critic_obs[0] = critic_obs
            self.policy_obs[0] = policy_obs

        if action_masks is not None and self.action_masks is not None:
            self.action_masks[0] = action_masks

    def insert(
        self,
        raw_obs,
        rnn_states,
        rnn_states_critic,
        actions,
        action_log_probs,
        value_preds,
        rewards,
        masks,
        bad_masks=None,
        active_masks=None,
        action_masks=None,
    ):
        """Store one transition and advance the write cursor.

        ``raw_obs`` is written twice: as the observation of row ``step + 1``
        and as the *next* observation of row ``step``. Actions, log-probs,
        value predictions and rewards go to row ``step``; recurrent states
        and all masks go to row ``step + 1``.

        Args:
            raw_obs: Raw observation received after acting.
            rnn_states: Policy recurrent states, or ``None`` to skip.
            rnn_states_critic: Critic recurrent states, or ``None`` to skip.
            actions: Actions taken at the current step.
            action_log_probs: Log-probabilities of ``actions``.
            value_preds: Value predictions for the current step.
            rewards: Rewards obtained for the current step.
            masks: Masks stored for the following step.
            bad_masks: Optional; stored for the following step when given.
            active_masks: Optional; stored for the following step when given.
            action_masks: Optional; stored for the following step when given
                and the buffer has ``action_masks`` storage.
        """
        step = self.step
        following = step + 1

        critic_obs = get_critic_obs(raw_obs)
        policy_obs = get_policy_obs(raw_obs)

        # Slice assignment copies into the buffers; see init_buffer().
        if self._mixed_obs:
            for key, buf in self.critic_obs.items():
                buf[following] = critic_obs[key]
            for key, buf in self.policy_obs.items():
                buf[following] = policy_obs[key]
            for key, buf in self.next_critic_obs.items():
                buf[step] = critic_obs[key]
            for key, buf in self.next_policy_obs.items():
                buf[step] = policy_obs[key]
        else:
            self.critic_obs[following] = critic_obs
            self.policy_obs[following] = policy_obs
            self.next_critic_obs[step] = critic_obs
            self.next_policy_obs[step] = policy_obs

        if rnn_states is not None:
            self.rnn_states[following] = rnn_states
        if rnn_states_critic is not None:
            self.rnn_states_critic[following] = rnn_states_critic

        self.actions[step] = actions
        self.action_log_probs[step] = action_log_probs
        self.value_preds[step] = value_preds
        self.masks[following] = masks
        self.rewards[step] = rewards

        if bad_masks is not None:
            self.bad_masks[following] = bad_masks
        if active_masks is not None:
            self.active_masks[following] = active_masks
        # Same guard as init_buffer()/after_update(): the buffer may have no
        # action-mask storage at all.
        if action_masks is not None and self.action_masks is not None:
            self.action_masks[following] = action_masks

        self.step = following % self.episode_length

    def after_update(self):
        """Carry the last time row over to row 0 for the next rollout.

        Raises:
            AssertionError: If the write cursor is not back at 0, i.e. the
                buffer was not filled for a whole episode length.
        """
        assert self.step == 0, "step:{} episode:{}".format(
            self.step, self.episode_length
        )

        if self._mixed_obs:
            for key in self.critic_obs.keys():
                self.critic_obs[key][0] = self.critic_obs[key][-1]
                self.next_critic_obs[key][0] = self.next_critic_obs[key][-1]
            for key in self.policy_obs.keys():
                self.policy_obs[key][0] = self.policy_obs[key][-1]
                self.next_policy_obs[key][0] = self.next_policy_obs[key][-1]
        else:
            self.critic_obs[0] = self.critic_obs[-1]
            self.next_critic_obs[0] = self.next_critic_obs[-1]
            self.policy_obs[0] = self.policy_obs[-1]
            self.next_policy_obs[0] = self.next_policy_obs[-1]

        self.rnn_states[0] = self.rnn_states[-1]
        self.rnn_states_critic[0] = self.rnn_states_critic[-1]
        self.actions[0] = self.actions[-1]
        self.masks[0] = self.masks[-1]
        self.bad_masks[0] = self.bad_masks[-1]
        self.active_masks[0] = self.active_masks[-1]
        if self.action_masks is not None:
            self.action_masks[0] = self.action_masks[-1]

    def feed_forward_generator(
        self,
        advantages,
        num_mini_batch=None,
        mini_batch_size=None,
        critic_obs_process_func=None,
    ):
        """Yield random mini-batches of flattened transitions.

        The time, rollout-thread and agent axes are merged into one sample
        axis, then shuffled index batches are drawn without replacement;
        an incomplete trailing batch is dropped.

        Args:
            advantages: Optional array reshaped to ``(-1, 1)`` and indexed
                per batch; when ``None`` the reward batch is yielded in its
                place.
            num_mini_batch: Number of mini-batches; used to derive
                ``mini_batch_size`` when that is ``None``.
            mini_batch_size: Explicit mini-batch size.
            critic_obs_process_func: Optional callable applied to each critic
                observation batch before it is yielded.

        Yields:
            A 15-tuple ``(critic_obs, policy_obs, next_critic_obs,
            next_policy_obs, rnn_states, rnn_states_critic, actions,
            value_preds, rewards, masks, next_masks, active_masks,
            old_action_log_probs, adv_targ, action_masks)`` where
            ``action_masks`` is ``None`` if the buffer stores none.

        Raises:
            ValueError: If both ``num_mini_batch`` and ``mini_batch_size``
                are ``None``.
            AssertionError: If there are fewer samples than mini-batches.
        """
        # ``rewards`` is allocated with episode_length + 1 time rows; only the
        # first ``num_steps`` of them are sampled.
        time_rows, n_rollout_threads, num_agents = self.rewards.shape[0:3]
        num_steps = time_rows - 1
        batch_size = n_rollout_threads * num_steps * num_agents

        if mini_batch_size is None:
            if num_mini_batch is None:
                raise ValueError(
                    "Either num_mini_batch or mini_batch_size must be provided."
                )
            assert batch_size >= num_mini_batch, (
                "DQN requires the number of processes ({}) "
                "* number of steps ({}) * number of agents ({}) = {} "
                "to be greater than or equal to the number of DQN mini batches ({})."
                "".format(
                    n_rollout_threads,
                    num_steps,
                    num_agents,
                    batch_size,
                    num_mini_batch,
                )
            )
            mini_batch_size = batch_size // num_mini_batch

        sampler = BatchSampler(
            SubsetRandomSampler(range(batch_size)),
            mini_batch_size,
            drop_last=True,
        )

        flatten = self._flatten_steps
        if self._mixed_obs:
            critic_obs = {key: flatten(buf) for key, buf in self.critic_obs.items()}
            policy_obs = {key: flatten(buf) for key, buf in self.policy_obs.items()}
            next_critic_obs = {
                key: flatten(buf) for key, buf in self.next_critic_obs.items()
            }
            next_policy_obs = {
                key: flatten(buf) for key, buf in self.next_policy_obs.items()
            }
        else:
            critic_obs = flatten(self.critic_obs)
            policy_obs = flatten(self.policy_obs)
            next_critic_obs = flatten(self.next_critic_obs)
            next_policy_obs = flatten(self.next_policy_obs)

        rnn_states = flatten(self.rnn_states)
        rnn_states_critic = flatten(self.rnn_states_critic)

        # ``actions`` and ``rewards`` keep their extra final time row: sampled
        # indices are < batch_size, and with time as the leading axis those
        # indices only reach the first ``num_steps`` rows.
        actions = self.actions.reshape(-1, self.actions.shape[-1])
        rewards = self.rewards.reshape(-1, 1)
        # NOTE(review): action_log_probs is allocated by ReplayData; it is
        # flattened without slicing exactly as before - confirm its row count.
        action_log_probs = self.action_log_probs.reshape(
            -1, self.action_log_probs.shape[-1]
        )

        action_masks = None
        if self.action_masks is not None:
            action_masks = self.action_masks[:-1].reshape(
                -1, self.action_masks.shape[-1]
            )

        value_preds = self.value_preds[:-1].reshape(-1, self.act_space)
        masks = self.masks[:-1].reshape(-1, 1)
        next_masks = self.masks[1:].reshape(-1, 1)
        active_masks = self.active_masks[:-1].reshape(-1, 1)

        flat_advantages = None
        if advantages is not None:
            flat_advantages = advantages.reshape(-1, 1)

        for indices in sampler:
            # obs size [T+1 N M Dim]-->[T N M Dim]-->[T*N*M,Dim]-->[index,Dim]
            if self._mixed_obs:
                critic_obs_batch = {
                    key: flat[indices] for key, flat in critic_obs.items()
                }
                policy_obs_batch = {
                    key: flat[indices] for key, flat in policy_obs.items()
                }
                next_critic_obs_batch = {
                    key: flat[indices] for key, flat in next_critic_obs.items()
                }
                next_policy_obs_batch = {
                    key: flat[indices] for key, flat in next_policy_obs.items()
                }
            else:
                critic_obs_batch = critic_obs[indices]
                policy_obs_batch = policy_obs[indices]
                next_critic_obs_batch = next_critic_obs[indices]
                next_policy_obs_batch = next_policy_obs[indices]

            rnn_states_batch = rnn_states[indices]
            rnn_states_critic_batch = rnn_states_critic[indices]
            actions_batch = actions[indices]
            action_masks_batch = (
                action_masks[indices] if action_masks is not None else None
            )
            value_preds_batch = value_preds[indices]
            rewards_batch = rewards[indices]
            masks_batch = masks[indices]
            next_masks_batch = next_masks[indices]
            active_masks_batch = active_masks[indices]
            old_action_log_probs_batch = action_log_probs[indices]

            # Without advantages the rewards stand in as the target.
            if flat_advantages is None:
                adv_targ = rewards_batch
            else:
                adv_targ = flat_advantages[indices]

            if critic_obs_process_func is not None:
                critic_obs_batch = critic_obs_process_func(critic_obs_batch)

            yield (
                critic_obs_batch,
                policy_obs_batch,
                next_critic_obs_batch,
                next_policy_obs_batch,
                rnn_states_batch,
                rnn_states_critic_batch,
                actions_batch,
                value_preds_batch,
                rewards_batch,
                masks_batch,
                next_masks_batch,
                active_masks_batch,
                old_action_log_probs_batch,
                adv_targ,
                action_masks_batch,
            )