Source code for openrl.modules.networks.utils.attention
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .util import get_clones, init
[docs]class Encoder(nn.Module):
def __init__(self, cfg, split_shape, cat_self=True):
super(Encoder, self).__init__()
self._use_orthogonal = cfg.use_orthogonal
self._activation_id = cfg.activation_id
self._attn_N = cfg.attn_N
self._attn_size = cfg.attn_size
self._attn_heads = cfg.attn_heads
self._dropout = cfg.dropout
self._use_average_pool = cfg.use_average_pool
self._cat_self = cat_self
if self._cat_self:
self.embedding = CatSelfEmbedding(
split_shape[1:],
self._attn_size,
self._use_orthogonal,
self._activation_id,
)
else:
self.embedding = Embedding(
split_shape[1:],
self._attn_size,
self._use_orthogonal,
self._activation_id,
)
self.layers = get_clones(
EncoderLayer(
self._attn_size,
self._attn_heads,
self._dropout,
self._use_orthogonal,
self._activation_id,
),
self._attn_N,
)
self.norm = nn.LayerNorm(self._attn_size)
[docs] def forward(self, x, self_idx=-1, mask=None):
x, self_x = self.embedding(x, self_idx)
for i in range(self._attn_N):
x = self.layers[i](x, mask)
x = self.norm(x)
if self._use_average_pool:
x = torch.transpose(x, 1, 2)
x = F.avg_pool1d(x, kernel_size=x.size(-1)).view(x.size(0), -1)
if self._cat_self:
x = torch.cat((x, self_x), dim=-1)
x = x.view(x.size(0), -1)
return x
# [L,[1,2],[1,2],[1,2]]
[docs]def split_obs(obs, split_shape):
start_idx = 0
split_obs = []
for i in range(len(split_shape)):
split_obs.append(
obs[:, start_idx : (start_idx + split_shape[i][0] * split_shape[i][1])]
)
start_idx += split_shape[i][0] * split_shape[i][1]
return split_obs
[docs]class FeedForward(nn.Module):
def __init__(
self, d_model, d_ff=512, dropout=0.0, use_orthogonal=True, activation_id=1
):
super(FeedForward, self).__init__()
# We set d_ff as a default to 2048
active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
gain = nn.init.calculate_gain(
["tanh", "relu", "leaky_relu", "selu"][activation_id]
)
def init_(m):
return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)
self.linear_1 = nn.Sequential(
init_(nn.Linear(d_model, d_ff)), active_func, nn.LayerNorm(d_ff)
)
self.dropout = nn.Dropout(dropout)
self.linear_2 = init_(nn.Linear(d_ff, d_model))
[docs]def ScaledDotProductAttention(q, k, v, d_k, mask=None, dropout=None):
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, -1e9)
scores = F.softmax(scores, dim=-1)
if dropout is not None:
scores = dropout(scores)
output = torch.matmul(scores, v)
return output
[docs]class MultiHeadAttention(nn.Module):
def __init__(self, heads, d_model, dropout=0.0, use_orthogonal=True):
super(MultiHeadAttention, self).__init__()
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
def init_(m):
return init(m, init_method, lambda x: nn.init.constant_(x, 0))
self.d_model = d_model
self.d_k = d_model // heads
self.h = heads
self.q_linear = init_(nn.Linear(d_model, d_model))
self.v_linear = init_(nn.Linear(d_model, d_model))
self.k_linear = init_(nn.Linear(d_model, d_model))
self.dropout = nn.Dropout(dropout)
self.out = init_(nn.Linear(d_model, d_model))
[docs] def forward(self, q, k, v, mask=None):
bs = q.size(0)
# perform linear operation and split into h heads
k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
# transpose to get dimensions bs * h * sl * d_model
k = k.transpose(1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
# calculate attention
scores = ScaledDotProductAttention(q, k, v, self.d_k, mask, self.dropout)
# concatenate heads and put through final linear layer
concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
output = self.out(concat)
return output
[docs]class EncoderLayer(nn.Module):
def __init__(
self,
d_model,
heads,
dropout=0.0,
use_orthogonal=True,
activation_id=False,
d_ff=512,
use_FF=False,
):
super(EncoderLayer, self).__init__()
self._use_FF = use_FF
self.norm_1 = nn.LayerNorm(d_model)
self.norm_2 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(heads, d_model, dropout, use_orthogonal)
self.ff = FeedForward(d_model, d_ff, dropout, use_orthogonal, activation_id)
self.dropout_1 = nn.Dropout(dropout)
self.dropout_2 = nn.Dropout(dropout)
[docs] def forward(self, x, mask):
x2 = self.norm_1(x)
x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
if self._use_FF:
x2 = self.norm_2(x)
x = x + self.dropout_2(self.ff(x2))
return x
[docs]class CatSelfEmbedding(nn.Module):
def __init__(self, split_shape, d_model, use_orthogonal=True, activation_id=1):
super(CatSelfEmbedding, self).__init__()
self.split_shape = split_shape
active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
gain = nn.init.calculate_gain(
["tanh", "relu", "leaky_relu", "selu"][activation_id]
)
def init_(m):
return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)
for i in range(len(split_shape)):
if i == (len(split_shape) - 1):
setattr(
self,
"fc_" + str(i),
nn.Sequential(
init_(nn.Linear(split_shape[i][1], d_model)),
active_func,
nn.LayerNorm(d_model),
),
)
else:
setattr(
self,
"fc_" + str(i),
nn.Sequential(
init_(
nn.Linear(split_shape[i][1] + split_shape[-1][1], d_model)
),
active_func,
nn.LayerNorm(d_model),
),
)
[docs] def forward(self, x, self_idx=-1):
x = split_obs(x, self.split_shape)
N = len(x)
x1 = []
self_x = x[self_idx]
for i in range(N - 1):
K = self.split_shape[i][0]
L = self.split_shape[i][1]
for j in range(K):
# torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
# exec("x1.append(self.fc_{}(temp))".format(i))
temp = torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
x1.append(getattr(self, "fc_" + str(i))(temp))
x1.append(getattr(self, "fc_" + str(N - 1))(self_x))
# x[self_idx]
# exec("x1.append(self.fc_{}(temp))".format(N - 1))
out = torch.stack(x1, 1)
return out, self_x
[docs]class Embedding(nn.Module):
def __init__(self, split_shape, d_model, use_orthogonal=True, activation_id=1):
super(Embedding, self).__init__()
self.split_shape = split_shape
active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id]
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
gain = nn.init.calculate_gain(
["tanh", "relu", "leaky_relu", "selu"][activation_id]
)
def init_(m):
return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)
for i in range(len(split_shape)):
setattr(
self,
"fc_" + str(i),
nn.Sequential(
init_(nn.Linear(split_shape[i][1], d_model)),
active_func,
nn.LayerNorm(d_model),
),
)
[docs] def forward(self, x, self_idx=None):
x = split_obs(x, self.split_shape)
N = len(x)
x1 = []
for i in range(N):
K = self.split_shape[i][0]
L = self.split_shape[i][1]
for j in range(K):
# x[i][:, (L * j) : (L * j + L)]
# exec("x1.append(self.fc_{}(temp))".format(i))
temp = x[i][:, (L * j) : (L * j + L)]
x1.append(getattr(self, "fc_" + str(i))(temp))
out = torch.stack(x1, 1)
if self_idx is None:
return out, None
else:
return out, x[self_idx]