Shortcuts

Source code for openrl.buffers.utils.obs_data

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import numpy as np
from treevalue import TreeValue, reduce_


class ObsData(TreeValue):
    """A ``TreeValue`` of observation arrays with flattening/batching helpers.

    Integer indexing (``obs[step]``) is routed to :meth:`step_batch` and
    returns a plain ``dict``; any other key is looked up through the normal
    ``TreeValue.__getitem__``.

    NOTE(review): the helpers below assume each child is a NumPy array
    (they call ``.shape``/``.reshape`` and slice it) — not enforced here.
    """

    def flatten(self):
        """Reduce the tree by concatenating node values along axis 0.

        Returns:
            The value produced by ``reduce_`` when every node's children are
            combined with ``np.concatenate`` (default axis 0), in the order
            ``reduce_`` supplies them.
        """
        return reduce_(self, lambda **kwargs: np.concatenate(list(kwargs.values())))

    @staticmethod
    def prepare_input(obs):
        """Concatenate raw observation input along axis 0.

        Args:
            obs: Either a ``dict`` of array-likes, or a single array-like.

        Returns:
            For a ``dict``: a new ``dict`` with the same keys, where each value
            is ``np.concatenate(value)``. Otherwise:
            ``np.concatenate(obs, axis=0)``.
        """
        if isinstance(obs, dict):
            return {key: np.concatenate(value) for key, value in obs.items()}
        return np.concatenate(obs, axis=0)

    def step_batch(self, step):
        """Return the observations at time index ``step`` for every key.

        Each child is indexed with ``step`` and the result is concatenated
        along its own axis 0. For children shaped like
        (steps, envs, agents, ...) — assumed, TODO confirm — this yields
        (envs * agents, ...) per key.

        Args:
            step: Index applied to each child's first axis.

        Returns:
            A plain ``dict`` mapping each key to its concatenated batch.
        """
        return {key: np.concatenate(self[key][step]) for key in self.keys()}

    def all_batch(self, min, max):
        """Slice ``[min:max]`` from every child and merge its leading axes.

        Each sliced child is reshaped to ``(-1, *child.shape[3:])``, i.e. all
        axes before index 3 are collapsed into one batch axis. This presumably
        corresponds to (steps, envs, agents, ...) -> (batch, ...) — TODO
        confirm against the buffer that fills this tree.

        Args:
            min: Start of the slice on each child's first axis. (Shadows the
                builtin, but the name is kept for keyword callers.)
            max: End (exclusive) of the slice on each child's first axis.

        Returns:
            A plain ``dict`` mapping each key to its reshaped batch.
        """
        batch = {}
        for key in self.keys():
            value = self[key]
            batch[key] = value[min:max].reshape((-1, *value.shape[3:]))
        return batch

    def __getitem__(self, key):
        """Return a step batch for integer keys, else defer to ``TreeValue``.

        NumPy integer scalars (e.g. from ``np.arange``) are treated as step
        indices too. Plain ``TreeValue`` lookup resolves only string keys, so
        they would otherwise fail there.
        """
        if isinstance(key, (int, np.integer)):
            return self.step_batch(key)
        return super().__getitem__(key)

    def step_flatten(self, step):
        """Concatenate every node's values at ``step`` along the last axis.

        Args:
            step: Index applied to each value before concatenation.

        Returns:
            The value produced by ``reduce_``. The original version discarded
            this result and always returned ``None``.

        NOTE(review): ``reduce_`` passes already-reduced values into inner
        calls. For a nested tree, ``[step]`` would therefore be applied again
        to subtree results. Flat trees (leaves only) are unaffected.
        """
        return reduce_(
            self,
            lambda **kwargs: np.concatenate(
                [value[step] for value in kwargs.values()], -1
            ),
        )