Source code for openrl.modules.networks.utils.mlp
import torch
import torch.nn as nn
from .attention import Encoder
from .util import get_clones, init
class MLPLayer(nn.Module):
    """Fully connected stack whose every stage ends in a LayerNorm.

    Layout:
        * ``fc1``: Linear(input_dim -> hidden_size) -> activation -> LayerNorm
        * ``fc2``: ``get_clones(fc_h, layer_N - 1)`` where ``fc_h`` is
          Linear(hidden_size -> hidden_size) -> activation -> LayerNorm.
          Only built when ``layer_N > 1``.
        * ``fc3``: Linear(hidden_size -> hidden_size) -> LayerNorm (no activation)

    Note that ``fc_h`` itself is registered as a submodule, but ``forward``
    only calls the entries of ``fc2``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_size: Width of every layer, including the output.
        layer_N: ``forward`` applies ``fc2[0] .. fc2[layer_N - 2]``. A value
            <= 1 means no hidden blocks are applied.
        use_orthogonal: Used as an index; 0/False selects
            ``xavier_uniform_``, 1/True selects ``orthogonal_`` weight init.
        activation_id: Index into (Tanh, ReLU, LeakyReLU, ELU).
    """

    def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, activation_id):
        super().__init__()
        self._layer_N = layer_N

        activation = (nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU())[activation_id]
        weight_init = (nn.init.xavier_uniform_, nn.init.orthogonal_)[use_orthogonal]
        # NOTE(review): ELU (index 3) is paired with the "selu" gain;
        # calculate_gain has no "elu" option — confirm this is intended.
        gain_name = ("tanh", "relu", "leaky_relu", "selu")[activation_id]
        init_gain = nn.init.calculate_gain(gain_name)

        def _init_layer(module):
            # Re-initialize weights with the chosen scheme and zero the bias.
            return init(
                module,
                weight_init,
                lambda tensor: nn.init.constant_(tensor, 0),
                gain=init_gain,
            )

        self.fc1 = nn.Sequential(
            _init_layer(nn.Linear(input_dim, hidden_size)),
            activation,
            nn.LayerNorm(hidden_size),
        )

        if layer_N > 1:
            self.fc_h = nn.Sequential(
                _init_layer(nn.Linear(hidden_size, hidden_size)),
                activation,
                nn.LayerNorm(hidden_size),
            )
            self.fc2 = get_clones(self.fc_h, layer_N - 1)

        self.fc3 = nn.Sequential(
            _init_layer(nn.Linear(hidden_size, hidden_size)),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, x):
        """Apply ``fc1``, then ``fc2[0..layer_N-2]`` in order, then ``fc3``."""
        out = self.fc1(x)
        for hidden_idx in range(self._layer_N - 1):
            out = self.fc2[hidden_idx](out)
        return self.fc3(out)
