Source code for openrl.buffers.offpolicy_replay_data
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from openrl.buffers.replay_data import ReplayData
from openrl.buffers.utils.obs_data import ObsData
from openrl.buffers.utils.util import (
get_critic_obs,
get_critic_obs_space,
get_policy_obs,
get_policy_obs_space,
get_shape_from_act_space,
)
class OffPolicyReplayData(ReplayData):
    """Replay storage that adds off-policy fields on top of ``ReplayData``.

    Besides whatever the parent class allocates, this class owns ``actions``,
    ``value_preds`` (``act_space.n`` values per entry), ``rewards`` and the
    next-step policy/critic observations. Each of these arrays has
    ``episode_length + 1`` slots on its first axis, followed by the rollout
    thread axis and the agent axis.

    Slot convention used by ``insert`` and ``dict_insert``: observations,
    next observations, RNN states, actions, rewards and all masks are written
    to slot ``step + 1``; ``action_log_probs`` and ``value_preds`` are written
    to slot ``step``.
    """

    def __init__(
        self,
        cfg,
        num_agents,
        obs_space,
        act_space,
        data_client=None,
        episode_length=None,
    ):
        """Allocate the off-policy buffers.

        Args:
            cfg: Configuration object, forwarded unchanged to ``ReplayData``.
            num_agents: Size of the agent axis of every buffer created here.
            obs_space: Observation space; its policy/critic shapes size the
                next-observation buffers.
            act_space: Action space. Must expose ``n``, which sizes the last
                axis of ``value_preds``.
            data_client: Forwarded unchanged to ``ReplayData``.
            episode_length: Forwarded unchanged to ``ReplayData``.
        """
        super().__init__(
            cfg,
            num_agents,
            obs_space,
            act_space,
            data_client,
            episode_length,
        )
        # NOTE: from here on ``self.act_space`` holds ``act_space.n`` rather
        # than the space object; the mini-batch generators reshape
        # ``value_preds`` with it.
        self.act_space = act_space.n
        act_shape = get_shape_from_act_space(act_space)

        # Common leading axes: (slot, rollout thread, agent).
        leading = (self.episode_length + 1, self.n_rollout_threads, num_agents)

        self.actions = np.zeros((*leading, act_shape), dtype=np.float32)
        self.value_preds = np.zeros((*leading, act_space.n), dtype=np.float32)
        self.rewards = np.zeros((*leading, 1), dtype=np.float32)

        policy_obs_shape = get_policy_obs_space(obs_space)
        critic_obs_shape = get_critic_obs_space(obs_space)
        # NOTE(review): these are plain arrays, while the ``_mixed_obs``
        # branches in this class call ``.keys()`` on them -- confirm whether
        # mixed (dict) observations are meant to be supported here.
        self.next_policy_obs = np.zeros(
            (*leading, *policy_obs_shape), dtype=np.float32
        )
        self.next_critic_obs = np.zeros(
            (*leading, *critic_obs_shape), dtype=np.float32
        )

        self.first_insert_flag = True

    def _store_obs(
        self, slot, critic_obs, policy_obs, next_critic_obs, next_policy_obs
    ):
        """Copy current and next observations into ``slot`` of their buffers."""
        targets = (
            (self.critic_obs, critic_obs),
            (self.policy_obs, policy_obs),
            (self.next_critic_obs, next_critic_obs),
            (self.next_policy_obs, next_policy_obs),
        )
        for storage, incoming in targets:
            if self._mixed_obs:
                # The keys of the storage decide which entries are read.
                for key in storage.keys():
                    storage[key][slot] = incoming[key].copy()
            else:
                storage[slot] = incoming.copy()

    def dict_insert(self, data):
        """Store one transition given as a dictionary and advance ``step``.

        Args:
            data: Mapping that must contain ``"critic_obs"``,
                ``"policy_obs"``, ``"next_critic_obs"`` and
                ``"next_policy_obs"``. The keys ``"rnn_states"``,
                ``"rnn_states_critic"``, ``"actions"``,
                ``"action_log_probs"``, ``"value_preds"``, ``"rewards"``,
                ``"masks"``, ``"bad_masks"``, ``"active_masks"`` and
                ``"available_actions"`` are stored only when present.

        Raises:
            KeyError: If one of the four observation keys is missing.
        """
        slot = self.step + 1
        self._store_obs(
            slot,
            data["critic_obs"],
            data["policy_obs"],
            data["next_critic_obs"],
            data["next_policy_obs"],
        )

        # Optional fields stored alongside the observations (slot step + 1).
        for name in (
            "rnn_states",
            "rnn_states_critic",
            "actions",
            "rewards",
            "masks",
            "bad_masks",
            "active_masks",
            "available_actions",
        ):
            if name in data:
                getattr(self, name)[slot] = data[name].copy()

        # These two are stored one slot earlier, at the current step.
        for name in ("action_log_probs", "value_preds"):
            if name in data:
                getattr(self, name)[self.step] = data[name].copy()

        next_step = slot % self.episode_length
        if next_step != 0:
            self.first_insert_flag = False
        self.step = next_step

    def insert(
        self,
        raw_obs,
        next_raw_obs,
        rnn_states,
        rnn_states_critic,
        actions,
        action_log_probs,
        value_preds,
        rewards,
        masks,
        bad_masks=None,
        active_masks=None,
        available_actions=None,
    ):
        """Store one transition given as separate arrays and advance ``step``.

        Args:
            raw_obs: Current raw observation; split into critic and policy
                observations with ``get_critic_obs`` / ``get_policy_obs``.
            next_raw_obs: Next raw observation, split the same way.
            rnn_states: Policy RNN states, skipped when ``None``.
            rnn_states_critic: Critic RNN states, skipped when ``None``.
            actions: Actions, stored at slot ``step + 1``.
            action_log_probs: Stored at slot ``step``.
            value_preds: Stored at slot ``step``.
            rewards: Rewards, stored at slot ``step + 1``.
            masks: Masks, stored at slot ``step + 1``.
            bad_masks: Optional, skipped when ``None``.
            active_masks: Optional, skipped when ``None``.
            available_actions: Optional, skipped when ``None``.
        """
        slot = self.step + 1
        self._store_obs(
            slot,
            get_critic_obs(raw_obs),
            get_policy_obs(raw_obs),
            get_critic_obs(next_raw_obs),
            get_policy_obs(next_raw_obs),
        )

        if rnn_states is not None:
            self.rnn_states[slot] = rnn_states.copy()
        if rnn_states_critic is not None:
            self.rnn_states_critic[slot] = rnn_states_critic.copy()

        self.actions[slot] = actions.copy()
        self.action_log_probs[self.step] = action_log_probs.copy()
        self.value_preds[self.step] = value_preds.copy()
        self.rewards[slot] = rewards.copy()
        self.masks[slot] = masks.copy()

        if bad_masks is not None:
            self.bad_masks[slot] = bad_masks.copy()
        if active_masks is not None:
            self.active_masks[slot] = active_masks.copy()
        if available_actions is not None:
            self.available_actions[slot] = available_actions.copy()

        # NOTE: unlike dict_insert, this method leaves ``first_insert_flag``
        # untouched.
        self.step = slot % self.episode_length

    def after_update(self):
        """Copy the last slot of the carried-over buffers into slot 0.

        Raises:
            AssertionError: If ``step`` is not 0 when this is called.
        """
        assert self.step == 0, "step:{} episode:{}".format(
            self.step, self.episode_length
        )
        if self._mixed_obs:
            for key in self.critic_obs.keys():
                self.critic_obs[key][0] = self.critic_obs[key][-1].copy()
            for key in self.policy_obs.keys():
                self.policy_obs[key][0] = self.policy_obs[key][-1].copy()
        else:
            self.critic_obs[0] = self.critic_obs[-1].copy()
            self.policy_obs[0] = self.policy_obs[-1].copy()

        self.rnn_states[0] = self.rnn_states[-1].copy()
        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()
        self.actions[0] = self.actions[-1].copy()
        self.masks[0] = self.masks[-1].copy()
        self.bad_masks[0] = self.bad_masks[-1].copy()
        self.active_masks[0] = self.active_masks[-1].copy()
        if self.available_actions is not None:
            self.available_actions[0] = self.available_actions[-1].copy()
        # NOTE(review): rewards, value_preds, action_log_probs and the
        # next-observation buffers are not carried over to slot 0 -- confirm
        # this is intended.

    def _resolve_batching(self, num_mini_batch, mini_batch_size):
        """Return ``(mini_batch_size, num_valid_rows)`` for the generators.

        ``num_valid_rows`` is the row count of the buffers once their final
        slot is dropped and the (slot, thread, agent) axes are merged; only
        indices below it are valid for every flattened buffer.
        """
        num_slots, n_rollout_threads, num_agents = self.rewards.shape[0:3]
        batch_size = n_rollout_threads * num_slots * num_agents
        if mini_batch_size is None:
            assert (
                batch_size >= num_mini_batch
            ), (
                "DQN requires the number of processes ({}) "
                "* number of steps ({}) * number of agents ({}) = {} "
                "to be greater than or equal to the number of DQN mini batches ({})."
                "".format(
                    n_rollout_threads,
                    num_slots,
                    num_agents,
                    batch_size,
                    num_mini_batch,
                )
            )
            mini_batch_size = batch_size // num_mini_batch
        num_valid_rows = (num_slots - 1) * n_rollout_threads * num_agents
        return mini_batch_size, num_valid_rows

    def _iter_minibatches(
        self, advantages, mini_batch_size, num_rows, critic_obs_process_func
    ):
        """Yield random mini-batches drawn from flat row indices ``< num_rows``.

        Incomplete trailing batches are dropped. Each yielded tuple is
        ``(critic_obs, policy_obs, next_critic_obs, next_policy_obs,
        rnn_states, rnn_states_critic, actions, value_preds, rewards, masks,
        active_masks, old_action_log_probs, adv_targ, available_actions)``.
        """
        sampler = BatchSampler(
            SubsetRandomSampler(range(num_rows)),
            mini_batch_size,
            drop_last=True,
        )
        mixed = self._mixed_obs

        def flatten_obs(storage):
            # obs size [T+1 N M Dim]-->[T N M Dim]-->[T*N*M,Dim]
            if mixed:
                return {
                    key: storage[key][:-1].reshape(-1, *storage[key].shape[3:])
                    for key in storage.keys()
                }
            return storage[:-1].reshape(-1, *storage.shape[3:])

        def select(flat, rows):
            # [T*N*M,Dim]-->[index,Dim]
            if mixed:
                return {key: flat[key][rows] for key in flat}
            return flat[rows]

        critic_obs = flatten_obs(self.critic_obs)
        policy_obs = flatten_obs(self.policy_obs)
        next_critic_obs = flatten_obs(self.next_critic_obs)
        next_policy_obs = flatten_obs(self.next_policy_obs)

        rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[3:])
        rnn_states_critic = self.rnn_states_critic[:-1].reshape(
            -1, *self.rnn_states_critic.shape[3:]
        )
        # actions, rewards and action_log_probs are flattened in full (no
        # slot is dropped), exactly as before.
        actions = self.actions.reshape(-1, self.actions.shape[-1])
        available_actions = None
        if self.available_actions is not None:
            available_actions = self.available_actions[:-1].reshape(
                -1, self.available_actions.shape[-1]
            )
        value_preds = self.value_preds[:-1].reshape(-1, self.act_space)
        rewards = self.rewards.reshape(-1, 1)
        masks = self.masks[:-1].reshape(-1, 1)
        active_masks = self.active_masks[:-1].reshape(-1, 1)
        action_log_probs = self.action_log_probs.reshape(
            -1, self.action_log_probs.shape[-1]
        )
        if advantages is not None:
            advantages = advantages.reshape(-1, 1)

        for indices in sampler:
            critic_obs_batch = select(critic_obs, indices)
            policy_obs_batch = select(policy_obs, indices)
            next_critic_obs_batch = select(next_critic_obs, indices)
            next_policy_obs_batch = select(next_policy_obs, indices)

            rnn_states_batch = rnn_states[indices]
            rnn_states_critic_batch = rnn_states_critic[indices]
            actions_batch = actions[indices]
            if available_actions is not None:
                available_actions_batch = available_actions[indices]
            else:
                available_actions_batch = None
            value_preds_batch = value_preds[indices]
            rewards_batch = rewards[indices]
            masks_batch = masks[indices]
            active_masks_batch = active_masks[indices]
            old_action_log_probs_batch = action_log_probs[indices]

            # Without advantages the rewards double as the target.
            if advantages is None:
                adv_targ = rewards_batch
            else:
                adv_targ = advantages[indices]

            if critic_obs_process_func is not None:
                critic_obs_batch = critic_obs_process_func(critic_obs_batch)

            yield (
                critic_obs_batch,
                policy_obs_batch,
                next_critic_obs_batch,
                next_policy_obs_batch,
                rnn_states_batch,
                rnn_states_critic_batch,
                actions_batch,
                value_preds_batch,
                rewards_batch,
                masks_batch,
                active_masks_batch,
                old_action_log_probs_batch,
                adv_targ,
                available_actions_batch,
            )

    def feed_forward_generator(
        self,
        advantages,
        num_mini_batch=None,
        mini_batch_size=None,
        critic_obs_process_func=None,
    ):
        """Yield random mini-batches of stored transitions.

        Indices are drawn only from rows that exist in every flattened
        buffer, i.e. ``(slots - 1) * threads * agents`` rows. For a single
        agent this equals the previous ``batch_size - n_rollout_threads``;
        for several agents the previous range reached past the end of the
        buffers whose final slot is dropped.

        Args:
            advantages: Optional array reshaped to ``(-1, 1)`` and indexed per
                batch; when ``None`` the reward batch is yielded as the target.
            num_mini_batch: Number of mini-batches; used only when
                ``mini_batch_size`` is ``None``.
            mini_batch_size: Explicit mini-batch size.
            critic_obs_process_func: Optional callable applied to each critic
                observation batch before it is yielded.

        Yields:
            A 14-tuple: critic obs, policy obs, next critic obs, next policy
            obs, RNN states, critic RNN states, actions, value predictions,
            rewards, masks, active masks, old action log-probs, target, and
            available actions (``None`` when not stored).

        Raises:
            AssertionError: If there are fewer rows than mini-batches
                requested, or fewer valid rows than ``mini_batch_size``.
        """
        mini_batch_size, num_valid_rows = self._resolve_batching(
            num_mini_batch, mini_batch_size
        )
        assert num_valid_rows >= mini_batch_size
        yield from self._iter_minibatches(
            advantages, mini_batch_size, num_valid_rows, critic_obs_process_func
        )

    def feed_forward_generator_old(
        self,
        advantages,
        num_mini_batch=None,
        mini_batch_size=None,
        critic_obs_process_func=None,
    ):
        """Legacy entry point kept for callers of the old generator.

        Takes the same arguments and yields the same 14-tuple as
        ``feed_forward_generator``. It differs only in that it does not
        assert that a full mini-batch fits; if none fits, nothing is yielded.
        Indices are now limited to rows present in every flattened buffer;
        the previous range also covered the dropped final slot and could
        raise ``IndexError``.
        """
        mini_batch_size, num_valid_rows = self._resolve_batching(
            num_mini_batch, mini_batch_size
        )
        yield from self._iter_minibatches(
            advantages, mini_batch_size, num_valid_rows, critic_obs_process_func
        )