Shortcuts

Source code for openrl.modules.networks.ddpg_network

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Actor and critic network modules for DDPG (openrl.modules.networks.ddpg_network)."""

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F

from openrl.buffers.utils.util import get_critic_obs_space, get_policy_obs_space
from openrl.modules.networks.base_policy_network import BasePolicyNetwork
from openrl.modules.networks.base_value_network import BaseValueNetwork
from openrl.modules.networks.utils.cnn import CNNBase
from openrl.modules.networks.utils.mix import MIXBase
from openrl.modules.networks.utils.mlp import MLPBase
from openrl.modules.networks.utils.rnn import RNNLayer
from openrl.modules.networks.utils.util import init
from openrl.utils.util import check_v2 as check


class ActorNetwork(BasePolicyNetwork):
    """Actor network that maps observations to actions.

    The observation is encoded by a feature extractor (``MIXBase`` for
    Dict-like observation spaces, ``CNNBase`` for 3-dimensional shapes,
    ``MLPBase`` otherwise) followed by a single linear output layer whose
    width depends on the action space:

    * ``Discrete``: one output per action (``action_space.n``).
    * ``Box``: one output per action dimension (``action_space.shape[0]``).

    Args:
        cfg: Configuration object. This class reads ``hidden_size``,
            ``use_orthogonal``, ``use_attn_internal`` and (for Dict-like
            observations) ``cnn_layers_params`` from it.
        input_space: Observation space, passed through
            ``get_policy_obs_space`` to obtain the observation shape.
        action_space: ``gym.spaces.Discrete`` or ``gym.spaces.Box``.
        device: Device the module is moved to and inputs are cast to.
        use_half: If true, the module's parameters are converted with
            ``self.half()``.
        extra_args: Accepted for interface compatibility; not used here.

    Raises:
        NotImplementedError: If ``action_space`` is neither ``Discrete``
            nor ``Box``.
    """

    def __init__(
        self,
        cfg,
        input_space,
        action_space,
        device=torch.device("cpu"),
        use_half=False,
        extra_args=None,
    ) -> None:
        super().__init__(cfg, device)
        self.hidden_size = cfg.hidden_size
        self.action_space = action_space
        self.use_half = use_half
        self._use_orthogonal = cfg.use_orthogonal
        # Keyword arguments used to cast every input tensor in ``forward``.
        self.tpdv = dict(dtype=torch.float32, device=device)

        # ``use_orthogonal`` is used as an index: falsy -> Xavier uniform,
        # truthy -> orthogonal initialisation.
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][
            self._use_orthogonal
        ]

        obs_shape = get_policy_obs_space(input_space)

        if "Dict" in obs_shape.__class__.__name__:
            self._mixed_obs = True
            self.base = MIXBase(
                cfg, obs_shape, cnn_layers_params=cfg.cnn_layers_params
            )
        else:
            self._mixed_obs = False
            if len(obs_shape) == 3:
                self.base = CNNBase(cfg, obs_shape)
            else:
                self.base = MLPBase(
                    cfg,
                    obs_shape,
                    use_attn_internal=cfg.use_attn_internal,
                    use_cat_self=True,
                )

        input_size = self.base.output_size

        def init_(m):
            # Weights use ``init_method``; biases are set to zero.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        if isinstance(self.action_space, gym.spaces.discrete.Discrete):
            self.actor_out = init_(nn.Linear(input_size, action_space.n))
        elif isinstance(self.action_space, gym.spaces.box.Box):
            self.actor_out = init_(nn.Linear(input_size, action_space.shape[0]))
        else:
            raise NotImplementedError("This type of game has not been implemented.")

        if use_half:
            self.half()
        self.to(device)

    def forward(self, obs):
        """Compute the actor output for a batch of observations.

        Args:
            obs: Observation batch. For Dict-like observation spaces this is
                a mapping whose values are converted **in place**; otherwise
                a single array/tensor.

        Returns:
            For ``Discrete`` action spaces, the raw output of the final
            linear layer applied to ReLU-activated features. For ``Box``
            action spaces, the tanh-squashed output rescaled affinely from
            ``[-1, 1]`` to ``[action_space.low, action_space.high]``.

        Raises:
            NotImplementedError: If the action space type is unsupported.
        """
        if self._mixed_obs:
            # NOTE: this rewrites the entries of the caller's mapping.
            for key in obs.keys():
                obs[key] = check(obs[key]).to(**self.tpdv)
        else:
            obs = check(obs).to(**self.tpdv)

        features = self.base(obs)

        if isinstance(self.action_space, gym.spaces.discrete.Discrete):
            return self.actor_out(F.relu(features))

        if isinstance(self.action_space, gym.spaces.box.Box):
            squashed = torch.tanh(self.actor_out(features))
            # Build the bounds with the same dtype/device as the network
            # output. Plain ``torch.tensor(...)`` would create CPU tensors in
            # the space's numpy dtype, which fails on non-CPU devices and can
            # silently promote the result to float64.
            low = torch.as_tensor(
                self.action_space.low, dtype=squashed.dtype, device=squashed.device
            )
            high = torch.as_tensor(
                self.action_space.high, dtype=squashed.dtype, device=squashed.device
            )
            # Map [-1, 1] -> [0, 1] -> [low, high].
            return (squashed + 1) / 2 * (high - low) + low

        raise NotImplementedError("This type of game has not been implemented.")
class CriticNetwork(BaseValueNetwork):
    """Critic network that scores a (state, action) pair.

    State and action are concatenated along dimension 1, encoded by an
    ``MLPBase``, optionally passed through an ``RNNLayer``, and projected to
    a single value by a linear layer.

    Args:
        cfg: Configuration object. This class reads ``hidden_size``,
            ``gain``, ``use_orthogonal``, ``activation_id``,
            ``use_policy_active_masks``, ``use_naive_recurrent_policy``,
            ``use_recurrent_policy``, ``recurrent_N``, ``use_attn_internal``
            and (when recurrent) ``rnn_type`` from it.
        input_space: State space; the first entry of
            ``get_critic_obs_space(input_space)`` is used as the state size.
        action_space: Action space; the first entry of
            ``get_policy_obs_space(action_space)`` is used as the action size.
        device: Device the module is moved to and inputs are cast to.
        use_half: If true, the module's parameters are converted with
            ``self.half()``.
        extra_args: Accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        cfg,
        input_space,
        action_space,
        device=torch.device("cpu"),
        use_half=False,
        extra_args=None,
    ) -> None:
        super().__init__(cfg, device)
        self.hidden_size = cfg.hidden_size
        self._gain = cfg.gain
        self._use_orthogonal = cfg.use_orthogonal
        self._activation_id = cfg.activation_id
        self._use_policy_active_masks = cfg.use_policy_active_masks
        self._use_naive_recurrent_policy = cfg.use_naive_recurrent_policy
        self._use_recurrent_policy = cfg.use_recurrent_policy
        self._recurrent_N = cfg.recurrent_N
        self.use_half = use_half
        # Keyword arguments used to cast every input tensor in ``forward``.
        self.tpdv = dict(dtype=torch.float32, device=device)

        # ``use_orthogonal`` is used as an index: falsy -> Xavier uniform,
        # truthy -> orthogonal initialisation.
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][
            self._use_orthogonal
        ]

        state_dim = get_critic_obs_space(input_space)[0]
        action_dim = get_policy_obs_space(action_space)[0]
        # The critic consumes the concatenation of state and action, so its
        # input is a flat vector of the summed size.
        input_shape = (state_dim + action_dim,)

        # ``input_shape`` is always a plain tuple built right above, so the
        # Dict/``MIXBase`` path can never apply to this class. The attribute
        # is kept (always False) for anything that inspects it.
        self._mixed_obs = False
        self.input_base = MLPBase(
            cfg,
            input_shape,
            use_attn_internal=cfg.use_attn_internal,
            use_cat_self=True,
        )

        input_size = self.input_base.output_size

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(
                input_size,
                self.hidden_size,
                self._recurrent_N,
                self._use_orthogonal,
                rnn_type=cfg.rnn_type,
            )
            # NOTE(review): ``critic_out`` below is sized from ``input_size``,
            # not ``hidden_size``; this presumably relies on the base's
            # output size matching ``cfg.hidden_size`` - confirm.

        def init_(m):
            # Weights use ``init_method``; biases are set to zero.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        self.critic_out = init_(nn.Linear(input_size, 1))

        if use_half:
            self.half()
        self.to(device)

    def forward(self, state, action, rnn_states, masks):
        """Evaluate the critic on a batch of (state, action) pairs.

        Args:
            state: State batch (array or tensor).
            action: Action batch (array or tensor); concatenated with
                ``state`` along dimension 1.
            rnn_states: Recurrent state passed to the RNN layer when a
                recurrent policy is enabled; otherwise returned after
                conversion only.
            masks: Masks passed to the RNN layer alongside ``rnn_states``.

        Returns:
            Tuple ``(critic_out, rnn_states)``: the output of the final
            linear layer and the (possibly updated) recurrent state.
        """
        state = check(state).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        critic_input = torch.cat((state, action), 1)
        # The base output is used as-is (no extra activation on top).
        features = self.input_base(critic_input)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            features, rnn_states = self.rnn(features, rnn_states, masks)

        critic_out = self.critic_out(features)
        return critic_out, rnn_states
