Source code for openrl.algorithms.vdn
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VDN algorithm: trains a value-decomposition Q-network ("vdn_net") on team rewards summed over agents."""
from typing import Union
import torch
import torch.nn.functional as F
from openrl.algorithms.base_algorithm import BaseAlgorithm
from openrl.modules.networks.utils.distributed_utils import reduce_tensor
from openrl.modules.utils.util import get_grad_norm, huber_loss, mse_loss
from openrl.utils.util import check
class VDNAlgorithm(BaseAlgorithm):
    """Value-decomposition (VDN) training algorithm.

    Trains the ``"vdn_net"`` model against a one-step TD target computed by the
    ``"target_vdn_net"`` model. The per-agent rewards of each step are summed
    into a single team reward. The target network is refreshed from the online
    network every ``cfg.train_interval`` gradient updates.
    """

    def __init__(
        self,
        cfg,
        init_module,
        agent_num: int = 1,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        """Build the algorithm.

        Args:
            cfg: Configuration object. This class reads ``gamma``,
                ``num_agents`` and ``train_interval``; everything else is
                consumed by ``BaseAlgorithm``.
            init_module: Module container forwarded to ``BaseAlgorithm``.
            agent_num: Number of agents, forwarded to ``BaseAlgorithm``.
            device: Torch device the algorithm runs on.
        """
        super().__init__(cfg, init_module, agent_num, device)
        self.gamma = cfg.gamma
        self.n_agent = cfg.num_agents
        # Gradient updates performed since the target network was last synced.
        self.update_count = 0
        # NOTE(review): ``train_interval`` is reused as the target-sync period.
        self.target_update_frequency = cfg.train_interval

    def vdn_update(self, sample, turn_on=True):
        """Run one gradient step on a single mini-batch.

        Args:
            sample: 15-tuple yielded by the buffer's mini-batch generator, in
                the order unpacked below.
            turn_on: Forwarded to :meth:`prepare_loss`.

        Returns:
            The last loss tensor returned by :meth:`prepare_loss`.

        Raises:
            NotImplementedError: If the module holds a ``"transformer"`` model.
        """
        for optimizer in self.algo_module.optimizers.values():
            optimizer.zero_grad()

        (
            obs_batch,
            _,
            next_obs_batch,
            _,
            rnn_states_batch,
            _,  # rnn_states_critic_batch: unused by VDN
            actions_batch,
            value_preds_batch,
            rewards_batch,
            masks_batch,
            next_masks_batch,
            active_masks_batch,
            _,  # old_action_log_probs_batch: unused by VDN
            _,  # adv_targ: unused by VDN
            action_masks_batch,
        ) = sample

        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        rewards_batch = check(rewards_batch).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)
        next_masks_batch = check(next_masks_batch).to(**self.tpdv)

        loss_args = (
            obs_batch,
            next_obs_batch,
            rnn_states_batch,
            actions_batch,
            masks_batch,
            next_masks_batch,
            action_masks_batch,
            value_preds_batch,
            rewards_batch,
            active_masks_batch,
            turn_on,
        )
        if self.use_amp:
            # Only the forward pass runs under autocast; PyTorch's AMP guide
            # advises against running backward inside the autocast region.
            with torch.cuda.amp.autocast():
                loss_list = self.prepare_loss(*loss_args)
            for loss in loss_list:
                self.algo_module.scaler.scale(loss).backward()
        else:
            loss_list = self.prepare_loss(*loss_args)
            for loss in loss_list:
                loss.backward()

        if "transformer" in self.algo_module.models:
            raise NotImplementedError

        self._step_optimizers()

        if self.world_size > 1:
            torch.cuda.synchronize()

        self._maybe_sync_target()
        return loss_list[-1]

    def _step_optimizers(self):
        """Apply the accumulated gradients with every optimizer (AMP-aware)."""
        optimizers = self.algo_module.optimizers.values()
        if self.use_amp:
            scaler = self.algo_module.scaler
            for optimizer in optimizers:
                scaler.unscale_(optimizer)
            for optimizer in optimizers:
                scaler.step(optimizer)
            scaler.update()
        else:
            for optimizer in optimizers:
                optimizer.step()

    def _maybe_sync_target(self):
        """Count one update; copy online weights into the target net when due."""
        # Previously the counter was never incremented, so the target network
        # was overwritten on every update and never lagged the online network.
        self.update_count += 1
        if self.update_count % self.target_update_frequency == 0:
            self.update_count = 0
            self.algo_module.models["target_vdn_net"].load_state_dict(
                self.algo_module.models["vdn_net"].state_dict()
            )

    def cal_value_loss(
        self,
        value_normalizer,
        values,
        value_preds_batch,
        return_batch,
        active_masks_batch,
    ):
        """Compute a clipped value-function loss.

        NOTE(review): this is not called anywhere in this class. It appears to
        be a PPO-style helper kept for interface parity; confirm before removing.

        Args:
            value_normalizer: Normalizer used when popart/valuenorm is enabled;
                its ``update`` and ``normalize`` methods are called.
            values: Current value predictions.
            value_preds_batch: Value predictions recorded in the buffer.
            return_batch: Return targets.
            active_masks_batch: Per-sample weights, used when
                ``_use_value_active_masks`` is set.

        Returns:
            Scalar value loss tensor.
        """
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param
        )
        if self._use_popart or self._use_valuenorm:
            value_normalizer.update(return_batch)
            error_clipped = (
                value_normalizer.normalize(return_batch) - value_pred_clipped
            )
            error_original = value_normalizer.normalize(return_batch) - values
        else:
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        if self._use_huber_loss:
            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
            value_loss_original = huber_loss(error_original, self.huber_delta)
        else:
            value_loss_clipped = mse_loss(error_clipped)
            value_loss_original = mse_loss(error_original)

        if self._use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        if self._use_value_active_masks:
            value_loss = (
                value_loss * active_masks_batch
            ).sum() / active_masks_batch.sum()
        else:
            value_loss = value_loss.mean()
        return value_loss

    def to_single_np(self, input):
        """Keep only the first agent's slice of an agent-flattened array.

        ``input`` is reshaped to ``(-1, self.agent_num, *input.shape[1:])`` and
        agent index 0 is returned.
        """
        reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:])
        return reshape_input[:, 0, ...]

    def prepare_loss(
        self,
        obs_batch,
        next_obs_batch,
        rnn_states_batch,
        actions_batch,
        masks_batch,
        next_masks_batch,
        action_masks_batch,
        value_preds_batch,
        rewards_batch,
        active_masks_batch,
        turn_on,
    ):
        """Build the VDN TD loss for one mini-batch.

        The target is ``team_reward + gamma * max_next_q * next_mask``, where
        the team reward is the sum of the ``self.n_agent`` per-agent rewards.
        ``value_preds_batch`` and ``turn_on`` are accepted for interface
        compatibility but are not used here.

        Returns:
            A one-element list holding the scalar MSE loss.
        """
        critic_masks_batch = masks_batch
        q_values, max_next_q_values = self.algo_module.evaluate_actions(
            obs_batch,
            next_obs_batch,
            rnn_states_batch,
            rewards_batch,
            actions_batch,
            masks_batch,
            action_masks_batch,
            active_masks_batch,
            critic_masks_batch=critic_masks_batch,
        )

        # Sum the per-agent rewards of each step into one team reward: (B, 1).
        team_rewards = rewards_batch.reshape(-1, self.n_agent, 1).sum(dim=1)
        q_targets = team_rewards + self.gamma * max_next_q_values * next_masks_batch
        # Mean-squared-error loss; F.mse_loss already averages over elements.
        q_loss = F.mse_loss(q_values, q_targets.detach())
        return [q_loss]

    def train(self, buffer, turn_on=True):
        """Run the training updates for one buffer and return averaged stats.

        Args:
            buffer: Replay buffer exposing ``feed_forward_generator``.
            turn_on: Forwarded to :meth:`vdn_update`.

        Returns:
            Dict with the mean ``"q_loss"``, plus ``"reduced_q_loss"`` when
            running with more than one process.

        Raises:
            NotImplementedError: For transformer or recurrent policies.
        """
        train_info = {"q_loss": 0}
        if self.world_size > 1:
            train_info["reduced_q_loss"] = 0

        num_updates = 0
        # TODO: add RNN and transformer support.
        for _ in range(self.num_mini_batch):
            if "transformer" in self.algo_module.models:
                raise NotImplementedError
            if self._use_recurrent_policy or self._use_naive_recurrent:
                raise NotImplementedError

            data_generator = buffer.feed_forward_generator(
                None,
                num_mini_batch=self.num_mini_batch,
            )
            for sample in data_generator:
                q_loss = self.vdn_update(sample, turn_on)
                if self.world_size > 1:
                    train_info["reduced_q_loss"] += reduce_tensor(
                        q_loss.data, self.world_size
                    )
                train_info["q_loss"] += q_loss.item()
                num_updates += 1

        # Average over the updates that actually ran. The previous fixed
        # divisor (num_mini_batch) over-reported the loss whenever each pass
        # yielded several mini-batches.
        if num_updates > 0:
            for key in train_info:
                train_info[key] /= num_updates

        for optimizer in self.algo_module.optimizers.values():
            if hasattr(optimizer, "sync_lookahead"):
                optimizer.sync_lookahead()

        return train_info