# Source code for openrl.modules.networks.utils.transformer_act
import torch
from torch.distributions import Categorical, Normal
from torch.nn import functional as F
def discrete_autoregreesive_act(
    decoder,
    obs_rep,
    obs,
    batch_size,
    n_agent,
    action_dim,
    tpdv,
    action_masks=None,
    deterministic=False,
):
    """Choose a discrete action for every agent, one agent at a time.

    Agent ``i`` is conditioned on the actions already chosen for agents
    ``0..i-1``. This happens through a one-hot "shifted action" tensor that is
    updated after each step and passed to ``decoder`` again.

    Args:
        decoder: Callable invoked as ``decoder(shifted_action, obs_rep, obs)``.
            Its output is indexed as ``[:, i, :]`` and used as the logits of
            agent ``i`` over ``action_dim`` actions.
        obs_rep: Forwarded to ``decoder`` unchanged.
        obs: Forwarded to ``decoder`` unchanged.
        batch_size: Leading dimension of the shifted-action and output tensors.
        n_agent: Number of agents decoded in sequence.
        action_dim: Number of discrete actions.
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the
            shifted-action input (presumably ``dict(dtype=..., device=...)``).
        action_masks: Optional mask indexed as ``[:, i, :]``. Entries equal to
            0 get a logit of -1e10, so those actions are effectively excluded.
        deterministic: If True, take the most probable action; otherwise sample.

    Returns:
        ``(output_action, output_action_log)``: tensors of shape
        ``(batch_size, n_agent, 1)`` with dtypes ``torch.long`` and
        ``torch.float32``. They are allocated without a device argument, so
        they live on PyTorch's default device, not the one in ``tpdv``.
    """
    # Row 0 gets a 1 in slot 0 (a start marker). Each later row receives the
    # one-hot of the previous agent's action in slots 1..action_dim.
    shifted_action = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    shifted_action[:, 0, 0] = 1
    output_action = torch.zeros((batch_size, n_agent, 1), dtype=torch.long)
    output_action_log = torch.zeros_like(output_action, dtype=torch.float32)
    for i in range(n_agent):
        logit = decoder(shifted_action, obs_rep, obs)[:, i, :]
        if action_masks is not None:
            # Mask out of place. `logit` is a view of the decoder's output, so
            # writing into it would mutate that tensor. It could also break
            # autograd if the decoder saved its output for backward.
            unavailable = torch.as_tensor(
                action_masks[:, i, :] == 0, device=logit.device
            )
            logit = logit.masked_fill(unavailable, -1e10)
        distri = Categorical(logits=logit)
        action = distri.probs.argmax(dim=-1) if deterministic else distri.sample()
        action_log = distri.log_prob(action)
        output_action[:, i, :] = action.unsqueeze(-1)
        output_action_log[:, i, :] = action_log.unsqueeze(-1)
        if i + 1 < n_agent:
            # Feed this agent's choice to the next agent's decoding step.
            shifted_action[:, i + 1, 1:] = F.one_hot(action, num_classes=action_dim)
    return output_action, output_action_log
def discrete_parallel_act(
    decoder,
    obs_rep,
    obs,
    action,
    batch_size,
    n_agent,
    action_dim,
    tpdv,
    action_masks=None,
):
    """Score given discrete actions for all agents in a single decoder pass.

    This uses teacher forcing. Agent ``k`` is conditioned on the one-hot of
    the supplied action of agent ``k-1``, and agent 0 gets a 1 in slot 0.
    All agents are therefore scored with one ``decoder`` call.

    Args:
        decoder: Callable invoked as ``decoder(shifted_action, obs_rep, obs)``
            that returns logits over ``action_dim`` actions in its last
            dimension.
        obs_rep: Forwarded to ``decoder`` unchanged.
        obs: Forwarded to ``decoder`` unchanged.
        action: Action indices of shape ``(batch_size, n_agent, 1)``. Integer
            and integral-valued float tensors are both accepted.
        batch_size: Leading dimension of the shifted-action tensor.
        n_agent: Number of agents.
        action_dim: Number of discrete actions.
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the
            shifted-action input (presumably ``dict(dtype=..., device=...)``).
        action_masks: Optional mask compared elementwise against 0. Entries
            equal to 0 get a logit of -1e10.

    Returns:
        ``(action_log, entropy)``: log-probability of ``action`` and entropy of
        each agent's categorical distribution, with a trailing singleton
        dimension appended. These are ``(batch_size, n_agent, 1)`` when the
        decoder returns ``(batch_size, n_agent, action_dim)``.
    """
    # `F.one_hot` only accepts integer index tensors. Casting lets callers
    # pass actions stored as floats (`Categorical.log_prob` casts anyway).
    action_idx = action.squeeze(-1).long()
    one_hot_action = F.one_hot(
        action_idx, num_classes=action_dim
    )  # (batch, n_agent, action_dim)
    shifted_action = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    shifted_action[:, 0, 0] = 1
    shifted_action[:, 1:, 1:] = one_hot_action[:, :-1, :]
    logit = decoder(shifted_action, obs_rep, obs)
    if action_masks is not None:
        # Mask out of place. This path usually runs with gradients enabled, and
        # an in-place write would mutate the decoder's output and fail on
        # backward if the decoder saved that output.
        unavailable = torch.as_tensor(action_masks == 0, device=logit.device)
        logit = logit.masked_fill(unavailable, -1e10)
    distri = Categorical(logits=logit)
    action_log = distri.log_prob(action_idx).unsqueeze(-1)
    entropy = distri.entropy().unsqueeze(-1)
    return action_log, entropy
