# Source code for openrl.algorithms.behavior_cloning
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Behavior cloning (BC): train the policy to maximize the log-probability of buffered actions."""
from typing import Union
import torch
import torch.nn as nn
from openrl.algorithms.base_algorithm import BaseAlgorithm
from openrl.modules.networks.utils.distributed_utils import reduce_tensor
from openrl.modules.utils.util import get_grad_norm
from openrl.utils.util import check
class BCAlgorithm(BaseAlgorithm):
    """Behavior-cloning trainer.

    The policy loss is the negative mean log-probability that the current
    policy assigns to the actions stored in the buffer. No value loss is
    computed by this algorithm, so the value-related entries it returns are
    always ``None``.
    """

    def __init__(
        self,
        cfg,
        init_module,
        agent_num: int = 1,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        """Store BC-specific flags and register ``train_bc`` as the training step.

        Args:
            cfg: Configuration object; ``use_share_model`` and
                ``use_joint_action_loss`` are read from it here.
            init_module: Forwarded unchanged to ``BaseAlgorithm.__init__``.
            agent_num: Number of agents; forwarded to the base class.
            device: Target device; forwarded to the base class.
        """
        # Set before the base constructor runs -- presumably so the flags are
        # already available during base-class initialisation; TODO confirm.
        self._use_share_model = cfg.use_share_model
        self.use_joint_action_loss = cfg.use_joint_action_loss
        super().__init__(cfg, init_module, agent_num, device)
        self.train_list = [self.train_bc]

    def bc_update(self, sample, turn_on=True):
        """Run one optimisation step on a single mini-batch.

        Args:
            sample: A 12-element sequence unpacked as (critic_obs, obs,
                rnn_states, rnn_states_critic, actions, value_preds, returns,
                masks, active_masks, old_action_log_probs, adv_targ,
                action_masks).
            turn_on: Forwarded to ``prepare_loss``. When falsy, the loss list
                is empty and no backward pass is performed.

        Returns:
            Tuple ``(value_loss, critic_grad_norm, policy_loss, dist_entropy,
            actor_grad_norm, ratio)``. ``critic_grad_norm`` is always ``None``
            here; ``value_loss`` and ``ratio`` are whatever ``prepare_loss``
            returns (``None`` in this class).
        """
        optimizers = list(self.algo_module.optimizers.values())
        for optimizer in optimizers:
            optimizer.zero_grad()

        (
            critic_obs_batch,
            obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            value_preds_batch,
            return_batch,
            masks_batch,
            active_masks_batch,
            old_action_log_probs_batch,
            adv_targ,
            action_masks_batch,
        ) = sample

        # `check` is a project helper (presumably array -> tensor conversion;
        # verify in openrl.utils.util); `self.tpdv` supplies the `.to()` kwargs.
        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)

        # Built once so the AMP and non-AMP paths cannot drift apart.
        loss_args = (
            critic_obs_batch,
            obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            masks_batch,
            action_masks_batch,
            old_action_log_probs_batch,
            adv_targ,
            value_preds_batch,
            return_batch,
            active_masks_batch,
            turn_on,
        )

        if self.use_amp:
            # Only the forward pass / loss computation runs under autocast.
            with torch.cuda.amp.autocast():
                (
                    loss_list,
                    value_loss,
                    policy_loss,
                    dist_entropy,
                    ratio,
                ) = self.prepare_loss(*loss_args)
            for loss in loss_list:
                self.algo_module.scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale at this point.
            # Unscale them BEFORE clipping/measuring so that `max_grad_norm`
            # and the reported norm refer to the true gradient magnitudes.
            for optimizer in optimizers:
                self.algo_module.scaler.unscale_(optimizer)
        else:
            (
                loss_list,
                value_loss,
                policy_loss,
                dist_entropy,
                ratio,
            ) = self.prepare_loss(*loss_args)
            for loss in loss_list:
                loss.backward()

        actor_para = self.algo_module.models["policy"].parameters()
        if self._use_max_grad_norm:
            actor_grad_norm = nn.utils.clip_grad_norm_(actor_para, self.max_grad_norm)
        else:
            actor_grad_norm = get_grad_norm(actor_para)

        if self.use_amp:
            # `step` does not unscale again for optimizers already unscaled
            # in this iteration.
            for optimizer in optimizers:
                self.algo_module.scaler.step(optimizer)
            self.algo_module.scaler.update()
        else:
            for optimizer in optimizers:
                optimizer.step()

        if self.world_size > 1:
            torch.cuda.synchronize()

        # No critic is trained by behavior cloning.
        critic_grad_norm = None

        return (
            value_loss,
            critic_grad_norm,
            policy_loss,
            dist_entropy,
            actor_grad_norm,
            ratio,
        )

    def to_single_np(self, input):
        """Return the entries belonging to the first agent.

        The leading axis of ``input`` is split into ``(-1, self.agent_num)``
        and index 0 along the agent axis is returned.
        """
        reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:])
        return reshape_input[:, 0, ...]

    def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on):
        """Build the list of losses to back-propagate.

        Args:
            policy_loss: Policy (negative log-likelihood) loss.
            dist_entropy: Entropy term, weighted by ``self.entropy_coef`` and
                subtracted from the policy loss.
            value_loss: Unused here; kept in the signature.
            turn_on: When falsy, an empty list is returned.

        Returns:
            A list with the single entropy-regularised policy loss, or ``[]``.
        """
        loss_list = []
        if turn_on:
            final_p_loss = policy_loss - dist_entropy * self.entropy_coef
            loss_list.append(final_p_loss)
        return loss_list

    def prepare_loss(
        self,
        critic_obs_batch,
        obs_batch,
        rnn_states_batch,
        rnn_states_critic_batch,
        actions_batch,
        masks_batch,
        action_masks_batch,
        old_action_log_probs_batch,
        adv_targ,
        value_preds_batch,
        return_batch,
        active_masks_batch,
        turn_on,
    ):
        """Evaluate the buffered actions and compute the BC loss.

        ``old_action_log_probs_batch``, ``adv_targ``, ``value_preds_batch``
        and ``return_batch`` are accepted but not used by this method.

        Returns:
            Tuple ``(loss_list, value_loss, policy_loss, dist_entropy, ratio)``
            where ``value_loss`` and ``ratio`` are always ``None``.
        """
        (
            values,
            action_log_probs,
            dist_entropy,
            policy_values,
        ) = self.algo_module.evaluate_actions(
            critic_obs_batch,
            obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            masks_batch,
            action_masks_batch,
            active_masks_batch,
            critic_masks_batch=None,
        )

        if self.use_joint_action_loss:
            # Sum log-probs over the agent axis and the last axis so each row
            # holds one joint log-probability, then average over rows.
            action_log_probs_copy = (
                action_log_probs.reshape(-1, self.agent_num, action_log_probs.shape[-1])
                .sum(dim=(1, -1), keepdim=True)
                .reshape(-1, 1)
            )
            policy_loss = -action_log_probs_copy.mean()
        else:
            policy_loss = -action_log_probs.mean()

        # Behavior cloning has neither a value loss nor an importance ratio.
        value_loss = None
        ratio = None

        loss_list = self.construct_loss_list(
            policy_loss, dist_entropy, value_loss, turn_on
        )
        return loss_list, value_loss, policy_loss, dist_entropy, ratio

    def get_data_generator(self, buffer):
        """Pick the buffer mini-batch generator matching the policy type.

        ``None`` is always passed as the advantages argument.
        """
        advantages = None

        if self._use_recurrent_policy:
            if self.use_joint_action_loss:
                data_generator = buffer.recurrent_generator_v3(
                    advantages, self.num_mini_batch, self.data_chunk_length
                )
            else:
                data_generator = buffer.recurrent_generator(
                    advantages, self.num_mini_batch, self.data_chunk_length
                )
        elif self._use_naive_recurrent:
            data_generator = buffer.naive_recurrent_generator(
                advantages, self.num_mini_batch
            )
        else:
            data_generator = buffer.feed_forward_generator(
                advantages, self.num_mini_batch
            )
        return data_generator

    def train_bc(self, buffer, turn_on):
        """Run ``self.bc_epoch`` passes over the buffer and average the metrics.

        Args:
            buffer: Data buffer handed to ``get_data_generator``.
            turn_on: Forwarded to ``bc_update``.

        Returns:
            Dict of metrics, each divided by ``bc_epoch * num_mini_batch``.
        """
        train_info = {}
        train_info["policy_loss"] = 0
        train_info["dist_entropy"] = 0
        train_info["actor_grad_norm"] = 0
        if self.world_size > 1:
            # "reduced_value_loss" is never accumulated below (there is no
            # value loss); the key is kept so the reported dict is unchanged.
            train_info["reduced_value_loss"] = 0
            train_info["reduced_policy_loss"] = 0

        for _ in range(self.bc_epoch):
            data_generator = self.get_data_generator(buffer)

            for sample in data_generator:
                (
                    value_loss,
                    critic_grad_norm,
                    policy_loss,
                    dist_entropy,
                    actor_grad_norm,
                    ratio,
                ) = self.bc_update(sample, turn_on)

                if self.world_size > 1:
                    train_info["reduced_policy_loss"] += reduce_tensor(
                        policy_loss.data, self.world_size
                    )

                train_info["policy_loss"] += policy_loss.item()
                train_info["dist_entropy"] += dist_entropy.item()
                train_info["actor_grad_norm"] += actor_grad_norm

        num_updates = self.bc_epoch * self.num_mini_batch
        # Only values are rebound, so iterating the dict directly is safe.
        for k in train_info:
            train_info[k] /= num_updates

        return train_info

    def train(self, buffer, turn_on=True):
        """Run every registered training function and merge their metrics.

        After training, ``sync_lookahead()`` is called on each optimizer that
        provides it.

        Returns:
            Dict combining the metrics returned by each function in
            ``self.train_list``.
        """
        train_info = {}
        for train_func in self.train_list:
            train_info.update(train_func(buffer, turn_on))

        for optimizer in self.algo_module.optimizers.values():
            if hasattr(optimizer, "sync_lookahead"):
                optimizer.sync_lookahead()

        return train_info