Shortcuts

Source code for openrl.drivers.onpolicy_driver

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel

from openrl.drivers.rl_driver import RLDriver
from openrl.envs.vec_env.utils.util import prepare_available_actions
from openrl.runners.common.base_agent import BaseAgent
from openrl.utils.logger import Logger
from openrl.utils.type_aliases import MaybeCallback
from openrl.utils.util import _t2n


class OnPolicyDriver(RLDriver):
    """Driver that alternates environment rollouts with learner updates.

    One call to :meth:`_inner_loop` collects ``self.episode_length`` steps via
    :meth:`actor_rollout`, runs ``self.learner_update()`` and logs the results.
    Attributes such as ``envs``, ``episode``, ``episode_length``,
    ``n_rollout_threads``, ``num_agents`` and ``log_interval`` are read here but
    not set in this class -- presumably they are provided by ``RLDriver``.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        trainer,
        buffer,
        agent,
        rank: int = 0,
        world_size: int = 1,
        client=None,
        logger: Optional[Logger] = None,
        callback: MaybeCallback = None,
    ) -> None:
        """Forward every argument unchanged to ``RLDriver.__init__``."""
        super().__init__(
            config,
            trainer,
            buffer,
            agent,
            rank,
            world_size,
            client,
            logger,
            callback=callback,
        )

    def _inner_loop(
        self,
    ) -> bool:
        """Run one rollout + update iteration.

        :return: True if training should continue, False if training should stop
        """
        rollout_infos, continue_training = self.actor_rollout()
        if not continue_training:
            # A callback asked to stop during the rollout: skip the update.
            return False

        train_infos = self.learner_update()
        self.buffer.after_update()

        self.total_num_steps = (
            (self.episode + 1) * self.episode_length * self.n_rollout_threads
        )

        if self.episode % self.log_interval == 0:
            # rollout_infos can only be used when env is wrapped with VecMonitor
            self.logger.log_info(rollout_infos, step=self.total_num_steps)
            self.logger.log_info(train_infos, step=self.total_num_steps)

        return True

    def add2buffer(self, data):
        """Derive the per-step masks and insert one transition into the buffer.

        :param data: 9-tuple ``(obs, rewards, dones, infos, values, actions,
            action_log_probs, rnn_states, rnn_states_critic)``. ``dones`` is
            used as a boolean index of shape ``(n_rollout_threads, num_agents)``;
            ``rnn_states`` / ``rnn_states_critic`` may be ``None``.
        """
        (
            obs,
            rewards,
            dones,
            infos,
            values,
            actions,
            action_log_probs,
            rnn_states,
            rnn_states_critic,
        ) = data

        # An environment counts as finished only when every entry along axis 1
        # of ``dones`` is set.
        dones_env = np.all(dones, axis=1)

        # Reset recurrent state (in place) for environments that just finished.
        if rnn_states is not None:
            rnn_states[dones_env] = np.zeros(
                (dones_env.sum(), self.num_agents, self.recurrent_N, self.hidden_size),
                dtype=np.float32,
            )

        if rnn_states_critic is not None:
            rnn_states_critic[dones_env] = np.zeros(
                (
                    dones_env.sum(),
                    self.num_agents,
                    *self.buffer.data.rnn_states_critic.shape[3:],
                ),
                dtype=np.float32,
            )

        # masks: 0 for finished environments, 1 otherwise.
        masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
        masks[dones_env] = np.zeros(
            (dones_env.sum(), self.num_agents, 1), dtype=np.float32
        )

        available_actions = prepare_available_actions(
            infos, agent_num=self.num_agents, as_batch=False
        )

        # active_masks: 0 for an individually-done agent, but set back to 1 for
        # every agent of an environment that finished as a whole.
        active_masks = np.ones(
            (self.n_rollout_threads, self.num_agents, 1), dtype=np.float32
        )
        active_masks[dones] = np.zeros((dones.sum(), 1), dtype=np.float32)
        active_masks[dones_env] = np.ones(
            (dones_env.sum(), self.num_agents, 1), dtype=np.float32
        )

        # bad_masks: 0 where the env info flags ``bad_transition`` for that
        # agent, 1 otherwise.
        bad_masks = np.array(
            [
                [
                    (
                        [0.0]
                        if "bad_transition" in info
                        and info["bad_transition"][agent_id]
                        else [1.0]
                    )
                    for agent_id in range(self.num_agents)
                ]
                for info in infos
            ]
        )

        self.buffer.insert(
            obs,
            rnn_states,
            rnn_states_critic,
            actions,
            action_log_probs,
            values,
            rewards,
            masks,
            active_masks=active_masks,
            bad_masks=bad_masks,
            available_actions=available_actions,
        )

    def actor_rollout(self) -> Tuple[Dict[str, Any], bool]:
        """Step the environments for ``self.episode_length`` steps.

        :return: ``(infos, continue_training)``. ``continue_training`` is False
            (and ``infos`` is empty) as soon as ``self.callback.on_step()``
            returns ``False``; otherwise ``infos`` holds the batch reward info,
            merged into the env statistics when ``self.envs.use_monitor`` is set.
        """
        self.callback.on_rollout_start()

        self.trainer.prep_rollout()

        for step in range(self.episode_length):
            values, actions, action_log_probs, rnn_states, rnn_states_critic = self.act(
                step
            )

            extra_data = {
                "values": values,
                "action_log_probs": action_log_probs,
                "step": step,
                "buffer": self.buffer,
            }

            obs, rewards, dones, infos = self.envs.step(actions, extra_data)
            self.agent.num_time_steps += self.envs.parallel_env_num

            # Give access to local variables
            # (local names above are kept stable because they are exposed here).
            self.callback.update_locals(locals())
            if self.callback.on_step() is False:
                return {}, False

            data = (
                obs,
                rewards,
                dones,
                infos,
                values,
                actions,
                action_log_probs,
                rnn_states,
                rnn_states_critic,
            )

            self.add2buffer(data)

        batch_rew_infos = self.envs.batch_rewards(self.buffer)

        self.callback.on_rollout_end()

        if self.envs.use_monitor:
            statistics_info = self.envs.statistics(self.buffer)
            statistics_info.update(batch_rew_infos)
            return statistics_info, True
        else:
            return batch_rew_infos, True

    @torch.no_grad()
    def compute_returns(self):
        """Bootstrap from the last stored step and let the buffer compute returns.

        Critic values are evaluated on the final (``-1``) buffer entry, split
        into ``self.learner_n_rollout_threads`` chunks and handed to
        ``self.buffer.compute_returns`` together with the value normalizer.
        """
        self.trainer.prep_rollout()

        next_values = self.trainer.algo_module.get_values(
            self.buffer.data.get_batch_data("critic_obs", -1),
            np.concatenate(self.buffer.data.rnn_states_critic[-1]),
            np.concatenate(self.buffer.data.masks[-1]),
        )

        next_values = np.array(
            np.split(_t2n(next_values), self.learner_n_rollout_threads)
        )

        # A DistributedDataParallel wrapper does not expose attributes of the
        # network it wraps; they must be reached through ``.module``.
        if "critic" in self.trainer.algo_module.models and isinstance(
            self.trainer.algo_module.models["critic"], DistributedDataParallel
        ):
            value_normalizer = self.trainer.algo_module.models[
                "critic"
            ].module.value_normalizer
        elif "model" in self.trainer.algo_module.models and isinstance(
            self.trainer.algo_module.models["model"], DistributedDataParallel
        ):
            value_normalizer = self.trainer.algo_module.models[
                "model"
            ].module.value_normalizer
        else:
            value_normalizer = self.trainer.algo_module.get_critic_value_normalizer()

        self.buffer.compute_returns(next_values, value_normalizer)

    @torch.no_grad()
    def act(
        self,
        step: int,
    ):
        """Query the algorithm module for actions at buffer index ``step``.

        :param step: index of the buffer entry whose observations, recurrent
            states, masks and available actions are fed to ``get_actions``.
        :return: ``(values, actions, action_log_probs, rnn_states,
            rnn_states_critic)`` as numpy arrays split into
            ``self.n_rollout_threads`` equal chunks along the first axis; the
            two recurrent-state entries are passed through as ``None`` when
            ``get_actions`` returns ``None`` for them.
        """
        self.trainer.prep_rollout()

        (
            value,
            action,
            action_log_prob,
            rnn_states,
            rnn_states_critic,
        ) = self.trainer.algo_module.get_actions(
            self.buffer.data.get_batch_data("critic_obs", step),
            self.buffer.data.get_batch_data("policy_obs", step),
            self.buffer.data.get_batch_data("rnn_states", step),
            self.buffer.data.get_batch_data("rnn_states_critic", step),
            self.buffer.data.get_batch_data("masks", step),
            available_actions=self.buffer.data.get_batch_data(
                "available_actions", step
            ),
        )

        # Regroup the flat batch into one leading entry per rollout thread.
        values = np.array(np.split(_t2n(value), self.n_rollout_threads))
        actions = np.array(np.split(_t2n(action), self.n_rollout_threads))
        action_log_probs = np.array(
            np.split(_t2n(action_log_prob), self.n_rollout_threads)
        )
        if rnn_states is not None:
            rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads))
        if rnn_states_critic is not None:
            rnn_states_critic = np.array(
                np.split(_t2n(rnn_states_critic), self.n_rollout_threads)
            )

        return (
            values,
            actions,
            action_log_probs,
            rnn_states,
            rnn_states_critic,
        )