Source code for openrl.algorithms.gail
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from typing import Union
import torch
from openrl.algorithms.ppo import PPOAlgorithm
from openrl.datasets.expert_dataset import ExpertDataset
[docs]class GAILAlgorithm(PPOAlgorithm):
def __init__(
self,
cfg,
init_module,
agent_num: int = 1,
device: Union[str, torch.device] = "cpu",
) -> None:
super(GAILAlgorithm, self).__init__(cfg, init_module, agent_num, device)
self.train_list.append(self.train_gail)
self.gail_epoch = cfg.gail_epoch
assert cfg.expert_data is not None
expert_dataset = ExpertDataset(file_name=cfg.expert_data)
drop_last = len(expert_dataset) > self.cfg.gail_batch_size
self.dataset_loader = torch.utils.data.DataLoader(
dataset=expert_dataset,
batch_size=self.cfg.gail_batch_size,
shuffle=True,
drop_last=drop_last,
)
[docs] def train_gail(self, buffer, turn_on):
train_info = {"gail_loss": 0}
for _ in range(self.gail_epoch):
loss = self.algo_module.models["gail_discriminator"].update(
self.dataset_loader, buffer
)
train_info["gail_loss"] += loss
for k in train_info.keys():
train_info[k] /= self.gail_epoch
return train_info