
openrl.modules.networks package

Subpackages

Submodules

openrl.modules.networks.MAT_network module

class openrl.modules.networks.MAT_network.DecodeBlock(n_embd, n_head, n_agent)[source]

Bases: torch.nn.modules.module.Module

An unassuming Transformer block.

forward(x, rep_enc)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
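The note above is PyTorch's standard advice about hooks. A minimal, generic PyTorch illustration (not OpenRL-specific) of the difference: calling the module instance runs registered forward hooks, while calling `.forward()` directly computes the same output but skips them.

```python
import torch
import torch.nn as nn

calls = []  # records every time the forward hook fires

layer = nn.Linear(4, 2)
# A forward hook runs after forward() when the module is *called*.
layer.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(3, 4)
y1 = layer(x)          # __call__: runs forward() and then the hook
y2 = layer.forward(x)  # direct call: forward() only, the hook is skipped

print(len(calls))  # 1
```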

class openrl.modules.networks.MAT_network.Decoder(obs_dim, action_dim, n_block, n_embd, n_head, n_agent, action_type='Discrete', dec_actor=False, share_actor=False)[source]

Bases: torch.nn.modules.module.Module

forward(action, obs_rep, obs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

zero_std(device)[source]

class openrl.modules.networks.MAT_network.EncodeBlock(n_embd, n_head, n_agent)[source]

Bases: torch.nn.modules.module.Module

An unassuming Transformer block.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
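To make "Transformer block" concrete, here is a minimal sketch of one encoder block operating over the agent axis, built from `torch.nn.MultiheadAttention`. It assumes the common residual + LayerNorm layout; whether OpenRL's `EncodeBlock` uses post- or pre-norm, GELU, or this MLP width is not stated here, and `EncodeBlockSketch` is a hypothetical name. `n_agent` is unused because plain self-attention needs no mask.

```python
import torch
import torch.nn as nn


class EncodeBlockSketch(nn.Module):
    """Hypothetical post-LayerNorm Transformer block over the agent axis."""

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, n_embd), nn.GELU(), nn.Linear(n_embd, n_embd)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_agent, n_embd); every agent attends to every agent.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.ln1(x + attn_out)         # residual connection + norm
        return self.ln2(x + self.mlp(x))  # position-wise MLP + residual + norm


block = EncodeBlockSketch(n_embd=16, n_head=4)
out = block(torch.randn(2, 3, 16))  # 2 environments, 3 agents
print(out.shape)  # torch.Size([2, 3, 16])
```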

class openrl.modules.networks.MAT_network.Encoder(state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state)[source]

Bases: torch.nn.modules.module.Module

forward(state, obs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.MAT_network.MultiAgentTransformer(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork

eval_actions(obs, rnn_states, action, masks, action_masks=None, active_masks=None)[source]
get_actions(obs, rnn_states_actor=None, masks=None, action_masks=None, deterministic=False)[source]
get_actor_para()[source]
get_critic_para()[source]
get_values(critic_obs, rnn_states_critic=None, masks=None)[source]
zero_std()[source]

class openrl.modules.networks.MAT_network.SelfAttention(n_embd, n_head, n_agent, masked=False)[source]

Bases: torch.nn.modules.module.Module

forward(key, value, query)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
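The `masked` flag suggests an optional causal mask over the `n_agent` tokens, as used when decoding agents' actions one after another. A minimal sketch under that assumption (hypothetical `SelfAttentionSketch`, not OpenRL's code): with `masked=True`, agent *i* may only attend to agents 0..*i*, so perturbing the last agent leaves earlier agents' outputs unchanged.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttentionSketch(nn.Module):
    """Hypothetical multi-head attention over n_agent tokens with an optional causal mask."""

    def __init__(self, n_embd, n_head, n_agent, masked=False):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head, self.masked = n_head, masked
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        # Lower-triangular boolean mask: agent i may attend to agents 0..i only.
        self.register_buffer("causal", torch.tril(torch.ones(n_agent, n_agent, dtype=torch.bool)))

    def forward(self, key, value, query):
        B, L, D = query.shape
        h = self.n_head

        def split(t):  # (B, L, D) -> (B, h, L, D // h)
            return t.view(B, L, h, D // h).transpose(1, 2)

        k, q, v = split(self.key(key)), split(self.query(query)), split(self.value(value))
        att = (q @ k.transpose(-2, -1)) / math.sqrt(D // h)
        if self.masked:
            att = att.masked_fill(~self.causal[:L, :L], float("-inf"))
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.proj(y)


torch.manual_seed(0)
attn = SelfAttentionSketch(n_embd=8, n_head=2, n_agent=3, masked=True)
x = torch.randn(1, 3, 8)
y = attn(x, x, x)
x_changed = x.clone()
x_changed[:, 2] += 1.0  # perturb only the last agent
y_changed = attn(x_changed, x_changed, x_changed)
# Agents 0 and 1 cannot see agent 2, so their outputs do not change.
print(torch.allclose(y[:, :2], y_changed[:, :2], atol=1e-6))  # True
```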

openrl.modules.networks.MAT_network.init_(m, gain=0.01, activate=False)[source]
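`init_` is listed without a description. A plausible reading of its signature is orthogonal weight initialization scaled by `gain`, with zero biases; the sketch below assumes that, and further assumes `activate=True` switches to the ReLU gain. `init_sketch` is a hypothetical stand-in, not the documented behavior.

```python
import torch
import torch.nn as nn


def init_sketch(m: nn.Module, gain: float = 0.01, activate: bool = False) -> nn.Module:
    """Hypothetical stand-in for init_: orthogonal weights scaled by gain, zero bias."""
    if activate:
        # Assumption: "activate" means the layer feeds a ReLU.
        gain = nn.init.calculate_gain("relu")
    nn.init.orthogonal_(m.weight, gain=gain)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)
    return m


layer = init_sketch(nn.Linear(4, 4), gain=0.01)
w = layer.weight.detach()
# A square orthogonal matrix scaled by gain satisfies W @ W.T == gain**2 * I.
print(torch.allclose(w @ w.T, 0.01**2 * torch.eye(4), atol=1e-7))  # True
```

A small gain such as 0.01 keeps the initial outputs of, for example, a policy head close to zero.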

openrl.modules.networks.base_policy_network module

class openrl.modules.networks.base_policy_network.BasePolicyNetwork(cfg, device)[source]

Bases: torch.nn.modules.module.Module

openrl.modules.networks.base_value_network module

class openrl.modules.networks.base_value_network.BaseValueNetwork(cfg, device)[source]

Bases: abc.ABC, torch.nn.modules.module.Module

abstract forward()[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

openrl.modules.networks.base_value_policy_network module

class openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork(cfg, device)[source]

Bases: abc.ABC, torch.nn.modules.module.Module

abstract eval_actions(*args, **kwargs)[source]
forward(forward_type, *args, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
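`forward(forward_type, *args, **kwargs)` implies that a single `__call__` is routed to one of several methods (actor, critic, evaluation) by name, which keeps hooks active for every path. A minimal sketch of that dispatch pattern; the string keys `"get_actions"` and `"get_values"` and the class `DispatchSketch` are assumptions for illustration, since the accepted `forward_type` values are not documented here.

```python
import torch
import torch.nn as nn


class DispatchSketch(nn.Module):
    """Hypothetical illustration: route __call__ to a named method via forward_type."""

    def __init__(self):
        super().__init__()
        self.actor = nn.Linear(4, 2)
        self.critic = nn.Linear(4, 1)

    def get_actions(self, obs):
        return self.actor(obs).argmax(dim=-1)

    def get_values(self, obs):
        return self.critic(obs)

    def forward(self, forward_type, *args, **kwargs):
        # Calling the module (not a method directly) keeps hooks active
        # while still selecting the actor or critic path by name.
        handlers = {"get_actions": self.get_actions, "get_values": self.get_values}
        if forward_type not in handlers:
            raise NotImplementedError(f"unknown forward_type: {forward_type!r}")
        return handlers[forward_type](*args, **kwargs)


net = DispatchSketch()
obs = torch.randn(5, 4)
actions = net("get_actions", obs)
values = net("get_values", obs)
print(actions.shape, values.shape)  # torch.Size([5]) torch.Size([5, 1])
```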

abstract get_actions(*args, **kwargs)[source]
get_actor_para()[source]
get_critic_para()[source]
abstract get_values(*args, **kwargs)[source]

openrl.modules.networks.ddpg_network module

class openrl.modules.networks.ddpg_network.ActorNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_policy_network.BasePolicyNetwork

forward(obs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
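A DDPG actor maps an observation to a single deterministic action. A minimal sketch of the usual design, assuming a `tanh` output rescaled to the action-space bounds; the hidden size and `DeterministicActorSketch` are hypothetical, and `forward(obs)` here mirrors only the documented signature, not OpenRL's internals.

```python
import torch
import torch.nn as nn


class DeterministicActorSketch(nn.Module):
    """Hypothetical DDPG-style actor: obs -> action within [low, high]."""

    def __init__(self, obs_dim, act_dim, low, high, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, act_dim), nn.Tanh(),  # raw output in [-1, 1]
        )
        low = torch.as_tensor(low, dtype=torch.float32)
        high = torch.as_tensor(high, dtype=torch.float32)
        # Affine map from [-1, 1] to [low, high], stored as non-trainable buffers.
        self.register_buffer("scale", (high - low) / 2)
        self.register_buffer("offset", (high + low) / 2)

    def forward(self, obs):
        return self.net(obs) * self.scale + self.offset


actor = DeterministicActorSketch(obs_dim=3, act_dim=2, low=[-2.0, 0.0], high=[2.0, 1.0])
a = actor(torch.randn(10, 3))
print(a.shape)  # torch.Size([10, 2])
```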

class openrl.modules.networks.ddpg_network.CriticNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork

forward(state, action, rnn_states, masks)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
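The critic's `forward(state, action, ...)` estimates Q(s, a). A minimal sketch of the common approach of concatenating state and action before an MLP; it deliberately omits `rnn_states` and `masks`, and `QCriticSketch` and its hidden size are hypothetical.

```python
import torch
import torch.nn as nn


class QCriticSketch(nn.Module):
    """Hypothetical DDPG-style critic: Q(s, a) from concatenated state and action."""

    def __init__(self, obs_dim, act_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, state, action):
        # Join along the feature axis, then regress one scalar Q-value per sample.
        return self.net(torch.cat([state, action], dim=-1))


critic = QCriticSketch(obs_dim=3, act_dim=2)
q = critic(torch.randn(8, 3), torch.randn(8, 2))
print(q.shape)  # torch.Size([8, 1])
```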

class openrl.modules.networks.ddpg_network.CriticNetwork_v0(cfg, input_space, action_space, device=device(type='cpu'), use_half=False)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork

forward(state, action, rnn_states, masks)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

openrl.modules.networks.gail_discriminator module

class openrl.modules.networks.gail_discriminator.Discriminator(cfg, input_space, action_space, device, use_half, extra_args=None)[source]

Bases: torch.nn.modules.module.Module

compute_grad_pen(expert_state, expert_action, policy_state, policy_action, lambda_=10)[source]
predict_reward(state, action, gamma, masks, update_rms=True)[source]
update(expert_loader, buffer, obsfilt=None)[source]
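The `lambda_=10` default of `compute_grad_pen` matches the coefficient commonly used for the WGAN-GP gradient penalty. A minimal sketch of that penalty on random interpolations between expert and policy (state, action) pairs, assuming the discriminator scores the concatenated pair; `grad_pen_sketch` is hypothetical and may differ from OpenRL's implementation.

```python
import torch
import torch.nn as nn


def grad_pen_sketch(disc, expert_state, expert_action, policy_state, policy_action, lambda_=10.0):
    """Hypothetical gradient penalty on interpolated (state, action) pairs."""
    expert = torch.cat([expert_state, expert_action], dim=1)
    policy = torch.cat([policy_state, policy_action], dim=1)
    # One random mixing coefficient per sample, broadcast across features.
    alpha = torch.rand(expert.size(0), 1)
    mixup = (alpha * expert + (1 - alpha) * policy).requires_grad_(True)
    out = disc(mixup)
    grad = torch.autograd.grad(
        outputs=out, inputs=mixup, grad_outputs=torch.ones_like(out),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )[0]
    # Penalize deviations of the per-sample gradient norm from 1.
    return lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()


# A linear discriminator has the same input gradient everywhere: its weight vector.
disc = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    disc.weight.copy_(torch.tensor([[2.0, 0.0, 0.0]]))  # gradient norm 2
s, a = torch.randn(4, 2), torch.randn(4, 1)
pen = grad_pen_sketch(disc, s, a, torch.randn(4, 2), torch.randn(4, 1))
print(pen.item())  # 10.0, i.e. lambda_ * (2 - 1) ** 2
```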

openrl.modules.networks.policy_network module

class openrl.modules.networks.policy_network.PolicyNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_policy_network.BasePolicyNetwork

eval_actions(obs, rnn_states, action, masks, action_masks=None, active_masks=None)[source]
forward(forward_type, *args, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_original(raw_obs, rnn_states, masks, action_masks=None, deterministic=False)[source]
get_policy_values(obs, rnn_states, masks)[source]

openrl.modules.networks.policy_network_gpt module

openrl.modules.networks.policy_value_network module

class openrl.modules.networks.policy_value_network.PolicyValueNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork

eval_actions(obs, rnn_states, action, masks, action_masks, active_masks=None)[source]
get_actions(obs, rnn_states, masks, action_masks=None, deterministic=False)[source]
get_values(critic_obs, rnn_states, masks)[source]

openrl.modules.networks.policy_value_network_gpt module

openrl.modules.networks.policy_value_network_sb3 module

openrl.modules.networks.q_network module

class openrl.modules.networks.q_network.QNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork

forward(obs, rnn_states, masks, action_masks=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
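The optional `action_masks` argument suggests that unavailable actions are excluded when Q-values are used. A minimal sketch of masked greedy selection, assuming the mask uses 1 for available and 0 for unavailable actions (this convention and `masked_greedy_sketch` are assumptions): masked entries become `-inf` so `argmax` can never pick them.

```python
import torch


def masked_greedy_sketch(q_values, action_masks):
    """Hypothetical greedy action selection that never picks an unavailable action."""
    # Assumption: 1 = available, 0 = unavailable.
    masked_q = q_values.masked_fill(action_masks == 0, float("-inf"))
    return masked_q.argmax(dim=-1)


q = torch.tensor([[1.0, 5.0, 3.0], [0.5, 0.2, 0.9]])
mask = torch.tensor([[1, 0, 1], [1, 1, 0]])
print(masked_greedy_sketch(q, mask))  # tensor([2, 0])
```

Each row needs at least one available action; a fully masked row would reduce to an argmax over all `-inf` values.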

openrl.modules.networks.sac_network module

class openrl.modules.networks.sac_network.SACActorNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None, log_std_min=-20, log_std_max=2)[source]

Bases: openrl.modules.networks.ddpg_network.ActorNetwork

evaluate(obs, deterministic=True)[source]
forward(obs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
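The `log_std_min=-20` and `log_std_max=2` defaults point to the common SAC practice of clamping the Gaussian policy's log standard deviation, usually combined with a `tanh`-squashed action and the matching log-probability correction. A minimal sketch of such a head; `TanhGaussianHeadSketch` and its layer layout are hypothetical, and OpenRL's `evaluate` may differ.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class TanhGaussianHeadSketch(nn.Module):
    """Hypothetical SAC policy head with clamped log std and tanh squashing."""

    def __init__(self, in_dim, act_dim, log_std_min=-20.0, log_std_max=2.0):
        super().__init__()
        self.mu = nn.Linear(in_dim, act_dim)
        self.log_std = nn.Linear(in_dim, act_dim)
        self.log_std_min, self.log_std_max = log_std_min, log_std_max

    def forward(self, feat, deterministic=False):
        mu = self.mu(feat)
        # Clamp so the std stays within [exp(-20), exp(2)].
        log_std = self.log_std(feat).clamp(self.log_std_min, self.log_std_max)
        dist = Normal(mu, log_std.exp())
        u = mu if deterministic else dist.rsample()  # rsample keeps the reparameterization gradient
        action = torch.tanh(u)
        # Change of variables for the tanh squashing; 1e-6 avoids log(0).
        log_prob = dist.log_prob(u) - torch.log(1 - action.pow(2) + 1e-6)
        return action, log_prob.sum(dim=-1)


head = TanhGaussianHeadSketch(in_dim=4, act_dim=2)
feat = torch.randn(5, 4)
action, log_prob = head(feat)
print(action.shape, log_prob.shape)  # torch.Size([5, 2]) torch.Size([5])
```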

openrl.modules.networks.value_network module

class openrl.modules.networks.value_network.ValueNetwork(cfg, input_space, action_space=None, use_half=False, device=device(type='cpu'), extra_args=None)[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork

forward(critic_obs, rnn_states, masks)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

openrl.modules.networks.value_network_gpt module

openrl.modules.networks.vdn_network module

class openrl.modules.networks.vdn_network.VDNNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False, extra_args=None)[source]

Bases: openrl.modules.networks.base_value_policy_network.BaseValuePolicyNetwork

eval_actions(obs, rnn_states, action, masks, action_masks, active_masks=None)[source]
eval_actions_target(obs, rnn_states, action, masks, action_masks, active_masks=None)[source]
eval_values(obs, rnn_states, masks, action_masks=None)[source]
forward(forward_type, *args, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_actions(*args, **kwargs)[source]
get_values(obs, rnn_states, masks, action_masks=None)[source]
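VDN (Value-Decomposition Networks) defines the joint action value as the sum of the per-agent Q-values of the chosen actions. A minimal sketch of that mixing step alone, assuming per-agent Q-values shaped `(batch, n_agents, n_actions)`; `vdn_q_tot_sketch` is hypothetical and is not OpenRL's `VDNNetwork` code.

```python
import torch


def vdn_q_tot_sketch(agent_q, actions):
    """Hypothetical VDN mixing: Q_tot = sum over agents of Q_i(o_i, a_i)."""
    # agent_q: (batch, n_agents, n_actions); actions: (batch, n_agents) integer indices
    chosen = agent_q.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # (batch, n_agents)
    return chosen.sum(dim=-1)  # (batch,)


q = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # 1 sample, 2 agents, 2 actions
actions = torch.tensor([[1, 0]])               # agent 0 picks action 1, agent 1 picks action 0
print(vdn_q_tot_sketch(q, actions))  # tensor([5.])
```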

Module contents