Shortcuts

openrl.modules.networks.policy_value_network 源代码

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import torch
import torch.nn as nn

from openrl.buffers.utils.util import get_policy_obs_space
from openrl.modules.networks.utils.act import ACTLayer
from openrl.modules.networks.utils.cnn import CNNBase
from openrl.modules.networks.utils.mlp import MLPBase, MLPLayer
from openrl.modules.networks.utils.popart import PopArt
from openrl.modules.networks.utils.rnn import RNNLayer
from openrl.modules.networks.utils.util import init
from openrl.utils.util import check_v2 as check


class PolicyValueNetwork(nn.Module):
    """Actor-critic network whose policy and value heads share one trunk.

    The trunk is an observation encoder (``CNNBase`` when the policy
    observation shape has three dimensions, ``MLPBase`` otherwise), followed
    by a single ``MLPLayer`` and, when a recurrent policy is enabled in
    ``cfg``, an ``RNNLayer``.  On top of the trunk sit an ``ACTLayer`` action
    head and a scalar value head (``PopArt`` or ``nn.Linear``).
    """

    def __init__(
        self,
        cfg,
        obs_space,
        critic_obs_space,
        action_space,
        device=torch.device("cpu"),
        use_half=False,
    ):
        """Build the shared trunk and both heads.

        Args:
            cfg: Configuration object; the attributes read here are ``gain``,
                ``use_orthogonal``, ``activation_id``, ``recurrent_N``,
                ``use_naive_recurrent_policy``, ``use_recurrent_policy``,
                ``concat_obs_as_critic_obs``, ``use_popart``, ``hidden_size``
                and ``use_attn_internal``.
            obs_space: Observation space handed to ``get_policy_obs_space``
                to obtain the policy observation shape.
            critic_obs_space: Accepted for interface compatibility; it is not
                used because the critic reuses the policy encoder.
            action_space: Action space forwarded to ``ACTLayer``.
            device: Device the module is moved to after construction.
            use_half: When true, the module is converted with ``half()``
                before being moved to ``device``.
        """
        super().__init__()
        self._gain = cfg.gain
        self._use_orthogonal = cfg.use_orthogonal
        self._activation_id = cfg.activation_id
        self._recurrent_N = cfg.recurrent_N
        self._use_naive_recurrent_policy = cfg.use_naive_recurrent_policy
        self._use_recurrent_policy = cfg.use_recurrent_policy
        self._concat_obs_as_critic_obs = cfg.concat_obs_as_critic_obs
        self._use_popart = cfg.use_popart
        self.hidden_size = cfg.hidden_size
        self.device = device
        self.use_half = use_half
        # Keyword arguments passed to `check` for every incoming array.
        self.tpdv = dict(dtype=torch.float32, device=device)

        # Indexed by the flag: falsy -> Xavier uniform, truthy -> orthogonal.
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][
            self._use_orthogonal
        ]

        # Observation encoder: three-dimensional shapes go through the CNN.
        policy_obs_shape = get_policy_obs_space(obs_space)
        self.obs_prep = (
            CNNBase(cfg, policy_obs_shape)
            if len(policy_obs_shape) == 3
            else MLPBase(
                cfg,
                policy_obs_shape,
                use_attn_internal=cfg.use_attn_internal,
                use_cat_self=True,
            )
        )
        # NOTE: no dedicated critic encoder is built. `get_values` reuses
        # `obs_prep` (see `_critic_obs_encoder`); it is deliberately not
        # registered under a second attribute name so that the state_dict
        # keys stay unchanged.

        # Layer shared by the actor and critic paths.
        self.common = MLPLayer(
            self.hidden_size,
            self.hidden_size,
            layer_N=0,
            use_orthogonal=self._use_orthogonal,
            activation_id=self._activation_id,
        )

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(
                self.hidden_size,
                self.hidden_size,
                self._recurrent_N,
                self._use_orthogonal,
            )

        def init_(m):
            # Delegates to the project's `init` helper with the selected
            # initializer and a callable that fills a tensor with zeros.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        input_size = self.hidden_size

        # Value head: a single output unit.
        if self._use_popart:
            self.v_out = init_(PopArt(input_size, 1, device=device))
        else:
            self.v_out = init_(nn.Linear(input_size, 1))

        # Action head.
        self.act = ACTLayer(
            action_space, self.hidden_size, self._use_orthogonal, self._gain
        )

        if use_half:
            self.half()
        self.to(self.device)

    def _critic_obs_encoder(self):
        """Return the module used to encode critic observations.

        A ``critic_obs_prep`` attribute is honoured when one has been set on
        the instance (for example by a subclass); otherwise the policy
        encoder ``obs_prep`` is shared with the critic.
        """
        encoder = getattr(self, "critic_obs_prep", None)
        return self.obs_prep if encoder is None else encoder

    def _run_trunk(self, encoder, x, rnn_states, masks):
        """Apply ``encoder``, the common layer and, if enabled, the RNN.

        Returns:
            A ``(features, rnn_states)`` tuple; ``rnn_states`` is returned
            unchanged when no recurrent policy is configured.
        """
        features = self.common(encoder(x))
        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            features, rnn_states = self.rnn(features, rnn_states, masks)
        return features, rnn_states

    def get_actions(
        self, obs, rnn_states, masks, available_actions=None, deterministic=False
    ):
        """Compute actions for the given observations.

        Every array argument is first passed through ``check`` together with
        ``use_half`` and ``tpdv`` (presumably a tensor/dtype/device
        conversion -- confirm against ``check_v2``).

        Args:
            obs: Policy observations.
            rnn_states: Recurrent state fed to the RNN when one is enabled.
            masks: Masks forwarded to the RNN together with ``rnn_states``.
            available_actions: Optional value forwarded to the action head.
            deterministic: Forwarded to the action head.

        Returns:
            ``(actions, action_log_probs, rnn_states)``.
        """
        obs = check(obs, self.use_half, self.tpdv)
        rnn_states = check(rnn_states, self.use_half, self.tpdv)
        masks = check(masks, self.use_half, self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions, self.use_half, self.tpdv)

        actor_features, rnn_states = self._run_trunk(
            self.obs_prep, obs, rnn_states, masks
        )

        actions, action_log_probs = self.act(
            actor_features, available_actions, deterministic
        )
        return actions, action_log_probs, rnn_states

    def evaluate_actions(
        self, obs, rnn_states, action, masks, available_actions, active_masks=None
    ):
        """Evaluate previously taken actions under the current policy.

        Args:
            obs: Policy observations.
            rnn_states: Recurrent state fed to the RNN when one is enabled.
            action: Actions to evaluate.
            masks: Masks forwarded to the RNN together with ``rnn_states``.
            available_actions: Value forwarded to the action head; may be
                ``None``.
            active_masks: Optional value forwarded to the action head.

        Returns:
            ``(action_log_probs, dist_entropy)`` as produced by
            ``ACTLayer.evaluate_actions``.
        """
        obs = check(obs, self.use_half, self.tpdv)
        rnn_states = check(rnn_states, self.use_half, self.tpdv)
        action = check(action, self.use_half, self.tpdv)
        masks = check(masks, self.use_half, self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions, self.use_half, self.tpdv)
        if active_masks is not None:
            active_masks = check(active_masks, self.use_half, self.tpdv)

        # The updated recurrent state is not part of this method's result.
        actor_features, _ = self._run_trunk(self.obs_prep, obs, rnn_states, masks)

        action_log_probs, dist_entropy = self.act.evaluate_actions(
            actor_features, action, available_actions, active_masks
        )
        return action_log_probs, dist_entropy

    def get_values(self, critic_obs, rnn_states, masks):
        """Compute value estimates for the given critic observations.

        Args:
            critic_obs: Critic observations.
            rnn_states: Recurrent state fed to the RNN when one is enabled.
            masks: Masks forwarded to the RNN together with ``rnn_states``.

        Returns:
            ``(values, rnn_states)``.
        """
        critic_obs = check(critic_obs, self.use_half, self.tpdv)
        rnn_states = check(rnn_states, self.use_half, self.tpdv)
        masks = check(masks, self.use_half, self.tpdv)

        # `critic_obs_prep` is never assigned in `__init__`, so reading it
        # directly raised AttributeError; resolve the encoder with a fallback
        # to the shared policy encoder instead.
        critic_features, rnn_states = self._run_trunk(
            self._critic_obs_encoder(), critic_obs, rnn_states, masks
        )

        values = self.v_out(critic_features)
        return values, rnn_states