Shortcuts

Source code for openrl.modules.networks.utils.distributions

import torch
import torch.nn as nn

from .util import init

"""
Modify standard PyTorch distributions so they are compatible with this code.
"""

#
# Standardize distribution interfaces
#


# Categorical
class FixedCategorical(torch.distributions.Categorical):
    """Categorical distribution that keeps a trailing singleton action axis.

    ``sample`` and ``mode`` return index tensors with an extra last axis of
    size 1. ``log_probs`` accepts actions in that same layout.
    """

    def sample(self):
        """Draw one sample per batch entry, shaped ``(*batch_shape, 1)``."""
        drawn = super().sample()
        return drawn.unsqueeze(-1)

    def log_probs(self, actions):
        """Return the summed log-probability of ``actions`` per leading entry.

        ``actions`` carries a trailing singleton axis, as produced by
        ``sample``/``mode``. That axis is squeezed before scoring. The
        element-wise log-probs are then flattened per entry of axis 0 and
        summed. The result has shape ``(actions.size(0), 1)``.
        """
        per_element = super().log_prob(actions.squeeze(-1))
        per_row = per_element.view(actions.size(0), -1)
        return per_row.sum(-1, keepdim=True)

    def mode(self):
        """Return the most probable category index, keeping the last axis."""
        return self.probs.argmax(dim=-1, keepdim=True)
# Normal
class FixedNormal(torch.distributions.Normal):
    """Normal distribution exposing the module's ``log_probs``/``mode`` names.

    Unlike ``FixedCategorical`` and ``FixedBernoulli``, ``log_probs`` here is
    not reduced over the action axis. It returns the element-wise
    log-density with the same shape as ``actions``.
    """

    def log_probs(self, actions):
        """Return the element-wise log-density of ``actions`` (no summation)."""
        # NOTE: an earlier variant summed over the last axis with
        # ``.sum(-1, keepdim=True)``. This version deliberately returns the
        # unreduced tensor.
        return super().log_prob(actions)

    def entropy(self):
        """Return the element-wise entropy (unreduced)."""
        return super().entropy()

    def mode(self):
        """Return the mean, which is the mode of a Normal distribution."""
        return self.mean
# Bernoulli
class FixedBernoulli(torch.distributions.Bernoulli):
    """Bernoulli distribution whose log-probs and entropy are reduced per row."""

    def log_probs(self, actions):
        """Return the summed log-probability of ``actions`` per leading entry.

        The element-wise log-probs are flattened per entry of axis 0 and
        summed. The result has shape ``(actions.size(0), 1)``.
        """
        # ``super()`` must be *called*. Bare ``super`` is the builtin type and
        # has no ``log_prob`` attribute, so using it raises AttributeError.
        return (
            super()
            .log_prob(actions)
            .view(actions.size(0), -1)
            .sum(-1)
            .unsqueeze(-1)
        )

    def entropy(self):
        """Return the entropy summed over the last axis."""
        return super().entropy().sum(-1)

    def mode(self):
        """Return 1.0 where ``probs > 0.5`` and 0.0 elsewhere (0.5 maps to 0.0)."""
        return torch.gt(self.probs, 0.5).float()
