# Shortcuts
#
# Source code for openrl.modules.networks.utils.transformer_act

import torch
from torch.distributions import Categorical, Normal
from torch.nn import functional as F


def discrete_autoregreesive_act(
    decoder,
    obs_rep,
    obs,
    batch_size,
    n_agent,
    action_dim,
    tpdv,
    action_masks=None,
    deterministic=False,
):
    """Choose discrete actions agent by agent, feeding each choice forward.

    The decoder is queried once per agent. Before each query the input for
    the next agent is filled with the one-hot encoding of the action that
    was just selected, so later agents are conditioned on earlier ones.

    Args:
        decoder: Callable invoked as ``decoder(shifted_action, obs_rep, obs)``.
            Its result is indexed as ``[:, i, :]`` and used as the logits of
            agent ``i``.
        obs_rep: Passed through to ``decoder`` unchanged.
        obs: Passed through to ``decoder`` unchanged.
        batch_size: Size of the leading dimension of the allocated tensors.
        n_agent: Number of agents, i.e. number of decoding steps.
        action_dim: Number of discrete actions (``num_classes`` of the
            one-hot encoding).
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the decoder
            input (presumably dtype and device; confirm against the caller).
        action_masks: Optional tensor indexed as ``[:, i, :]``. Positions
            equal to 0 get their logit replaced by ``-1e10``.
        deterministic: If true, take the most probable action instead of
            sampling.

    Returns:
        A tuple ``(output_action, output_action_log)``, both of shape
        ``(batch_size, n_agent, 1)``: the chosen action indices
        (``torch.long``) and their log-probabilities (``torch.float32``).
    """
    # One extra leading slot per agent; it is set to 1 for the first agent
    # only, the remaining slots carry the previous agent's one-hot action.
    shifted_action = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    shifted_action[:, 0, 0] = 1

    # NOTE: the result buffers are created without ``tpdv`` (default device),
    # exactly as before, so the returned tensors keep their original placement.
    output_action = torch.zeros((batch_size, n_agent, 1), dtype=torch.long)
    output_action_log = torch.zeros_like(output_action, dtype=torch.float32)

    for i in range(n_agent):
        logit = decoder(shifted_action, obs_rep, obs)[:, i, :]
        if action_masks is not None:
            # Out-of-place masking: never write into the decoder's own output.
            logit = logit.masked_fill(action_masks[:, i, :] == 0, -1e10)

        distri = Categorical(logits=logit)
        action = distri.probs.argmax(dim=-1) if deterministic else distri.sample()
        action_log = distri.log_prob(action)

        output_action[:, i, :] = action.unsqueeze(-1)
        output_action_log[:, i, :] = action_log.unsqueeze(-1)

        # Expose this agent's choice to the next decoding step.
        if i + 1 < n_agent:
            shifted_action[:, i + 1, 1:] = F.one_hot(action, num_classes=action_dim)

    return output_action, output_action_log
def discrete_parallel_act(
    decoder,
    obs_rep,
    obs,
    action,
    batch_size,
    n_agent,
    action_dim,
    tpdv,
    action_masks=None,
):
    """Evaluate given discrete actions for all agents with one decoder call.

    The decoder input for agent ``k`` is the one-hot encoding of agent
    ``k - 1``'s action from ``action``; the first agent instead gets a 1 in
    the extra leading slot.

    Args:
        decoder: Callable invoked as ``decoder(shifted_action, obs_rep, obs)``;
            its result is used as the logits for every agent.
        obs_rep: Passed through to ``decoder`` unchanged.
        obs: Passed through to ``decoder`` unchanged.
        action: Integer tensor of action indices; its last dimension is
            squeezed before one-hot encoding, giving a
            ``(batch, n_agent, action_dim)`` one-hot tensor.
        batch_size: Size of the leading dimension of the decoder input.
        n_agent: Number of agents.
        action_dim: Number of discrete actions.
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the decoder
            input (presumably dtype and device; confirm against the caller).
        action_masks: Optional tensor compared element-wise against the
            logits. Positions equal to 0 get their logit replaced by
            ``-1e10``.

    Returns:
        A tuple ``(action_log, entropy)``: the log-probability of each given
        action and the entropy of each agent's categorical distribution,
        each with a trailing dimension of size 1.
    """
    one_hot_action = F.one_hot(
        action.squeeze(-1), num_classes=action_dim
    )  # (batch, n_agent, action_dim)

    shifted_action = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    shifted_action[:, 0, 0] = 1
    # Agent k sees agent k-1's action; the last agent's action is never fed in.
    shifted_action[:, 1:, 1:] = one_hot_action[:, :-1, :]

    logit = decoder(shifted_action, obs_rep, obs)
    if action_masks is not None:
        # Out-of-place masking: an in-place write into the decoder output can
        # invalidate values autograd saved for the backward pass.
        logit = logit.masked_fill(action_masks == 0, -1e10)

    distri = Categorical(logits=logit)
    action_log = distri.log_prob(action.squeeze(-1)).unsqueeze(-1)
    entropy = distri.entropy().unsqueeze(-1)
    return action_log, entropy
