Shortcuts

Source code for openrl.modules.networks.utils.running_mean_std

from typing import Tuple

import numpy as np


[docs]class RunningMeanStd: def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): """ Calulates the running mean and std of a data stream https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm :param epsilon: helps with arithmetic issues :param shape: the shape of the data stream's output """ self.mean = np.zeros(shape, np.float64) self.var = np.ones(shape, np.float64) self.count = epsilon
[docs] def copy(self) -> "RunningMeanStd": """ :return: Return a copy of the current object. """ new_object = RunningMeanStd(shape=self.mean.shape) new_object.mean = self.mean.copy() new_object.var = self.var.copy() new_object.count = float(self.count) return new_object
[docs] def combine(self, other: "RunningMeanStd") -> None: """ Combine stats from another ``RunningMeanStd`` object. :param other: The other object to combine with. """ self.update_from_moments(other.mean, other.var, other.count)
[docs] def update(self, arr: np.ndarray) -> None: batch_mean = np.mean(arr, axis=0) batch_var = np.var(arr, axis=0) batch_count = arr.shape[0] self.update_from_moments(batch_mean, batch_var, batch_count)
[docs] def update_from_moments( self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: float ) -> None: delta = batch_mean - self.mean tot_count = self.count + batch_count new_mean = self.mean + delta * batch_count / tot_count m_a = self.var * self.count m_b = batch_var * batch_count m_2 = ( m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) ) new_var = m_2 / (self.count + batch_count) new_count = batch_count + self.count self.mean = new_mean self.var = new_var self.count = new_count