# Source code for openrl.modules.rl_module
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base reinforcement-learning module that builds models and optimizers from model configs."""
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
from gym import spaces
from openrl.modules.base_module import BaseModule
from openrl.modules.model_config import ModelTrainConfig
class RLModule(BaseModule):
    """Base RL module that builds one model (and, for trainers, one optimizer) per config key.

    Models are stored in ``self.models`` and optimizers in ``self.optimizers``,
    both keyed by the keys of ``model_configs`` (both containers are presumably
    created by ``BaseModule.__init__`` — not visible here). When
    ``cfg.use_deepspeed`` is set, each trainable model is wrapped in a DeepSpeed
    engine and that engine is stored under the same key in both containers.
    Actor programs (``cfg.program_type == "actor"``) get models but no optimizers.
    """

    def __init__(
        self,
        cfg,
        act_space: spaces.Box,
        rank: int = 0,
        world_size: int = 1,
        device: Union[str, torch.device] = "cpu",
        model_configs: Optional[Dict[str, ModelTrainConfig]] = None,
    ) -> None:
        """Build the models/optimizers described by ``model_configs``.

        Args:
            cfg: Run configuration. Attributes read here include ``lr``,
                ``critic_lr``, ``opti_eps``, ``weight_decay``,
                ``load_optimizer``, ``program_type``, ``use_deepspeed``,
                ``use_half_actor`` and ``use_amp`` (plus ``use_fp16``,
                ``use_offload`` and ``deepspeed_config`` when DeepSpeed is on).
            act_space: Action space passed to every model constructor.
            rank: Process rank, stored on the instance.
            world_size: Number of processes, stored on the instance.
            device: Target device; a string is converted with ``torch.device``.
            model_configs: Mapping from model key to its config, which must
                provide ``"model"`` (a factory), ``"input_space"``, ``"lr"``
                and optionally ``"extra_args"``. Defaults to
                ``self.get_model_configs(cfg)``.
        """
        super().__init__(cfg)
        if isinstance(device, str):
            device = torch.device(device)

        self.cfg = cfg
        self.device = device
        self.lr = cfg.lr
        self.critic_lr = cfg.critic_lr
        self.opti_eps = cfg.opti_eps
        self.weight_decay = cfg.weight_decay
        self.load_optimizer = cfg.load_optimizer
        self.act_space = act_space
        self.program_type = cfg.program_type
        self.rank = rank
        self.world_size = world_size
        self.use_deepspeed = cfg.use_deepspeed

        # Half precision is only requested for actor processes.
        use_half_actor = self.program_type == "actor" and cfg.use_half_actor

        if model_configs is None:
            model_configs = self.get_model_configs(cfg)

        for model_key in model_configs:
            model_cg = model_configs[model_key]
            model = model_cg["model"](
                cfg=cfg,
                input_space=model_cg["input_space"],
                action_space=act_space,
                device=device,
                use_half=use_half_actor,
                extra_args=model_cg["extra_args"] if "extra_args" in model_cg else None,
            )
            # Register the model for every program type. Previously this only
            # happened after the actor `continue`, so actor modules ended up
            # with no models at all (breaking load_policy/restore).
            self.models.update({model_key: model})

            # Actor programs do not train, so they get no optimizer.
            if self.program_type == "actor":
                continue

            if self.use_deepspeed:
                # The engine replaces the raw model and also acts as optimizer.
                engine = self._wrap_with_deepspeed(model, model_cg["lr"])
                self.models.update({model_key: engine})
                self.optimizers.update({model_key: engine})
            else:
                optimizer = torch.optim.Adam(
                    model.parameters(),
                    lr=model_cg["lr"],
                    eps=cfg.opti_eps,
                    weight_decay=cfg.weight_decay,
                )
                self.optimizers.update({model_key: optimizer})

        self.scaler = torch.cuda.amp.GradScaler() if cfg.use_amp else None

    def _wrap_with_deepspeed(self, model, lr):
        """Wrap ``model`` in a DeepSpeed engine with an Adam optimizer and constant LR schedule.

        Side effects: sets ``self.use_fp16`` / ``self.use_offload`` from
        ``self.cfg`` and asserts that the DeepSpeed JSON config's
        ``fp16.enabled`` flag (when present) agrees with ``cfg.use_fp16``.

        Returns:
            The engine returned first by ``deepspeed.initialize``.
        """
        # Imported lazily so DeepSpeed/transformers are only needed when enabled.
        import json

        import deepspeed
        from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
        from transformers import get_constant_schedule

        cfg = self.cfg
        self.use_fp16 = cfg.use_fp16
        self.use_offload = cfg.use_offload

        # Check the DeepSpeed config file is consistent with the run config.
        assert cfg.deepspeed_config is not None
        with open(cfg.deepspeed_config, encoding="utf-8") as file:
            ds_config = json.load(file)
        if "fp16" in ds_config:
            assert ds_config["fp16"]["enabled"] == self.use_fp16

        # Offloaded runs use DeepSpeed's CPU Adam; otherwise its fused Adam.
        # NOTE(review): unlike the torch.optim.Adam path, cfg.opti_eps and
        # cfg.weight_decay are not applied here and betas are hard-coded.
        adam_cls = DeepSpeedCPUAdam if self.use_offload else FusedAdam
        trainable_params = (p for p in model.parameters() if p.requires_grad)
        optim = adam_cls(trainable_params, lr=lr, betas=(0.9, 0.95))

        lr_scheduler = get_constant_schedule(optimizer=optim)

        engine, *_ = deepspeed.initialize(
            args=cfg,
            model=model,
            optimizer=optim,
            lr_scheduler=lr_scheduler,
        )
        return engine

    def load_policy(self, model_path: str) -> None:
        """Load a policy state dict from ``model_path``.

        The weights go into ``self.models["policy"]`` when that key exists,
        otherwise into ``self.models["model"]``.

        Raises:
            AssertionError: if ``model_path`` does not exist.
        """
        model_path = Path(model_path)
        assert (
            model_path.exists()
        ), "can not find policy weight file to load: {}".format(model_path)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(str(model_path), map_location=self.device)
        target_key = "policy" if "policy" in self.models else "model"
        self.models[target_key].load_state_dict(state_dict)
        del state_dict

    def restore(self, model_dir: str) -> None:
        """Restore every model (and optionally every optimizer) from ``model_dir``.

        Each model ``name`` is loaded from ``<model_dir>/<name>.pt``. When
        ``self.load_optimizer`` is set, each optimizer ``name`` is loaded from
        ``<model_dir>/<name>_optimizer.pt``. This happens only if all of those
        files exist; otherwise a message is printed and optimizers are left
        untouched.

        Raises:
            AssertionError: if ``model_dir`` does not exist.
        """
        model_dir = Path(model_dir)
        assert model_dir.exists(), "can not find model directory to restore: {}".format(
            model_dir
        )

        for model_name in self.models:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(
                str(model_dir / "{}.pt".format(model_name)), map_location=self.device
            )
            self.models[model_name].load_state_dict(state_dict)
            del state_dict

        if self.load_optimizer:
            # Check the exact files that will be loaded, rather than a
            # hard-coded "actor_optimizer.pt" sentinel that may not match the
            # optimizer keys in use.
            optimizer_paths = {
                name: model_dir / "{}_optimizer.pt".format(name)
                for name in self.optimizers
            }
            if all(path.exists() for path in optimizer_paths.values()):
                for optimizer_name, path in optimizer_paths.items():
                    state_dict = torch.load(str(path), map_location=self.device)
                    self.optimizers[optimizer_name].load_state_dict(state_dict)
                    del state_dict
            else:
                print("can't find optimizer to restore")
        # TODO: optimizer.load_state_dict(resume_state['optimizer'])