"""VDN-style mixer for ``openrl.modules.networks.utils.vdn``."""

import torch
import torch.nn as nn


class VDNBase(nn.Module):
    """Combine per-agent Q-values by summing them along one dimension.

    The module holds no learnable parameters. Its output is the sum of
    ``agent_qs`` over ``dim``, with that dimension kept at size 1.

    Args:
        dim: Dimension to sum over. Defaults to ``1``, which reproduces the
            previously hard-coded behaviour. NOTE(review): presumably this is
            the agent axis of ``agent_qs`` -- confirm against callers.
    """

    def __init__(self, dim: int = 1):
        super().__init__()
        # Reduction axis; configurable so callers whose agent axis is not 1
        # (e.g. a (batch, time, n_agents) layout) can reuse this module.
        self.dim = dim

    def forward(self, agent_qs: torch.Tensor) -> torch.Tensor:
        """Sum ``agent_qs`` over ``self.dim``.

        Args:
            agent_qs: Tensor of per-agent values.

        Returns:
            A tensor with the same rank as ``agent_qs``, where ``self.dim``
            has size 1 (``keepdim=True``) and holds the sum.
        """
        return torch.sum(agent_qs, dim=self.dim, keepdim=True)