Shortcuts

Source code for openrl.selfplay.wrappers.opponent_pool_wrapper

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
from typing import List, Optional

from openrl.selfplay.opponents.utils import get_opponent_from_info
from openrl.selfplay.selfplay_api.selfplay_client import SelfPlayClient
from openrl.selfplay.wrappers.base_multiplayer_wrapper import BaseMultiPlayerWrapper


class OpponentPoolWrapper(BaseMultiPlayerWrapper):
    """Multi-player wrapper whose opponent is obtained through the self-play API.

    An opponent is requested from ``SelfPlayClient`` on every ``reset`` (and
    lazily on the first opponent move if none is loaded yet). While no
    opponent is available, opponent moves are sampled from the environment's
    action space. At the end of each episode the winners and losers are
    reported back through the same client.

    NOTE(review): ``self.opponent_players`` and ``self.self_player`` are not
    set in this class; they are presumably provided by
    ``BaseMultiPlayerWrapper`` -- confirm against the base class.
    """

    def __init__(self, env, cfg, reward_class=None) -> None:
        """Create the wrapper and the self-play API client.

        Args:
            env: Environment to wrap; forwarded to the base class.
            cfg: Configuration object. Reads ``cfg.selfplay_api.host``,
                ``cfg.selfplay_api.port`` and ``cfg.lazy_load_opponent``.
            reward_class: Forwarded unchanged to the base class.
        """
        super().__init__(env, cfg, reward_class)

        host = cfg.selfplay_api.host
        port = cfg.selfplay_api.port
        self.api_client = SelfPlayClient(f"http://{host}:{port}/selfplay/")
        # Currently loaded opponent; None until the API provides one.
        self.opponent = None
        # Player slot the current opponent was last bound to via ``set_env``.
        self.opponent_player = None
        self.lazy_load_opponent = cfg.lazy_load_opponent
        self.player_ids = None

    def _fetch_and_reset_opponent(self) -> None:
        """Ask the API for an opponent, store it, and reset it if one exists."""
        self.opponent = self.get_opponent(self.opponent_players)
        if self.opponent is not None:
            self.opponent.reset()

    def _battle_id(self, player):
        """Return the id under which ``player`` is reported in battle results."""
        if player == self.self_player:
            return "training_agent"
        return self.opponent.opponent_id

    def reset(self, *, seed: Optional[int] = None, **kwargs):
        """Reset the wrapped environment, then refresh and reset the opponent.

        Args:
            seed: Optional seed forwarded to the base class ``reset``.
            **kwargs: Extra keyword arguments forwarded to the base class.

        Returns:
            Whatever the base class ``reset`` returns.
        """
        results = super().reset(seed=seed, **kwargs)
        self._fetch_and_reset_opponent()
        return results

    def get_opponent(self, opponent_players: List[str]):
        """Request an opponent for ``opponent_players`` from the self-play API.

        Args:
            opponent_players: Names of the opponent-controlled players. Only
                the first entry is used.

        Returns:
            The opponent to play against. Falls back to the currently loaded
            opponent (possibly ``None``) when the API offers nothing usable.
        """
        opponent_info = self.api_client.get_opponent(opponent_players)
        # An empty answer (or an empty player list) is handled like "no
        # opponent offered" instead of raising IndexError on the [0] lookups.
        if not opponent_info or not opponent_players:
            return self.opponent

        # Currently only one opponent is supported, i.e. only two-player
        # games are supported.
        opponent_info = opponent_info[0]
        opponent_player = opponent_players[0]
        opponent, is_new_opponent = get_opponent_from_info(
            opponent_info,
            current_opponent=self.opponent,
            lazy_load_opponent=self.lazy_load_opponent,
        )
        if opponent is None:
            return self.opponent

        # Bind the opponent to the environment again whenever it is a new
        # object or it has to play a different player slot than before.
        if is_new_opponent or (self.opponent_player != opponent_player):
            opponent.set_env(self.env, opponent_player)
            self.opponent_player = opponent_player

        return opponent

    def get_opponent_action(
        self, player_name, observation, reward, termination, truncation, info
    ):
        """Return the opponent's action for ``player_name``.

        If no opponent is loaded, one more attempt is made to obtain it from
        the API. If that also fails, an action is sampled from the
        environment's action space using ``observation["action_mask"]``.

        Args:
            player_name: Name of the player that has to act.
            observation: Current observation; must contain ``"action_mask"``
                when the sampling fallback is used.
            reward: Forwarded to the opponent's ``act``.
            termination: Forwarded to the opponent's ``act``.
            truncation: Forwarded to the opponent's ``act``.
            info: Forwarded to the opponent's ``act``.

        Returns:
            The chosen action, or a list of actions (one per sub-space) when
            the action space is a list.
        """
        if self.opponent is None:
            self._fetch_and_reset_opponent()

        if self.opponent is not None:
            return self.opponent.act(
                player_name, observation, reward, termination, truncation, info
            )

        # No opponent available: sample an action under the action mask.
        mask = observation["action_mask"]
        action_space = self.env.action_space(player_name)
        if isinstance(action_space, list):
            return [space.sample(mask) for space in action_space]
        return action_space.sample(mask)

    def on_episode_end(
        self, player_name, observation, reward, termination, truncation, info
    ):
        """Report the finished episode's winners and losers to the API.

        The training agent is reported as ``"training_agent"``; any other
        player is reported under the current opponent's ``opponent_id``.

        Nothing is reported when a non-training player appears in the result
        but no opponent is loaded (the episode was played against sampled
        actions), because there is no opponent id to attribute it to.

        Args:
            player_name: Unused here.
            observation: Unused here.
            reward: Unused here.
            termination: Unused here.
            truncation: Unused here.
            info: Must contain ``"winners"`` (at least one entry) and
                ``"losers"``.

        Raises:
            AssertionError: If ``info`` lacks the required keys, has no
                winner, or the mapped winner and loser ids overlap.
        """
        assert "winners" in info, "winners must be in info"
        assert "losers" in info, "losers must be in info"
        assert len(info["winners"]) >= 1, "winners must be at least 1"

        winners = info["winners"]
        losers = info["losers"]

        # Without a loaded opponent there is no ``opponent_id``; previously
        # this path raised AttributeError on ``None.opponent_id``.
        if self.opponent is None and any(
            player != self.self_player for player in (*winners, *losers)
        ):
            return

        winner_ids = [self._battle_id(player) for player in winners]
        loser_ids = [self._battle_id(player) for player in losers]

        assert set(winner_ids).isdisjoint(set(loser_ids)), (
            "winners and losers must be disjoint, but get winners: {}, losers: {}"
            .format(winner_ids, loser_ids)
        )
        battle_info = {"winner_ids": winner_ids, "loser_ids": loser_ids}
        self.api_client.add_battle_result(battle_info)