
openrl.modules.networks package

Subpackages

Submodules

openrl.modules.networks.MAT_network module

class openrl.modules.networks.MAT_network.DecodeBlock(n_embd, n_head, n_agent)[source]

Bases: torch.nn.modules.module.Module

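A Transformer block for the decoder side; forward(x, rep_enc) takes the decoder input x together with an encoder representation rep_enc.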
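As a rough illustration only, the sketch below shows one way a decoder-side block with the forward(x, rep_enc) signature could be assembled from standard PyTorch layers. It assumes a post-norm layout with two causally masked attention sub-layers (the second using rep_enc as the query and x as key and value) followed by an MLP. The class name SketchDecodeBlock and these layout choices are assumptions, not OpenRL's code.

```python
import torch
import torch.nn as nn


class SketchDecodeBlock(nn.Module):
    """Hypothetical decoder block; not the OpenRL implementation."""

    def __init__(self, n_embd: int, n_head: int, n_agent: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)
        self.attn1 = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.attn2 = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd)
        )
        # True above the diagonal: position i may not attend to positions j > i.
        mask = torch.triu(torch.ones(n_agent, n_agent, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor, rep_enc: torch.Tensor) -> torch.Tensor:
        n = x.size(1)
        mask = self.causal_mask[:n, :n]
        # Masked self-attention over the decoder input.
        a, _ = self.attn1(x, x, x, attn_mask=mask)
        x = self.ln1(x + a)
        # Masked attention with rep_enc as the query and x as key and value.
        a, _ = self.attn2(rep_enc, x, x, attn_mask=mask)
        x = self.ln2(rep_enc + a)
        # Position-wise feed-forward sub-layer.
        return self.ln3(x + self.mlp(x))


block = SketchDecodeBlock(n_embd=16, n_head=4, n_agent=3)
x, rep = torch.randn(2, 3, 16), torch.randn(2, 3, 16)
print(block(x, rep).shape)  # torch.Size([2, 3, 16])
```

Because of the causal mask, changing the input of the last agent leaves the outputs for earlier agents unchanged.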

forward(x, rep_enc)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them. This inherited torch.nn.Module.forward documentation applies to every forward() method listed on this page.
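The effect described in the note can be checked with any PyTorch module. This sketch uses a plain torch.nn.Linear (not an OpenRL network) and a forward hook registered with register_forward_hook:

```python
import torch
import torch.nn as nn

calls = []

layer = nn.Linear(4, 2)
# The hook appends to `calls` whenever the module is invoked through __call__.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
y1 = layer(x)          # runs forward() and then the registered hook
y2 = layer.forward(x)  # runs forward() only; the hook is skipped

print(calls)  # ['hook']
```

Both calls produce the same tensor; only the first one triggers the hook.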

class openrl.modules.networks.MAT_network.Decoder(obs_dim, action_dim, n_block, n_embd, n_head, n_agent, action_type='Discrete', dec_actor=False, share_actor=False)[source]

Bases: torch.nn.modules.module.Module

forward(action, obs_rep, obs)[source]

zero_std(device)[source]

class openrl.modules.networks.MAT_network.EncodeBlock(n_embd, n_head, n_agent)[source]

Bases: torch.nn.modules.module.Module

A Transformer block for the encoder side; forward(x) takes a single input sequence x.

forward(x)[source]

class openrl.modules.networks.MAT_network.Encoder(state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state)[source]

Bases: torch.nn.modules.module.Module

forward(state, obs)[source]

class openrl.modules.networks.MAT_network.MultiAgentTransformer(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork
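In the Multi-Agent Transformer formulation, actions are decoded autoregressively: agent i's action is chosen conditioned on the actions already picked for agents 0..i-1. The sketch below shows greedy discrete decoding in that style. It assumes a decoder with the same argument order as Decoder.forward(action, obs_rep, obs) that returns per-agent logits, and it plugs in a toy decoder in place of the real one; all names here are hypothetical.

```python
import torch
import torch.nn.functional as F


def sketch_autoregressive_act(decoder, obs_rep, obs, n_agent, action_dim):
    """Hypothetical greedy decoding of discrete actions, one agent at a time.

    decoder(shifted_action, obs_rep, obs) is assumed to return logits of shape
    (batch, n_agent, action_dim), where position i depends only on agents < i.
    """
    batch = obs_rep.size(0)
    # Last-dim slot 0 is a start-of-sequence flag; slots 1.. hold a one-hot action.
    shifted_action = torch.zeros(batch, n_agent, action_dim + 1)
    shifted_action[:, 0, 0] = 1.0
    actions = torch.zeros(batch, n_agent, dtype=torch.long)
    for i in range(n_agent):
        logits = decoder(shifted_action, obs_rep, obs)[:, i, :]
        a = logits.argmax(dim=-1)  # greedy; sample from the logits for exploration
        actions[:, i] = a
        if i + 1 < n_agent:
            # Feed agent i's action into position i + 1 for the next step.
            shifted_action[:, i + 1, 1:] = F.one_hot(a, num_classes=action_dim).float()
    return actions


def toy_decoder(shifted_action, obs_rep, obs):
    # Agent 0 picks action 0; every later agent picks (previous action + 1) mod A.
    logits = torch.roll(shifted_action[..., 1:], shifts=1, dims=-1)
    logits[..., 0] += shifted_action[..., 0]
    return logits


obs_rep = torch.zeros(1, 4, 8)
print(sketch_autoregressive_act(toy_decoder, obs_rep, None, n_agent=4, action_dim=3).tolist())
# [[0, 1, 2, 0]]
```

The loop calls the decoder once per agent, which is why autoregressive decoding costs n_agent forward passes at action-selection time.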

eval_actions(obs, rnn_states, action, masks, action_masks=None, active_masks=None)[source]
get_actions(obs, rnn_states_actor=None, masks=None, action_masks=None, deterministic=False)[source]
get_actor_para()[source]
get_critic_para()[source]
get_values(critic_obs, rnn_states_critic=None, masks=None)[source]
zero_std()[source]

class openrl.modules.networks.MAT_network.SelfAttention(n_embd, n_head, n_agent, masked=False)[source]

Bases: torch.nn.modules.module.Module
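SelfAttention's forward takes separate key, value and query tensors, and the masked flag together with n_agent suggests an optional causal mask over the agent sequence. The following is a minimal scaled dot-product sketch under those assumptions. The class name is hypothetical, it is not OpenRL's implementation, and it assumes key, value and query share one sequence length.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchSelfAttention(nn.Module):
    """Hypothetical multi-head attention with an optional causal mask."""

    def __init__(self, n_embd: int, n_head: int, n_agent: int, masked: bool = False):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.masked = masked
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        # Lower-triangular mask: position i may attend to positions 0..i only.
        tril = torch.tril(torch.ones(n_agent, n_agent))
        self.register_buffer("mask", tril.view(1, 1, n_agent, n_agent))

    def forward(self, key, value, query):
        B, L, D = query.shape
        h, d = self.n_head, D // self.n_head

        def split(t):  # (B, L, D) -> (B, h, L, d)
            return t.view(B, L, h, d).transpose(1, 2)

        k = split(self.key(key))
        q = split(self.query(query))
        v = split(self.value(value))
        att = (q @ k.transpose(-2, -1)) / math.sqrt(d)  # (B, h, L, L)
        if self.masked:
            att = att.masked_fill(self.mask[:, :, :L, :L] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().view(B, L, D)  # merge heads
        return self.proj(y)


attn = SketchSelfAttention(n_embd=8, n_head=2, n_agent=3, masked=True)
x = torch.randn(4, 3, 8)
print(attn(x, x, x).shape)  # torch.Size([4, 3, 8])
```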

forward(key, value, query)[source]

openrl.modules.networks.MAT_network.init_(m, gain=0.01, activate=False)[source]
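init_ has no docstring. Judging only from its signature, a plausible reading is that it initialises a layer m with the given gain, switching to a ReLU-appropriate gain when activate=True. The sketch below assumes orthogonal weights and zero biases, a common choice in on-policy RL code; the function name is hypothetical and the body is not taken from OpenRL.

```python
import torch
import torch.nn as nn


def sketch_init_(m: nn.Module, gain: float = 0.01, activate: bool = False) -> nn.Module:
    """Hypothetical stand-in for init_: orthogonal weights, zero bias."""
    if activate:
        gain = nn.init.calculate_gain("relu")  # sqrt(2)
    nn.init.orthogonal_(m.weight, gain=gain)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
    return m


layer = sketch_init_(nn.Linear(8, 4), activate=True)
w = layer.weight  # shape (4, 8): orthogonal rows, each with norm sqrt(2)
print(torch.allclose(w @ w.T, 2.0 * torch.eye(4), atol=1e-5))  # True
```

The small default gain of 0.01 keeps initial outputs close to zero, which is often used for policy output layers.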

openrl.modules.networks.base_policy_network module

class openrl.modules.networks.base_policy_network.BasePolicyNetwork(cfg, device)[source]

Bases: torch.nn.modules.module.Module

openrl.modules.networks.base_value_network module

class openrl.modules.networks.base_value_network.BaseValueNetwork(cfg, device)[source]

Bases: abc.ABC, torch.nn.modules.module.Module

abstract forward()[source]

openrl.modules.networks.base_value_policy_network module

class openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork(cfg, device)[source]

Bases: abc.ABC, torch.nn.modules.module.Module

abstract eval_actions(*args, **kwargs)[source]
forward(forward_type, *args, **kwargs)[source]

abstract get_actions(*args, **kwargs)[source]
get_actor_para()[source]
get_critic_para()[source]
abstract get_values(*args, **kwargs)[source]
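BaseValuePolicyNetwork declares get_actions, eval_actions and get_values as abstract and exposes a single forward(forward_type, *args, **kwargs). One plausible reading is that forward dispatches on forward_type to the matching method, so that callers can invoke the module instance (and its hooks) for every operation. The sketch below illustrates that pattern in plain Python; the forward_type strings and class names are hypothetical, and the real class additionally subclasses torch.nn.Module.

```python
from abc import ABC, abstractmethod


class SketchValuePolicyNetwork(ABC):
    """Hypothetical illustration of routing one entry point to several methods."""

    def forward(self, forward_type, *args, **kwargs):
        # The forward_type strings below are assumptions for illustration.
        handlers = {
            "get_actions": self.get_actions,
            "eval_actions": self.eval_actions,
            "get_values": self.get_values,
        }
        if forward_type not in handlers:
            raise NotImplementedError(f"unknown forward_type: {forward_type!r}")
        return handlers[forward_type](*args, **kwargs)

    @abstractmethod
    def get_actions(self, *args, **kwargs): ...

    @abstractmethod
    def eval_actions(self, *args, **kwargs): ...

    @abstractmethod
    def get_values(self, *args, **kwargs): ...


class ToyNetwork(SketchValuePolicyNetwork):
    def get_actions(self, obs):
        return [o * 2 for o in obs]

    def eval_actions(self, obs, action):
        return sum(action)

    def get_values(self, obs):
        return len(obs)


net = ToyNetwork()
print(net.forward("get_actions", [1, 2]))        # [2, 4]
print(net.forward("get_values", obs=[1, 2, 3]))  # 3
```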

openrl.modules.networks.ddpg_network module

class openrl.modules.networks.ddpg_network.ActorNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_policy_network.BasePolicyNetwork

forward(obs)[source]

class openrl.modules.networks.ddpg_network.CriticNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork
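In DDPG the critic estimates Q(s, a) for a state and the action produced by the deterministic actor. A common construction, assumed here rather than taken from OpenRL, concatenates state and action and maps them through an MLP to a scalar; the rnn_states and masks arguments of CriticNetwork.forward are omitted. The class names and hidden size are hypothetical.

```python
import torch
import torch.nn as nn


class SketchActor(nn.Module):
    """Deterministic policy mu(s) with actions squashed into [-1, 1]."""

    def __init__(self, state_dim: int, action_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(), nn.Linear(hidden, action_dim), nn.Tanh()
        )

    def forward(self, obs):
        return self.net(obs)


class SketchCritic(nn.Module):
    """Q(s, a): concatenate state and action, output one value per sample."""

    def __init__(self, state_dim: int, action_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, state, action):
        return self.net(torch.cat([state, action], dim=-1))


actor, critic = SketchActor(3, 2), SketchCritic(3, 2)
s = torch.randn(5, 3)
q = critic(s, actor(s))  # shape (5, 1)
actor_loss = -q.mean()   # DDPG actor objective: maximise Q(s, mu(s))
actor_loss.backward()
```

Backpropagating through the critic into the actor is what lets the deterministic policy follow the gradient of Q with respect to the action.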

forward(state, action, rnn_states, masks)[source]

class openrl.modules.networks.ddpg_network.CriticNetwork_v0(cfg, input_space, action_space, device=device(type='cpu'), use_half=False)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork

forward(state, action, rnn_states, masks)[source]

openrl.modules.networks.gail_discriminator module

class openrl.modules.networks.gail_discriminator.Discriminator(cfg, input_space, action_space, device, use_half, extra_args=None)[source]

Bases: torch.nn.modules.module.Module

compute_grad_pen(expert_state, expert_action, policy_state, policy_action, lambda_=10)[source]
predict_reward(state, action, gamma, masks, update_rms=True)[source]
update(expert_loader, buffer, obsfilt=None)[source]
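compute_grad_pen takes expert and policy state-action batches plus a coefficient lambda_ (default 10). A widely used form of such a penalty, assumed here for illustration in the style of WGAN-GP, evaluates the discriminator on random interpolations between expert and policy samples and penalises deviations of the input-gradient norm from 1. The helper name and the trunk argument are hypothetical.

```python
import torch
import torch.nn as nn


def sketch_grad_pen(trunk, expert_state, expert_action, policy_state, policy_action, lambda_=10.0):
    """Hypothetical gradient penalty on expert/policy interpolations."""
    expert = torch.cat([expert_state, expert_action], dim=1)
    policy = torch.cat([policy_state, policy_action], dim=1)
    # One mixing coefficient per sample, broadcast over features.
    alpha = torch.rand(expert.size(0), 1, device=expert.device)
    mixup = (alpha * expert + (1 - alpha) * policy).detach().requires_grad_(True)
    disc = trunk(mixup)
    (grad,) = torch.autograd.grad(
        outputs=disc,
        inputs=mixup,
        grad_outputs=torch.ones_like(disc),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )
    return lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()


# A linear discriminator has a constant input gradient equal to its weight vector.
trunk = nn.Linear(5, 1)
with torch.no_grad():
    trunk.weight.zero_()
    trunk.weight[0, 0] = 2.0  # gradient norm 2 -> penalty lambda_ * (2 - 1)^2
pen = sketch_grad_pen(trunk, torch.randn(4, 3), torch.randn(4, 2), torch.randn(4, 3), torch.randn(4, 2))
print(pen.item())  # 10.0
```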

openrl.modules.networks.policy_network module

class openrl.modules.networks.policy_network.PolicyNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_policy_network.BasePolicyNetwork

eval_actions(obs, rnn_states, action, masks, action_masks=None, active_masks=None)[source]
forward(forward_type, *args, **kwargs)[source]

forward_original(raw_obs, rnn_states, masks, action_masks=None, deterministic=False)[source]
get_policy_values(obs, rnn_states, masks)[source]
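PolicyNetwork.forward_original accepts optional action_masks and a deterministic flag. For a discrete action space, one common way to honour both, assumed here and not confirmed by the OpenRL source, is to push the logits of unavailable actions to a very low value and then either take the most probable action or sample. The helper name is hypothetical; action_masks uses 1 for allowed actions and 0 for forbidden ones.

```python
import torch
from torch.distributions import Categorical


def sketch_select_action(logits, action_masks=None, deterministic=False):
    """Hypothetical masked action selection for a discrete policy head."""
    if action_masks is not None:
        # Forbidden actions get a very low (finite) logit, so their probability is ~0.
        logits = logits.masked_fill(action_masks == 0, -1e10)
    dist = Categorical(logits=logits)
    action = dist.probs.argmax(dim=-1) if deterministic else dist.sample()
    return action, dist.log_prob(action)


logits = torch.tensor([[3.0, 1.0, 0.5]])
masks = torch.tensor([[0, 1, 1]])  # action 0 is unavailable
action, _ = sketch_select_action(logits, masks, deterministic=True)
print(action.tolist())  # [1]
```

A large finite value is used instead of -inf so that quantities such as the distribution's entropy stay finite.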

openrl.modules.networks.policy_network_gpt module

openrl.modules.networks.policy_value_network module

class openrl.modules.networks.policy_value_network.PolicyValueNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork

eval_actions(obs, rnn_states, action, masks, action_masks, active_masks=None)[source]
get_actions(obs, rnn_states, masks, action_masks=None, deterministic=False)[source]
get_values(critic_obs, rnn_states, masks)[source]

openrl.modules.networks.policy_value_network_gpt module

openrl.modules.networks.policy_value_network_sb3 module

openrl.modules.networks.q_network module

class openrl.modules.networks.q_network.QNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork

forward(obs, rnn_states, masks, action_masks=None)[source]

openrl.modules.networks.sac_network module

class openrl.modules.networks.sac_network.SACActorNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None, log_std_min=-20, log_std_max=2)[source]

Bases: openrl.modules.networks.ddpg_network.ActorNetwork
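SACActorNetwork takes log_std_min=-20 and log_std_max=2, and evaluate(obs, deterministic=True) suggests a stochastic policy with a deterministic mode. The standard SAC construction, sketched here under the assumption that OpenRL follows it, clamps the predicted log standard deviation to that range, samples with the reparameterisation trick, squashes through tanh, and corrects the log-probability for the squashing. The function name and the eps constant are hypothetical.

```python
import torch
from torch.distributions import Normal


def sketch_squashed_gaussian(mean, log_std, deterministic=False,
                             log_std_min=-20.0, log_std_max=2.0, eps=1e-6):
    """Hypothetical tanh-squashed Gaussian policy head in the usual SAC style."""
    log_std = log_std.clamp(log_std_min, log_std_max)
    dist = Normal(mean, log_std.exp())
    u = mean if deterministic else dist.rsample()  # rsample keeps gradients w.r.t. mean and std
    action = torch.tanh(u)
    # Change of variables for a = tanh(u): subtract log|da/du| = log(1 - tanh(u)^2).
    log_prob = dist.log_prob(u) - torch.log(1 - action.pow(2) + eps)
    return action, log_prob.sum(dim=-1, keepdim=True)


mean = torch.zeros(2, 3)
log_std = torch.full((2, 3), -50.0)  # clamped up to -20
a, logp = sketch_squashed_gaussian(mean, log_std, deterministic=True)
print(a.abs().max().item(), logp.shape)  # 0.0 torch.Size([2, 1])
```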

evaluate(obs, deterministic=True)[source]
forward(obs)[source]

openrl.modules.networks.value_network module

class openrl.modules.networks.value_network.ValueNetwork(cfg, input_space, action_space=None, use_half=False, device=device(type='cpu'), extra_args=None)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork
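ValueNetwork.forward, like several other forward methods in this package, takes rnn_states and masks. In many on-policy multi-agent codebases the mask is 0 at the first step of a new episode and 1 otherwise, and the recurrent state is multiplied by it so that memory does not leak across episodes. The sketch below assumes that convention; the function name and tensor shapes are hypothetical.

```python
import torch
import torch.nn as nn


def sketch_masked_gru_step(gru, x, rnn_states, masks):
    """Hypothetical single recurrent step that resets state where masks == 0.

    x: (batch, input_size), rnn_states: (batch, hidden_size), masks: (batch, 1)
    """
    h = rnn_states * masks  # zero the hidden state at episode boundaries
    return gru(x, h)


gru = nn.GRUCell(input_size=4, hidden_size=8)
x = torch.randn(2, 4)
h = torch.randn(2, 8)
masks = torch.tensor([[1.0], [0.0]])  # sample 1 starts a new episode
h_next = sketch_masked_gru_step(gru, x, h, masks)
# Sample 1 behaves exactly as if it had started from a zero state.
print(torch.allclose(h_next[1], gru(x[1:2], torch.zeros(1, 8))[0], atol=1e-6))  # True
```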

forward(critic_obs, rnn_states, masks)[source]

openrl.modules.networks.value_network_gpt module

openrl.modules.networks.vdn_network module

class openrl.modules.networks.vdn_network.VDNNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork

eval_actions(obs, rnn_states, action, masks, action_masks, active_masks=None)[source]
eval_actions_target(obs, rnn_states, action, masks, action_masks, active_masks=None)[source]
eval_values(obs, rnn_states, masks, action_masks=None)[source]
forward(forward_type, *args, **kwargs)[source]

get_actions(*args, **kwargs)[source]
get_values(obs, rnn_states, masks, action_masks=None)[source]
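VDN factorises the joint action value as a sum of per-agent utilities, Q_tot = sum over agents i of Q_i(o_i, a_i). The sketch below shows how the Q-values of the chosen actions could be gathered and summed, as a method like eval_actions would need; the function name and the (batch, n_agents, n_actions) tensor layout are assumptions.

```python
import torch


def sketch_vdn_q_tot(q_values, actions):
    """Hypothetical VDN mixing: sum each agent's Q-value for its chosen action.

    q_values: (batch, n_agents, n_actions), actions: (batch, n_agents) integer indices
    """
    chosen = q_values.gather(dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)  # (batch, n_agents)
    return chosen.sum(dim=-1, keepdim=True)  # (batch, 1)


q = torch.tensor([[[1.0, 5.0], [2.0, 0.5], [0.0, 3.0]]])  # 1 sample, 3 agents, 2 actions
a = torch.tensor([[1, 0, 1]])
print(sketch_vdn_q_tot(q, a))  # tensor([[10.]])
```

Because the sum is monotone in every agent's Q-value, the greedy joint action is simply each agent's own argmax.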

Module contents