Shortcuts

Source code for openrl.modules.networks.utils.popart

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PopArt(torch.nn.Module):
    """Linear output layer with running target normalization (PopArt).

    The layer keeps exponential moving averages of the first and second
    moments of the values passed to :meth:`update`. Every update also
    rescales ``weight`` and ``bias`` so that the un-normalized prediction
    ``forward(x) * stddev + mean`` is unchanged by the statistics update.

    Args:
        input_shape: Number of input features of the linear layer.
        output_shape: Number of outputs (one running statistic per output).
        norm_axes: Number of leading axes of the target tensor that are
            averaged over when computing batch statistics.
        beta: Decay factor of the exponential moving averages.
        epsilon: Lower bound applied to the debiasing term before dividing.
        device: Device on which parameters and statistics are created.
    """

    def __init__(
        self,
        input_shape,
        output_shape,
        norm_axes=1,
        beta=0.99999,
        epsilon=1e-5,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.beta = beta
        self.epsilon = epsilon
        self.norm_axes = norm_axes
        self.tpdv = dict(dtype=torch.float32, device=device)

        self.input_shape = input_shape
        self.output_shape = output_shape

        # Build every tensor on the target device/dtype *before* wrapping it
        # in nn.Parameter. Calling ``.to()`` on a Parameter returns a plain
        # (non-leaf) tensor whenever a copy is required, which would silently
        # drop it from ``parameters()`` and ``state_dict()``.
        self.weight = nn.Parameter(
            torch.empty(output_shape, input_shape, **self.tpdv)
        )
        self.bias = nn.Parameter(torch.empty(output_shape, **self.tpdv))

        # Running statistics are stored as non-trainable parameters so they
        # are saved and loaded together with the weights.
        self.stddev = nn.Parameter(
            torch.ones(output_shape, **self.tpdv), requires_grad=False
        )
        self.mean = nn.Parameter(
            torch.zeros(output_shape, **self.tpdv), requires_grad=False
        )
        self.mean_sq = nn.Parameter(
            torch.zeros(output_shape, **self.tpdv), requires_grad=False
        )
        self.debiasing_term = nn.Parameter(
            torch.tensor(0.0, **self.tpdv), requires_grad=False
        )

        self.reset_parameters()

    def _as_tensor(self, input_vector):
        """Convert a NumPy array or tensor to the layer's dtype and device."""
        if isinstance(input_vector, np.ndarray):
            input_vector = torch.from_numpy(input_vector)
        return input_vector.to(**self.tpdv)

    def reset_parameters(self):
        """Re-initialize the linear weights and clear the running statistics."""
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            # Guard against a zero-width input to avoid dividing by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)
        with torch.no_grad():
            self.mean.zero_()
            self.mean_sq.zero_()
            self.debiasing_term.zero_()
            # Keep stddev consistent with the freshly cleared moments.
            self.stddev.fill_(1.0)

    def forward(self, input_vector):
        """Return the normalized prediction ``input_vector @ weight.T + bias``."""
        input_vector = self._as_tensor(input_vector)
        return F.linear(input_vector, self.weight, self.bias)

    @torch.no_grad()
    def update(self, input_vector):
        """Update the running statistics and rescale the linear layer.

        Args:
            input_vector: Batch of target values (NumPy array or tensor).
                The first ``norm_axes`` axes are averaged over.

        NOTE(review): the weight/bias correction uses the raw running
        ``mean``/``stddev``, while ``normalize``/``denormalize`` use the
        debiased statistics -- confirm this asymmetry is intended.
        """
        input_vector = self._as_tensor(input_vector)

        # Snapshot the old statistics. The running tensors are modified in
        # place below, so plain references would alias the *new* values.
        old_mean = self.mean.clone()
        old_stddev = self.stddev.clone()

        reduce_dims = tuple(range(self.norm_axes))
        batch_mean = input_vector.mean(dim=reduce_dims)
        batch_sq_mean = (input_vector**2).mean(dim=reduce_dims)

        one_minus_beta = 1.0 - self.beta
        self.mean.mul_(self.beta).add_(batch_mean * one_minus_beta)
        self.mean_sq.mul_(self.beta).add_(batch_sq_mean * one_minus_beta)
        self.debiasing_term.mul_(self.beta).add_(one_minus_beta)

        # Clamp the variance before the square root: floating-point rounding
        # can make it slightly negative, and clamp() does not remove NaN.
        variance = (self.mean_sq - self.mean**2).clamp(min=0.0)
        new_stddev = variance.sqrt().clamp(min=1e-4)

        # One scale factor per output row of the (output, input) weight.
        new_weight = self.weight * (old_stddev / new_stddev).unsqueeze(-1)
        new_bias = (old_stddev * self.bias + old_mean - self.mean) / new_stddev

        # Write in place so the tensors stay registered nn.Parameters and any
        # optimizer holding references to them keeps seeing the live values.
        self.weight.copy_(new_weight)
        self.bias.copy_(new_bias)
        self.stddev.copy_(new_stddev)

    def debiased_mean_var(self):
        """Return the bias-corrected running mean and variance.

        The moving averages are divided by ``debiasing_term`` (bounded below
        by ``epsilon``); the variance is bounded below by ``1e-2``.
        """
        correction = self.debiasing_term.clamp(min=self.epsilon)
        debiased_mean = self.mean / correction
        debiased_mean_sq = self.mean_sq / correction
        debiased_var = (debiased_mean_sq - debiased_mean**2).clamp(min=1e-2)
        return debiased_mean, debiased_var

    def normalize(self, input_vector):
        """Return ``(input_vector - mean) / sqrt(var)`` as a tensor.

        Uses the debiased running statistics, broadcast over the first
        ``norm_axes`` axes.
        """
        input_vector = self._as_tensor(input_vector)
        mean, var = self.debiased_mean_var()
        lead = (None,) * self.norm_axes
        return (input_vector - mean[lead]) / torch.sqrt(var)[lead]

    def denormalize(self, input_vector):
        """Return ``input_vector * sqrt(var) + mean`` as a NumPy array.

        Uses the debiased running statistics, broadcast over the first
        ``norm_axes`` axes.
        """
        input_vector = self._as_tensor(input_vector)
        mean, var = self.debiased_mean_var()
        lead = (None,) * self.norm_axes
        out = input_vector * torch.sqrt(var)[lead] + mean[lead]
        # detach() so inputs that track gradients can still be converted.
        return out.detach().cpu().numpy()