Shortcuts

openrl.algorithms package

Submodules

openrl.algorithms.a2c module

class openrl.algorithms.a2c.A2CAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.ppo.PPOAlgorithm

prepare_loss(critic_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, masks_batch, action_masks_batch, old_action_log_probs_batch, adv_targ, value_preds_batch, return_batch, active_masks_batch, turn_on)[source]
train(buffer, turn_on: bool = True)[source]

openrl.algorithms.base_algorithm module

class openrl.algorithms.base_algorithm.BaseAlgorithm(cfg, init_module, agent_num: int, device=torch.device('cpu'))[source]

Bases: abc.ABC

prep_rollout()[source]
prep_training()[source]
abstract train(buffer, turn_on=True)[source]

openrl.algorithms.behavior_cloning module

class openrl.algorithms.behavior_cloning.BCAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.base_algorithm.BaseAlgorithm

bc_update(sample, turn_on=True)[source]
construct_loss_list(policy_loss, dist_entropy, value_loss, turn_on)[source]
get_data_generator(buffer)[source]
prepare_loss(critic_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, masks_batch, action_masks_batch, old_action_log_probs_batch, adv_targ, value_preds_batch, return_batch, active_masks_batch, turn_on)[source]
to_single_np(input)[source]
train(buffer, turn_on=True)[source]
train_bc(buffer, turn_on)[source]

openrl.algorithms.ddpg module

class openrl.algorithms.ddpg.DDPGAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.base_algorithm.BaseAlgorithm

cal_value_loss(value_normalizer, values, value_preds_batch, return_batch, active_masks_batch)[source]
ddpg_update(sample, turn_on=True)[source]
prepare_actor_loss(obs_batch, next_obs_batch, rnn_states_batch, actions_batch, masks_batch, action_masks_batch, value_preds_batch, rewards_batch, active_masks_batch, turn_on)[source]
prepare_critic_loss(obs_batch, next_obs_batch, rnn_states_batch, actions_batch, masks_batch, next_masks_batch, action_masks_batch, value_preds_batch, rewards_batch, active_masks_batch, turn_on)[source]
to_single_np(input)[source]
train(buffer, turn_on=True)[source]

openrl.algorithms.dqn module

class openrl.algorithms.dqn.DQNAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.base_algorithm.BaseAlgorithm

dqn_update(sample, turn_on=True)[source]
prepare_loss(obs_batch, next_obs_batch, rnn_states_batch, actions_batch, masks_batch, next_masks_batch, action_masks_batch, value_preds_batch, rewards_batch, active_masks_batch, turn_on)[source]
to_single_np(input)[source]
train(buffer, turn_on=True)[source]

openrl.algorithms.gail module

class openrl.algorithms.gail.GAILAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.ppo.PPOAlgorithm

train_gail(buffer, turn_on)[source]

openrl.algorithms.mat module

class openrl.algorithms.mat.MATAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.ppo.PPOAlgorithm

construct_loss_list(policy_loss, dist_entropy, value_loss, turn_on)[source]
get_data_generator(buffer, advantages)[source]

openrl.algorithms.ppo module

class openrl.algorithms.ppo.PPOAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.base_algorithm.BaseAlgorithm

cal_value_loss(value_normalizer, values, value_preds_batch, return_batch, active_masks_batch)[source]
construct_loss_list(policy_loss, dist_entropy, value_loss, turn_on)[source]
get_data_generator(buffer, advantages)[source]
ppo_update(sample, turn_on=True)[source]
prepare_loss(critic_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, masks_batch, action_masks_batch, old_action_log_probs_batch, adv_targ, value_preds_batch, return_batch, active_masks_batch, turn_on)[source]
to_single_np(input)[source]
train(buffer, turn_on=True)[source]
train_ppo(buffer, turn_on)[source]

openrl.algorithms.sac module

class openrl.algorithms.sac.SACAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.base_algorithm.BaseAlgorithm

cal_value_loss(value_normalizer, values, value_preds_batch, return_batch, active_masks_batch)[source]
prepare_actor_loss(obs_batch, next_obs_batch, rnn_states_batch, actions_batch, masks_batch, action_masks_batch, value_preds_batch, rewards_batch, active_masks_batch, turn_on)[source]
prepare_alpha_loss(log_prob)[source]
prepare_critic_loss(obs_batch, next_obs_batch, rnn_states_batch, actions_batch, masks_batch, next_masks_batch, action_masks_batch, value_preds_batch, rewards_batch, active_masks_batch, turn_on)[source]
sac_update(sample, turn_on=True)[source]
to_single_np(input)[source]
train(buffer, turn_on=True)[source]

openrl.algorithms.vdn module

class openrl.algorithms.vdn.VDNAlgorithm(cfg, init_module, agent_num: int = 1, device: Union[str, torch.device] = 'cpu')[source]

Bases: openrl.algorithms.base_algorithm.BaseAlgorithm

cal_value_loss(value_normalizer, values, value_preds_batch, return_batch, active_masks_batch)[source]
prepare_loss(obs_batch, next_obs_batch, rnn_states_batch, actions_batch, masks_batch, next_masks_batch, action_masks_batch, value_preds_batch, rewards_batch, active_masks_batch, turn_on)[source]
to_single_np(input)[source]
train(buffer, turn_on=True)[source]
vdn_update(sample, turn_on=True)[source]

Module contents