openrl.modules.networks.utils package

Subpackages

Submodules

openrl.modules.networks.utils.act module

class openrl.modules.networks.utils.act.ACTLayer(action_space, inputs_dim, use_orthogonal, gain)[source]

Bases: torch.nn.modules.module.Module

evaluate_actions(x, action, action_masks=None, active_masks=None, get_probs=False)[source]
forward(x, action_masks=None, deterministic=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_probs(x, action_masks=None)[source]
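`evaluate_actions` accepts an `active_masks` argument. In MAPPO-style actor layers, which this signature resembles (an assumption, not confirmed on this page), the mask marks which agents are still active, and the entropy term is averaged only over entries whose mask is 1. The pure-Python sketch below shows that weighting; `masked_mean_entropy` is a hypothetical helper, not part of OpenRL.

```python
from typing import Sequence


def masked_mean_entropy(entropies: Sequence[float], active_masks: Sequence[float]) -> float:
    """Mean entropy over active agents only (hypothetical helper, not the OpenRL API).

    Equivalent to (entropy * active_masks).sum() / active_masks.sum().
    """
    total_weight = sum(active_masks)
    if total_weight == 0:
        raise ValueError("active_masks contains no active entries")
    weighted_sum = sum(e * m for e, m in zip(entropies, active_masks))
    return weighted_sum / total_weight


# The third agent is inactive (mask 0), so its entropy of 2.0 is ignored.
print(masked_mean_entropy([1.0, 0.5, 2.0], [1.0, 1.0, 0.0]))  # 0.75
```

Dividing by the mask sum rather than the number of agents keeps inactive agents from diluting the average.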

openrl.modules.networks.utils.attention module

class openrl.modules.networks.utils.attention.CatSelfEmbedding(split_shape, d_model, use_orthogonal=True, activation_id=1)[source]

Bases: torch.nn.modules.module.Module

forward(x, self_idx=-1)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.attention.Embedding(split_shape, d_model, use_orthogonal=True, activation_id=1)[source]

Bases: torch.nn.modules.module.Module

forward(x, self_idx=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.attention.Encoder(cfg, split_shape, cat_self=True)[source]

Bases: torch.nn.modules.module.Module

forward(x, self_idx=-1, mask=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.attention.EncoderLayer(d_model, heads, dropout=0.0, use_orthogonal=True, activation_id=False, d_ff=512, use_FF=False)[source]

Bases: torch.nn.modules.module.Module

forward(x, mask)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.attention.FeedForward(d_model, d_ff=512, dropout=0.0, use_orthogonal=True, activation_id=1)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.attention.MultiHeadAttention(heads, d_model, dropout=0.0, use_orthogonal=True)[source]

Bases: torch.nn.modules.module.Module

forward(q, k, v, mask=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

openrl.modules.networks.utils.attention.ScaledDotProductAttention(q, k, v, d_k, mask=None, dropout=None)[source]
openrl.modules.networks.utils.attention.split_obs(obs, split_shape)[source]
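`ScaledDotProductAttention(q, k, v, d_k, mask=None, dropout=None)` names the standard attention operation softmax(q k^T / sqrt(d_k)) v. The single-head, unbatched sketch below illustrates that formula in plain Python. It omits dropout, and the mask convention (an entry of 0 blocks attention, implemented by a score of -1e9) is an assumption rather than something this page documents.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum for numerical stability.
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(
    q: Matrix, k: Matrix, v: Matrix, d_k: int, mask: Optional[Matrix] = None
) -> Matrix:
    """softmax(q @ k^T / sqrt(d_k)) @ v for a single head, without batching."""
    output = []
    for i, q_row in enumerate(q):
        scores = [
            sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in k
        ]
        if mask is not None:
            # Assumed convention: a mask entry of 0 blocks attention to that key.
            scores = [s if mask[i][j] != 0 else -1e9 for j, s in enumerate(scores)]
        weights = softmax(scores)
        output.append(
            [sum(w * v_row[c] for w, v_row in zip(weights, v)) for c in range(len(v[0]))]
        )
    return output


q = [[0.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(scaled_dot_product_attention(q, k, v, d_k=2))                 # [[2.0, 3.0]]
print(scaled_dot_product_attention(q, k, v, d_k=2, mask=[[1, 0]]))  # [[1.0, 2.0]]
```

With equal scores the output is the plain average of the value rows; masking the second key makes the query copy the first value row. The 1/sqrt(d_k) factor keeps the dot products from growing with the key dimension and saturating the softmax.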

openrl.modules.networks.utils.cnn module

class openrl.modules.networks.utils.cnn.CNNBase(cfg, obs_shape)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

property output_size
class openrl.modules.networks.utils.cnn.CNNLayer(obs_shape, hidden_size, use_orthogonal, activation_id, kernel_size=3, stride=1)[source]

Bases: torch.nn.modules.module.Module

calc_flatten_size(h, w, filter, stride)[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
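`CNNLayer.calc_flatten_size(h, w, filter, stride)` is listed here only by signature. A helper of that kind usually applies the standard convolution output-size formula floor((n + 2p - k) / s) + 1 along each spatial axis and multiplies by the channel count. Whether OpenRL's version handles padding, multiple conv layers, or channels is not stated, so the sketch below (with hypothetical names) is an assumption.

```python
def conv_output_size(size: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Spatial length after a Conv2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def flatten_size(h: int, w: int, kernel_size: int, stride: int, channels: int) -> int:
    """Number of features produced by flattening a (channels, h', w') feature map."""
    return (
        channels
        * conv_output_size(h, kernel_size, stride)
        * conv_output_size(w, kernel_size, stride)
    )


print(conv_output_size(84, kernel_size=8, stride=4))              # 20
print(flatten_size(84, 84, kernel_size=3, stride=1, channels=32))  # 215168
```

The flattened size is what the fully connected layer after the convolution must accept as its input dimension.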

class openrl.modules.networks.utils.cnn.Flatten(*args, **kwargs)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

openrl.modules.networks.utils.distributed_utils module

openrl.modules.networks.utils.distributed_utils.reduce_tensor(tensor, n)[source]

openrl.modules.networks.utils.distributions module

class openrl.modules.networks.utils.distributions.AddBias(bias)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.distributions.Bernoulli(num_inputs, num_outputs, use_orthogonal=True, gain=0.01)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.distributions.Categorical(num_inputs, num_outputs, use_orthogonal=True, gain=0.01)[source]

Bases: torch.nn.modules.module.Module

forward(x, action_masks=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
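`Categorical.forward(x, action_masks=None)` takes optional action masks. A common way to apply them, assumed here rather than confirmed by this page, is to replace the logits of unavailable actions with a very large negative number before the softmax, so those actions get probability zero. The plain-Python sketch below uses a hypothetical helper name.

```python
import math
from typing import List, Optional


def masked_categorical_probs(
    logits: List[float], action_masks: Optional[List[int]] = None
) -> List[float]:
    """Softmax over logits after pushing unavailable actions to a huge negative value."""
    if action_masks is not None:
        # Assumed convention: mask 1 = available, 0 = unavailable.
        logits = [x if m != 0 else -1e10 for x, m in zip(logits, action_masks)]
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_categorical_probs([2.0, 1.0, 0.5], action_masks=[1, 0, 1])
print(probs[1])  # 0.0, so the masked action can never be sampled
```

Masking the logits (rather than zeroing probabilities afterwards) keeps the remaining probabilities normalized without a separate renormalization step.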

class openrl.modules.networks.utils.distributions.DiagGaussian(num_inputs, num_outputs, use_orthogonal=True, gain=0.01)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.distributions.FixedBernoulli(probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.bernoulli.Bernoulli

entropy()[source]

Method to compute the entropy using Bregman divergence of the log normalizer.

log_probs(actions)[source]
mode()[source]

Returns the mode of the distribution.
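`FixedBernoulli` extends `torch.distributions.Bernoulli` with `log_probs`, `entropy`, and `mode`. In MAPPO-derived code (assumed here, not confirmed on this page), each output dimension is treated as an independent binary action: `log_probs` and `entropy` sum over dimensions, and `mode` thresholds the probabilities at 0.5. The plain-Python sketch below uses hypothetical function names.

```python
import math
from typing import List


def bernoulli_log_probs(probs: List[float], actions: List[int]) -> float:
    """Joint log-probability of independent binary action dimensions (summed)."""
    return sum(math.log(p) if a == 1 else math.log(1.0 - p) for p, a in zip(probs, actions))


def bernoulli_entropy(probs: List[float]) -> float:
    """Sum of per-dimension entropies; assumes every p lies strictly in (0, 1)."""
    return sum(-(p * math.log(p) + (1.0 - p) * math.log(1.0 - p)) for p in probs)


def bernoulli_mode(probs: List[float]) -> List[float]:
    """Most likely value of each dimension: 1.0 when p > 0.5, else 0.0."""
    return [1.0 if p > 0.5 else 0.0 for p in probs]


probs = [0.9, 0.2]
print(bernoulli_mode(probs))  # [1.0, 0.0]
print(math.isclose(bernoulli_log_probs(probs, [1, 0]), math.log(0.72)))  # True
```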

class openrl.modules.networks.utils.distributions.FixedCategorical(probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.categorical.Categorical

log_probs(actions)[source]
mode()[source]

Returns the mode of the distribution.

sample()[source]

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
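`FixedCategorical` overrides `sample`, `mode`, and `log_probs`. In MAPPO-style code (an assumption here), these overrides keep a trailing dimension of size 1 so that discrete actions have shape (batch, 1): `mode` is the per-row argmax and `log_probs` returns the log-probability of each row's chosen action. The plain-Python sketch below uses nested lists in place of tensors and hypothetical names.

```python
import math
from typing import List


def categorical_mode(probs: List[List[float]]) -> List[List[int]]:
    """Argmax per batch row, keeping a trailing dimension of size 1."""
    return [[max(range(len(row)), key=row.__getitem__)] for row in probs]


def categorical_log_probs(
    probs: List[List[float]], actions: List[List[int]]
) -> List[List[float]]:
    """Log-probability of each row's chosen action, shaped like `actions`."""
    return [[math.log(row[a[0]])] for row, a in zip(probs, actions)]


probs = [[0.1, 0.7, 0.2], [0.5, 0.25, 0.25]]
modes = categorical_mode(probs)
print(modes)  # [[1], [0]]
print(categorical_log_probs(probs, modes))  # approximately [[-0.357], [-0.693]]
```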

class openrl.modules.networks.utils.distributions.FixedNormal(loc, scale, validate_args=None)[source]

Bases: torch.distributions.normal.Normal

entropy()[source]

Method to compute the entropy using Bregman divergence of the log normalizer.

log_probs(actions)[source]
mode()[source]

Returns the mode of the distribution.
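`FixedNormal` wraps `torch.distributions.Normal` for continuous actions with a diagonal covariance. Assuming, as in MAPPO-style code, that `log_probs` and `entropy` sum over action dimensions and `mode` returns the mean, the per-dimension formulas are log N(a; mu, sigma) = -(a - mu)^2 / (2 sigma^2) - log sigma - 0.5 log(2 pi) and H = 0.5 + 0.5 log(2 pi) + log sigma. The sketch below evaluates them in plain Python with hypothetical names.

```python
import math
from typing import List

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def normal_log_probs(actions: List[float], loc: List[float], scale: List[float]) -> float:
    """Sum over action dimensions of the diagonal-Gaussian log density."""
    return sum(
        -((a - mu) ** 2) / (2.0 * s ** 2) - math.log(s) - LOG_SQRT_2PI
        for a, mu, s in zip(actions, loc, scale)
    )


def normal_entropy(scale: List[float]) -> float:
    """Summed entropy: each dimension contributes 0.5 + 0.5 * log(2*pi) + log(sigma)."""
    return sum(0.5 + LOG_SQRT_2PI + math.log(s) for s in scale)


def normal_mode(loc: List[float]) -> List[float]:
    """A Gaussian's mode is its mean."""
    return list(loc)


print(normal_mode([0.3, -1.2]))  # [0.3, -1.2]
print(math.isclose(normal_log_probs([0.0, 0.0], [0.0, 0.0], [1.0, 1.0]), -math.log(2.0 * math.pi)))  # True
```

Because the dimensions are independent, summing log densities is the same as taking the log of the product of per-dimension densities.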

openrl.modules.networks.utils.mix module

class openrl.modules.networks.utils.mix.Flatten(*args, **kwargs)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.mix.MIXBase(cfg, obs_shape, cnn_layers_params=None)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

property output_size

openrl.modules.networks.utils.mlp module

class openrl.modules.networks.utils.mlp.CONVLayer(input_dim, hidden_size, use_orthogonal, activation_id)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class openrl.modules.networks.utils.mlp.MLPBase(cfg, obs_shape, use_attn_internal=False, use_cat_self=True)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

property output_size
class openrl.modules.networks.utils.mlp.MLPLayer(input_dim, hidden_size, layer_N, use_orthogonal, activation_id)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

openrl.modules.networks.utils.popart module

class openrl.modules.networks.utils.popart.PopArt(input_shape, output_shape, norm_axes=1, beta=0.99999, epsilon=1e-05, device=device(type='cpu'))[source]

Bases: torch.nn.modules.module.Module

debiased_mean_var()[source]
denormalize(input_vector)[source]
forward(input_vector)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

normalize(input_vector)[source]
reset_parameters()[source]
update(input_vector)[source]
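`PopArt` exposes `update`, `debiased_mean_var`, `normalize`, and `denormalize`, with `beta` and `epsilon` constructor arguments. A common reading of these, assumed here, is that `update` keeps exponential moving averages of the mean and mean square, and `debiased_mean_var` divides them by a debiasing term that corrects for the zero initialization. The scalar sketch below is hypothetical and not OpenRL's class; the variance floor of 1e-2 is likewise an assumption. The full PopArt method also rescales the output layer's weight and bias on each update so that denormalized predictions are preserved; this sketch omits that step.

```python
import math
from typing import List, Tuple


class ScalarPopArtStats:
    """Scalar sketch of PopArt-style running statistics (hypothetical, not OpenRL's class)."""

    def __init__(self, beta: float = 0.99999, epsilon: float = 1e-5) -> None:
        self.beta = beta
        self.epsilon = epsilon
        self.running_mean = 0.0
        self.running_mean_sq = 0.0
        self.debiasing_term = 0.0  # grows toward 1 as updates accumulate

    def update(self, batch: List[float]) -> None:
        batch_mean = sum(batch) / len(batch)
        batch_sq_mean = sum(x * x for x in batch) / len(batch)
        b = self.beta
        self.running_mean = b * self.running_mean + (1.0 - b) * batch_mean
        self.running_mean_sq = b * self.running_mean_sq + (1.0 - b) * batch_sq_mean
        self.debiasing_term = b * self.debiasing_term + (1.0 - b)

    def debiased_mean_var(self) -> Tuple[float, float]:
        # Dividing by the debiasing term removes the pull toward the zero initial state.
        d = max(self.debiasing_term, self.epsilon)
        mean = self.running_mean / d
        mean_sq = self.running_mean_sq / d
        var = max(mean_sq - mean ** 2, 1e-2)  # assumed floor keeps the std away from 0
        return mean, var

    def normalize(self, x: float) -> float:
        mean, var = self.debiased_mean_var()
        return (x - mean) / math.sqrt(var)

    def denormalize(self, y: float) -> float:
        mean, var = self.debiased_mean_var()
        return y * math.sqrt(var) + mean


stats = ScalarPopArtStats(beta=0.9)
stats.update([1.0, 3.0])
print(stats.debiased_mean_var())  # approximately (2.0, 1.0)
```

After a single update, the debiased mean equals the batch mean regardless of `beta`, which is the point of the correction.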

openrl.modules.networks.utils.rnn module

class openrl.modules.networks.utils.rnn.RNNLayer(inputs_dim, outputs_dim, recurrent_N, use_orthogonal, rnn_type='gru')[source]

Bases: torch.nn.modules.module.Module

forward(x, hxs, masks)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

rnn_forward(x, h)[source]
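`RNNLayer.forward(x, hxs, masks)` receives masks alongside the hidden states `hxs`. In MAPPO-style recurrent layers (an assumption here), the hidden state is multiplied by the mask before each step, so a mask of 0 at the start of a new episode clears the memory carried over from the previous one. The sketch below shows that effect with a scalar tanh RNN; the weights `w_x` and `w_h` are hypothetical, and the real layer uses a GRU or LSTM (`rnn_type='gru'`).

```python
import math
from typing import List


def masked_rnn_rollout(
    xs: List[float], masks: List[float], h0: float = 0.0, w_x: float = 0.5, w_h: float = 0.9
) -> List[float]:
    """Scalar tanh RNN where mask 0 clears the carried hidden state before a step."""
    h = h0
    hidden_states = []
    for x, m in zip(xs, masks):
        h = h * m  # assumed convention: 0 marks the start of a new episode
        h = math.tanh(w_x * x + w_h * h)
        hidden_states.append(h)
    return hidden_states


hs = masked_rnn_rollout([1.0, 1.0, 1.0], masks=[1.0, 1.0, 0.0])
print(hs[2] == hs[0])  # True: step 2 starts from a zeroed state, just like step 0
```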

openrl.modules.networks.utils.running_mean_std module

class openrl.modules.networks.utils.running_mean_std.RunningMeanStd(epsilon: float = 0.0001, shape: Tuple[int, ...] = ())[source]

Bases: object

combine(other: openrl.modules.networks.utils.running_mean_std.RunningMeanStd) → None[source]

Combine stats from another RunningMeanStd object.

Parameters

other -- The other object to combine with.

copy() → openrl.modules.networks.utils.running_mean_std.RunningMeanStd[source]
Returns

Return a copy of the current object.

update(arr: numpy.ndarray) → None[source]
update_from_moments(batch_mean: numpy.ndarray, batch_var: numpy.ndarray, batch_count: float) → None[source]
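`update_from_moments(batch_mean, batch_var, batch_count)` merges a batch summary into the running statistics. The usual way to do that, assumed here, is the parallel-variance combination of Chan et al.: the new mean shifts by delta * n_B / (n_A + n_B), and the summed squared deviations gain a delta^2 * n_A * n_B / (n_A + n_B) correction. `combine(other)` can then be a single call with the other object's mean, variance, and count. The scalar class below is a hypothetical sketch that mirrors the documented method names, not OpenRL's code.

```python
from typing import List


class ScalarRunningMeanStd:
    """Scalar sketch of a running mean/variance tracker (hypothetical, mirrors the documented methods)."""

    def __init__(self, epsilon: float = 1e-4) -> None:
        self.mean = 0.0
        self.var = 1.0
        self.count = epsilon  # small pseudo-count avoids division by zero

    def update(self, batch: List[float]) -> None:
        n = len(batch)
        batch_mean = sum(batch) / n
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / n
        self.update_from_moments(batch_mean, batch_var, n)

    def update_from_moments(self, batch_mean: float, batch_var: float, batch_count: float) -> None:
        # Chan et al. parallel combination of two (mean, variance, count) summaries.
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot_count
        m_2 = (
            self.var * self.count
            + batch_var * batch_count
            + delta ** 2 * self.count * batch_count / tot_count
        )
        self.mean, self.var, self.count = new_mean, m_2 / tot_count, tot_count


one_shot = ScalarRunningMeanStd()
one_shot.update([1.0, 2.0, 3.0, 4.0])

two_step = ScalarRunningMeanStd()
two_step.update([1.0, 2.0])
two_step.update([3.0, 4.0])
print(abs(one_shot.mean - two_step.mean) < 1e-9, abs(one_shot.var - two_step.var) < 1e-9)  # True True
```

Because the merge is exact, feeding data in several batches gives the same statistics as one large batch, which is what makes incremental normalization of observations or returns safe.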

openrl.modules.networks.utils.transformer_act module

openrl.modules.networks.utils.transformer_act.continuous_autoregreesive_act(decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv, deterministic=False)[source]
openrl.modules.networks.utils.transformer_act.continuous_parallel_act(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv)[source]
openrl.modules.networks.utils.transformer_act.discrete_autoregreesive_act(decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv, action_masks=None, deterministic=False)[source]
openrl.modules.networks.utils.transformer_act.discrete_parallel_act(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv, action_masks=None)[source]
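The `transformer_act` helpers come in two flavours. The `*_autoregreesive_act` functions (spelled this way in the API) take no `action` argument and appear to decode one agent at a time, feeding each chosen action back to the decoder before the next agent acts; the `*_parallel_act` functions take an already-known joint `action` and can score every agent in one decoder pass. That reading, which matches the Multi-Agent Transformer design, is an assumption. The plain-Python sketch below shows greedy (deterministic) discrete decoding with a hypothetical toy decoder; it omits `obs_rep`, `obs`, `tpdv`, action masks, and sampling.

```python
from typing import Callable, List

# A decoder maps the actions chosen so far to one logit vector per agent.
Decoder = Callable[[List[int]], List[List[float]]]


def greedy_autoregressive_act(decoder: Decoder, n_agent: int) -> List[int]:
    """Pick actions agent by agent; agent i sees only the actions of agents 0..i-1."""
    actions: List[int] = []
    for i in range(n_agent):
        logits = decoder(actions)[i]
        actions.append(max(range(len(logits)), key=logits.__getitem__))  # deterministic=True
    return actions


def toy_decoder(prev_actions: List[int], n_agent: int = 3) -> List[List[float]]:
    """Hypothetical decoder: each agent prefers the opposite of the previous agent's action."""
    logits = []
    for j in range(n_agent):
        if j == 0 or j - 1 >= len(prev_actions):
            logits.append([1.0, 0.0])  # no information yet: prefer action 0
        else:
            logits.append([0.0, 1.0] if prev_actions[j - 1] == 0 else [1.0, 0.0])
    return logits


print(greedy_autoregressive_act(toy_decoder, n_agent=3))  # [0, 1, 0]
```

Decoding costs one decoder call per agent, which is why a parallel variant is useful during training, when the joint action is already in the buffer.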

openrl.modules.networks.utils.util module

openrl.modules.networks.utils.util.get_clones(module, N)[source]
openrl.modules.networks.utils.util.init(module, weight_init, bias_init, gain=1)[source]
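`get_clones(module, N)` is commonly implemented as N `copy.deepcopy` copies of a module (for example, to stack identical encoder layers); that is an assumption here. The plain-Python sketch below, with hypothetical names, shows why deep copies matter: each clone owns its parameters, so changing one leaves the others untouched. In PyTorch code the copies would be wrapped in `torch.nn.ModuleList` so their parameters are registered with the parent module.

```python
import copy
from typing import Any, List


def get_clones_sketch(module: Any, n: int) -> List[Any]:
    """Return n independent deep copies of `module` (hypothetical stand-in for get_clones)."""
    return [copy.deepcopy(module) for _ in range(n)]


class ToyLayer:
    def __init__(self) -> None:
        self.weights = [0.0, 0.0]


layers = get_clones_sketch(ToyLayer(), 3)
layers[0].weights[0] = 1.0
print([layer.weights[0] for layer in layers])  # [1.0, 0.0, 0.0]
```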

openrl.modules.networks.utils.vdn module

class openrl.modules.networks.utils.vdn.VDNBase[source]

Bases: torch.nn.modules.module.Module

forward(agent_qs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
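`VDNBase` takes no constructor arguments, which is consistent with Value Decomposition Networks: the joint value is the sum of the per-agent values, Q_tot = sum_i Q_i, so the mixer has nothing to learn. The sketch below assumes `agent_qs` is laid out as (batch, n_agents) and uses nested lists; the real `forward` operates on tensors, and the axis it sums over is not stated on this page. `vdn_mix` is a hypothetical name.

```python
from typing import List


def vdn_mix(agent_qs: List[List[float]]) -> List[float]:
    """Q_tot for each batch entry = sum of the chosen per-agent Q-values."""
    return [sum(qs) for qs in agent_qs]


# Two transitions, three agents each.
print(vdn_mix([[1.0, 2.0, 0.5], [0.0, -1.0, 4.0]]))  # [3.5, 3.0]
```

Because the sum is monotonic in every Q_i, each agent maximizing its own Q-value also maximizes Q_tot, which is what allows decentralized greedy action selection.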

Module contents