Shortcuts

Source code for openrl.modules.networks.gail_discriminator

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch import autograd

from openrl.modules.networks.utils.mlp import MLPLayer
from openrl.modules.networks.utils.running_mean_std import RunningMeanStd


class Discriminator(nn.Module):
    """Discriminator network for GAIL-style adversarial imitation.

    The network maps a state (optionally concatenated with an action, when
    ``cfg.gail_use_action`` is true) to a single logit through an
    ``MLPLayer`` trunk followed by a linear head.  During training expert
    samples are given target ``0`` and policy samples target ``1``; the
    reward returned by :meth:`predict_reward` is ``-log(sigmoid(logit))``,
    so inputs scored as expert-like receive a larger reward.

    Args:
        cfg: Config object.  The fields read here are ``gail_hidden_size``,
            ``gail_layer_num``, ``gail_lr``, ``gail_use_action``,
            ``use_orthogonal`` and ``activation_id``.
        input_space: Forwarded unchanged to ``MLPLayer`` as its input spec.
        action_space: Accepted for interface compatibility; not used here.
        device: Device the module and all training tensors are moved to.
        use_half: Accepted for interface compatibility; not used here.
        extra_args: Optional dict.  If it contains the key
            ``"critic_obs_process_func"``, that callable is applied to
            observations before they are scored; otherwise observations are
            used as-is.
    """

    def __init__(
        self,
        cfg,
        input_space,
        action_space,
        device,
        use_half,
        extra_args=None,
    ):
        super().__init__()
        hidden_dim = cfg.gail_hidden_size
        layer_num = cfg.gail_layer_num

        self.cfg = cfg
        self.device = device

        if extra_args is not None and "critic_obs_process_func" in extra_args:
            self.critic_obs_process_func = extra_args["critic_obs_process_func"]
        else:
            # Identity pre-processing when the caller supplies none.
            self.critic_obs_process_func = lambda obs: obs

        self.base = MLPLayer(
            input_space,
            hidden_dim,
            layer_N=layer_num,
            use_orthogonal=cfg.use_orthogonal,
            activation_id=cfg.activation_id,
        )

        # Single-logit head; start with small weights and a zero bias.
        self.gail_out = nn.Linear(hidden_dim, 1)
        self.gail_out.weight.data.mul_(0.1)
        self.gail_out.bias.data.mul_(0.0)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.gail_lr)

        # Running discounted return and its statistics, used to scale rewards.
        self.returns = None
        self.ret_rms = RunningMeanStd(shape=())
        # The very first mini-batch is trained on but excluded from the
        # reported average loss.
        self.first_train = True

        self.to(device)
        self.train()

    def _logits(self, state, action):
        """Return discriminator logits for ``state`` (and ``action`` if used)."""
        if self.cfg.gail_use_action:
            features = torch.cat([state, action], dim=-1)
        else:
            features = state
        return self.gail_out(self.base(features))

    def compute_grad_pen(
        self, expert_state, expert_action, policy_state, policy_action, lambda_=10
    ):
        """Compute a gradient penalty on random expert/policy interpolations.

        Each row of the expert batch is mixed with the matching row of the
        policy batch using a per-row uniform coefficient.  The penalty is
        ``lambda_ * mean((||dD/dx||_2 - 1) ** 2)`` where the gradient is
        taken with respect to the mixed input.

        Args:
            expert_state: Expert state tensor; its first dimension sets the
                number of mixing coefficients drawn.
            expert_action: Expert action tensor (only read when
                ``cfg.gail_use_action`` is true).
            policy_state: Policy state tensor, same shape as ``expert_state``.
            policy_action: Policy action tensor (only read when
                ``cfg.gail_use_action`` is true).
            lambda_: Penalty weight.

        Returns:
            Scalar tensor, differentiable w.r.t. the network parameters.
        """
        if self.cfg.gail_use_action:
            expert_data = torch.cat([expert_state, expert_action], dim=-1)
            policy_data = torch.cat([policy_state, policy_action], dim=-1)
        else:
            expert_data = expert_state
            policy_data = policy_state

        # One coefficient per row, broadcast across the feature dimension.
        alpha = torch.rand(expert_state.size(0), 1)
        alpha = alpha.expand_as(expert_data).to(expert_data.device)

        # detach() guarantees a leaf tensor so requires_grad can be enabled
        # even if a caller passes inputs that already track gradients.
        mixup_data = (alpha * expert_data + (1 - alpha) * policy_data).detach()
        mixup_data.requires_grad_(True)

        disc = self.gail_out(self.base(mixup_data))
        grad = autograd.grad(
            outputs=disc,
            inputs=mixup_data,
            grad_outputs=torch.ones_like(disc),
            create_graph=True,  # the penalty itself is back-propagated
            retain_graph=True,
            only_inputs=True,
        )[0]

        return lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()

    def update(self, expert_loader, buffer, obsfilt=None):
        """Run one pass of discriminator training over paired mini-batches.

        Expert batches from ``expert_loader`` are zipped with policy batches
        produced by ``buffer.feed_forward_critic_obs_generator``; iteration
        stops when either source is exhausted.  One optimizer step is taken
        per pair.

        Args:
            expert_loader: Iterable with a ``batch_size`` attribute yielding
                ``(expert_state, expert_action)`` tensor pairs.  The first
                two dimensions of each tensor are flattened into one.
            buffer: Object exposing ``feed_forward_critic_obs_generator``.
                Element ``0`` of each yielded batch is used as the state and
                element ``4`` as the action (both numpy arrays).
            obsfilt: Optional callable ``obsfilt(array, update=False)``
                applied to expert states before they are scored.

        Returns:
            Mean of ``gail_loss + grad_pen`` over the counted mini-batches,
            or ``0`` if none were counted.
        """
        self.train()
        policy_data_generator = buffer.feed_forward_critic_obs_generator(
            None,
            mini_batch_size=expert_loader.batch_size,
            critic_obs_process_func=self.critic_obs_process_func,
        )
        use_action = self.cfg.gail_use_action

        loss = 0
        n = 0
        for expert_batch, policy_batch in zip(expert_loader, policy_data_generator):
            # ---- policy samples -------------------------------------------
            policy_state, policy_action = policy_batch[0], policy_batch[4]
            policy_state = torch.from_numpy(policy_state).to(self.device)
            if use_action:
                policy_action = torch.from_numpy(policy_action).to(self.device)
            policy_d = self._logits(policy_state, policy_action)

            # ---- expert samples -------------------------------------------
            expert_state, expert_action = expert_batch
            expert_state = expert_state.reshape(-1, *expert_state.shape[2:])
            expert_action = expert_action.reshape(-1, *expert_action.shape[2:])
            if obsfilt is not None:
                expert_state = obsfilt(expert_state.numpy(), update=False)
                expert_state = torch.FloatTensor(expert_state).to(self.device)
            else:
                expert_state = expert_state.to(self.device)
            if use_action:
                expert_action = expert_action.to(self.device)
            expert_d = self._logits(expert_state, expert_action)

            # ---- losses: expert -> 0, policy -> 1 -------------------------
            expert_loss = F.binary_cross_entropy_with_logits(
                expert_d, torch.zeros_like(expert_d)
            )
            policy_loss = F.binary_cross_entropy_with_logits(
                policy_d, torch.ones_like(policy_d)
            )
            gail_loss = expert_loss + policy_loss
            grad_pen = self.compute_grad_pen(
                expert_state, expert_action, policy_state, policy_action
            )
            total_loss = gail_loss + grad_pen

            if not self.first_train:
                loss += total_loss.item()
                n += 1
            else:
                self.first_train = False

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

        return loss / n if n > 0 else 0

    def predict_reward(self, state, action, gamma, masks, update_rms=True):
        """Score transitions and return variance-normalised rewards.

        The module is switched to eval mode and no gradients are recorded.

        Args:
            state: Numpy array; flattened to ``(-1, state.shape[-1])`` and
                passed through ``critic_obs_process_func`` before scoring.
            action: Numpy array, flattened the same way; only read when
                ``cfg.gail_use_action`` is true.
            gamma: Discount applied to the running return.
            masks: Numpy array multiplied into the running return (flattened
                to ``(-1, masks.shape[-1])``).
            update_rms: When true, update the running return and feed it to
                the return statistics before normalising.

        Returns:
            Numpy array of shape ``(*state.shape[:2], reward_dim)`` holding
            ``-log(sigmoid(logit) + 1e-8) / sqrt(var + 1e-8)``.
        """
        with torch.no_grad():
            self.eval()
            state_shape = state.shape
            masks_shape = masks.shape

            state = self.critic_obs_process_func(state.reshape(-1, state_shape[-1]))
            state = torch.from_numpy(state).to(self.device)
            masks = torch.from_numpy(masks).to(self.device).reshape(-1, masks_shape[-1])

            if self.cfg.gail_use_action:
                action_shape = action.shape
                action = (
                    torch.from_numpy(action)
                    .to(self.device)
                    .reshape(-1, action_shape[-1])
                )
            d = self._logits(state, action)

            # Epsilon keeps log() finite when the sigmoid saturates at 0.
            s = torch.sigmoid(d) + 1e-8
            reward = -s.log()

            if self.returns is None:
                self.returns = reward.clone()

            if update_rms:
                self.returns = self.returns * masks * gamma + reward
                self.ret_rms.update(self.returns.cpu().numpy())

            # The tracker is created with shape=(), so ``var`` may still be a
            # 0-d array here (e.g. update_rms=False before any update —
            # NOTE(review): depends on RunningMeanStd internals, confirm).
            # ravel() lets the first element be read for both 0-d and 1-d.
            variance = np.ravel(self.ret_rms.var)[0]
            reward = reward / float(np.sqrt(variance + 1e-8))

            reward = reward.reshape((*state_shape[:2], reward.shape[-1])).cpu().numpy()
            return reward