Shortcuts

Source code for openrl.modules.networks.utils.mix

import numpy as np
import torch
import torch.nn as nn

from .util import init


class Flatten(nn.Module):
    """Collapse every non-batch dimension of a tensor into a single one."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        Args:
            x: Tensor whose first dimension is kept as-is.

        Returns:
            A 2-D tensor with the same leading dimension as ``x``.
        """
        if x.dim() < 2:
            # 0-D / 1-D inputs keep the original reshape behaviour
            # (a 1-D tensor of length N becomes (N, 1)).
            return x.reshape(x.size(0), -1)
        # ``flatten`` gives the same result as ``reshape(x.size(0), -1)`` but
        # also works for an empty batch, where ``-1`` cannot be inferred.
        return x.flatten(start_dim=1)
class MIXBase(nn.Module):
    """Feature extractor for dictionary observations mixing images and vectors.

    Every entry of ``obs_shape`` whose shape has three dimensions is routed to
    a shared CNN; every other entry is flattened and routed to a single-layer
    MLP. An optional embedding branch handles the keys listed in
    ``self.embed_keys``. The outputs of the active branches are concatenated
    along dimension 1.

    Args:
        cfg: Configuration object providing ``use_orthogonal``,
            ``activation_id``, ``use_maxpool2d`` and ``hidden_size``.
        obs_shape: Mapping from observation key to a space object exposing
            ``.shape``; only spaces whose class is named ``Box`` or
            ``MultiBinary`` are accepted.
        cnn_layers_params: Optional CNN layout, see ``_build_cnn_model``.

    Raises:
        NotImplementedError: If a space of any other class is encountered.
    """

    # Keys whose per-sample layout is (height, width, channels).
    _CHANNEL_LAST_KEYS = ("rgb", "depth", "image", "occupy_image")
    # Keys whose per-sample layout is (channels, height, width).
    _CHANNEL_FIRST_KEYS = (
        "global_map",
        "local_map",
        "global_obs",
        "global_merge_obs",
        "global_merge_goal",
        "gt_map",
    )

    def __init__(self, cfg, obs_shape, cnn_layers_params=None):
        super().__init__()
        self._use_orthogonal = cfg.use_orthogonal
        self._activation_id = cfg.activation_id
        self._use_maxpool2d = cfg.use_maxpool2d
        self.hidden_size = cfg.hidden_size

        self.cnn_keys = []
        # Never populated in this constructor, so the embedding branch stays
        # inactive unless a key is added to this list by other code.
        self.embed_keys = []
        self.mlp_keys = []
        self.n_cnn_input = 0
        self.n_embed_input = 0
        self.n_embed_output = 0
        self.embed_dim = 0
        self.n_mlp_input = 0

        for key in obs_shape:
            space_name = obs_shape[key].__class__.__name__
            if space_name not in ("Box", "MultiBinary"):
                raise NotImplementedError(
                    f"Unsupported observation space {space_name!r} "
                    f"for key {key!r}"
                )
            if len(obs_shape[key].shape) == 3:
                self.cnn_keys.append(key)
            else:
                self.mlp_keys.append(key)

        if self.cnn_keys:
            self.cnn = self._build_cnn_model(
                obs_shape,
                cnn_layers_params,
                self.hidden_size,
                self._use_orthogonal,
                self._activation_id,
            )
        if self.embed_keys:
            self.embed = self._build_embed_model(obs_shape)
        if self.mlp_keys:
            self.mlp = self._build_mlp_model(
                obs_shape,
                self.hidden_size,
                self._use_orthogonal,
                self._activation_id,
            )

    def forward(self, x):
        """Encode a dictionary of batched observations.

        Args:
            x: Mapping from observation key to a batched tensor.

        Returns:
            The CNN, embedding and MLP features (in that order, for the
            branches that are active) concatenated along dimension 1, or
            ``None`` when no branch is active.
        """
        features = []
        if self.cnn_keys:
            features.append(self.cnn(self._build_cnn_input(x)))
        if self.embed_keys:
            embed_input = self._build_embed_input(x)
            embedded = self.embed(embed_input.long())
            features.append(embedded.view(embed_input.size(0), -1))
        if self.mlp_keys:
            mlp_input = self._build_mlp_input(x)
            # NOTE(review): the original source tagged the CNN/MLP
            # concatenation with "# ! wrong" without explanation; the
            # behaviour is kept unchanged -- confirm the intended fusion.
            features.append(self.mlp(mlp_input).view(mlp_input.size(0), -1))

        if not features:
            return None
        if len(features) == 1:
            return features[0]
        return torch.cat(features, dim=1)

    def _build_cnn_model(
        self, obs_shape, cnn_layers_params, hidden_size, use_orthogonal, activation_id
    ):
        """Build the convolutional branch shared by all keys in ``cnn_keys``.

        Args:
            obs_shape: Mapping from key to a space exposing ``.shape``.
            cnn_layers_params: ``None`` for the default three-layer layout, a
                string such as ``"32,8,4,0 64,4,2,0"`` (whitespace-separated
                layers, comma-separated integers), or a sequence of
                ``(out_channels, kernel_size, stride, padding)`` tuples.
            hidden_size: Width of the final linear layer.
            use_orthogonal: Index/bool choosing Xavier-uniform (0) or
                orthogonal (1) weight initialisation.
            activation_id: Index into ``[Tanh, ReLU, LeakyReLU, ELU]``.

        Returns:
            An ``nn.Sequential`` ending in ``Flatten -> Linear -> activation
            -> LayerNorm``.

        Raises:
            NotImplementedError: If a CNN key is not one of the known names.
        """
        if cnn_layers_params is None:
            cnn_layers_params = [(32, 8, 4, 0), (64, 4, 2, 0), (64, 3, 1, 0)]
        elif isinstance(cnn_layers_params, str):
            cnn_layers_params = [
                tuple(map(int, layer.split(",")))
                for layer in cnn_layers_params.split()
            ]
        else:
            # Already-parsed layouts are accepted as-is.
            cnn_layers_params = [tuple(layer) for layer in cnn_layers_params]

        active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        # NOTE(review): index 3 pairs nn.ELU with the "selu" gain, exactly as
        # in the original code -- confirm this is intended.
        gain = nn.init.calculate_gain(
            ["tanh", "relu", "leaky_relu", "selu"][activation_id]
        )

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)

        # Channels of all image-like keys are summed because the inputs are
        # concatenated along the channel axis in ``_build_cnn_input``. The
        # spatial size is taken from the last key only.
        # NOTE(review): assumes every CNN key shares one spatial size -- confirm.
        cnn_dims = None
        for key in self.cnn_keys:
            shape = obs_shape[key].shape
            if key in self._CHANNEL_LAST_KEYS:
                self.n_cnn_input += int(shape[2])
                cnn_dims = (int(shape[0]), int(shape[1]))
            elif key in self._CHANNEL_FIRST_KEYS:
                self.n_cnn_input += int(shape[0])
                cnn_dims = (int(shape[1]), int(shape[2]))
            else:
                raise NotImplementedError(f"Unknown image observation key {key!r}")

        cnn_layers = []
        in_channels = self.n_cnn_input
        last_index = len(cnn_layers_params) - 1
        for i, (out_channels, kernel_size, stride, padding) in enumerate(
            cnn_layers_params
        ):
            # Optional 2x2 max-pooling precedes every convolution but the last.
            if self._use_maxpool2d and i != last_index:
                cnn_layers.append(nn.MaxPool2d(2))
                cnn_dims = self._maxpool_output_dim(
                    dimension=cnn_dims,
                    dilation=(1, 1),
                    kernel_size=(2, 2),
                    stride=(2, 2),
                )

            cnn_layers.append(
                init_(
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                    )
                )
            )
            cnn_layers.append(active_func)
            # Track the spatial size so the final Linear layer can be sized.
            cnn_dims = self._cnn_output_dim(
                dimension=cnn_dims,
                padding=(padding, padding),
                dilation=(1, 1),
                kernel_size=(kernel_size, kernel_size),
                stride=(stride, stride),
            )
            in_channels = out_channels

        flat_features = cnn_layers_params[-1][0] * cnn_dims[0] * cnn_dims[1]
        cnn_layers += [
            Flatten(),
            init_(nn.Linear(flat_features, hidden_size)),
            active_func,
            nn.LayerNorm(hidden_size),
        ]
        return nn.Sequential(*cnn_layers)

    def _build_embed_model(self, obs_shape):
        """Build the embedding table used for the keys in ``embed_keys``.

        Sets ``n_embed_input`` (72), ``n_embed_output`` (8) and ``embed_dim``
        (total number of scalar entries over all embed keys).

        Args:
            obs_shape: Mapping from key to a space exposing ``.shape``.

        Returns:
            ``nn.Embedding(72, 8)``.
        """
        # NOTE(review): 72 and 8 are hard-coded in the original; presumably
        # 72 is the number of discrete values an embed key can take -- confirm.
        self.n_embed_input = 72
        self.n_embed_output = 8
        self.embed_dim = sum(
            int(np.prod(obs_shape[key].shape)) for key in self.embed_keys
        )
        return nn.Embedding(self.n_embed_input, self.n_embed_output)

    def _build_mlp_model(self, obs_shape, hidden_size, use_orthogonal, activation_id):
        """Build the ``Linear -> activation -> LayerNorm`` vector branch.

        Args:
            obs_shape: Mapping from key to a space exposing ``.shape``.
            hidden_size: Output width of the linear layer.
            use_orthogonal: Index/bool choosing Xavier-uniform (0) or
                orthogonal (1) weight initialisation.
            activation_id: Index into ``[Tanh, ReLU, LeakyReLU, ELU]``.

        Returns:
            An ``nn.Sequential`` mapping the concatenated flat inputs to
            ``hidden_size`` features.
        """
        active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        gain = nn.init.calculate_gain(
            ["tanh", "relu", "leaky_relu", "selu"][activation_id]
        )

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)

        for key in self.mlp_keys:
            # ``np.prod`` returns a NumPy scalar (a float for an empty shape);
            # nn.Linear needs a plain Python int.
            self.n_mlp_input += int(np.prod(obs_shape[key].shape))

        return nn.Sequential(
            init_(nn.Linear(self.n_mlp_input, hidden_size)),
            active_func,
            nn.LayerNorm(hidden_size),
        )

    def _maxpool_output_dim(self, dimension, dilation, kernel_size, stride):
        """Compute the output height and width of an unpadded max-pooling layer.

        ref: https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html

        Args:
            dimension: Input ``(height, width)``.
            dilation: Per-axis dilation.
            kernel_size: Per-axis kernel size.
            stride: Per-axis stride.

        Returns:
            A tuple ``(out_height, out_width)`` of ints.
        """
        assert len(dimension) == 2
        return tuple(
            int(np.floor((size - dil * (kernel - 1) - 1) / step + 1))
            for size, dil, kernel, step in zip(
                dimension, dilation, kernel_size, stride
            )
        )

    def _cnn_output_dim(self, dimension, padding, dilation, kernel_size, stride):
        """Compute the output height and width of a convolution layer.

        ref: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Args:
            dimension: Input ``(height, width)``.
            padding: Per-axis zero padding.
            dilation: Per-axis dilation.
            kernel_size: Per-axis kernel size.
            stride: Per-axis stride.

        Returns:
            A tuple ``(out_height, out_width)`` of ints.
        """
        assert len(dimension) == 2
        return tuple(
            int(np.floor((size + 2 * pad - dil * (kernel - 1) - 1) / step + 1))
            for size, pad, dil, kernel, step in zip(
                dimension, padding, dilation, kernel_size, stride
            )
        )

    def _build_cnn_input(self, obs):
        """Stack all image-like observations along the channel axis.

        Channel-last keys are permuted to channel-first and divided by 255;
        channel-first keys are used unchanged.

        Args:
            obs: Mapping from key to a batched tensor.

        Returns:
            A tensor concatenated along dimension 1.

        Raises:
            NotImplementedError: If a CNN key is not one of the known names.
        """
        cnn_input = []
        for key in self.cnn_keys:
            if key in self._CHANNEL_LAST_KEYS:
                cnn_input.append(obs[key].permute(0, 3, 1, 2) / 255.0)
            elif key in self._CHANNEL_FIRST_KEYS:
                cnn_input.append(obs[key])
            else:
                raise NotImplementedError(f"Unknown image observation key {key!r}")
        return torch.cat(cnn_input, dim=1)

    def _build_embed_input(self, obs):
        """Flatten and concatenate the observations of all embed keys.

        Args:
            obs: Mapping from key to a batched tensor.

        Returns:
            A 2-D tensor whose first dimension is the batch dimension.
        """
        # ``reshape`` instead of ``view`` so non-contiguous inputs also work.
        flat = [obs[key].reshape(obs[key].size(0), -1) for key in self.embed_keys]
        return torch.cat(flat, dim=1)

    def _build_mlp_input(self, obs):
        """Flatten and concatenate the observations of all MLP keys.

        Args:
            obs: Mapping from key to a batched tensor.

        Returns:
            A 2-D tensor whose first dimension is the batch dimension.
        """
        # ``reshape`` instead of ``view`` so non-contiguous inputs also work.
        flat = [obs[key].reshape(obs[key].size(0), -1) for key in self.mlp_keys]
        return torch.cat(flat, dim=1)

    @property
    def output_size(self):
        """Total width of the tensor returned by ``forward``."""
        output_size = 0
        if self.cnn_keys:
            output_size += self.hidden_size
        if self.embed_keys:
            # One embedding vector per scalar entry of the embed inputs.
            output_size += int(self.n_embed_output * self.embed_dim)
        if self.mlp_keys:
            output_size += self.hidden_size
        return output_size