def continuous_autoregreesive_act(
    decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv, deterministic=False
):
    """Choose continuous actions agent by agent, feeding each choice forward.

    The decoder is queried once per agent. Its ``[:, i, :]`` slice is the
    mean of a Normal distribution for agent ``i``, and the action taken is
    written into the decoder input of agent ``i + 1``.

    Args:
        decoder: Callable invoked as ``decoder(shifted_action, obs_rep, obs)``.
            Must also expose a ``log_std`` tensor attribute.
        obs_rep: Passed through to ``decoder`` unchanged.
        obs: Passed through to ``decoder`` unchanged.
        batch_size: Size of the leading dimension of the allocated tensors.
        n_agent: Number of agents, i.e. number of decoding steps.
        action_dim: Size of each agent's action vector.
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the decoder
            input (presumably dtype and device; confirm against the caller).
        deterministic: If true, use the mean as the action instead of
            sampling.

    Returns:
        A tuple ``(output_action, output_action_log)``, both ``torch.float32``
        of shape ``(batch_size, n_agent, action_dim)``: the chosen actions and
        their per-dimension log-probabilities.
    """
    shifted_action = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)

    # NOTE: the result buffers are created without ``tpdv`` (default device),
    # exactly as before, so the returned tensors keep their original placement.
    output_action = torch.zeros((batch_size, n_agent, action_dim), dtype=torch.float32)
    output_action_log = torch.zeros_like(output_action, dtype=torch.float32)

    # The scale does not depend on the agent index, so compute it once.
    # sigmoid(...) * 0.5 keeps it inside (0, 0.5).
    action_std = torch.sigmoid(decoder.log_std) * 0.5

    for i in range(n_agent):
        act_mean = decoder(shifted_action, obs_rep, obs)[:, i, :]
        distri = Normal(act_mean, action_std)
        action = act_mean if deterministic else distri.sample()
        action_log = distri.log_prob(action)

        output_action[:, i, :] = action
        output_action_log[:, i, :] = action_log

        # Expose this agent's action to the next decoding step.
        if i + 1 < n_agent:
            shifted_action[:, i + 1, :] = action

    return output_action, output_action_log
def continuous_parallel_act(
    decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv
):
    """Evaluate given continuous actions for all agents with one decoder call.

    The decoder input for agent ``k`` is agent ``k - 1``'s action from
    ``action``; the first agent gets an all-zero input.

    Args:
        decoder: Callable invoked as ``decoder(shifted_action, obs_rep, obs)``;
            its result is used as the mean of a Normal distribution. Must
            also expose a ``log_std`` tensor attribute.
        obs_rep: Passed through to ``decoder`` unchanged.
        obs: Passed through to ``decoder`` unchanged.
        action: Tensor of actions, sliced as ``action[:, :-1, :]`` to build
            the decoder input and scored under the resulting distribution.
        batch_size: Size of the leading dimension of the decoder input.
        n_agent: Number of agents.
        action_dim: Size of each agent's action vector.
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the decoder
            input (presumably dtype and device; confirm against the caller).

    Returns:
        A tuple ``(action_log, entropy)``: the per-dimension log-probability
        of ``action`` and the per-dimension entropy of the distribution.
    """
    shifted_action = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    # Agent k sees agent k-1's action; the last agent's action is never fed in.
    shifted_action[:, 1:, :] = action[:, :-1, :]

    act_mean = decoder(shifted_action, obs_rep, obs)
    # sigmoid(...) * 0.5 keeps the scale inside (0, 0.5).
    action_std = torch.sigmoid(decoder.log_std) * 0.5
    distri = Normal(act_mean, action_std)

    action_log = distri.log_prob(action)
    entropy = distri.entropy()
    return action_log, entropy