# Source code for openrl.modules.networks.utils.distributions
import torch
import torch.nn as nn
from .util import init
"""
Modify standard PyTorch distributions so they are compatible with this code.
"""
#
# Standardize distribution interfaces
#
# Categorical
class FixedCategorical(torch.distributions.Categorical):
    """Categorical distribution with a per-row ``log_probs`` helper."""

    def log_probs(self, actions):
        """Return the log-probability of ``actions`` summed per leading index.

        Any trailing size-1 dimension of ``actions`` is squeezed away before
        scoring. The resulting log-probabilities are flattened per leading
        (row) index and summed, giving a tensor of shape
        ``(actions.size(0), 1)``.
        """
        row_count = actions.size(0)
        scored = super().log_prob(actions.squeeze(-1))
        per_row = scored.view(row_count, -1)
        return per_row.sum(-1, keepdim=True)
# Normal
class FixedNormal(torch.distributions.Normal):
    """Normal distribution exposing the shared ``log_probs`` method name."""

    def log_probs(self, actions):
        """Return the element-wise log-density of ``actions``.

        Unlike the categorical/Bernoulli wrappers, no reduction is applied.
        The result has the same shape that ``Normal.log_prob`` returns.
        Callers that need a joint log-probability over action dimensions
        must sum over the last axis themselves.
        """
        log_density = super().log_prob(actions)
        return log_density
# Bernoulli
class FixedBernoulli(torch.distributions.Bernoulli):
    """Bernoulli distribution with a per-row ``log_probs`` helper."""

    def log_probs(self, actions):
        """Return the log-probability of ``actions`` summed per leading index.

        The element-wise log-probabilities are flattened per leading (row)
        index and summed, giving a tensor of shape ``(actions.size(0), 1)``.
        """
        # Zero-argument ``super()`` binds to the parent Bernoulli class;
        # bare ``super`` is the builtin type and has no ``log_prob``.
        return (
            super()
            .log_prob(actions)
            .view(actions.size(0), -1)
            .sum(-1)
            .unsqueeze(-1)
        )
class Categorical(nn.Module):
    """Linear layer whose outputs parameterize a ``FixedCategorical``.

    Args:
        num_inputs: size of the last dimension of the input features.
        num_outputs: number of categories (logits) produced.
        use_orthogonal: index into ``[xavier_uniform_, orthogonal_]`` choosing
            the weight initializer (true selects ``nn.init.orthogonal_``).
        gain: gain forwarded to the project ``init`` helper.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super(Categorical, self).__init__()
        # Indexing by the flag (rather than a conditional) keeps the original
        # selection semantics for bool/int flags.
        weight_init = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        # NOTE(review): ``init`` comes from ``.util``; presumably it applies
        # ``weight_init`` to the weights and the zero-constant callable to the
        # bias, then returns the module — confirm against ``.util.init``.
        self.linear = init(
            nn.Linear(num_inputs, num_outputs),
            weight_init,
            lambda param: nn.init.constant_(param, 0),
            gain,
        )

    def forward(self, x, action_masks=None):
        """Return a ``FixedCategorical`` over the logits of ``self.linear(x)``.

        Where ``action_masks == 0``, the logit is overwritten with ``-6e4``
        so that category gets negligible probability.
        """
        logits = self.linear(x)
        if action_masks is not None:
            # -6e4 rather than -inf: it stays finite and fits in the fp16
            # range (max magnitude 65504).
            logits[action_masks == 0] = -6e4
        return FixedCategorical(logits=logits)
class DiagGaussian(nn.Module):
    """Module producing a ``FixedNormal`` with a linear mean and learned log-std.

    Args:
        num_inputs: size of the last dimension of the input features.
        num_outputs: number of action dimensions (length of the mean).
        use_orthogonal: index into ``[xavier_uniform_, orthogonal_]`` choosing
            the weight initializer (true selects ``nn.init.orthogonal_``).
        gain: gain forwarded to the project ``init`` helper.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super(DiagGaussian, self).__init__()
        weight_init = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        # NOTE(review): ``init`` comes from ``.util``; presumably it applies
        # ``weight_init`` to the weights and zeroes the bias — confirm.
        self.fc_mean = init(
            nn.Linear(num_inputs, num_outputs),
            weight_init,
            lambda param: nn.init.constant_(param, 0),
            gain,
        )
        # Log standard deviation parameters, one per output, starting at zero.
        self.logstd = AddBias(torch.zeros(num_outputs))

    def forward(self, x):
        """Return a ``FixedNormal(mean, exp(log_std))`` for input ``x``."""
        mean = self.fc_mean(x)
        # ``zeros_like`` matches the mean's shape, dtype and device.
        # NOTE(review): assumes AddBias adds its stored bias to the input, so
        # this yields the log-std broadcast to the mean's shape — confirm.
        log_std = self.logstd(torch.zeros_like(mean))
        return FixedNormal(mean, log_std.exp())
[docs]class Bernoulli(nn.Module):
def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
super(Bernoulli, self).__init__()
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
def init_(m):
return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)
self.linear = init_(nn.Linear(num_inputs, num_outputs))