Shortcuts

Source code for openrl.modules.networks.utils.act

import torch
import torch.nn as nn

from .distributions import Bernoulli, Categorical, DiagGaussian


class ACTLayer(nn.Module):
    """Action head that turns features into actions and their log-probabilities.

    The kind of head is chosen from the class name of ``action_space``:

    * ``Discrete``      -> one ``Categorical`` head (``self.action_out``)
    * ``Box``           -> one ``DiagGaussian`` head (``self.action_out``)
    * ``MultiBinary``   -> one ``Bernoulli`` head (``self.action_out``)
    * ``MultiDiscrete`` -> one ``Categorical`` head per sub-action
      (``self.action_outs``)
    * anything else     -> treated as a mixed space indexed as
      ``(continuous, discrete)``: a ``DiagGaussian`` head followed by a
      ``Categorical`` head (``self.action_outs``)

    Every head is called on the features and must return a distribution
    object exposing ``sample()``, ``mode()``, ``log_probs(actions)``,
    ``entropy()`` and ``probs``.
    """

    def __init__(self, action_space, inputs_dim, use_orthogonal, gain):
        """Build the head(s) matching ``action_space``.

        Args:
            action_space: Action space; dispatched on its class name.
            inputs_dim: Size of the feature vector fed to each head.
            use_orthogonal: Forwarded unchanged to the distribution heads.
            gain: Forwarded unchanged to the distribution heads.
        """
        super().__init__()
        self.multidiscrete_action = False
        self.continuous_action = False
        self.mixed_action = False

        space_type = action_space.__class__.__name__
        if space_type == "Discrete":
            action_dim = action_space.n
            self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain)
        elif space_type == "Box":
            self.continuous_action = True
            action_dim = action_space.shape[0]
            self.action_out = DiagGaussian(
                inputs_dim, action_dim, use_orthogonal, gain
            )
        elif space_type == "MultiBinary":
            action_dim = action_space.shape[0]
            self.action_out = Bernoulli(inputs_dim, action_dim, use_orthogonal, gain)
        elif space_type == "MultiDiscrete":
            self.multidiscrete_action = True
            if hasattr(action_space, "high") and hasattr(action_space, "low"):
                # Spaces described by inclusive per-dimension bounds.
                action_dims = action_space.high - action_space.low + 1
            else:
                # gym-style MultiDiscrete stores the per-dimension sizes in
                # ``nvec`` and has no ``high``/``low`` attributes.
                action_dims = action_space.nvec
            self.action_outs = nn.ModuleList(
                [
                    Categorical(inputs_dim, action_dim, use_orthogonal, gain)
                    for action_dim in action_dims
                ]
            )
        else:  # discrete + continuous
            self.mixed_action = True
            continuous_dim = action_space[0].shape[0]
            discrete_dim = action_space[1].n
            self.action_outs = nn.ModuleList(
                [
                    DiagGaussian(inputs_dim, continuous_dim, use_orthogonal, gain),
                    Categorical(inputs_dim, discrete_dim, use_orthogonal, gain),
                ]
            )

    @staticmethod
    def _select_action(dist, deterministic):
        """Return the mode of ``dist`` if ``deterministic``, else a sample."""
        return dist.mode() if deterministic else dist.sample()

    @staticmethod
    def _mean_entropy(entropy, active_masks):
        """Average ``entropy``, optionally weighted by ``active_masks``.

        Without a mask this is a plain mean. With a mask, the mask's trailing
        dimension is squeezed when its rank differs from the entropy's, so the
        product is element-wise instead of broadcasting to a
        batch-by-batch matrix.
        """
        if active_masks is None:
            return entropy.mean()
        mask = active_masks
        if entropy.dim() != mask.dim():
            mask = mask.squeeze(-1)
        return (entropy * mask).sum() / active_masks.sum()

    def forward(self, x, action_masks=None, deterministic=False):
        """Choose actions for the features ``x``.

        Args:
            x: Input features passed to the head(s).
            action_masks: Forwarded to the single Discrete/MultiBinary head
                only; the Box, MultiDiscrete and mixed branches call their
                heads without it.
            deterministic: Take each distribution's mode instead of sampling.

        Returns:
            Tuple ``(actions, action_log_probs)``. With several heads the
            per-head actions are concatenated on the last dimension; the
            log-probabilities are concatenated (MultiDiscrete) or summed
            over heads with ``keepdim=True`` (mixed).
        """
        if self.mixed_action or self.multidiscrete_action:
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                dist = action_out(x)
                action = self._select_action(dist, deterministic)
                action_log_probs.append(dist.log_probs(action))
                # Mixed heads yield different dtypes; cast so they can be
                # concatenated into one tensor.
                actions.append(action.float() if self.mixed_action else action)
            actions = torch.cat(actions, -1)
            action_log_probs = torch.cat(action_log_probs, -1)
            if self.mixed_action:
                action_log_probs = torch.sum(action_log_probs, -1, keepdim=True)
            return actions, action_log_probs

        if self.continuous_action:
            action_logits = self.action_out(x)
        else:
            action_logits = self.action_out(x, action_masks)
        actions = self._select_action(action_logits, deterministic)
        action_log_probs = action_logits.log_probs(actions)
        return actions, action_log_probs

    def get_probs(self, x, action_masks=None):
        """Return the ``probs`` attribute of the action distribution(s).

        Args:
            x: Input features passed to the head(s).
            action_masks: Forwarded to the single Discrete/MultiBinary head
                only.

        Returns:
            The distribution's ``probs``; for MultiDiscrete and mixed spaces
            the per-head values concatenated on the last dimension.
        """
        # NOTE(review): this reads ``.probs`` from Gaussian heads as well;
        # whether DiagGaussian's distribution exposes it is not visible here.
        if self.mixed_action or self.multidiscrete_action:
            return torch.cat(
                [action_out(x).probs for action_out in self.action_outs], -1
            )
        if self.continuous_action:
            return self.action_out(x).probs
        return self.action_out(x, action_masks).probs

    def evaluate_actions(
        self, x, action, action_masks=None, active_masks=None, get_probs=False
    ):
        """Score given actions under the current distribution(s).

        Args:
            x: Input features passed to the head(s).
            action: Actions to evaluate. For mixed spaces the last column is
                the discrete action and all preceding columns are the
                continuous action. For MultiDiscrete the first two dimensions
                are transposed so that iterating yields one slice per head.
            action_masks: Forwarded to the single Discrete/MultiBinary head
                only.
            active_masks: Optional weights for the entropy average; when
                omitted a plain mean is used.
            get_probs: When true, also return the distribution object(s).

        Returns:
            ``(action_log_probs, dist_entropy)``, or
            ``(action_log_probs, dist_entropy, action_logits)`` when
            ``get_probs`` is true. ``action_logits`` is the distribution
            object for single-head spaces and a list of per-head
            distributions for MultiDiscrete and mixed spaces.
        """
        if self.mixed_action:
            # The discrete action occupies the last column; everything before
            # it belongs to the continuous head.
            continuous_dim = action.shape[-1] - 1
            continuous_act, discrete_act = action.split((continuous_dim, 1), -1)
            head_actions = [continuous_act, discrete_act.long()]

            action_logits = []
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, head_actions):
                dist = action_out(x)
                action_logits.append(dist)
                action_log_probs.append(dist.log_probs(act))
                dist_entropy.append(self._mean_entropy(dist.entropy(), active_masks))
            action_log_probs = torch.sum(
                torch.cat(action_log_probs, -1), -1, keepdim=True
            )
            # NOTE(review): fixed per-head entropy weights (continuous head
            # first); their origin is not visible in this file.
            dist_entropy = dist_entropy[0] * 0.0025 + dist_entropy[1] * 0.01
        elif self.multidiscrete_action:
            action = torch.transpose(action, 0, 1)
            action_logits = []
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                dist = action_out(x)
                action_logits.append(dist)
                action_log_probs.append(dist.log_probs(act))
                dist_entropy.append(self._mean_entropy(dist.entropy(), active_masks))
            # NOTE: the original author flagged this concatenation as
            # "could be wrong"; it is kept unchanged.
            action_log_probs = torch.cat(action_log_probs, -1)
            # Stack instead of torch.tensor(...): keeps the autograd graph
            # and the device of the per-head entropies.
            dist_entropy = torch.stack(dist_entropy).mean()
        else:
            if self.continuous_action:
                action_logits = self.action_out(x)
            else:
                action_logits = self.action_out(x, action_masks)
            action_log_probs = action_logits.log_probs(action)
            dist_entropy = self._mean_entropy(action_logits.entropy(), active_masks)

        if not get_probs:
            return action_log_probs, dist_entropy
        return action_log_probs, dist_entropy, action_logits