
openrl.modules.networks package

Subpackages

Submodules

openrl.modules.networks.base_policy_network module

class openrl.modules.networks.base_policy_network.BasePolicyNetwork(cfg, device)[source]

Bases: torch.nn.modules.module.Module

openrl.modules.networks.base_value_network module

class openrl.modules.networks.base_value_network.BaseValueNetwork(cfg, device)[source]

Bases: abc.ABC, torch.nn.modules.module.Module

abstract forward()[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
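BaseValueNetwork combines abc.ABC with torch.nn.Module and marks forward() as abstract. The sketch below mirrors that arrangement with hypothetical classes (AbstractValueNet, TinyValueNet and Incomplete are illustrative names, not part of OpenRL) to show two consequences: a subclass that forgets to implement forward() cannot be instantiated, and, as the note says, only calling the instance (not .forward()) triggers registered hooks.

```python
import abc

import torch
import torch.nn as nn


class AbstractValueNet(abc.ABC, nn.Module):
    """Hypothetical stand-in with the same bases as BaseValueNetwork."""

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def forward(self, x):
        ...


class TinyValueNet(AbstractValueNet):
    """Concrete subclass: implements forward(), so it can be instantiated."""

    def __init__(self, obs_dim):
        super().__init__()
        self.v_out = nn.Linear(obs_dim, 1)

    def forward(self, x):
        return self.v_out(x)


class Incomplete(AbstractValueNet):
    pass  # forgets to implement forward()


# 1) The abstract forward() makes incomplete subclasses fail at construction time.
try:
    Incomplete()
    raised = False
except TypeError:
    raised = True

# 2) Calling the instance runs registered hooks; calling .forward() does not.
calls = []
net = TinyValueNet(obs_dim=4)
net.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.zeros(2, 4)
net(x)          # goes through nn.Module.__call__, so the hook fires
net.forward(x)  # bypasses __call__, so the hook is silently skipped
print(raised, len(calls))  # True 1
```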

openrl.modules.networks.policy_network module

class openrl.modules.networks.policy_network.PolicyNetwork(cfg, input_space, action_space, device=device(type='cpu'), use_half=False)[source]

Bases: openrl.modules.networks.base_policy_network.BasePolicyNetwork

eval_actions(obs, rnn_states, action, masks, available_actions=None, active_masks=None)[source]
forward(forward_type, *args, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

forward_original(raw_obs, rnn_states, masks, available_actions=None, deterministic=False)[source]
get_policy_values(obs, rnn_states, masks)[source]
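This page does not document how PolicyNetwork.forward uses forward_type. One plausible reading, given that the class also exposes forward_original and eval_actions, is that forward dispatches to one of them by name. The sketch below assumes that design; DispatchingPolicy and the keys "original" and "eval_actions" are hypothetical, and rnn_states, masks and available_actions are omitted for brevity.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class DispatchingPolicy(nn.Module):
    """Hypothetical sketch: forward() routes to a named method via forward_type."""

    def __init__(self, obs_dim=4, n_actions=3):
        super().__init__()
        self.body = nn.Linear(obs_dim, n_actions)

    def forward(self, forward_type, *args, **kwargs):
        # Map each (hypothetical) forward_type string to a bound method.
        handlers = {
            "original": self.forward_original,
            "eval_actions": self.eval_actions,
        }
        if forward_type not in handlers:
            raise NotImplementedError(f"unknown forward_type: {forward_type!r}")
        return handlers[forward_type](*args, **kwargs)

    def forward_original(self, obs, deterministic=False):
        # Act: pick the most likely action, or sample one.
        dist = Categorical(logits=self.body(obs))
        action = dist.probs.argmax(dim=-1) if deterministic else dist.sample()
        return action, dist.log_prob(action)

    def eval_actions(self, obs, action):
        # Re-score previously taken actions (e.g. for a PPO update).
        dist = Categorical(logits=self.body(obs))
        return dist.log_prob(action), dist.entropy().mean()


torch.manual_seed(0)
policy = DispatchingPolicy()
obs = torch.randn(5, 4)
actions, log_probs = policy("original", obs, deterministic=True)
new_log_probs, entropy = policy("eval_actions", obs, actions)
```

Routing every mode through forward() keeps each call inside nn.Module.__call__, so hooks fire for all modes. Wrappers such as torch.nn.parallel.DistributedDataParallel also intercept only forward(), which is one likely motivation for this pattern.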

openrl.modules.networks.policy_value_network module

class openrl.modules.networks.policy_value_network.PolicyValueNetwork(cfg, obs_space, critic_obs_space, action_space, device=device(type='cpu'), use_half=False)[source]

Bases: torch.nn.modules.module.Module

evaluate_actions(obs, rnn_states, action, masks, available_actions, active_masks=None)[source]
get_actions(obs, rnn_states, masks, available_actions=None, deterministic=False)[source]
get_values(critic_obs, rnn_states, masks)[source]
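PolicyValueNetwork takes separate obs_space and critic_obs_space arguments and exposes get_actions, evaluate_actions and get_values, the usual actor-critic interface of PPO-style algorithms. The sketch below is a minimal, hypothetical version of that interface (TinyActorCritic is not OpenRL code): it drops rnn_states and masks, and it assumes available_actions is a 0/1 tensor whose zeros mark illegal actions. Those actions are excluded by pushing their logits to a large negative value before building the categorical distribution.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class TinyActorCritic(nn.Module):
    """Hypothetical actor-critic with separate actor/critic inputs and no RNN."""

    def __init__(self, obs_dim, critic_obs_dim, n_actions, hidden=32):
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(), nn.Linear(hidden, n_actions)
        )
        self.critic = nn.Sequential(
            nn.Linear(critic_obs_dim, hidden), nn.Tanh(), nn.Linear(hidden, 1)
        )

    def _dist(self, obs, available_actions=None):
        logits = self.actor(obs)
        if available_actions is not None:
            # Unavailable actions (mask == 0) get a huge negative logit,
            # so their probability underflows to zero.
            logits = logits.masked_fill(available_actions == 0, -1e10)
        return Categorical(logits=logits)

    def get_actions(self, obs, available_actions=None, deterministic=False):
        dist = self._dist(obs, available_actions)
        actions = dist.probs.argmax(dim=-1) if deterministic else dist.sample()
        return actions, dist.log_prob(actions)

    def evaluate_actions(self, obs, actions, available_actions=None):
        dist = self._dist(obs, available_actions)
        return dist.log_prob(actions), dist.entropy().mean()

    def get_values(self, critic_obs):
        return self.critic(critic_obs)


torch.manual_seed(0)
model = TinyActorCritic(obs_dim=6, critic_obs_dim=10, n_actions=4)
obs = torch.randn(8, 6)
critic_obs = torch.randn(8, 10)
avail = torch.ones(8, 4)
avail[:, 0] = 0  # action 0 is unavailable for every agent
actions, log_probs = model.get_actions(obs, available_actions=avail)
values = model.get_values(critic_obs)
```

Giving the critic its own input (critic_obs) lets it see more than a single agent's observation, for example a global state, while the actor still acts on local observations.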

openrl.modules.networks.policy_value_network_gpt module

openrl.modules.networks.value_network module

class openrl.modules.networks.value_network.ValueNetwork(cfg, input_space, action_space=None, use_half=False, device=device(type='cpu'))[source]

Bases: openrl.modules.networks.base_value_network.BaseValueNetwork

forward(critic_obs, rnn_states, masks)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
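ValueNetwork.forward receives rnn_states and masks alongside critic_obs. A common convention in recurrent on-policy RL code, assumed here rather than confirmed by this page, is that masks is 1 for steps that continue an episode and 0 right after an episode ends, so multiplying the carried hidden state by masks resets memory at episode boundaries. RecurrentCritic is a hypothetical sketch of that convention using a single GRUCell.

```python
import torch
import torch.nn as nn


class RecurrentCritic(nn.Module):
    """Hypothetical recurrent critic: masks zero the hidden state at episode ends."""

    def __init__(self, obs_dim, hidden=16):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, hidden)
        self.gru = nn.GRUCell(hidden, hidden)
        self.v_out = nn.Linear(hidden, 1)

    def forward(self, critic_obs, rnn_states, masks):
        # masks has shape (batch, 1): 1.0 keeps the carried state, 0.0 resets it.
        h = rnn_states * masks
        h = self.gru(torch.relu(self.encoder(critic_obs)), h)
        return self.v_out(h), h


torch.manual_seed(0)
critic = RecurrentCritic(obs_dim=5)
obs = torch.randn(3, 5)
rnn_states = torch.randn(3, 16)
masks = torch.tensor([[1.0], [0.0], [1.0]])  # agent 1's episode just ended
values, new_states = critic(obs, rnn_states, masks)

# Recompute agent 1 from an all-zero hidden state for comparison.
fresh, _ = critic(obs[1:2], torch.zeros(1, 16), torch.ones(1, 1))
```

With the reset applied, the value for agent 1 matches what a fresh, all-zero hidden state produces.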

Module contents