# Source code for openrl.modules.networks.sac_network
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from openrl.buffers.utils.util import get_critic_obs_space, get_policy_obs_space
from openrl.modules.networks.base_policy_network import BasePolicyNetwork
from openrl.modules.networks.base_value_network import BaseValueNetwork
from openrl.modules.networks.ddpg_network import ActorNetwork
from openrl.modules.networks.utils.cnn import CNNBase
from openrl.modules.networks.utils.mix import MIXBase
from openrl.modules.networks.utils.mlp import MLPBase
from openrl.modules.networks.utils.rnn import RNNLayer
from openrl.modules.networks.utils.util import init
from openrl.utils.util import check_v2 as check
class SACActorNetwork(ActorNetwork):
    """Gaussian policy head for SAC built on top of ``ActorNetwork``.

    The parent class provides the feature extractor (``self.base``); this
    subclass replaces the output layer with a single linear layer that emits
    twice as many values as the action has dimensions. ``forward`` splits that
    output into a mean and a clamped log standard deviation, and ``evaluate``
    turns them into an action together with its log-probability.

    Only ``gym.spaces.Box`` action spaces are supported.
    """

    def __init__(
        self,
        cfg,
        input_space,
        action_space,
        device=torch.device("cpu"),
        use_half=False,
        extra_args=None,
        log_std_min=-20,
        log_std_max=2,
    ) -> None:
        """Build the actor and its mean/log-std output layer.

        Args:
            cfg: Forwarded unchanged to ``ActorNetwork.__init__``.
            input_space: Forwarded unchanged to ``ActorNetwork.__init__``.
            action_space: Action space; its ``shape[0]`` sets the action
                dimension of the output layer.
            device: Forwarded unchanged to ``ActorNetwork.__init__``.
            use_half: Forwarded unchanged to ``ActorNetwork.__init__``.
            extra_args: Forwarded unchanged to ``ActorNetwork.__init__``.
            log_std_min: Lower clamp applied to the predicted log-std.
            log_std_max: Upper clamp applied to the predicted log-std.

        Raises:
            NotImplementedError: If ``self.action_space`` is not a ``Box``.
        """
        super().__init__(cfg, input_space, action_space, device, use_half, extra_args)

        # ``self._use_orthogonal`` is not set in this class; it is expected to
        # come from the parent constructor. Used as an index: falsy -> xavier,
        # truthy -> orthogonal.
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][
            self._use_orthogonal
        ]
        input_size = self.base.output_size

        def init_(m):
            # Hand the chosen initializer and a constant-zero initializer to
            # the project's ``init`` helper.
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        if isinstance(self.action_space, gym.spaces.box.Box):
            # Twice the action dimension: one half for the mean, the other
            # half for the log-std (split in ``forward``).
            self.actor_out = init_(nn.Linear(input_size, action_space.shape[0] * 2))
        else:
            raise NotImplementedError(
                f"This type ({type(self.action_space)}) of game has not been"
                " implemented."
            )

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, obs):
        """Compute the mean and clamped log-std of the action distribution.

        Args:
            obs: Observation. When ``self._mixed_obs`` is true it is treated
                as a mapping whose values are converted one by one; otherwise
                it is converted as a whole.

        Returns:
            Tuple ``(mean, log_std)``: the two halves of the output layer
            along the last dimension, with ``log_std`` clamped to
            ``[self.log_std_min, self.log_std_max]``.

        Raises:
            NotImplementedError: If ``self.action_space`` is not a ``Box``.
        """
        if self._mixed_obs:
            # Build a fresh dict so the caller's observation mapping is not
            # overwritten with the converted tensors.
            obs = {key: check(obs[key]).to(**self.tpdv) for key in obs.keys()}
        else:
            obs = check(obs).to(**self.tpdv)

        features = self.base(obs)

        if not isinstance(self.action_space, gym.spaces.box.Box):
            raise NotImplementedError("This type of game has not been implemented.")

        mean, log_std = self.actor_out(features).chunk(2, dim=-1)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def _normalize(self, action) -> torch.Tensor:
        """Return ``action`` unchanged.

        This is a pass-through hook: no clamping or rescaling to the bounds of
        ``self.action_space`` is performed here.
        """
        return action

    @staticmethod
    def _squashed_log_prob(dist, action):
        """Sum the log-density of ``action`` and apply the tanh correction.

        The correction ``2 * (log 2 - a - softplus(-2a))`` is a numerically
        stable form of ``log(1 - tanh(a)^2)``.
        NOTE: The correction formula from the original SAC paper
        (arXiv 1801.01290) appendix C.

        Args:
            dist: Distribution exposing ``log_prob``.
            action: Pre-squash action tensor.

        Returns:
            Log-probability summed over the last dimension.
        """
        log_prob = dist.log_prob(action).sum(dim=-1)
        correction = (2 * (np.log(2) - action - F.softplus(-2 * action))).sum(dim=-1)
        return log_prob - correction

    def evaluate(self, obs, deterministic=True):
        """Produce an action and its log-probability for ``obs``.

        Args:
            obs: Observation passed to ``forward``.
            deterministic: If true, use the distribution mean as the action;
                otherwise draw a reparameterized sample (``rsample``) so
                gradients can flow through the sample.

        Returns:
            Tuple ``(action, log_prob)`` where ``log_prob`` has a trailing
            dimension of size 1 added.
        """
        mean, log_std = self.forward(obs)
        dist = torch.distributions.Normal(mean, torch.exp(log_std))

        action = mean if deterministic else dist.rsample()

        # NOTE(review): the log-probability includes the tanh change-of-variables
        # correction, but the action returned below is NOT passed through tanh
        # (``_normalize`` is an identity). Presumably the squashing happens
        # elsewhere - confirm against the caller before relying on it.
        log_prob = self._squashed_log_prob(dist, action)
        return self._normalize(action), log_prob.unsqueeze(dim=-1)