Shortcuts

Source code for openrl.selfplay.opponents.network_opponent

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from pathlib import Path
from typing import Dict, Optional, Union

from openrl.selfplay.opponents.base_opponent import BaseOpponent


class NetworkOpponent(BaseOpponent):
    """Opponent whose actions come from an agent loaded from a checkpoint.

    The class itself never creates ``agent`` or ``opponent_env``; both start
    as ``None`` and are expected to be supplied elsewhere (presumably by a
    subclass -- TODO confirm against concrete opponents).
    """

    def __init__(
        self,
        opponent_id: str,
        opponent_path: Union[str, Path],
        opponent_info: Dict[str, str],
    ):
        """Initialise the placeholders, then delegate to ``BaseOpponent``.

        Args:
            opponent_id: Forwarded unchanged to the base constructor.
            opponent_path: Forwarded unchanged to the base constructor.
            opponent_info: Forwarded unchanged to the base constructor.
        """
        # The attributes are assigned *before* the base constructor runs.
        # NOTE(review): presumably the base constructor may invoke hooks such
        # as ``_load`` that read them -- confirm in BaseOpponent.
        self.deterministic_action = False
        self.opponent_env = None
        self.agent = None
        super().__init__(opponent_id, opponent_path, opponent_info)

    def reset(self, env=None, opponent_player: Optional[str] = None):
        """Reset the base opponent, then the wrapped env and agent if set.

        Args:
            env: Forwarded to ``BaseOpponent.reset``.
            opponent_player: Forwarded to ``BaseOpponent.reset``.
        """
        super().reset(env, opponent_player)

        wrapped_env = self.opponent_env
        if wrapped_env is not None:
            wrapped_env.reset()

        policy = self.agent
        if policy is not None:
            policy.reset()

    def _load(self, opponent_path: Union[str, Path]):
        """Load ``module.pt`` from ``opponent_path`` into the agent, if any.

        Args:
            opponent_path: Directory expected to contain ``module.pt``.
        """
        # The path is built before the guard, so an invalid ``opponent_path``
        # fails here even when no agent has been attached.
        checkpoint = Path(opponent_path).joinpath("module.pt")
        if self.agent is None:
            return
        self.agent.load(checkpoint)

    def _set_env(self, env, opponent_player: Optional[str] = None):
        """Do nothing; this class attaches no environment on its own."""
        pass

    def act(self, player_name, observation, reward, termination, truncation, info):
        """Compute the opponent's action for the current step.

        The observation is first passed through ``opponent_env.process_obs``,
        the agent is queried, and the agent's action is passed through
        ``opponent_env.process_action`` before being returned.

        Args:
            player_name: Not used by this implementation.
            observation: Raw observation handed to ``process_obs``.
            reward: Not used by this implementation.
            termination: Termination flag handed to ``process_obs``.
            truncation: Truncation flag handed to ``process_obs``.
            info: Extra info handed to ``process_obs`` and then to the agent.

        Returns:
            Whatever ``opponent_env.process_action`` returns for the action
            chosen by the agent.
        """
        processed = self.opponent_env.process_obs(
            observation, termination, truncation, info
        )
        observation, termination, truncation, info = processed

        # Only the first element of the agent's result is used.
        raw_action, _ = self.agent.act(
            observation, info, deterministic=self.deterministic_action
        )
        return self.opponent_env.process_action(raw_action)