# Source code for openrl.algorithms.a2c
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from typing import Union
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel
from openrl.algorithms.ppo import PPOAlgorithm
class A2CAlgorithm(PPOAlgorithm):
    """Advantage actor-critic (A2C) variant layered on ``PPOAlgorithm``.

    Construction and the training loop are inherited from the parent class.
    This subclass only replaces the loss computation: the policy term is
    ``-advantage * log_prob`` (no importance ratio, no clipping), and
    ``num_mini_batch`` is pinned to ``1``.
    """

    def __init__(
        self,
        cfg,
        init_module,
        agent_num: int = 1,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        """Initialise the parent algorithm, then force a single mini-batch.

        Args:
            cfg: Configuration object, forwarded unchanged to the parent.
            init_module: Module object, forwarded unchanged to the parent.
            agent_num: Number of agents, forwarded unchanged to the parent.
            device: Device name or ``torch.device``, forwarded unchanged.
        """
        super().__init__(cfg, init_module, agent_num, device)
        # Overrides whatever value the parent constructor set.
        # NOTE(review): presumably read by the inherited training loop so the
        # whole batch is used in one update -- confirm in PPOAlgorithm.
        self.num_mini_batch = 1

    def prepare_loss(
        self,
        critic_obs_batch,
        obs_batch,
        rnn_states_batch,
        rnn_states_critic_batch,
        actions_batch,
        masks_batch,
        action_masks_batch,
        old_action_log_probs_batch,
        adv_targ,
        value_preds_batch,
        return_batch,
        active_masks_batch,
        turn_on,
    ):
        """Build the A2C policy and value losses for one batch.

        Args:
            critic_obs_batch: Critic observations passed to
                ``evaluate_actions``.
            obs_batch: Actor observations passed to ``evaluate_actions``.
            rnn_states_batch: Actor recurrent states.
            rnn_states_critic_batch: Critic recurrent states.
            actions_batch: Actions whose log-probabilities are evaluated.
            masks_batch: Masks passed to ``evaluate_actions``; also the source
                of the critic masks.
            action_masks_batch: Action masks passed to ``evaluate_actions``.
            old_action_log_probs_batch: Unused here; kept so the signature
                matches the one this method overrides.
            adv_targ: Advantage targets; detached before being multiplied
                with the new log-probabilities.
            value_preds_batch: Stored value predictions, forwarded to
                ``cal_value_loss``.
            return_batch: Return targets, forwarded to ``cal_value_loss``.
            active_masks_batch: Active-agent masks used to weight the losses.
            turn_on: Forwarded unchanged to ``construct_loss_list``.

        Returns:
            A tuple ``(loss_list, value_loss, policy_loss, dist_entropy,
            ratio)``. ``ratio`` is always ``np.zeros(1)`` because no
            importance ratio is computed in this loss.
        """
        if self.use_joint_action_loss:
            # Critic-side inputs go through ``to_single_np``.
            # NOTE(review): presumably collapses the per-agent axis to a
            # single entry -- confirm against the parent implementation.
            critic_obs_batch = self.to_single_np(critic_obs_batch)
            rnn_states_critic_batch = self.to_single_np(rnn_states_critic_batch)
            critic_masks_batch = self.to_single_np(masks_batch)
            value_preds_batch = self.to_single_np(value_preds_batch)
            return_batch = self.to_single_np(return_batch)
            # Regroup by agent and keep only the first agent's advantage.
            adv_targ = adv_targ.reshape(-1, self.agent_num, 1)
            adv_targ = adv_targ[:, 0, :]
        else:
            critic_masks_batch = masks_batch

        (
            values,
            action_log_probs,
            dist_entropy,
            policy_values,
        ) = self.algo_module.evaluate_actions(
            critic_obs_batch,
            obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            masks_batch,
            action_masks_batch,
            active_masks_batch,
            critic_masks_batch=critic_masks_batch,
        )

        if self.use_joint_action_loss:
            # Same first-agent reduction as for ``adv_targ``; done only after
            # ``evaluate_actions`` has received the un-reduced masks.
            active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
            active_masks_batch = active_masks_batch[:, 0, :]

        # Plain policy gradient: the advantage is treated as a constant.
        policy_gradient_loss = -adv_targ.detach() * action_log_probs

        if self._use_policy_active_masks:
            # Mask-weighted mean over the batch.
            policy_action_loss = (
                torch.sum(policy_gradient_loss, dim=-1, keepdim=True)
                * active_masks_batch
            ).sum() / active_masks_batch.sum()
        else:
            policy_action_loss = torch.sum(
                policy_gradient_loss, dim=-1, keepdim=True
            ).mean()

        if self._use_policy_vhead:
            # A DistributedDataParallel wrapper exposes the real model (and
            # therefore its normalizer) through ``.module``.
            if isinstance(self.algo_module.models["actor"], DistributedDataParallel):
                policy_value_normalizer = self.algo_module.models[
                    "actor"
                ].module.value_normalizer
            else:
                policy_value_normalizer = self.algo_module.models[
                    "actor"
                ].value_normalizer
            policy_value_loss = self.cal_value_loss(
                policy_value_normalizer,
                policy_values,
                value_preds_batch,
                return_batch,
                active_masks_batch,
            )
            policy_loss = (
                policy_action_loss + policy_value_loss * self.policy_value_loss_coef
            )
        else:
            policy_loss = policy_action_loss

        # critic update
        if self._use_share_model:
            value_normalizer = self.algo_module.models["model"].value_normalizer
        elif isinstance(self.algo_module.models["critic"], DistributedDataParallel):
            value_normalizer = self.algo_module.models["critic"].module.value_normalizer
        else:
            value_normalizer = self.algo_module.get_critic_value_normalizer()

        value_loss = self.cal_value_loss(
            value_normalizer,
            values,
            value_preds_batch,
            return_batch,
            active_masks_batch,
        )

        loss_list = self.construct_loss_list(
            policy_loss, dist_entropy, value_loss, turn_on
        )

        # Placeholder that keeps the five-element return shape.
        ratio = np.zeros(1)
        return loss_list, value_loss, policy_loss, dist_entropy, ratio

    def train(self, buffer, turn_on: bool = True):
        """Run the inherited training step and drop the ``"ratio"`` entry.

        Args:
            buffer: Buffer object, forwarded unchanged to the parent ``train``.
            turn_on: Forwarded unchanged to the parent ``train``.

        Returns:
            The mapping returned by the parent ``train`` with the ``"ratio"``
            key removed if it was present.
        """
        train_info = super().train(buffer, turn_on)
        # ``prepare_loss`` only reports a zero placeholder for the ratio, so
        # the statistic carries no information here.
        train_info.pop("ratio", None)
        return train_info