Source code for openrl.algorithms.ppo
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from typing import Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from openrl.algorithms.base_algorithm import BaseAlgorithm
from openrl.modules.networks.utils.distributed_utils import reduce_tensor
from openrl.modules.utils.util import get_grad_norm, huber_loss, mse_loss
from openrl.utils.util import check
[docs]class PPOAlgorithm(BaseAlgorithm):
def __init__(
self,
cfg,
init_module,
agent_num: int = 1,
device: Union[str, torch.device] = "cpu",
) -> None:
self._use_share_model = cfg.use_share_model
self.use_joint_action_loss = cfg.use_joint_action_loss
super(PPOAlgorithm, self).__init__(cfg, init_module, agent_num, device)
self.train_list = [self.train_ppo]
self.use_deepspeed = cfg.use_deepspeed
[docs] def ppo_update(self, sample, turn_on=True):
for optimizer in self.algo_module.optimizers.values():
if not self.use_deepspeed:
optimizer.zero_grad()
(
critic_obs_batch,
obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
value_preds_batch,
return_batch,
masks_batch,
active_masks_batch,
old_action_log_probs_batch,
adv_targ,
action_masks_batch,
) = sample
old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
adv_targ = check(adv_targ).to(**self.tpdv)
value_preds_batch = check(value_preds_batch).to(**self.tpdv)
return_batch = check(return_batch).to(**self.tpdv)
active_masks_batch = check(active_masks_batch).to(**self.tpdv)
if self.use_amp:
with torch.cuda.amp.autocast():
(
loss_list,
value_loss,
policy_loss,
dist_entropy,
ratio,
) = self.prepare_loss(
critic_obs_batch,
obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
masks_batch,
action_masks_batch,
old_action_log_probs_batch,
adv_targ,
value_preds_batch,
return_batch,
active_masks_batch,
turn_on,
)
for loss in loss_list:
self.algo_module.scaler.scale(loss).backward()
else:
loss_list, value_loss, policy_loss, dist_entropy, ratio = self.prepare_loss(
critic_obs_batch,
obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
masks_batch,
action_masks_batch,
old_action_log_probs_batch,
adv_targ,
value_preds_batch,
return_batch,
active_masks_batch,
turn_on,
)
if self.use_deepspeed:
if self._use_share_model:
for loss in loss_list:
self.algo_module.models["model"].backward(loss)
else:
actor_loss = loss_list[0]
critic_loss = loss_list[1]
self.algo_module.models["policy"].backward(actor_loss)
self.algo_module.models["critic"].backward(critic_loss)
else:
for loss in loss_list:
loss.backward()
# else:
if self._use_share_model:
actor_para = self.algo_module.models["model"].get_actor_para()
else:
actor_para = self.algo_module.models["policy"].parameters()
if self._use_max_grad_norm:
actor_grad_norm = nn.utils.clip_grad_norm_(actor_para, self.max_grad_norm)
else:
actor_grad_norm = get_grad_norm(actor_para)
if self._use_share_model:
critic_para = self.algo_module.models["model"].get_critic_para()
else:
critic_para = self.algo_module.models["critic"].parameters()
if self._use_max_grad_norm:
critic_grad_norm = nn.utils.clip_grad_norm_(critic_para, self.max_grad_norm)
else:
critic_grad_norm = get_grad_norm(critic_para)
if self.use_amp:
for optimizer in self.algo_module.optimizers.values():
self.algo_module.scaler.unscale_(optimizer)
for optimizer in self.algo_module.optimizers.values():
self.algo_module.scaler.step(optimizer)
self.algo_module.scaler.update()
else:
if self.use_deepspeed:
if self._use_share_model:
self.algo_module.optimizers["model"].step()
else:
self.algo_module.optimizers["policy"].step()
self.algo_module.optimizers["critic"].step()
else:
for optimizer in self.algo_module.optimizers.values():
optimizer.step()
if self.world_size > 1:
torch.cuda.synchronize()
return (
value_loss,
critic_grad_norm,
policy_loss,
dist_entropy,
actor_grad_norm,
ratio,
)
[docs] def cal_value_loss(
self,
value_normalizer,
values,
value_preds_batch,
return_batch,
active_masks_batch,
):
value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
-self.clip_param, self.clip_param
)
if (self._use_popart or self._use_valuenorm) and value_normalizer is not None:
value_normalizer.update(return_batch)
error_clipped = (
value_normalizer.normalize(return_batch) - value_pred_clipped
)
error_original = value_normalizer.normalize(return_batch) - values
else:
error_clipped = return_batch - value_pred_clipped
error_original = return_batch - values
if self._use_huber_loss:
value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
value_loss_original = huber_loss(error_original, self.huber_delta)
else:
value_loss_clipped = mse_loss(error_clipped)
value_loss_original = mse_loss(error_original)
if self._use_clipped_value_loss:
value_loss = torch.max(value_loss_original, value_loss_clipped)
else:
value_loss = value_loss_original
if self._use_value_active_masks:
value_loss = (
value_loss * active_masks_batch
).sum() / active_masks_batch.sum()
else:
value_loss = value_loss.mean()
# print(value_loss)
# import pdb;pdb.set_trace()
return value_loss
[docs] def to_single_np(self, input):
reshape_input = input.reshape(-1, self.agent_num, *input.shape[1:])
return reshape_input[:, 0, ...]
[docs] def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on):
loss_list = []
if turn_on:
final_p_loss = policy_loss - dist_entropy * self.entropy_coef
loss_list.append(final_p_loss)
final_v_loss = value_loss * self.value_loss_coef
loss_list.append(final_v_loss)
return loss_list
[docs] def prepare_loss(
self,
critic_obs_batch,
obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
masks_batch,
action_masks_batch,
old_action_log_probs_batch,
adv_targ,
value_preds_batch,
return_batch,
active_masks_batch,
turn_on,
):
if self.use_joint_action_loss:
critic_obs_batch = self.to_single_np(critic_obs_batch)
rnn_states_critic_batch = self.to_single_np(rnn_states_critic_batch)
critic_masks_batch = self.to_single_np(masks_batch)
value_preds_batch = self.to_single_np(value_preds_batch)
return_batch = self.to_single_np(return_batch)
adv_targ = adv_targ.reshape(-1, self.agent_num, 1)
adv_targ = adv_targ[:, 0, :]
else:
critic_masks_batch = masks_batch
(
values,
action_log_probs,
dist_entropy,
policy_values,
) = self.algo_module.evaluate_actions(
critic_obs_batch,
obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
masks_batch,
action_masks_batch,
active_masks_batch,
critic_masks_batch=critic_masks_batch,
)
if self.use_joint_action_loss:
action_log_probs_copy = (
action_log_probs.reshape(-1, self.agent_num, action_log_probs.shape[-1])
.sum(dim=(1, -1), keepdim=True)
.reshape(-1, 1)
)
old_action_log_probs_batch_copy = (
old_action_log_probs_batch.reshape(
-1, self.agent_num, old_action_log_probs_batch.shape[-1]
)
.sum(dim=(1, -1), keepdim=True)
.reshape(-1, 1)
)
active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
active_masks_batch = active_masks_batch[:, 0, :]
ratio = torch.exp(action_log_probs_copy - old_action_log_probs_batch_copy)
else:
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
if self.dual_clip_ppo:
ratio = torch.min(ratio, self.dual_clip_coeff)
surr1 = ratio * adv_targ
surr2 = (
torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
)
surr_final = torch.min(surr1, surr2)
if self._use_policy_active_masks:
policy_action_loss = (
-torch.sum(surr_final, dim=-1, keepdim=True) * active_masks_batch
).sum() / active_masks_batch.sum()
else:
policy_action_loss = -torch.sum(surr_final, dim=-1, keepdim=True).mean()
if self._use_policy_vhead:
if isinstance(self.algo_module.models["actor"], DistributedDataParallel):
policy_value_normalizer = self.algo_module.models[
"actor"
].module.value_normalizer
else:
policy_value_normalizer = self.algo_module.models[
"actor"
].value_normalizer
policy_value_loss = self.cal_value_loss(
policy_value_normalizer,
policy_values,
value_preds_batch,
return_batch,
active_masks_batch,
)
policy_loss = (
policy_action_loss + policy_value_loss * self.policy_value_loss_coef
)
else:
policy_loss = policy_action_loss
# critic update
if self._use_share_model:
value_normalizer = self.algo_module.models["model"].value_normalizer
elif isinstance(self.algo_module.models["critic"], DistributedDataParallel):
value_normalizer = self.algo_module.models["critic"].module.value_normalizer
else:
value_normalizer = self.algo_module.get_critic_value_normalizer()
value_loss = self.cal_value_loss(
value_normalizer,
values,
value_preds_batch,
return_batch,
active_masks_batch,
)
loss_list = self.construct_loss_list(
policy_loss, dist_entropy, value_loss, turn_on
)
return loss_list, value_loss, policy_loss, dist_entropy, ratio
[docs] def get_data_generator(self, buffer, advantages):
if self._use_recurrent_policy:
if self.use_joint_action_loss:
data_generator = buffer.recurrent_generator_v3(
advantages, self.num_mini_batch, self.data_chunk_length
)
else:
data_generator = buffer.recurrent_generator(
advantages, self.num_mini_batch, self.data_chunk_length
)
elif self._use_naive_recurrent:
data_generator = buffer.naive_recurrent_generator(
advantages, self.num_mini_batch
)
else:
data_generator = buffer.feed_forward_generator(
advantages, self.num_mini_batch
)
return data_generator
[docs] def train_ppo(self, buffer, turn_on):
if self._use_popart or self._use_valuenorm:
if self._use_share_model:
value_normalizer = self.algo_module.models["model"].value_normalizer
elif isinstance(self.algo_module.models["critic"], DistributedDataParallel):
value_normalizer = self.algo_module.models[
"critic"
].module.value_normalizer
else:
value_normalizer = self.algo_module.get_critic_value_normalizer()
if value_normalizer is not None:
advantages = buffer.returns[:-1] - value_normalizer.denormalize(
buffer.value_preds[:-1]
)
else:
advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
else:
advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
if self._use_adv_normalize:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
advantages_copy = advantages.copy()
advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
mean_advantages = np.nanmean(advantages_copy)
std_advantages = np.nanstd(advantages_copy)
advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)
train_info = {}
train_info["value_loss"] = 0
train_info["policy_loss"] = 0
train_info["dist_entropy"] = 0
train_info["actor_grad_norm"] = 0
train_info["critic_grad_norm"] = 0
train_info["ratio"] = 0
if self.world_size > 1:
train_info["reduced_value_loss"] = 0
train_info["reduced_policy_loss"] = 0
for _ in range(self.ppo_epoch):
data_generator = self.get_data_generator(buffer, advantages)
for sample in data_generator:
(
value_loss,
critic_grad_norm,
policy_loss,
dist_entropy,
actor_grad_norm,
ratio,
) = self.ppo_update(sample, turn_on)
if self.world_size > 1:
train_info["reduced_value_loss"] += reduce_tensor(
value_loss.data, self.world_size
)
train_info["reduced_policy_loss"] += reduce_tensor(
policy_loss.data, self.world_size
)
train_info["value_loss"] += value_loss.item()
train_info["policy_loss"] += policy_loss.item()
train_info["dist_entropy"] += dist_entropy.item()
train_info["actor_grad_norm"] += actor_grad_norm
train_info["critic_grad_norm"] += critic_grad_norm
train_info["ratio"] += ratio.mean().item()
num_updates = self.ppo_epoch * self.num_mini_batch
for k in train_info.keys():
train_info[k] /= num_updates
return train_info
[docs] def train(self, buffer, turn_on=True):
train_info = {}
for train_func in self.train_list:
train_info.update(train_func(buffer, turn_on))
for optimizer in self.algo_module.optimizers.values():
if hasattr(optimizer, "sync_lookahead"):
optimizer.sync_lookahead()
return train_info