Shortcuts

Source code for openrl.buffers.replay_data

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

from collections import defaultdict

import numpy as np
import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from openrl.buffers.utils.obs_data import ObsData
from openrl.buffers.utils.util import (
    _cast,
    _cast_v3,
    _flatten,
    _flatten_v3,
    _shuffle_agent_grid,
    get_critic_obs,
    get_critic_obs_space,
    get_policy_obs,
    get_policy_obs_space,
    get_shape_from_act_space,
)


[docs]class ReplayData(object): def __init__( self, cfg, num_agents, obs_space, act_space, data_client=None, episode_length=None, ): if episode_length is None: episode_length = cfg.episode_length self.episode_length = episode_length self.n_rollout_threads = cfg.n_rollout_threads if hasattr(cfg, "rnn_hidden_size"): self.hidden_size = cfg.rnn_hidden_size else: self.hidden_size = cfg.hidden_size self.recurrent_N = cfg.recurrent_N self.gamma = cfg.gamma self.gae_lambda = cfg.gae_lambda self._use_gae = cfg.use_gae self._use_popart = cfg.use_popart self._use_valuenorm = cfg.use_valuenorm self._use_proper_time_limits = cfg.use_proper_time_limits self._mixed_obs = False # for mixed observation policy_obs_shape = get_policy_obs_space(obs_space) critic_obs_shape = get_critic_obs_space(obs_space) # for mixed observation if "Dict" in policy_obs_shape.__class__.__name__: self._mixed_obs = True self.policy_obs = {} self.critic_obs = {} for key in policy_obs_shape: self.policy_obs[key] = np.zeros( ( self.episode_length + 1, self.n_rollout_threads, num_agents, *policy_obs_shape[key].shape, ), dtype=np.float32, ) for key in critic_obs_shape: self.critic_obs[key] = np.zeros( ( self.episode_length + 1, self.n_rollout_threads, num_agents, *critic_obs_shape[key].shape, ), dtype=np.float32, ) self.policy_obs = ObsData(self.policy_obs) self.critic_obs = ObsData(self.critic_obs) else: # deal with special attn format if type(policy_obs_shape[-1]) == list: policy_obs_shape[:1] if type(critic_obs_shape[-1]) == list: critic_obs_shape = critic_obs_shape[:1] self.critic_obs = np.zeros( ( self.episode_length + 1, self.n_rollout_threads, num_agents, *critic_obs_shape, ), dtype=np.float32, ) self.policy_obs = np.zeros( ( self.episode_length + 1, self.n_rollout_threads, num_agents, *policy_obs_shape, ), dtype=np.float32, ) self.rnn_states = np.zeros( ( self.episode_length + 1, self.n_rollout_threads, num_agents, self.recurrent_N, self.hidden_size, ), dtype=np.float32, ) self.rnn_states_critic = np.zeros_like(self.rnn_states) self.value_preds = np.zeros( (self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32, ) self.returns = np.zeros_like(self.value_preds) if act_space.__class__.__name__ == "Discrete": self.action_masks = np.ones( ( self.episode_length + 1, self.n_rollout_threads, num_agents, act_space.n, ), dtype=np.float32, ) else: self.action_masks = None act_shape = get_shape_from_act_space(act_space) self.actions = np.zeros( (self.episode_length, self.n_rollout_threads, num_agents, act_shape), dtype=np.float32, ) self.action_log_probs = np.zeros( (self.episode_length, self.n_rollout_threads, num_agents, act_shape), dtype=np.float32, ) self.rewards = np.zeros( (self.episode_length, self.n_rollout_threads, num_agents, 1), dtype=np.float32, ) self.masks = np.ones( (self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32, ) self.bad_masks = np.ones_like(self.masks) self.active_masks = np.ones_like(self.masks) self.step = 0
[docs] def get_batch_data( self, data_name: str, step: int, ): assert hasattr(self, data_name) data = getattr(self, data_name) if data is None: return None if isinstance(data, ObsData): return data.step_batch(step) else: return np.concatenate(data[step])
# def all_batch_data(self, data_name: str, min=None, max=None): # assert hasattr(self, data_name) # data = getattr(self, data_name) # # if isinstance(data, ObsData): # return data.all_batch(min, max) # else: # return data[min:max].reshape((-1, *data.shape[3:])) # def dict_insert(self, data): # if self._mixed_obs: # for key in self.critic_obs.keys(): # self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() # for key in self.policy_obs.keys(): # self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() # else: # self.critic_obs[self.step + 1] = data["critic_obs"].copy() # self.policy_obs[self.step + 1] = data["policy_obs"].copy() # # if "rnn_states" in data: # self.rnn_states[self.step + 1] = data["rnn_states"].copy() # if "rnn_states_critic" in data: # self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() # if "actions" in data: # self.actions[self.step] = data["actions"].copy() # if "action_log_probs" in data: # self.action_log_probs[self.step] = data["action_log_probs"].copy() # # if "value_preds" in data: # self.value_preds[self.step] = data["value_preds"].copy() # if "rewards" in data: # self.rewards[self.step] = data["rewards"].copy() # if "masks" in data: # self.masks[self.step + 1] = data["masks"].copy() # # if "bad_masks" in data: # self.bad_masks[self.step + 1] = data["bad_masks"].copy() # if "active_masks" in data: # self.active_masks[self.step + 1] = data["active_masks"].copy() # if "action_masks" in data: # self.action_masks[self.step + 1] = data["action_masks"].copy() # # self.step = (self.step + 1) % self.episode_length
[docs] def insert( self, raw_obs, rnn_states, rnn_states_critic, actions, action_log_probs, value_preds, rewards, masks, bad_masks=None, active_masks=None, action_masks=None, ): critic_obs = get_critic_obs(raw_obs) policy_obs = get_policy_obs(raw_obs) if self._mixed_obs: for key in self.critic_obs.keys(): self.critic_obs[key][self.step + 1] = critic_obs[key].copy() for key in self.policy_obs.keys(): self.policy_obs[key][self.step + 1] = policy_obs[key].copy() else: self.critic_obs[self.step + 1] = critic_obs.copy() self.policy_obs[self.step + 1] = policy_obs.copy() if rnn_states is not None: self.rnn_states[self.step + 1] = rnn_states.copy() if rnn_states_critic is not None: self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy() self.actions[self.step] = actions.copy() self.action_log_probs[self.step] = action_log_probs.copy() self.value_preds[self.step] = value_preds.copy() self.rewards[self.step] = rewards.copy() self.masks[self.step + 1] = masks.copy() if bad_masks is not None: self.bad_masks[self.step + 1] = bad_masks.copy() if active_masks is not None: self.active_masks[self.step + 1] = active_masks.copy() if action_masks is not None: self.action_masks[self.step + 1] = action_masks.copy() self.step = (self.step + 1) % self.episode_length
[docs] def init_buffer(self, raw_obs, action_masks=None): critic_obs = get_critic_obs(raw_obs) policy_obs = get_policy_obs(raw_obs) if self._mixed_obs: for key in self.critic_obs.keys(): self.critic_obs[key][0] = critic_obs[key].copy() for key in self.policy_obs.keys(): self.policy_obs[key][0] = policy_obs[key].copy() else: self.critic_obs[0] = critic_obs.copy() self.policy_obs[0] = policy_obs.copy() if action_masks is not None and self.action_masks is not None: self.action_masks[0] = action_masks
[docs] def after_update(self): assert self.step == 0, "step:{} episode:{}".format( self.step, self.episode_length ) if self._mixed_obs: for key in self.critic_obs.keys(): self.critic_obs[key][0] = self.critic_obs[key][-1].copy() for key in self.policy_obs.keys(): self.policy_obs[key][0] = self.policy_obs[key][-1].copy() else: self.critic_obs[0] = self.critic_obs[-1].copy() self.policy_obs[0] = self.policy_obs[-1].copy() self.rnn_states[0] = self.rnn_states[-1].copy() self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy() self.masks[0] = self.masks[-1].copy() self.bad_masks[0] = self.bad_masks[-1].copy() self.active_masks[0] = self.active_masks[-1].copy() if self.action_masks is not None: self.action_masks[0] = self.action_masks[-1].copy()
[docs] def compute_returns(self, next_value, value_normalizer=None): if self._use_proper_time_limits: if self._use_gae: self.value_preds[-1] = next_value gae = 0 for step in reversed(range(self.rewards.shape[0])): if ( self._use_popart or self._use_valuenorm ) and value_normalizer is not None: # step + 1 delta = ( self.rewards[step] + self.gamma * value_normalizer.denormalize(self.value_preds[step + 1]) * self.masks[step + 1] - value_normalizer.denormalize(self.value_preds[step]) ) gae = ( delta + self.gamma * self.gae_lambda * gae * self.masks[step + 1] ) gae = gae * self.bad_masks[step + 1] self.returns[step] = gae + value_normalizer.denormalize( self.value_preds[step] ) else: delta = ( self.rewards[step] + self.gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step] ) gae = ( delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae ) gae = gae * self.bad_masks[step + 1] self.returns[step] = gae + self.value_preds[step] else: self.returns[-1] = next_value for step in reversed(range(self.rewards.shape[0])): if ( self._use_popart or self._use_valuenorm ) and value_normalizer is not None: self.returns[step] = ( self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step] ) * self.bad_masks[step + 1] + ( 1 - self.bad_masks[step + 1] ) * value_normalizer.denormalize( self.value_preds[step] ) else: self.returns[step] = ( self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step] ) * self.bad_masks[step + 1] + ( 1 - self.bad_masks[step + 1] ) * self.value_preds[ step ] else: if self._use_gae: self.value_preds[-1] = next_value gae = 0 for step in reversed(range(self.rewards.shape[0])): if ( self._use_popart or self._use_valuenorm ) and value_normalizer is not None: delta = ( self.rewards[step] + self.gamma * value_normalizer.denormalize(self.value_preds[step + 1]) * self.masks[step + 1] - value_normalizer.denormalize(self.value_preds[step]) ) gae = ( delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae ) self.returns[step] = gae + value_normalizer.denormalize( self.value_preds[step] ) else: delta = ( self.rewards[step] + self.gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step] ) gae = ( delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae ) self.returns[step] = gae + self.value_preds[step] else: self.returns[-1] = next_value for step in reversed(range(self.rewards.shape[0])): self.returns[step] = ( self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step] )
[docs] def recurrent_generator_v3(self, advantages, num_mini_batch, data_chunk_length): episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] batch_size = n_rollout_threads * episode_length data_chunks = batch_size // data_chunk_length # [C=r*T*M/L] mini_batch_size = data_chunks // num_mini_batch assert n_rollout_threads * episode_length >= data_chunk_length, ( "PPO requires the nfumber of processes ({}) * episode length ({}) " "to be greater than or equal to the number of " "data chunk length ({}).".format( n_rollout_threads, episode_length, data_chunk_length ) ) rand = torch.randperm(data_chunks).numpy() sampler = [ rand[i * mini_batch_size : (i + 1) * mini_batch_size] for i in range(num_mini_batch) ] critic_obs = _cast_v3(self.critic_obs[:-1]) policy_obs = _cast_v3(self.policy_obs[:-1]) actions = _cast_v3(self.actions) action_log_probs = _cast_v3(self.action_log_probs) advantages = _cast_v3(advantages) value_preds = _cast_v3(self.value_preds[:-1]) returns = _cast_v3(self.returns[:-1]) masks = _cast_v3(self.masks[:-1]) active_masks = _cast_v3(self.active_masks[:-1]) rnn_states = ( self.rnn_states[:-1] .transpose(1, 0, 2, 3, 4) .reshape(-1, *self.rnn_states.shape[2:]) ) rnn_states_critic = ( self.rnn_states_critic[:-1] .transpose(1, 0, 2, 3, 4) .reshape(-1, *self.rnn_states_critic.shape[2:]) ) if self.action_masks is not None: action_masks = _cast_v3(self.action_masks[:-1]) for indices in sampler: critic_obs_batch = [] policy_obs_batch = [] rnn_states_batch = [] rnn_states_critic_batch = [] actions_batch = [] action_masks_batch = [] value_preds_batch = [] return_batch = [] masks_batch = [] active_masks_batch = [] old_action_log_probs_batch = [] adv_targ = [] for index in indices: ind = index * data_chunk_length # size [T+1 N M Dim]-->[T N M Dim]-->[N,M,T,Dim]-->[N*M*T,Dim]-->[L,Dim] # [L, agent_num, Dim] critic_obs_batch.append(critic_obs[ind : ind + data_chunk_length]) policy_obs_batch.append(policy_obs[ind : ind + data_chunk_length]) actions_batch.append(actions[ind : ind + data_chunk_length]) if self.action_masks is not None: action_masks_batch.append( action_masks[ind : ind + data_chunk_length] ) value_preds_batch.append(value_preds[ind : ind + data_chunk_length]) return_batch.append(returns[ind : ind + data_chunk_length]) masks_batch.append(masks[ind : ind + data_chunk_length]) active_masks_batch.append(active_masks[ind : ind + data_chunk_length]) old_action_log_probs_batch.append( action_log_probs[ind : ind + data_chunk_length] ) adv_targ.append(advantages[ind : ind + data_chunk_length]) # size [T+1 N M Dim]-->[T N M Dim]-->[N M T Dim]-->[N*M*T,Dim]-->[1,Dim] # [1,agent_num, Dim] rnn_states_batch.append(rnn_states[ind]) rnn_states_critic_batch.append(rnn_states_critic[ind]) L, N = data_chunk_length, mini_batch_size # These are all from_numpys of size (L, N, agent_num, Dim) critic_obs_batch = np.stack(critic_obs_batch, axis=1) policy_obs_batch = np.stack(policy_obs_batch, axis=1) actions_batch = np.stack(actions_batch, axis=1) if self.action_masks is not None: action_masks_batch = np.stack(action_masks_batch, axis=1) value_preds_batch = np.stack(value_preds_batch, axis=1) return_batch = np.stack(return_batch, axis=1) masks_batch = np.stack(masks_batch, axis=1) active_masks_batch = np.stack(active_masks_batch, axis=1) old_action_log_probs_batch = np.stack(old_action_log_probs_batch, axis=1) adv_targ = np.stack(adv_targ, axis=1) # States is just a (N, agent_num, -1) from_numpy rnn_states_batch = np.stack(rnn_states_batch).reshape( N * num_agents, *self.rnn_states.shape[3:] ) rnn_states_critic_batch = np.stack(rnn_states_critic_batch).reshape( N * num_agents, *self.rnn_states_critic.shape[3:] ) # Flatten the (L, N, ...) from_numpys to (L * N, ...) critic_obs_batch = _flatten_v3(L, N, num_agents, critic_obs_batch) policy_obs_batch = _flatten_v3(L, N, num_agents, policy_obs_batch) actions_batch = _flatten_v3(L, N, num_agents, actions_batch) if self.action_masks is not None: action_masks_batch = _flatten_v3(L, N, num_agents, action_masks_batch) else: action_masks_batch = None value_preds_batch = _flatten_v3(L, N, num_agents, value_preds_batch) return_batch = _flatten_v3(L, N, num_agents, return_batch) masks_batch = _flatten_v3(L, N, num_agents, masks_batch) active_masks_batch = _flatten_v3(L, N, num_agents, active_masks_batch) old_action_log_probs_batch = _flatten_v3( L, N, num_agents, old_action_log_probs_batch ) adv_targ = _flatten_v3(L, N, num_agents, adv_targ) yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
[docs] def feed_forward_generator( self, advantages, num_mini_batch=None, mini_batch_size=None, critic_obs_process_func=None, ): episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] batch_size = n_rollout_threads * episode_length * num_agents if mini_batch_size is None: assert ( batch_size >= num_mini_batch ), ( "PPO requires the number of processes ({}) " "* number of steps ({}) * number of agents ({}) = {} " "to be greater than or equal to the number of PPO mini batches ({})." "".format( n_rollout_threads, episode_length, num_agents, n_rollout_threads * episode_length * num_agents, num_mini_batch, ) ) mini_batch_size = batch_size // num_mini_batch sampler = BatchSampler( SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=True ) if self._mixed_obs: critic_obs = {} policy_obs = {} for key in self.critic_obs.keys(): critic_obs[key] = self.critic_obs[key][:-1].reshape( -1, *self.critic_obs[key].shape[3:] ) for key in self.policy_obs.keys(): policy_obs[key] = self.policy_obs[key][:-1].reshape( -1, *self.policy_obs[key].shape[3:] ) else: critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[3:]) policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[3:]) rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[3:]) rnn_states_critic = self.rnn_states_critic[:-1].reshape( -1, *self.rnn_states_critic.shape[3:] ) actions = self.actions.reshape(-1, self.actions.shape[-1]) if self.action_masks is not None: action_masks = self.action_masks[:-1].reshape( -1, self.action_masks.shape[-1] ) value_preds = self.value_preds[:-1].reshape(-1, 1) returns = self.returns[:-1].reshape(-1, 1) masks = self.masks[:-1].reshape(-1, 1) active_masks = self.active_masks[:-1].reshape(-1, 1) action_log_probs = self.action_log_probs.reshape( -1, self.action_log_probs.shape[-1] ) if advantages is not None: advantages = advantages.reshape(-1, 1) for indices in sampler: # obs size [T+1 N M Dim]-->[T N M Dim]-->[T*N*M,Dim]-->[index,Dim] if self._mixed_obs: critic_obs_batch = {} policy_obs_batch = {} for key in critic_obs.keys(): critic_obs_batch[key] = critic_obs[key][indices] for key in policy_obs.keys(): policy_obs_batch[key] = policy_obs[key][indices] else: critic_obs_batch = critic_obs[indices] policy_obs_batch = policy_obs[indices] rnn_states_batch = rnn_states[indices] rnn_states_critic_batch = rnn_states_critic[indices] actions_batch = actions[indices] if self.action_masks is not None: action_masks_batch = action_masks[indices] else: action_masks_batch = None value_preds_batch = value_preds[indices] return_batch = returns[indices] masks_batch = masks[indices] active_masks_batch = active_masks[indices] old_action_log_probs_batch = action_log_probs[indices] if advantages is None: adv_targ = None else: adv_targ = advantages[indices] if critic_obs_process_func is not None: critic_obs_batch = critic_obs_process_func(critic_obs_batch) yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
[docs] def feed_forward_critic_obs_generator( self, advantages, num_mini_batch=None, mini_batch_size=None, critic_obs_process_func=None, ): episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] batch_size = n_rollout_threads * episode_length if mini_batch_size is None: assert ( batch_size >= num_mini_batch ), ( "PPO requires the number of processes ({}) " "* number of steps ({}) * number of agents ({}) = {} " "to be greater than or equal to the number of PPO mini batches ({})." "".format( n_rollout_threads, episode_length, num_agents, n_rollout_threads * episode_length, num_mini_batch, ) ) mini_batch_size = batch_size // num_mini_batch sampler = BatchSampler( SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=True ) if self._mixed_obs: critic_obs = {} for key in self.critic_obs.keys(): critic_obs[key] = self.critic_obs[key][:-1].reshape( -1, *self.critic_obs[key].shape[3:] ) else: critic_obs = self.critic_obs[:-1, :, 0].reshape( -1, *self.critic_obs.shape[3:] ) # [T*N,Dim] actions = self.actions[:, :, 0].reshape(-1, self.actions.shape[-1]) for indices in sampler: # T is episode length, N is rollout number, M is agent number, dim is dimension # critic_obs size [T+1 N M Dim]-->[T N Dim]-->[T*N,Dim]-->[index,Dim] if self._mixed_obs: critic_obs_batch = {} for key in critic_obs.keys(): critic_obs_batch[key] = critic_obs[key][indices] else: critic_obs_batch = critic_obs[indices] actions_batch = actions[indices] if critic_obs_process_func is not None: critic_obs_batch = critic_obs_process_func(critic_obs_batch) yield critic_obs_batch, None, None, None, actions_batch, None, None, None, None, None, None, None
[docs] def feed_forward_generator_transformer( self, advantages, num_mini_batch=None, mini_batch_size=None ): """ Yield training data for MLP policies. :param advantages: (np.ndarray) advantage estimates. :param num_mini_batch: (int) number of minibatches to split the batch into. :param mini_batch_size: (int) number of samples in each minibatch. """ episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] batch_size = n_rollout_threads * episode_length if mini_batch_size is None: assert ( batch_size >= num_mini_batch ), ( "PPO requires the number of processes ({}) " "* number of steps ({}) = {} " "to be greater than or equal to the number of PPO mini batches ({})." "".format( n_rollout_threads, episode_length, n_rollout_threads * episode_length, num_mini_batch, ) ) mini_batch_size = batch_size // num_mini_batch rand = torch.randperm(batch_size).numpy() sampler = [ rand[i * mini_batch_size : (i + 1) * mini_batch_size] for i in range(num_mini_batch) ] rows, cols = _shuffle_agent_grid(batch_size, num_agents) # keep (num_agent, dim) critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:]) critic_obs = critic_obs[rows, cols] policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:]) policy_obs = policy_obs[rows, cols] rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:]) rnn_states = rnn_states[rows, cols] rnn_states_critic = self.rnn_states_critic[:-1].reshape( -1, *self.rnn_states_critic.shape[2:] ) rnn_states_critic = rnn_states_critic[rows, cols] actions = self.actions.reshape(-1, *self.actions.shape[2:]) actions = actions[rows, cols] if self.action_masks is not None: action_masks = self.action_masks[:-1].reshape( -1, *self.action_masks.shape[2:] ) action_masks = action_masks[rows, cols] value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:]) value_preds = value_preds[rows, cols] returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:]) returns = returns[rows, cols] masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:]) masks = masks[rows, cols] active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:]) active_masks = active_masks[rows, cols] action_log_probs = self.action_log_probs.reshape( -1, *self.action_log_probs.shape[2:] ) action_log_probs = action_log_probs[rows, cols] advantages = advantages.reshape(-1, *advantages.shape[2:]) advantages = advantages[rows, cols] for indices in sampler: # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim] critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:]) policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:]) rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:]) rnn_states_critic_batch = rnn_states_critic[indices].reshape( -1, *rnn_states_critic.shape[2:] ) actions_batch = actions[indices].reshape(-1, *actions.shape[2:]) if self.action_masks is not None: action_masks_batch = action_masks[indices].reshape( -1, *action_masks.shape[2:] ) else: action_masks_batch = None value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:]) return_batch = returns[indices].reshape(-1, *returns.shape[2:]) masks_batch = masks[indices].reshape(-1, *masks.shape[2:]) active_masks_batch = active_masks[indices].reshape( -1, *active_masks.shape[2:] ) old_action_log_probs_batch = action_log_probs[indices].reshape( -1, *action_log_probs.shape[2:] ) if advantages is None: adv_targ = None else: adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:]) yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
[docs] def naive_recurrent_generator(self, advantages, num_mini_batch): episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] batch_size = n_rollout_threads * num_agents assert n_rollout_threads * num_agents >= num_mini_batch, ( "PPO requires the number of processes ({})* number of agents ({}) " "to be greater than or equal to the number of " "PPO mini batches ({}).".format( n_rollout_threads, num_agents, num_mini_batch ) ) num_envs_per_batch = batch_size // num_mini_batch perm = torch.randperm(batch_size).numpy() if self._mixed_obs: critic_obs = {} policy_obs = {} for key in self.critic_obs.keys(): critic_obs[key] = self.critic_obs[key].reshape( -1, batch_size, *self.critic_obs[key].shape[3:] ) for key in self.policy_obs.keys(): policy_obs[key] = self.policy_obs[key].reshape( -1, batch_size, *self.policy_obs[key].shape[3:] ) else: critic_obs = self.critic_obs.reshape( -1, batch_size, *self.critic_obs.shape[3:] ) policy_obs = self.policy_obs.reshape( -1, batch_size, *self.policy_obs.shape[3:] ) rnn_states = self.rnn_states.reshape(-1, batch_size, *self.rnn_states.shape[3:]) rnn_states_critic = self.rnn_states_critic.reshape( -1, batch_size, *self.rnn_states_critic.shape[3:] ) actions = self.actions.reshape(-1, batch_size, self.actions.shape[-1]) if self.action_masks is not None: action_masks = self.action_masks.reshape( -1, batch_size, self.action_masks.shape[-1] ) value_preds = self.value_preds.reshape(-1, batch_size, 1) returns = self.returns.reshape(-1, batch_size, 1) masks = self.masks.reshape(-1, batch_size, 1) active_masks = self.active_masks.reshape(-1, batch_size, 1) action_log_probs = self.action_log_probs.reshape( -1, batch_size, self.action_log_probs.shape[-1] ) advantages = advantages.reshape(-1, batch_size, 1) for start_ind in range(0, batch_size, num_envs_per_batch): if self._mixed_obs: critic_obs_batch = defaultdict(list) policy_obs_batch = defaultdict(list) else: critic_obs_batch = [] policy_obs_batch = [] rnn_states_batch = [] rnn_states_critic_batch = [] actions_batch = [] action_masks_batch = [] value_preds_batch = [] return_batch = [] masks_batch = [] active_masks_batch = [] old_action_log_probs_batch = [] adv_targ = [] for offset in range(num_envs_per_batch): ind = perm[start_ind + offset] if self._mixed_obs: for key in critic_obs.keys(): critic_obs_batch[key].append(critic_obs[key][:-1, ind]) for key in policy_obs.keys(): policy_obs_batch[key].append(policy_obs[key][:-1, ind]) else: critic_obs_batch.append(critic_obs[:-1, ind]) policy_obs_batch.append(policy_obs[:-1, ind]) rnn_states_batch.append(rnn_states[0:1, ind]) rnn_states_critic_batch.append(rnn_states_critic[0:1, ind]) actions_batch.append(actions[:, ind]) if self.action_masks is not None: action_masks_batch.append(action_masks[:-1, ind]) value_preds_batch.append(value_preds[:-1, ind]) return_batch.append(returns[:-1, ind]) masks_batch.append(masks[:-1, ind]) active_masks_batch.append(active_masks[:-1, ind]) old_action_log_probs_batch.append(action_log_probs[:, ind]) adv_targ.append(advantages[:, ind]) # [N[T, dim]] T, N = self.episode_length, num_envs_per_batch # These are all from_numpys of size (T, N, -1) if self._mixed_obs: for key in critic_obs_batch.keys(): critic_obs_batch[key] = np.stack(critic_obs_batch[key], 1) for key in policy_obs_batch.keys(): policy_obs_batch[key] = np.stack(policy_obs_batch[key], 1) else: critic_obs_batch = np.stack(critic_obs_batch, 1) policy_obs_batch = np.stack(policy_obs_batch, 1) actions_batch = np.stack(actions_batch, 1) if self.action_masks is not None: action_masks_batch = np.stack(action_masks_batch, 1) value_preds_batch = np.stack(value_preds_batch, 1) return_batch = np.stack(return_batch, 1) masks_batch = np.stack(masks_batch, 1) active_masks_batch = np.stack(active_masks_batch, 1) old_action_log_probs_batch = np.stack(old_action_log_probs_batch, 1) adv_targ = np.stack(adv_targ, 1) # States is just a (N, dim) from_numpy [N[1,dim]] rnn_states_batch = np.stack(rnn_states_batch).reshape( N, *self.rnn_states.shape[3:] ) rnn_states_critic_batch = np.stack(rnn_states_critic_batch).reshape( N, *self.rnn_states_critic.shape[3:] ) # Flatten the (T, N, ...) from_numpys to (T * N, ...) if self._mixed_obs: for key in critic_obs_batch.keys(): critic_obs_batch[key] = _flatten(T, N, critic_obs_batch[key]) for key in policy_obs_batch.keys(): policy_obs_batch[key] = _flatten(T, N, policy_obs_batch[key]) else: critic_obs_batch = _flatten(T, N, critic_obs_batch) policy_obs_batch = _flatten(T, N, policy_obs_batch) actions_batch = _flatten(T, N, actions_batch) if self.action_masks is not None: action_masks_batch = _flatten(T, N, action_masks_batch) else: action_masks_batch = None value_preds_batch = _flatten(T, N, value_preds_batch) return_batch = _flatten(T, N, return_batch) masks_batch = _flatten(T, N, masks_batch) active_masks_batch = _flatten(T, N, active_masks_batch) old_action_log_probs_batch = _flatten(T, N, old_action_log_probs_batch) adv_targ = _flatten(T, N, adv_targ) yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
# def recurrent_generator_v2( # self, advantages, num_mini_batch=None, mini_batch_size=None # ): # """ # Yield training data for MLP policies. # :param advantages: (np.ndarray) advantage estimates. # :param num_mini_batch: (int) number of minibatches to split the batch into. # :param mini_batch_size: (int) number of samples in each minibatch. # """ # episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] # batch_size = n_rollout_threads * episode_length # # if mini_batch_size is None: # assert ( # batch_size >= num_mini_batch # ), ( # "PPO requires the number of processes ({}) " # "* number of steps ({}) = {} " # "to be greater than or equal to the number of PPO mini batches ({})." # "".format( # n_rollout_threads, # episode_length, # n_rollout_threads * episode_length, # num_mini_batch, # ) # ) # mini_batch_size = batch_size // num_mini_batch # # rand = torch.randperm(batch_size).numpy() # sampler = [ # rand[i * mini_batch_size : (i + 1) * mini_batch_size] # for i in range(num_mini_batch) # ] # # # keep (num_agent, dim) # critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:]) # # policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:]) # # rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:]) # # rnn_states_critic = self.rnn_states_critic[:-1].reshape( # -1, *self.rnn_states_critic.shape[2:] # ) # # actions = self.actions.reshape(-1, *self.actions.shape[2:]) # # if self.action_masks is not None: # action_masks = self.action_masks[:-1].reshape( # -1, *self.action_masks.shape[2:] # ) # # value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:]) # # returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:]) # # masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:]) # # active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:]) # # action_log_probs = self.action_log_probs.reshape( # -1, *self.action_log_probs.shape[2:] # ) # # advantages = advantages.reshape(-1, *advantages.shape[2:]) # # shuffle = False # if shuffle: # rows, cols = _shuffle_agent_grid(batch_size, num_agents) # # if self.action_masks is not None: # action_masks = action_masks[rows, cols] # critic_obs = critic_obs[rows, cols] # policy_obs = policy_obs[rows, cols] # rnn_states = rnn_states[rows, cols] # rnn_states_critic = rnn_states_critic[rows, cols] # actions = actions[rows, cols] # value_preds = value_preds[rows, cols] # returns = returns[rows, cols] # masks = masks[rows, cols] # active_masks = active_masks[rows, cols] # action_log_probs = action_log_probs[rows, cols] # advantages = advantages[rows, cols] # # for indices in sampler: # # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim] # critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:]) # policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:]) # rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:]) # rnn_states_critic_batch = rnn_states_critic[indices].reshape( # -1, *rnn_states_critic.shape[2:] # ) # actions_batch = actions[indices].reshape(-1, *actions.shape[2:]) # if self.action_masks is not None: # action_masks_batch = action_masks[indices].reshape( # -1, *action_masks.shape[2:] # ) # else: # action_masks_batch = None # value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:]) # return_batch = returns[indices].reshape(-1, *returns.shape[2:]) # masks_batch = masks[indices].reshape(-1, *masks.shape[2:]) # active_masks_batch = active_masks[indices].reshape( # -1, *active_masks.shape[2:] # ) # old_action_log_probs_batch = action_log_probs[indices].reshape( # -1, *action_log_probs.shape[2:] # ) # if advantages is None: # adv_targ = None # else: # adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:]) # yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
[docs] def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length): episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] batch_size = n_rollout_threads * episode_length * num_agents data_chunks = batch_size // data_chunk_length # [C=r*T*M/L] mini_batch_size = data_chunks // num_mini_batch assert n_rollout_threads * episode_length * num_agents >= data_chunk_length, ( "PPO requires the number of processes ({})* number of agents ({}) * episode" " length ({}) to be greater than or equal to the number of data chunk" " length ({}).".format( n_rollout_threads, num_agents, episode_length, data_chunk_length ) ) rand = torch.randperm(data_chunks).numpy() sampler = [ rand[i * mini_batch_size : (i + 1) * mini_batch_size] for i in range(num_mini_batch) ] if self._mixed_obs: critic_obs = {} policy_obs = {} for key in self.critic_obs.keys(): if len(self.critic_obs[key].shape) == 6: critic_obs[key] = ( self.critic_obs[key][:-1] .transpose(1, 2, 0, 3, 4, 5) .reshape(-1, *self.critic_obs[key].shape[3:]) ) elif len(self.critic_obs[key].shape) == 5: critic_obs[key] = ( self.critic_obs[key][:-1] .transpose(1, 2, 0, 3, 4) .reshape(-1, *self.critic_obs[key].shape[3:]) ) else: critic_obs[key] = _cast(self.critic_obs[key][:-1]) for key in self.policy_obs.keys(): if len(self.policy_obs[key].shape) == 6: policy_obs[key] = ( self.policy_obs[key][:-1] .transpose(1, 2, 0, 3, 4, 5) .reshape(-1, *self.policy_obs[key].shape[3:]) ) elif len(self.policy_obs[key].shape) == 5: policy_obs[key] = ( self.policy_obs[key][:-1] .transpose(1, 2, 0, 3, 4) .reshape(-1, *self.policy_obs[key].shape[3:]) ) else: policy_obs[key] = _cast(self.policy_obs[key][:-1]) else: if len(self.critic_obs.shape) > 4: critic_obs = ( self.critic_obs[:-1] .transpose(1, 2, 0, 3, 4, 5) .reshape(-1, *self.critic_obs.shape[3:]) ) policy_obs = ( self.policy_obs[:-1] .transpose(1, 2, 0, 3, 4, 5) .reshape(-1, *self.policy_obs.shape[3:]) ) else: critic_obs = _cast(self.critic_obs[:-1]) policy_obs = _cast(self.policy_obs[:-1]) actions = _cast(self.actions) action_log_probs = _cast(self.action_log_probs) advantages = _cast(advantages) value_preds = _cast(self.value_preds[:-1]) returns = _cast(self.returns[:-1]) masks = _cast(self.masks[:-1]) active_masks = _cast(self.active_masks[:-1]) rnn_states = ( self.rnn_states[:-1] .transpose(1, 2, 0, 3, 4) .reshape(-1, *self.rnn_states.shape[3:]) ) rnn_states_critic = ( self.rnn_states_critic[:-1] .transpose(1, 2, 0, 3, 4) .reshape(-1, *self.rnn_states_critic.shape[3:]) ) if self.action_masks is not None: action_masks = _cast(self.action_masks[:-1]) for indices in sampler: if self._mixed_obs: critic_obs_batch = defaultdict(list) policy_obs_batch = defaultdict(list) else: critic_obs_batch = [] policy_obs_batch = [] rnn_states_batch = [] rnn_states_critic_batch = [] actions_batch = [] action_masks_batch = [] value_preds_batch = [] return_batch = [] masks_batch = [] active_masks_batch = [] old_action_log_probs_batch = [] adv_targ = [] for index in indices: ind = index * data_chunk_length # size [T+1 N M Dim]-->[T N M Dim]-->[N,M,T,Dim]-->[N*M*T,Dim]-->[L,Dim] if self._mixed_obs: for key in critic_obs.keys(): critic_obs_batch[key].append( critic_obs[key][ind : ind + data_chunk_length] ) for key in policy_obs.keys(): policy_obs_batch[key].append( policy_obs[key][ind : ind + data_chunk_length] ) else: critic_obs_batch.append(critic_obs[ind : ind + data_chunk_length]) policy_obs_batch.append(policy_obs[ind : ind + data_chunk_length]) actions_batch.append(actions[ind : ind + data_chunk_length]) if self.action_masks is not None: action_masks_batch.append( action_masks[ind : ind + data_chunk_length] ) value_preds_batch.append(value_preds[ind : ind + data_chunk_length]) return_batch.append(returns[ind : ind + data_chunk_length]) masks_batch.append(masks[ind : ind + data_chunk_length]) active_masks_batch.append(active_masks[ind : ind + data_chunk_length]) old_action_log_probs_batch.append( action_log_probs[ind : ind + data_chunk_length] ) adv_targ.append(advantages[ind : ind + data_chunk_length]) # size [T+1 N M Dim]-->[T N M Dim]-->[N M T Dim]-->[N*M*T,Dim]-->[1,Dim] rnn_states_batch.append(rnn_states[ind]) rnn_states_critic_batch.append(rnn_states_critic[ind]) L, N = data_chunk_length, mini_batch_size # These are all from_numpys of size (L, N, Dim) if self._mixed_obs: for key in critic_obs_batch.keys(): critic_obs_batch[key] = np.stack(critic_obs_batch[key], axis=1) for key in policy_obs_batch.keys(): policy_obs_batch[key] = np.stack(policy_obs_batch[key], axis=1) else: critic_obs_batch = np.stack(critic_obs_batch, axis=1) policy_obs_batch = np.stack(policy_obs_batch, axis=1) actions_batch = np.stack(actions_batch, axis=1) if self.action_masks is not None: action_masks_batch = np.stack(action_masks_batch, axis=1) value_preds_batch = np.stack(value_preds_batch, axis=1) return_batch = np.stack(return_batch, axis=1) masks_batch = np.stack(masks_batch, axis=1) active_masks_batch = np.stack(active_masks_batch, axis=1) old_action_log_probs_batch = np.stack(old_action_log_probs_batch, axis=1) adv_targ = np.stack(adv_targ, axis=1) # States is just a (N, -1) from_numpy rnn_states_batch = np.stack(rnn_states_batch).reshape( N, *self.rnn_states.shape[3:] ) rnn_states_critic_batch = np.stack(rnn_states_critic_batch).reshape( N, *self.rnn_states_critic.shape[3:] ) # Flatten the (L, N, ...) from_numpys to (L * N, ...) if self._mixed_obs: for key in critic_obs_batch.keys(): critic_obs_batch[key] = _flatten(L, N, critic_obs_batch[key]) for key in policy_obs_batch.keys(): policy_obs_batch[key] = _flatten(L, N, policy_obs_batch[key]) else: critic_obs_batch = _flatten(L, N, critic_obs_batch) policy_obs_batch = _flatten(L, N, policy_obs_batch) actions_batch = _flatten(L, N, actions_batch) if self.action_masks is not None: action_masks_batch = _flatten(L, N, action_masks_batch) else: action_masks_batch = None value_preds_batch = _flatten(L, N, value_preds_batch) return_batch = _flatten(L, N, return_batch) masks_batch = _flatten(L, N, masks_batch) active_masks_batch = _flatten(L, N, active_masks_batch) old_action_log_probs_batch = _flatten(L, N, old_action_log_probs_batch) adv_targ = _flatten(L, N, adv_targ) yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch