Shortcuts

Source code for openrl.runners.common.a2c_agent

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from typing import Optional, Type, Union

import gym
import torch

from openrl.algorithms.a2c import A2CAlgorithm
from openrl.algorithms.base_algorithm import BaseAlgorithm
from openrl.drivers.base_driver import BaseDriver
from openrl.drivers.onpolicy_driver import OnPolicyDriver as Driver
from openrl.modules.common import BaseNet
from openrl.runners.common.base_agent import SelfAgent
from openrl.runners.common.ppo_agent import PPOAgent
from openrl.utils.logger import Logger
from openrl.utils.type_aliases import MaybeCallback


class A2CAgent(PPOAgent):
    """Agent that reuses the ``PPOAgent`` machinery with A2C as the default
    training algorithm.

    The class adds no state or logic of its own: ``__init__`` forwards every
    argument to ``PPOAgent.__init__`` and ``train`` forwards to
    ``PPOAgent.train``, differing only in the default ``train_algo_class``
    (``A2CAlgorithm``) and the default ``project_name``.
    """

    def __init__(
        self,
        net: Optional[Union[torch.nn.Module, BaseNet]] = None,
        env: Optional[Union[gym.Env, str]] = None,
        run_dir: Optional[str] = None,
        env_num: Optional[int] = None,
        rank: int = 0,
        world_size: int = 1,
        use_wandb: bool = False,
        use_tensorboard: bool = False,
        project_name: str = "A2CAgent",
    ) -> None:
        """Initialise the agent by delegating to ``PPOAgent.__init__``.

        All arguments are passed through unchanged; their precise semantics
        are defined by ``PPOAgent`` (not visible in this module).

        Args:
            net: Network/module the agent operates on.
            env: Environment instance, or a string identifying one.
            run_dir: Presumably the output directory for run artefacts --
                confirm against ``PPOAgent``.
            env_num: Presumably the number of parallel environments --
                confirm against ``PPOAgent``.
            rank: Presumably the process rank in distributed runs -- confirm
                against ``PPOAgent``.
            world_size: Presumably the total number of distributed
                processes -- confirm against ``PPOAgent``.
            use_wandb: Flag forwarded to the parent (name suggests Weights &
                Biases logging).
            use_tensorboard: Flag forwarded to the parent (name suggests
                TensorBoard logging).
            project_name: Project label forwarded to the parent. Defaults to
                this class's own name so runs are not filed under another
                agent's project.
        """
        # The first eight arguments are forwarded positionally, so their order
        # must stay aligned with PPOAgent.__init__'s parameter order.
        super().__init__(
            net,
            env,
            run_dir,
            env_num,
            rank,
            world_size,
            use_wandb,
            use_tensorboard,
            project_name=project_name,
        )

    def train(
        self: SelfAgent,
        total_time_steps: int,
        callback: MaybeCallback = None,
        train_algo_class: Type[BaseAlgorithm] = A2CAlgorithm,
        logger: Optional[Logger] = None,
        driver_class: Type[BaseDriver] = Driver,
    ) -> None:
        """Run training by delegating to ``PPOAgent.train``.

        Identical to the parent call except that ``train_algo_class``
        defaults to ``A2CAlgorithm`` and ``driver_class`` defaults to
        ``OnPolicyDriver`` (imported here as ``Driver``).

        Args:
            total_time_steps: Training budget forwarded to the parent.
            callback: Optional callback(s) forwarded to the parent.
            train_algo_class: Algorithm class the parent should use.
            logger: Optional logger forwarded to the parent.
            driver_class: Driver class the parent should use.
        """
        # The parent's return value is intentionally not propagated: this
        # method is annotated (and behaves) as returning None.
        super().train(
            total_time_steps,
            callback,
            train_algo_class,
            logger,
            driver_class,
        )