Source code for openrl.drivers.offpolicy_driver
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
import random
from typing import Any, Dict, Optional
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel
from openrl.drivers.rl_driver import RLDriver
from openrl.utils.logger import Logger
from openrl.utils.util import _t2n
class OffPolicyDriver(RLDriver):
    """Driver that alternates environment rollouts with learner updates.

    Each call to :meth:`_inner_loop` collects ``self.episode_length`` steps
    from ``self.envs`` (actions are chosen from the Q-values returned by
    ``self.trainer.algo_module.get_actions``), stores the transitions in
    ``self.buffer`` and then asks the learner for an update.

    Attributes such as ``envs``, ``episode``, ``episode_length``,
    ``n_rollout_threads``, ``num_agents``, ``log_interval`` and
    ``learner_update`` are not defined in this class; they are presumably
    provided by :class:`RLDriver` -- verify against the base class.
    """

    # Episode length (in steps) above which a finished episode raises
    # ``verbose_flag``. The flag is only written, never read, in this class.
    _LONG_EPISODE_STEPS = 200

    def __init__(
        self,
        config: Dict[str, Any],
        trainer,
        buffer,
        agent,
        rank: int = 0,
        world_size: int = 1,
        client=None,
        logger: Optional[Logger] = None,
    ) -> None:
        """Forward everything to :class:`RLDriver` and read exploration settings.

        :param config: mapping whose ``"cfg"`` entry exposes ``buffer_size``,
            ``epsilon_start``, ``epsilon_finish`` and ``epsilon_anneal_time``.
        :param trainer: passed through to the base class.
        :param buffer: passed through to the base class.
        :param agent: passed through to the base class.
        :param rank: passed through to the base class.
        :param world_size: passed through to the base class.
        :param client: passed through to the base class.
        :param logger: passed through to the base class.
        """
        super().__init__(
            config, trainer, buffer, agent, rank, world_size, client, logger
        )

        cfg = config["cfg"]
        # 20% of the configured buffer size. Stored for callers; see the
        # review note in ``_inner_loop`` -- it is not consulted there.
        self.buffer_minimal_size = int(cfg.buffer_size * 0.2)
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_finish = cfg.epsilon_finish
        self.epsilon_anneal_time = cfg.epsilon_anneal_time

        # Per-environment step counters when several environments run in
        # parallel, a plain int otherwise.
        if self.envs.parallel_env_num > 1:
            self.episode_steps = np.zeros((self.envs.parallel_env_num,))
        else:
            self.episode_steps = 0

        self.verbose_flag = False
        # While True, ``act`` always replaces the greedy action with a random
        # one; cleared at the end of the first ``actor_rollout``.
        self.first_insert_buffer = True

    def _inner_loop(
        self,
    ) -> bool:
        """Run one rollout followed by one learner update and log the results.

        :return: True if training should continue, False if training should stop
        """
        rollout_infos = self.actor_rollout()

        # NOTE(review): this test holds for every non-negative size, so the
        # ``else`` branch is only reachable if ``get_buffer_size()`` can be
        # negative, and ``self.buffer_minimal_size`` is never consulted.
        # Confirm whether a warm-up threshold was intended here.
        if self.buffer.get_buffer_size() >= 0:
            train_infos = self.learner_update()
            self.buffer.after_update()
        else:
            train_infos = {"q_loss": 0}

        self.total_num_steps = (
            (self.episode + 1) * self.episode_length * self.n_rollout_threads
        )

        if self.episode % self.log_interval == 0:
            # rollout_infos is only meaningful when the env is wrapped with a
            # monitor (see ``self.envs.use_monitor`` in ``actor_rollout``).
            self.logger.log_info(rollout_infos, step=self.total_num_steps)
            self.logger.log_info(train_infos, step=self.total_num_steps)

        return True

    def add2buffer(self, data):
        """Post-process one environment step and insert it into the buffer.

        For every entry flagged in ``dones`` the recurrent state and the next
        observation are zeroed **in place** and the corresponding mask is set
        to 0; all other masks are 1.

        :param data: 8-tuple ``(obs, next_obs, rewards, dones, infos,
            q_values, actions, rnn_states)``. ``dones`` is used as a boolean
            index into ``rnn_states``, ``next_obs`` and a
            ``(n_rollout_threads, num_agents, 1)`` mask array, so it must be a
            boolean array of shape ``(n_rollout_threads, num_agents)``.
            ``infos`` is accepted but not stored.
        """
        (
            obs,
            next_obs,
            rewards,
            dones,
            _infos,
            q_values,
            actions,
            rnn_states,
        ) = data

        # Scalar assignment through a boolean mask broadcasts over every
        # trailing dimension, so this also works for image-shaped
        # observations and not only for flat feature vectors.
        rnn_states[dones] = 0.0

        is_dict_obs = (
            isinstance(next_obs, dict) or "Dict" in type(next_obs).__name__
        )
        if is_dict_obs:
            for key in next_obs.keys():
                next_obs[key][dones] = 0.0
        else:
            next_obs[dones] = 0.0

        masks = np.ones(
            (self.n_rollout_threads, self.num_agents, 1), dtype=np.float32
        )
        masks[dones] = 0.0

        # The buffer interface expects critic RNN states and action
        # log-probabilities; this driver has neither, so the actor RNN states
        # and the actions themselves are passed in those slots.
        rnn_states_critic = rnn_states
        action_log_probs = actions

        self.buffer.insert(
            obs,
            next_obs,
            rnn_states,
            rnn_states_critic,
            actions,
            action_log_probs,
            q_values,
            rewards,
            masks,
        )

    def _update_episode_steps(self, dones) -> None:
        """Advance the running episode-length counters and reset finished ones.

        :param dones: done flags returned by ``self.envs.step``.
        """
        if isinstance(self.episode_steps, int):
            # Single-environment bookkeeping: one scalar counter. ``np.any``
            # keeps this working when ``dones`` holds more than one flag.
            if np.any(dones):
                self.episode_steps = 0
            else:
                self.episode_steps += 1
            return

        self.episode_steps += 1
        # Indices along the first axis of every flagged entry.
        done_envs = np.where(np.asarray(dones, dtype=bool))[0]
        if done_envs.size:
            if np.any(self.episode_steps[done_envs] > self._LONG_EPISODE_STEPS):
                self.verbose_flag = True
            self.episode_steps[done_envs] = 0

    def actor_rollout(self):
        """Collect ``self.episode_length`` environment steps into the buffer.

        :return: the dict from ``self.envs.batch_rewards``; when
            ``self.envs.use_monitor`` is set, the dict from
            ``self.envs.statistics`` updated with the batch-reward entries.
        """
        self.trainer.prep_rollout()

        obs = self.buffer.data.critic_obs[0]
        for step in range(self.episode_length):
            q_values, actions, rnn_states = self.act(step)

            extra_data = {
                "q_values": q_values,
                "step": step,
                "buffer": self.buffer,
            }
            next_obs, rewards, dones, infos = self.envs.step(actions, extra_data)

            self._update_episode_steps(dones)

            self.add2buffer(
                (
                    obs,
                    next_obs,
                    rewards,
                    dones,
                    infos,
                    q_values,
                    actions,
                    rnn_states,
                )
            )
            obs = next_obs

        batch_rew_infos = self.envs.batch_rewards(self.buffer)
        self.first_insert_buffer = False

        if self.envs.use_monitor:
            statistics_info = self.envs.statistics(self.buffer)
            statistics_info.update(batch_rew_infos)
            return statistics_info
        return batch_rew_infos

    @torch.no_grad()
    def act(
        self,
        step: int,
    ):
        """Query the Q-network for one rollout step and pick actions.

        The greedy action is the argmax over the last axis of the Q-values.
        It is replaced by uniformly random actions for *all* environments at
        once when ``random.random() >= epsilon`` or while
        ``self.first_insert_buffer`` is True.

        :param step: index of the current step inside the rollout.
        :return: tuple ``(q_values, actions, rnn_states)``; ``q_values`` and
            ``rnn_states`` are split into ``self.n_rollout_threads`` chunks
            along the first axis.
        """
        self.trainer.prep_rollout()

        # Every step after the first reads the buffer slot of the previous one.
        if step != 0:
            step = step - 1

        # NOTE(review): the key below is chosen *after* the decrement, so the
        # caller's steps 0 and 1 both read ``policy_obs`` at index 0. Confirm
        # this matches the buffer layout.
        obs_key = "next_policy_obs" if step != 0 else "policy_obs"
        (
            q_values,
            rnn_states,
        ) = self.trainer.algo_module.get_actions(
            self.buffer.data.get_batch_data(obs_key, step),
            np.concatenate(self.buffer.data.rnn_states[step]),
            np.concatenate(self.buffer.data.masks[step]),
        )

        q_values = np.array(np.split(_t2n(q_values), self.n_rollout_threads))
        rnn_states = np.array(np.split(_t2n(rnn_states), self.n_rollout_threads))

        # Linear schedule starting at ``epsilon_finish`` and capped above by
        # ``epsilon_start``.
        # NOTE(review): the random branch below fires with probability
        # ``1 - epsilon``; confirm this is the intended exploration schedule.
        elapsed_steps = self.episode * self.episode_length + step
        annealed = (
            self.epsilon_finish
            + (self.epsilon_start - self.epsilon_finish)
            / self.epsilon_anneal_time
            * elapsed_steps
        )
        epsilon = min(annealed, self.epsilon_start)

        actions = np.expand_dims(q_values.argmax(axis=-1), axis=-1)

        if random.random() >= epsilon or self.first_insert_buffer:
            actions = np.random.randint(
                low=0, high=self.envs.action_space.n, size=actions.shape
            )

        return (
            q_values,
            actions,
            rnn_states,
        )