class CONVLayer(nn.Module):
    """Three-stage 1-D convolutional feature extractor.

    ``self.conv`` is a ``nn.Sequential`` of three ``nn.Conv1d`` layers, each
    followed by the selected activation:

        1. ``input_dim -> hidden_size // 4``, kernel 3, stride 2, padding 0
        2. ``hidden_size // 4 -> hidden_size // 2``, kernel 3, stride 1, padding 1
        3. ``hidden_size // 2 -> hidden_size``, kernel 3, stride 1, padding 1

    Args:
        input_dim: Number of input channels. ``MLPBase`` passes
            ``stacked_frames`` here.
        hidden_size: Channel count of the final conv layer.
        use_orthogonal: Used as an index; 0/False selects
            ``xavier_uniform_``, 1/True selects ``orthogonal_`` weight init.
        activation_id: Index into (Tanh, ReLU, LeakyReLU, ELU).
    """

    def __init__(self, input_dim, hidden_size, use_orthogonal, activation_id):
        super().__init__()
        active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        # NOTE(review): ELU (index 3) is paired with the "selu" gain;
        # calculate_gain has no "elu" option — confirm this is intended.
        gain = nn.init.calculate_gain(
            ["tanh", "relu", "leaky_relu", "selu"][activation_id]
        )

        def init_(m):
            # Re-initialize weights with the chosen scheme and zero the bias.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)

        # No normalization layers between the convs (BatchNorm1d is disabled).
        self.conv = nn.Sequential(
            init_(
                nn.Conv1d(
                    in_channels=input_dim,
                    out_channels=hidden_size // 4,
                    kernel_size=3,
                    stride=2,
                    padding=0,
                )
            ),
            active_func,
            init_(
                nn.Conv1d(
                    in_channels=hidden_size // 4,
                    out_channels=hidden_size // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            ),
            active_func,
            init_(
                nn.Conv1d(
                    in_channels=hidden_size // 2,
                    out_channels=hidden_size,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            ),
            active_func,
        )

    def forward(self, x):
        """Run ``x`` through the convolution stack.

        Without this method ``nn.Module.__call__`` would raise
        ``NotImplementedError``, so any caller invoking ``conv_layer(x)``
        (e.g. ``MLPBase`` with ``use_conv1d``) would fail.

        Args:
            x: Input accepted by ``nn.Conv1d`` with ``input_dim`` channels,
                e.g. ``(batch, input_dim, length)``.

        Returns:
            The output of ``self.conv``, with ``hidden_size`` channels.
        """
        return self.conv(x)
class MLPBase(nn.Module):
    """Observation encoder: optional LayerNorm, attention, and Conv1d, then an MLP.

    Pipeline applied in ``forward`` (each stage is optional except the MLP):
        1. ``feature_norm`` — LayerNorm over ``obs_shape[0]`` features
           (when ``cfg.use_feature_normalization``).
        2. ``attn`` + ``attn_norm`` — project ``Encoder`` followed by LayerNorm
           (when both ``cfg.use_attn`` and ``use_attn_internal``).
        3. ``conv`` — ``CONVLayer`` over the input reshaped to
           ``(batch, stacked_frames, -1)`` and flattened back afterwards
           (when ``cfg.use_conv1d``).
        4. ``mlp`` — ``MLPLayer`` producing ``hidden_size`` features.

    Args:
        cfg: Config object. Reads ``use_feature_normalization``,
            ``use_orthogonal``, ``activation_id``, ``use_attn``,
            ``use_average_pool``, ``use_conv1d``, ``stacked_frames``,
            ``use_single_network``, ``layer_N``, ``attn_size``, ``hidden_size``.
            ``cfg`` is also forwarded to ``Encoder``.
        obs_shape: Observation shape description. ``obs_shape[0]`` is the flat
            feature size. In attention mode, the entries ``obs_shape[1:]`` are
            themselves indexable; presumably these are per-entity split shapes
            (TODO confirm against ``Encoder``).
        use_attn_internal: Enables the attention stage (together with
            ``cfg.use_attn``).
        use_cat_self: Forwarded to ``Encoder``. With average pooling, adds
            ``obs_shape[-1][1]`` to the MLP input size.
    """

    def __init__(self, cfg, obs_shape, use_attn_internal=False, use_cat_self=True):
        super().__init__()
        self._use_feature_normalization = cfg.use_feature_normalization
        self._use_orthogonal = cfg.use_orthogonal
        self._activation_id = cfg.activation_id
        self._use_attn = cfg.use_attn
        self._use_attn_internal = use_attn_internal
        self._use_average_pool = cfg.use_average_pool
        self._use_conv1d = cfg.use_conv1d
        self._stacked_frames = cfg.stacked_frames
        # A single shared network uses layer_N = 0 (no hidden MLP blocks).
        self._layer_N = 0 if cfg.use_single_network else cfg.layer_N
        self._attn_size = cfg.attn_size
        self.hidden_size = cfg.hidden_size

        obs_dim = obs_shape[0]

        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(obs_dim)

        if self._use_attn and self._use_attn_internal:
            if self._use_average_pool:
                if use_cat_self:
                    inputs_dim = self._attn_size + obs_shape[-1][1]
                else:
                    inputs_dim = self._attn_size
            else:
                # One attn_size-wide slot per split entry.
                split_inputs_dim = sum(shape[0] for shape in obs_shape[1:])
                inputs_dim = split_inputs_dim * self._attn_size
            self.attn = Encoder(cfg, obs_shape, use_cat_self)
            self.attn_norm = nn.LayerNorm(inputs_dim)
        else:
            inputs_dim = obs_dim

        if self._use_conv1d:
            self.conv = CONVLayer(
                self._stacked_frames,
                self.hidden_size,
                self._use_orthogonal,
                self._activation_id,
            )
            inputs_dim = self._conv_output_dim(inputs_dim)

        self.mlp = MLPLayer(
            inputs_dim,
            self.hidden_size,
            self._layer_N,
            self._use_orthogonal,
            self._activation_id,
        )

    def _conv_output_dim(self, inputs_dim):
        """Return the flattened per-sample size of ``self.conv``'s output.

        Probes the conv stack with a zero tensor of shape
        ``(1, stacked_frames, inputs_dim // stacked_frames)``. The probe runs
        under ``torch.no_grad()`` so no autograd graph is built. ``torch.zeros``
        uses the default dtype, matching the freshly built conv weights, and
        avoids reading uninitialized memory.

        Assumes ``inputs_dim`` is divisible by ``stacked_frames``; otherwise
        ``forward``'s ``view`` will not match (TODO confirm configs guarantee it).
        """
        probe = torch.zeros(
            1, self._stacked_frames, inputs_dim // self._stacked_frames
        )
        with torch.no_grad():
            probe_out = self.conv(probe)
        assert probe_out.dim() == 3
        return probe_out.size(-1) * probe_out.size(-2)

    def forward(self, x):
        """Encode ``x`` through the enabled stages and return the MLP output.

        Args:
            x: Batched observation tensor; ``x.size(0)`` is treated as the
                batch size in the conv stage.

        Returns:
            Output of ``self.mlp`` (last dimension ``hidden_size``).
        """
        if self._use_feature_normalization:
            x = self.feature_norm(x)
        if self._use_attn and self._use_attn_internal:
            x = self.attn(x, self_idx=-1)
            x = self.attn_norm(x)
        if self._use_conv1d:
            batch_size = x.size(0)
            x = x.view(batch_size, self._stacked_frames, -1)
            x = self.conv(x)
            x = x.view(batch_size, -1)
        return self.mlp(x)

    @property
    def output_size(self):
        """Feature width produced by ``forward`` (``cfg.hidden_size``)."""
        return self.hidden_size