Source code for openrl.modules.networks.utils.cnn
import torch.nn as nn
from .util import init
class CNNLayer(nn.Module):
    """Convolutional feature extractor for image-shaped observations.

    Builds ``self.cnn``, an ``nn.Sequential`` of three ``Conv2d`` + ``ReLU``
    stages (8x8 stride 4, 4x4 stride 2, 3x3 stride 1), a flatten, and one
    ``Linear`` + ``ReLU`` that projects to ``hidden_size`` features.

    Args:
        obs_shape: Indexable whose first three entries are read as
            (channels, height, width).
        hidden_size: Number of output features of the final linear layer.
        use_orthogonal: Used as an index into the initialiser list: 0/False
            selects Xavier-uniform, 1/True selects orthogonal.
        activation_id: Index (0-3) selecting the gain name passed to
            ``nn.init.calculate_gain`` ("tanh", "relu", "leaky_relu",
            "selu"). It only affects the initialisation gain; the layers
            themselves always use ``ReLU``.
        kernel_size: Accepted for interface compatibility; not used by the
            fixed architecture below.
        stride: Accepted for interface compatibility; not used by the fixed
            architecture below.

    Raises:
        IndexError: If ``activation_id`` or ``use_orthogonal`` is out of
            range for its lookup list.
        ValueError: If the observation is too small for the conv stack.
    """

    # (kernel, stride) of each conv stage; shared by the size calculation and
    # the layer construction so the two cannot drift apart.
    _CONV_SPECS = ((8, 4), (4, 2), (3, 1))

    def __init__(
        self,
        obs_shape,
        hidden_size,
        use_orthogonal,
        activation_id,
        kernel_size=3,
        stride=1,
    ):
        super().__init__()

        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        # NOTE(review): index 3 maps to the "selu" gain; confirm this is the
        # intended stand-in for an ELU-style activation.
        gain = nn.init.calculate_gain(
            ["tanh", "relu", "leaky_relu", "selu"][activation_id]
        )

        def init_(m):
            # Weights use the selected initialiser and gain; biases start at 0.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)

        input_channel = obs_shape[0]
        input_h = obs_shape[1]
        input_w = obs_shape[2]

        # Propagate the spatial size through every conv stage to obtain the
        # input width of the linear layer that follows the flatten.
        flatten_h, flatten_w = input_h, input_w
        for conv_kernel, conv_stride in self._CONV_SPECS:
            flatten_h, flatten_w = self.calc_flatten_size(
                flatten_h, flatten_w, conv_kernel, conv_stride
            )
        if flatten_h <= 0 or flatten_w <= 0:
            raise ValueError(
                "Observation of size ({}, {}) is too small for the "
                "convolutional stack".format(input_h, input_w)
            )

        (k1, s1), (k2, s2), (k3, s3) = self._CONV_SPECS
        # Fixed Atari-style architecture (ReLU throughout).
        self.cnn = nn.Sequential(
            init_(nn.Conv2d(input_channel, 32, k1, stride=s1)),
            nn.ReLU(),
            init_(nn.Conv2d(32, 64, k2, stride=s2)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 32, k3, stride=s3)),
            nn.ReLU(),
            # Built-in flatten: keeps dim 0 and collapses the rest.
            nn.Flatten(),
            init_(nn.Linear(32 * flatten_h * flatten_w, hidden_size)),
            nn.ReLU(),
        )

    def calc_flatten_size(self, h, w, filter, stride):
        """Return the (height, width) produced by an unpadded convolution.

        Uses ``floor((size - filter) / stride) + 1`` for each dimension, so an
        input smaller than the filter yields a non-positive size instead of
        being rounded up to 1.

        Args:
            h: Input height.
            w: Input width.
            filter: Kernel size (applied to both dimensions).
            stride: Stride (applied to both dimensions).

        Returns:
            Tuple ``(out_h, out_w)`` of ints.
        """
        h = int((h - filter) // stride) + 1
        w = int((w - filter) // stride) + 1
        return h, w
class CNNBase(nn.Module):
    """Config-driven wrapper that owns a ``CNNLayer`` as ``self.cnn``.

    Args:
        cfg: Object providing ``use_orthogonal``, ``activation_id`` and
            ``hidden_size`` attributes.
        obs_shape: Observation shape forwarded unchanged to ``CNNLayer``.
    """

    def __init__(self, cfg, obs_shape):
        super().__init__()

        # Cache the settings read from the config on the instance.
        self.hidden_size = cfg.hidden_size
        self._activation_id = cfg.activation_id
        self._use_orthogonal = cfg.use_orthogonal

        self.cnn = CNNLayer(
            obs_shape=obs_shape,
            hidden_size=self.hidden_size,
            use_orthogonal=self._use_orthogonal,
            activation_id=self._activation_id,
        )

    @property
    def output_size(self):
        """Number of features configured for the encoder (``hidden_size``)."""
        return self.hidden_size