def continuous_autoregreesive_act(
    decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv, deterministic=False
):
    """Choose a continuous action for every agent, one agent at a time.

    Agent ``i + 1`` is conditioned on the action chosen for agent ``i``,
    written into a "shifted action" tensor that is passed to ``decoder`` on
    every step. Agent 0 sees an all-zero row.

    Args:
        decoder: Callable invoked as ``decoder(shifted_action, obs_rep, obs)``.
            Its output is indexed as ``[:, i, :]`` and used as the Gaussian
            mean of agent ``i``. It must expose a ``log_std`` tensor.
        obs_rep: Forwarded to ``decoder`` unchanged.
        obs: Forwarded to ``decoder`` unchanged.
        batch_size: Leading dimension of the shifted-action and output tensors.
        n_agent: Number of agents decoded in sequence.
        action_dim: Size of each agent's action vector.
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the
            shifted-action input (presumably ``dict(dtype=..., device=...)``).
        deterministic: If True, use the mean as the action; otherwise sample.

    Returns:
        ``(output_action, output_action_log)``: ``float32`` tensors of shape
        ``(batch_size, n_agent, action_dim)``. The second holds per-dimension
        Normal log-probabilities (not summed). Both are allocated without a
        device argument, so they live on PyTorch's default device.
    """
    shifted_action = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    output_action = torch.zeros((batch_size, n_agent, action_dim), dtype=torch.float32)
    output_action_log = torch.zeros_like(output_action, dtype=torch.float32)
    # The std depends only on `decoder.log_std`, not on the agent index, so
    # compute it once instead of once per agent. sigmoid(.) * 0.5 keeps it in
    # (0, 0.5).
    action_std = torch.sigmoid(decoder.log_std) * 0.5
    for i in range(n_agent):
        act_mean = decoder(shifted_action, obs_rep, obs)[:, i, :]
        distri = Normal(act_mean, action_std)
        action = act_mean if deterministic else distri.sample()
        action_log = distri.log_prob(action)
        output_action[:, i, :] = action
        output_action_log[:, i, :] = action_log
        if i + 1 < n_agent:
            # Feed this agent's action to the next agent's decoding step.
            shifted_action[:, i + 1, :] = action
    return output_action, output_action_log
def continuous_parallel_act(
    decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv
):
    """Score given continuous actions for all agents in one decoder pass.

    This uses teacher forcing. Agent ``k`` is conditioned on the supplied
    action of agent ``k-1``, and agent 0 gets an all-zero row. A single
    ``decoder`` call therefore yields every agent's Gaussian mean.

    Args:
        decoder: Callable invoked as ``decoder(prev_actions, obs_rep, obs)``
            that returns the Gaussian means. It must expose a ``log_std``
            tensor.
        obs_rep: Forwarded to ``decoder`` unchanged.
        obs: Forwarded to ``decoder`` unchanged.
        action: Actions indexed as ``[:, k, :]``, presumably of shape
            ``(batch_size, n_agent, action_dim)``.
        batch_size: Leading dimension of the shifted-action tensor.
        n_agent: Number of agents.
        action_dim: Size of each agent's action vector.
        tpdv: Keyword arguments forwarded to ``Tensor.to`` for the
            shifted-action input (presumably ``dict(dtype=..., device=...)``).

    Returns:
        ``(action_log, entropy)``: elementwise Normal log-probability of
        ``action`` and elementwise entropy, with std
        ``sigmoid(decoder.log_std) * 0.5``.
    """
    # Shift the supplied actions one agent to the right; row 0 stays zero.
    prev_actions = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    prev_actions[:, 1:, :] = action[:, :-1, :]

    mean = decoder(prev_actions, obs_rep, obs)
    std = 0.5 * torch.sigmoid(decoder.log_std)
    policy = Normal(mean, std)
    return policy.log_prob(action), policy.entropy()