class CriticNetwork_v0(BaseValueNetwork):
    """Critic network with separate state and action encoders.

    State and action are encoded by their own feature extractors, the
    ReLU-activated features are summed element-wise, optionally passed
    through an ``RNNLayer``, and reduced to a single value by three linear
    layers.

    Args:
        cfg: Configuration object. This class reads ``hidden_size``,
            ``gain``, ``use_orthogonal``, ``activation_id``,
            ``use_policy_active_masks``, ``use_naive_recurrent_policy``,
            ``use_recurrent_policy``, ``recurrent_N``, ``use_attn_internal``,
            and, where applicable, ``cnn_layers_params`` and ``rnn_type``.
        input_space: State space, passed through ``get_critic_obs_space``.
        action_space: Action space, passed through ``get_policy_obs_space``.
        device: Device the module is moved to and inputs are cast to.
        use_half: If true, the module's parameters are converted with
            ``self.half()``.
    """

    def __init__(
        self,
        cfg,
        input_space,
        action_space,
        device=torch.device("cpu"),
        use_half=False,
    ) -> None:
        super().__init__(cfg, device)
        self.hidden_size = cfg.hidden_size
        self._gain = cfg.gain
        self._use_orthogonal = cfg.use_orthogonal
        self._activation_id = cfg.activation_id
        self._use_policy_active_masks = cfg.use_policy_active_masks
        self._use_naive_recurrent_policy = cfg.use_naive_recurrent_policy
        self._use_recurrent_policy = cfg.use_recurrent_policy
        self._recurrent_N = cfg.recurrent_N
        self.use_half = use_half
        # Keyword arguments used to cast every input tensor in ``forward``.
        self.tpdv = dict(dtype=torch.float32, device=device)

        # ``use_orthogonal`` is used as an index: falsy -> Xavier uniform,
        # truthy -> orthogonal initialisation.
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][
            self._use_orthogonal
        ]

        state_shape = get_critic_obs_space(input_space)
        action_shape = get_policy_obs_space(action_space)

        # Track Dict-ness of state and action separately. A single shared
        # flag would be overwritten by the action check and make ``forward``
        # treat a Dict state as a tensor (or the reverse).
        self._mixed_state_obs = "Dict" in state_shape.__class__.__name__
        self._mixed_action_obs = "Dict" in action_shape.__class__.__name__
        # Kept for anything that inspects the legacy attribute.
        self._mixed_obs = self._mixed_state_obs or self._mixed_action_obs

        if self._mixed_state_obs:
            self.state_base = MIXBase(
                cfg, state_shape, cnn_layers_params=cfg.cnn_layers_params
            )
        elif len(state_shape) == 3:
            self.state_base = CNNBase(cfg, state_shape)
        else:
            self.state_base = MLPBase(
                cfg,
                state_shape,
                use_attn_internal=cfg.use_attn_internal,
                use_cat_self=True,
            )

        if self._mixed_action_obs:
            self.action_base = MIXBase(
                cfg, action_shape, cnn_layers_params=cfg.cnn_layers_params
            )
        else:
            self.action_base = MLPBase(
                cfg,
                action_shape,
                use_attn_internal=cfg.use_attn_internal,
                use_cat_self=True,
            )

        # State and action features are summed (not concatenated) in
        # ``forward``, so the combined width is the state encoder's width.
        input_size = self.state_base.output_size

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(
                input_size,
                self.hidden_size,
                self._recurrent_N,
                self._use_orthogonal,
                rnn_type=cfg.rnn_type,
            )
            # NOTE(review): the following layer is still sized from
            # ``input_size`` rather than ``hidden_size``; presumably the two
            # are equal in practice - confirm.

        def init_(m):
            # Weights use ``init_method``; biases are set to zero.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        self.combine_feature_layer_1 = init_(nn.Linear(input_size, self.hidden_size))
        self.combine_feature_layer_2 = init_(
            nn.Linear(self.hidden_size, self.hidden_size)
        )
        self.combine_feature_layer_3 = init_(nn.Linear(self.hidden_size, 1))

        if use_half:
            self.half()
        self.to(device)

    def forward(self, state, action, rnn_states, masks):
        """Evaluate the critic on a batch of (state, action) pairs.

        Args:
            state: State batch. A mapping (converted in place) when the
                state space is Dict-like, otherwise an array/tensor.
            action: Action batch. A mapping (converted in place) when the
                action space is Dict-like, otherwise an array/tensor.
            rnn_states: Recurrent state passed to the RNN layer when a
                recurrent policy is enabled.
            masks: Masks passed to the RNN layer alongside ``rnn_states``.

        Returns:
            Tuple ``(critic_out, rnn_states)``: the output of the last
            linear layer and the (possibly updated) recurrent state.
        """
        if self._mixed_state_obs:
            # NOTE: this rewrites the entries of the caller's mapping.
            for key in state.keys():
                state[key] = check(state[key]).to(**self.tpdv)
        else:
            state = check(state).to(**self.tpdv)

        if self._mixed_action_obs:
            for key in action.keys():
                action[key] = check(action[key]).to(**self.tpdv)
        else:
            action = check(action).to(**self.tpdv)

        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        state_feature = F.relu(self.state_base(state))
        action_feature = F.relu(self.action_base(action))
        # Element-wise sum: both encoders must produce the same width.
        features = state_feature + action_feature

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            features, rnn_states = self.rnn(features, rnn_states, masks)

        # No activation is applied between layer 1 and layer 2; only the
        # output of layer 2 goes through ReLU.
        critic_out = self.combine_feature_layer_1(features)
        critic_out = F.relu(self.combine_feature_layer_2(critic_out))
        critic_out = self.combine_feature_layer_3(critic_out)

        return critic_out, rnn_states