# Source code for openrl.utils.type_aliases
# Modified from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/type_aliases.py
"""Common aliases for type hints"""
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
import gym
import numpy as np
import torch as th
from openrl.envs import vec_env
from openrl.utils.callbacks import callbacks
# An environment handle: either a single gym env or a vectorized env.
GymEnv = Union[gym.Env, vec_env.BaseVecEnv]

# Any of the observation forms an environment may hand back.
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]

# Result of an env step. Two layouts are accepted:
#   * (obs, reward, done_flag, info)
#   * (obs, reward, flag, flag, info) -- presumably gym's
#     terminated/truncated split; confirm against the env wrappers.
GymStepReturn = Union[
    Tuple[GymObs, float, bool, Dict],
    Tuple[GymObs, float, bool, bool, Dict],
]
# Plain string-keyed dictionary; presumably the payload of a torch
# optimizer's ``state_dict()`` (verify against callers).
OptimizerStateDict = Dict[str, Any]

# Dictionary of torch tensors keyed by either a string or an integer.
TensorDict = Dict[Union[str, int], th.Tensor]
# What a ``callback`` argument may be: nothing at all, a bare callable,
# a list of ``BaseCallback`` objects, or a single ``BaseCallback``.
MaybeCallback = Union[
    None,
    Callable,
    List[callbacks.BaseCallback],
    callbacks.BaseCallback,
]
class AgentActor(Protocol):
    """Structural interface for objects that turn observations into actions.

    This is a ``typing.Protocol``: any object exposing a compatible ``act``
    method satisfies it for static type checkers, without inheriting from it.
    (It is not decorated with ``runtime_checkable``, so it cannot be used with
    ``isinstance``.)
    """

    def act(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation.

        Implementations may include sugar-coating to handle different
        observations (e.g. normalizing images).

        :param observation: the input observation, either a single array or
            a dict of arrays
        :param deterministic: Whether to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies; may be ``None``)
        """