Source code for openrl.modules.networks.utils.act
import torch
import torch.nn as nn
from .distributions import Bernoulli, Categorical, DiagGaussian
class ACTLayer(nn.Module):
    """Action head that turns features into an action distribution.

    The constructor inspects the class name of ``action_space`` and builds
    one distribution head (``Discrete``, ``Box``, ``MultiBinary``) or
    several heads (``MultiDiscrete``, or any other space, which is treated
    as a ``(continuous, discrete)`` pair).

    Three flags record which layout was chosen: ``multidiscrete_action``,
    ``continuous_action`` and ``mixed_action``. When none of them is set the
    layer holds a single ``Categorical`` or ``Bernoulli`` head.
    """

    def __init__(self, action_space, inputs_dim, use_orthogonal, gain):
        """Build the distribution head(s) matching ``action_space``.

        Args:
            action_space: Space object; dispatch is on its class name.
            inputs_dim: Forwarded to every distribution head.
            use_orthogonal: Forwarded to every distribution head.
            gain: Forwarded to every distribution head.
        """
        super().__init__()
        self.multidiscrete_action = False
        self.continuous_action = False
        self.mixed_action = False

        space_type = action_space.__class__.__name__
        if space_type == "Discrete":
            self.action_out = Categorical(
                inputs_dim, action_space.n, use_orthogonal, gain
            )
        elif space_type == "Box":
            self.continuous_action = True
            self.action_out = DiagGaussian(
                inputs_dim, action_space.shape[0], use_orthogonal, gain
            )
        elif space_type == "MultiBinary":
            self.action_out = Bernoulli(
                inputs_dim, action_space.shape[0], use_orthogonal, gain
            )
        elif space_type == "MultiDiscrete":
            self.multidiscrete_action = True
            # One categorical head per sub-action; the number of choices is
            # derived from the inclusive [low, high] bounds of the space.
            action_dims = action_space.high - action_space.low + 1
            self.action_outs = nn.ModuleList(
                [
                    Categorical(inputs_dim, action_dim, use_orthogonal, gain)
                    for action_dim in action_dims
                ]
            )
        else:  # discrete + continuous
            self.mixed_action = True
            continuous_dim = action_space[0].shape[0]
            discrete_dim = action_space[1].n
            self.action_outs = nn.ModuleList(
                [
                    DiagGaussian(inputs_dim, continuous_dim, use_orthogonal, gain),
                    Categorical(inputs_dim, discrete_dim, use_orthogonal, gain),
                ]
            )

    @staticmethod
    def _masked_entropy(entropy, active_masks, squeeze_mask=True):
        """Return ``sum(entropy * mask) / sum(active_masks)``.

        ``squeeze_mask`` drops the trailing axis of ``active_masks`` before
        the multiplication; the normaliser always uses the unsqueezed mask.
        """
        mask = active_masks.squeeze(-1) if squeeze_mask else active_masks
        return (entropy * mask).sum() / active_masks.sum()

    def forward(self, x, action_masks=None, deterministic=False):
        """Pick actions for ``x`` and return them with their log-probabilities.

        Args:
            x: Input features passed to every head.
            action_masks: Only forwarded to the single Categorical/Bernoulli
                head; the multi-head and continuous branches ignore it.
            deterministic: Use each distribution's ``mode()`` instead of
                ``sample()``.

        Returns:
            ``(actions, action_log_probs)``. For mixed spaces the per-head
            log-probs are summed over the last axis (``keepdim=True``); for
            MultiDiscrete they are concatenated per head.
        """
        if self.mixed_action or self.multidiscrete_action:
            head_actions = []
            head_log_probs = []
            for action_out in self.action_outs:
                dist = action_out(x)
                action = dist.mode() if deterministic else dist.sample()
                head_log_probs.append(dist.log_probs(action))
                # Mixed heads are cast to float so the continuous and
                # discrete parts can be concatenated into one tensor.
                head_actions.append(action.float() if self.mixed_action else action)
            actions = torch.cat(head_actions, -1)
            action_log_probs = torch.cat(head_log_probs, -1)
            if self.mixed_action:
                action_log_probs = torch.sum(action_log_probs, -1, keepdim=True)
            return actions, action_log_probs

        if self.continuous_action:
            action_logits = self.action_out(x)
        else:
            action_logits = self.action_out(x, action_masks)
        actions = action_logits.mode() if deterministic else action_logits.sample()
        action_log_probs = action_logits.log_probs(actions)
        return actions, action_log_probs

    def get_probs(self, x, action_masks=None):
        """Return the ``probs`` attribute of the action distribution(s).

        For multi-head spaces the per-head ``probs`` are concatenated along
        the last axis. ``action_masks`` is only used by the single
        Categorical/Bernoulli head.
        """
        # NOTE(review): this relies on every distribution exposing `.probs`;
        # whether the Gaussian head does cannot be confirmed from this file.
        if self.mixed_action or self.multidiscrete_action:
            return torch.cat(
                [action_out(x).probs for action_out in self.action_outs], -1
            )
        if self.continuous_action:
            return self.action_out(x).probs
        return self.action_out(x, action_masks).probs

    def evaluate_actions(
        self, x, action, action_masks=None, active_masks=None, get_probs=False
    ):
        """Score given actions under the current distribution(s).

        Args:
            x: Input features passed to every head.
            action: Actions to evaluate. For mixed spaces the last column is
                taken as the discrete action and the rest as the continuous
                part. For MultiDiscrete the tensor is transposed on axes
                0 and 1 so that each head receives its own slice.
            action_masks: Only forwarded to the single Categorical/Bernoulli
                head.
            active_masks: Optional weights; when given, entropies are
                averaged as ``sum(entropy * mask) / sum(mask)`` instead of a
                plain mean.
            get_probs: When true, also return the distribution object(s).

        Returns:
            ``(action_log_probs, dist_entropy)``, plus a third element when
            ``get_probs`` is true: the distribution for single-head spaces,
            or a list with one distribution per head for mixed and
            MultiDiscrete spaces.
        """
        if self.mixed_action:
            # Split off the trailing discrete column instead of assuming a
            # fixed 2-dimensional continuous part.
            continuous_part, discrete_part = action.split(
                (action.shape[-1] - 1, 1), -1
            )
            head_targets = [continuous_part, discrete_part.long()]
            action_logits = []
            head_log_probs = []
            head_entropies = []
            for action_out, act in zip(self.action_outs, head_targets):
                dist = action_out(x)
                action_logits.append(dist)
                head_log_probs.append(dist.log_probs(act))
                entropy = dist.entropy()
                if active_masks is None:
                    head_entropies.append(entropy.mean())
                else:
                    # Squeeze the mask only when its rank differs from the
                    # entropy's rank.
                    same_rank = len(entropy.shape) == len(active_masks.shape)
                    head_entropies.append(
                        self._masked_entropy(
                            entropy, active_masks, squeeze_mask=not same_rank
                        )
                    )
            action_log_probs = torch.sum(
                torch.cat(head_log_probs, -1), -1, keepdim=True
            )
            # Fixed per-head weights kept from the original implementation;
            # their rationale is not documented in this file.
            dist_entropy = head_entropies[0] * 0.0025 + head_entropies[1] * 0.01
        elif self.multidiscrete_action:
            head_targets = torch.transpose(action, 0, 1)
            action_logits = []
            head_log_probs = []
            head_entropies = []
            for action_out, act in zip(self.action_outs, head_targets):
                dist = action_out(x)
                action_logits.append(dist)
                head_log_probs.append(dist.log_probs(act))
                if active_masks is None:
                    head_entropies.append(dist.entropy().mean())
                else:
                    head_entropies.append(
                        self._masked_entropy(dist.entropy(), active_masks)
                    )
            action_log_probs = torch.cat(head_log_probs, -1)  # ! could be wrong
            # torch.stack keeps the entropies attached to the autograd graph
            # and on their device; torch.tensor(list) would copy the values
            # into a new, gradient-free CPU tensor.
            dist_entropy = torch.stack(head_entropies).mean()
        elif self.continuous_action:
            action_logits = self.action_out(x)
            action_log_probs = action_logits.log_probs(action)
            entropy = action_logits.entropy()
            if active_masks is None:
                dist_entropy = entropy.mean()
            else:
                dist_entropy = self._masked_entropy(
                    entropy, active_masks, squeeze_mask=False
                )
        else:
            action_logits = self.action_out(x, action_masks)
            action_log_probs = action_logits.log_probs(action)
            if active_masks is None:
                dist_entropy = action_logits.entropy().mean()
            else:
                dist_entropy = self._masked_entropy(
                    action_logits.entropy(), active_masks
                )

        if get_probs:
            return action_log_probs, dist_entropy, action_logits
        return action_log_probs, dist_entropy