class Categorical(nn.Module):
    """Linear head that turns features into a ``FixedCategorical`` distribution.

    Args:
        num_inputs: input feature size of the linear layer.
        num_outputs: number of logits (discrete actions) produced.
        use_orthogonal: used as an index into
            ``(nn.init.xavier_uniform_, nn.init.orthogonal_)``. ``False``/0
            selects xavier, ``True``/1 selects orthogonal.
        gain: forwarded to the project helper ``init`` together with the
            chosen weight initialiser.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super().__init__()
        weight_init = (nn.init.xavier_uniform_, nn.init.orthogonal_)[use_orthogonal]

        def zero_fill(tensor):
            # Passed to ``init`` as its third argument; presumably the bias
            # initialiser (sets values to 0) -- semantics live in ``.util``.
            return nn.init.constant_(tensor, 0)

        layer = nn.Linear(num_inputs, num_outputs)
        self.linear = init(layer, weight_init, zero_fill, gain)

    def forward(self, x, action_masks=None):
        """Return a ``FixedCategorical`` built from the logits of ``x``.

        When ``action_masks`` is given, every logit whose mask entry equals 0
        is overwritten in place with -6e4 before the distribution is built.
        """
        logits = self.linear(x)
        if action_masks is not None:
            # -6e4 instead of -inf or a huge negative: it must stay finite
            # under fp16, whose largest magnitude is 65504.
            logits[action_masks == 0] = -6e4
        return FixedCategorical(logits=logits)
class DiagGaussian(nn.Module):
    """Linear head producing a diagonal ``FixedNormal``.

    The mean is a linear function of the input. The log standard deviation
    is a learnable bias that does not depend on the input values.

    Args:
        num_inputs: input feature size of the mean layer.
        num_outputs: action dimensionality (size of the mean / log-std).
        use_orthogonal: used as an index into
            ``(nn.init.xavier_uniform_, nn.init.orthogonal_)``. ``False``/0
            selects xavier, ``True``/1 selects orthogonal.
        gain: forwarded to the project helper ``init`` together with the
            chosen weight initialiser.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super().__init__()
        weight_init = (nn.init.xavier_uniform_, nn.init.orthogonal_)[use_orthogonal]

        def zero_fill(tensor):
            # Passed to ``init`` as its third argument; presumably the bias
            # initialiser (sets values to 0) -- semantics live in ``.util``.
            return nn.init.constant_(tensor, 0)

        layer = nn.Linear(num_inputs, num_outputs)
        self.fc_mean = init(layer, weight_init, zero_fill, gain)
        # Log-std parameter starts at 0, i.e. an initial std of 1.
        self.logstd = AddBias(torch.zeros(num_outputs))

    def forward(self, x):
        """Return a ``FixedNormal`` with mean ``fc_mean(x)`` and learned std."""
        mean = self.fc_mean(x)
        # Adding the bias to a zero tensor shaped like the mean broadcasts the
        # input-independent log-std to the batch (and matches mean's device).
        log_std = self.logstd(torch.zeros_like(mean))
        return FixedNormal(mean, log_std.exp())
class Bernoulli(nn.Module):
    """Linear head that turns features into a ``FixedBernoulli`` distribution.

    Args:
        num_inputs: input feature size of the linear layer.
        num_outputs: number of independent binary outputs (logits).
        use_orthogonal: used as an index into
            ``(nn.init.xavier_uniform_, nn.init.orthogonal_)``. ``False``/0
            selects xavier, ``True``/1 selects orthogonal.
        gain: forwarded to the project helper ``init`` together with the
            chosen weight initialiser.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super().__init__()
        weight_init = (nn.init.xavier_uniform_, nn.init.orthogonal_)[use_orthogonal]

        def zero_fill(tensor):
            # Passed to ``init`` as its third argument; presumably the bias
            # initialiser (sets values to 0) -- semantics live in ``.util``.
            return nn.init.constant_(tensor, 0)

        layer = nn.Linear(num_inputs, num_outputs)
        self.linear = init(layer, weight_init, zero_fill, gain)

    def forward(self, x):
        """Return a ``FixedBernoulli`` parameterised by the logits of ``x``."""
        logits = self.linear(x)
        return FixedBernoulli(logits=logits)
class AddBias(nn.Module):
    """Add a learnable per-channel bias to the input.

    The bias is stored as an ``(n, 1)`` parameter named ``_bias``; that name
    and shape are kept for state-dict compatibility. How it broadcasts
    depends on the input rank:

    * 1-D ``(n,)``: added element-wise.
    * 2-D ``(batch, n)``: added along the last axis.
    * any other rank: reshaped to ``(1, n, 1, 1)``, i.e. treated as the
      channel axis (dim 1) of an NCHW tensor.

    Args:
        bias: 1-D tensor of initial bias values, one per channel.
    """

    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the bias broadcast according to ``x.dim()``."""
        if x.dim() == 1:
            # A 1-D input has only one axis to add along. Without this branch
            # the 4-D reshape below would broadcast it into (1, n, 1, n).
            bias = self._bias.view(-1)
        elif x.dim() == 2:
            bias = self._bias.t().view(1, -1)
        else:
            # NOTE(review): 3-D inputs also land here; (1, n, 1, 1) against
            # (B, T, n) does not add along the last axis. Confirm no 3-D
            # callers exist before relying on this path for them.
            bias = self._bias.t().view(1, -1, 1, 1)
        return x + bias