Source code for openrl.modules.networks.utils.rnn
import torch
import torch.nn as nn
class RNNLayer(nn.Module):
    """Recurrent layer (GRU or LSTM) followed by LayerNorm, with mask-based resets.

    Callers exchange the recurrent state batch-first, as a tensor of shape
    ``(N, recurrent_N, H)``. ``H`` is ``outputs_dim`` for a GRU and
    ``2 * outputs_dim`` for an LSTM, whose ``h`` and ``c`` are concatenated
    along the last dimension (see :meth:`rnn_forward`).

    Args:
        inputs_dim: Size of each input feature vector.
        outputs_dim: Hidden size of the recurrent cell and of the output.
        recurrent_N: Number of stacked recurrent layers.
        use_orthogonal: If True, RNN weights are initialised with
            ``nn.init.orthogonal_``, otherwise with ``nn.init.xavier_uniform_``.
            Biases are always initialised to zero.
        rnn_type: ``"gru"`` or ``"lstm"``.

    Raises:
        NotImplementedError: If ``rnn_type`` is neither ``"gru"`` nor ``"lstm"``.
    """

    def __init__(
        self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal, rnn_type="gru"
    ):
        super().__init__()
        self._recurrent_N = recurrent_N
        self._use_orthogonal = use_orthogonal
        self.rnn_type = rnn_type
        if rnn_type == "gru":
            self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
        elif rnn_type == "lstm":
            self.rnn = nn.LSTM(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
        else:
            raise NotImplementedError(f"RNN type {rnn_type} has not been implemented.")
        weight_init = (
            nn.init.orthogonal_ if self._use_orthogonal else nn.init.xavier_uniform_
        )
        for name, param in self.rnn.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                weight_init(param)
        self.norm = nn.LayerNorm(outputs_dim)

    def rnn_forward(self, x, h):
        """Run the wrapped RNN on a time-major sequence.

        Args:
            x: Input of shape ``(T, N, inputs_dim)``.
            h: Initial state of shape ``(recurrent_N, N, H)``. For an LSTM the
                last dimension holds ``h`` and ``c`` concatenated.

        Returns:
            ``(output, state)``: the RNN output of shape ``(T, N, outputs_dim)``
            and the final state, packed the same way as ``h``.
        """
        if self.rnn_type == "lstm":
            # Unpack (h, c). split() returns strided views; they are made
            # contiguous before being handed to the RNN (presumably required
            # by the fused/cuDNN kernels).
            parts = torch.split(h, h.shape[-1] // 2, dim=-1)
            h = (parts[0].contiguous(), parts[1].contiguous())
        out, h_new = self.rnn(x, h)
        if self.rnn_type == "lstm":
            h_new = torch.cat(h_new, -1)
        return out, h_new

    def forward(self, x, hxs, masks):
        """Run the layer on a single step or on a flattened time-major rollout.

        The mode is chosen from the leading dimensions. If ``x`` has as many
        rows as ``hxs``, it is one step for ``N`` environments. Otherwise it is
        ``T`` steps flattened to ``(T * N, inputs_dim)``.

        Args:
            x: ``(N, inputs_dim)`` for one step, or ``(T * N, inputs_dim)``.
            hxs: Recurrent state of shape ``(N, recurrent_N, H)``.
            masks: One value per row of ``x`` (e.g. ``(N, 1)`` or ``(T * N, 1)``).
                An environment's state is multiplied by its mask before a step,
                so ``0`` resets it. In rollout mode masks are only applied at
                t=0 and at steps where some mask is zero, so they are expected
                to be 0/1.

        Returns:
            ``(output, hxs)``: the LayerNorm-ed output, with one row per row of
            ``x``, and the updated state of shape ``(N, recurrent_N, H)``.
        """
        if x.size(0) == hxs.size(0):
            x, hxs = self._forward_step(x, hxs, masks)
        else:
            x, hxs = self._forward_sequence(x, hxs, masks)
        return self.norm(x), hxs

    def _forward_step(self, x, hxs, masks):
        """Advance all N environments by one step and return the pre-norm output."""
        # Broadcast each environment's mask over layers and features, then
        # switch to the (recurrent_N, N, H) layout expected by the RNN.
        h0 = (hxs * masks.reshape(-1, 1, 1)).transpose(0, 1).contiguous()
        out, h_new = self.rnn_forward(x.unsqueeze(0), h0)
        return out.squeeze(0), h_new.transpose(0, 1)

    def _forward_sequence(self, x, hxs, masks):
        """Process a flattened (T * N, D) rollout in reset-free segments."""
        num_envs = hxs.size(0)
        seq_len = x.size(0) // num_envs
        # Unflatten the time-major rollout.
        x = x.view(seq_len, num_envs, x.size(1))
        masks = masks.view(seq_len, num_envs)

        # Steps t >= 1 where any environment's mask is zero. flatten() keeps the
        # result 1-D even with zero or one hit, and +1 undoes the masks[1:]
        # offset. t=0 and t=T are always boundaries.
        reset_steps = (
            ((masks[1:] == 0.0).any(dim=-1).nonzero().flatten() + 1).cpu().tolist()
        )
        boundaries = [0] + reset_steps + [seq_len]

        hxs = hxs.transpose(0, 1)
        outputs = []
        # No mask is zero strictly inside a segment, so each segment can be fed
        # to the RNN in one call, which is much faster than stepping.
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            h0 = (hxs * masks[start].view(1, -1, 1)).contiguous()
            out, hxs = self.rnn_forward(x[start:end], h0)
            outputs.append(out)

        # (T, N, outputs_dim) -> flatten back to (T * N, outputs_dim).
        x = torch.cat(outputs, dim=0).reshape(seq_len * num_envs, -1)
        return x, hxs.transpose(0, 1)