[MADRL] The Monotonic Value Function Factorisation (QMIX) Algorithm
This article is a study note recorded by the blogger while learning in the reinforcement learning (RL) field, intended for personal study, research, and appreciation, and based on the blogger's own understanding of the related areas. If anything here is inappropriate or infringing, it will be corrected immediately once pointed out; your understanding is appreciated. The article is filed in the Reinforcement Learning column:
Reinforcement Learning (5) --- [MADRL] The Monotonic Value Function Factorisation (QMIX) Algorithm
Contents
0. Preface
1. Background and Challenges
2. QMIX Algorithm Architecture
3. Training Procedure
4. Advantages of QMIX
5. Applications of QMIX
6. Limitations and Improvements
[Python] QMIX Implementation (Portable)
0. Preface
QMIX (Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning) is a multi-agent reinforcement learning algorithm that is particularly suited to cooperative multi-agent settings such as distributed control and team combat. QMIX was proposed by Rashid et al. in 2018. Its core idea is to use a mixing network to combine the local Q-values of the individual agents in a non-linear but monotonic way, producing a global Q-value.
Original paper: QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning
Algorithm example code:
The implementation that I successfully ported, together with comments and porting notes, is given in the [Python] section below.

1. Background and Challenges
In multi-agent reinforcement learning, each agent must learn its policy from its own observations and experience. In a cooperative environment the agents' decisions influence one another, so reasoning about each agent's Q-value in isolation is not enough. At the same time, directly modelling a Q-value over the joint state and action of the whole system is computationally infeasible, because the joint spaces grow exponentially with the number of agents (for example, 10 agents with 9 actions each already give $9^{10} \approx 3.5$ billion joint actions).
2. QMIX Algorithm Architecture
The QMIX algorithm consists of the following core components:
2.1 Individual Q Networks
- Each agent $i$ has its own individual Q network, which takes the agent's local observation $o_i$ and action $a_i$ as input and outputs that agent's local Q-value $Q_i(o_i, a_i)$ (a minimal sketch follows this list).
- The individual Q network can use any deep neural network architecture, such as a convolutional neural network (CNN) or a feed-forward network (FNN), chosen according to the needs of the task.
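As an illustration, here is a minimal sketch of such an individual Q network in PyTorch. The class name AgentQNet and the dimensions are illustrative assumptions; the networks actually used by the training code below are Q_network_RNN and Q_network_MLP in qmix_smac.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AgentQNet(nn.Module):
    # Minimal per-agent Q network: maps a local observation to one Q-value per action.
    # obs_dim, action_dim and hidden_dim are illustrative placeholders.
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, obs):
        x = F.relu(self.fc1(obs))
        return self.fc2(x)  # shape: (batch, action_dim), i.e. Q_i(o_i, .) for every action

# Usage: one such network (or one set of shared weights) per agent
q_net = AgentQNet(obs_dim=15, action_dim=9)
q_values = q_net(torch.randn(1, 15))  # local Q-values for a single observation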
2.2 Mixing Network
- The mixing network combines the agents' local Q-values into a global Q-value $Q_{tot}$. Structurally it is a fully connected network whose weights and biases are produced by hypernetworks that take the global state as input (see mix_net.py below).
- The input to the mixing network is the set of local Q-values $Q_1, \dots, Q_n$ of all agents, together with the global state $s$ (available during centralised training); the output is the global Q-value $Q_{tot}$.
- Monotonicity constraint: the mixing network is designed so that the global Q-value $Q_{tot}$ is a monotonically non-decreasing function of every local Q-value $Q_i$, i.e. $\partial Q_{tot} / \partial Q_i \ge 0$. This means that increasing any local Q-value can never decrease the global Q-value, so the greedy joint action can be obtained by letting each agent greedily maximise its own $Q_i$. The constraint is enforced by keeping the mixing weights non-negative; see the sketch after this list.
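A minimal sketch of how non-negative weights enforce the constraint: a single-layer mixer whose weights come from a state-conditioned hypernetwork and are passed through torch.abs. The class MonotonicMixer and its dimensions are illustrative assumptions; the full two-layer version used by the code in this article is QMIX_Net in mix_net.py below.

import torch
import torch.nn as nn

class MonotonicMixer(nn.Module):
    # Minimal one-layer monotonic mixer: Q_tot = sum_i w_i(s) * Q_i + b(s), with w_i(s) >= 0.
    # n_agents and state_dim are illustrative placeholders.
    def __init__(self, n_agents, state_dim):
        super().__init__()
        self.hyper_w = nn.Linear(state_dim, n_agents)  # hypernetwork producing the mixing weights
        self.hyper_b = nn.Linear(state_dim, 1)         # state-dependent bias

    def forward(self, q_agents, state):
        # q_agents: (batch, n_agents), state: (batch, state_dim)
        w = torch.abs(self.hyper_w(state))             # non-negative weights => dQ_tot/dQ_i >= 0
        b = self.hyper_b(state)
        return (w * q_agents).sum(dim=-1, keepdim=True) + b  # (batch, 1)

mixer = MonotonicMixer(n_agents=3, state_dim=100)
q_tot = mixer(torch.randn(4, 3), torch.randn(4, 100))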
2.3 Computing the Global Q-value
The mixing network computes the global Q-value according to:
$$Q_{tot}(\mathbf{o}, \mathbf{a}, s) = f\big(Q_1(o_1, a_1), Q_2(o_2, a_2), \dots, Q_n(o_n, a_n); s\big)$$
where $f$ is the monotonic mapping implemented by the mixing network, $n$ is the number of agents, and $s$ is the global state information.
3. Training Procedure
QMIX is trained within the Q-learning framework; the concrete steps are as follows:
3.1 Experience Collection
At each time step, all agents select actions according to their current policies and interact with the environment, collecting a transition $(\mathbf{o}, s, \mathbf{a}, r, \mathbf{o}', s')$, where $\mathbf{o} = (o_1, \dots, o_n)$ is the joint observation of all agents, $\mathbf{a} = (a_1, \dots, a_n)$ is the joint action, $r$ is the global reward, and $s'$ is the next state. A schematic collection step is sketched below.
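The sketch below is a deliberately simplified assumption: collect_step and agents.choose_greedy are hypothetical names, and the actual loop is run_episode_smac() in the main script, which additionally handles RNN hidden states, available-action masks, epsilon decay, and episode storage.

import numpy as np

def collect_step(env, agents, epsilon, action_dim=9):
    # One schematic collection step for N agents.
    obs_n = env.get_obs()            # joint observation, shape (N, obs_dim)
    s = env.get_state()              # global state, shape (state_dim,)
    if np.random.rand() < epsilon:   # epsilon-greedy exploration
        a_n = np.random.randint(action_dim, size=len(obs_n))
    else:
        a_n = agents.choose_greedy(obs_n)  # hypothetical greedy-action helper
    _, r, done, info = env.step(a_n)
    return (obs_n, s, a_n, r, done)  # transition pieces later written into the replay buffer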
3.2 Target Q Calculation
Compute the target Q-value for the next state $s'$:
$$y = r + \gamma \max_{\mathbf{a}'} Q_{tot}(\mathbf{o}', \mathbf{a}', s'; \theta^{-})$$
where $\gamma$ is the discount factor and $\theta^{-}$ are the parameters of the target network (updated with a delayed-copy strategy). Thanks to the monotonicity constraint, the maximisation over the joint action $\mathbf{a}'$ decomposes into each agent maximising its own local Q-value.
3.3 Loss Function and Optimization
The parameters of the mixing network and the individual Q networks are updated by minimising the TD error:
$$\mathcal{L}(\theta) = \mathbb{E}\Big[\big(y - Q_{tot}(\mathbf{o}, \mathbf{a}, s; \theta)\big)^2\Big]$$
The networks are updated with backpropagation and stochastic gradient descent (SGD) or a variant such as Adam; a minimal sketch of this computation follows.
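A minimal sketch of the TD target and loss, assuming the mixed values q_total_eval and q_total_target (for the greedy next joint action) have already been produced by the mixing networks. The function name qmix_loss is illustrative; the full version is the train() method in qmix_smac.py below, which additionally masks padded time steps.

import torch

def qmix_loss(q_total_eval, q_total_target, r, done, gamma=0.99):
    # y = r + gamma * (1 - done) * Q_tot(o', a'_greedy, s'; theta^-)
    targets = r + gamma * (1.0 - done) * q_total_target
    td_error = q_total_eval - targets.detach()   # gradients flow only through q_total_eval
    return (td_error ** 2).mean()

batch, T = 32, 50
q_eval = torch.randn(batch, T, 1, requires_grad=True)   # dummy mixed Q-values
loss = qmix_loss(q_eval, torch.randn(batch, T, 1),
                 torch.randn(batch, T, 1), torch.zeros(batch, T, 1))
loss.backward()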
3.4 Target Network Update
To stabilise training, QMIX uses target networks. The target-network parameters $\theta^{-}$ are copied from the current network parameters $\theta$ at a lower frequency (a hard update) or blended in gradually with Polyak averaging (a soft update); both variants appear in the implementation below and are sketched after this paragraph.
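A minimal sketch of the two update schemes, mirroring the hard/soft branch in the train() method below (the function names are illustrative):

import copy
import torch.nn as nn

def hard_update(target_net: nn.Module, online_net: nn.Module):
    # Full copy, performed every target_update_freq training steps
    target_net.load_state_dict(online_net.state_dict())

def soft_update(target_net: nn.Module, online_net: nn.Module, tau: float = 0.005):
    # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target
    for tp, p in zip(target_net.parameters(), online_net.parameters()):
        tp.data.copy_(tau * p.data + (1.0 - tau) * tp.data)

online = nn.Linear(4, 2)
target = copy.deepcopy(online)
soft_update(target, online, tau=0.005)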
4. Advantages of QMIX
- Cooperation: by optimising a single global Q-value, QMIX effectively captures the cooperative relationships between agents.
- Scalability: thanks to the factorised design of the mixing network, QMIX scales to environments with more agents without being crippled by the exponential growth of the joint action space.
- Flexibility: with its non-linear mixing network, QMIX can handle complex cooperative tasks and is not limited to linear combination strategies.
5. Applications of QMIX
- Distributed robot control: when multiple robots must cooperate to complete a task, QMIX can learn effective cooperative policies.
- Team game AI: QMIX is widely used to train multi-agent AI in games that require team coordination (e.g., StarCraft II micromanagement in the SMAC benchmark).
- Resource allocation and management: in smart grids or multi-UAV systems, QMIX can effectively coordinate resources among multiple agents.
6. Limitations and Improvements
- Limited representational capacity: because of the monotonicity constraint, QMIX cannot represent certain complex, non-monotonic joint value functions.
- Sample efficiency: in high-dimensional environments, QMIX needs a large number of samples and long training times.
- Improvements: follow-up algorithms such as QTRAN and QPLEX address these limitations to varying degrees and further improve the performance of multi-agent reinforcement learning.
[Python] QMIX Implementation (Portable)
If you have difficulty reproducing the code below or run into problems, feel free to leave a comment. If you need the code packaged as a complete project, leave your email address in the comments so that it can be shared with you promptly (private messages are hard to answer in time).
Main script file:
"""
@content: QMIX
@author: 不去幼儿园
@Timeline: 2024.08.21
"""
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from env_base import Env  # Porting note: import your environment
import argparse
from replay_buffer import ReplayBuffer  # Porting note: import the helper classes
from qmix_smac import QMIX_SMAC
from normalization import Normalization


class Runner_QMIX_SMAC:
    def __init__(self, args, env_name, number, seed):
        self.args = args
        self.env_name = env_name
        self.number = number
        self.seed = seed
        # Set random seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # Create env
        # self.env = StarCraft2Env(map_name=self.env_name, seed=self.seed)
        """
        Porting notes:
        1. Declare the environment
        2. Set the environment parameters (pay attention to the expected formats)
        """
        self.env = Env()  # Porting note: environment declaration
        # self.env_info = self.env.get_env_info()
        self.args.N = 3  # The number of agents
        self.args.obs_dim = 15  # The dimensions of an agent's observation space
        self.args.state_dim = 100 + 9 * 3 + 2 * 3  # The dimensions of global state space
        self.args.action_dim = 9  # The dimensions of an agent's action space
        self.args.episode_limit = 50  # Maximum number of steps per episode
        print("number of agents={}".format(self.args.N))
        print("obs_dim={}".format(self.args.obs_dim))
        print("state_dim={}".format(self.args.state_dim))
        print("action_dim={}".format(self.args.action_dim))
        print("episode_limit={}".format(self.args.episode_limit))

        # Create N agents
        self.agent_n = QMIX_SMAC(self.args)
        self.replay_buffer = ReplayBuffer(self.args)

        # Create a tensorboard
        self.writer = SummaryWriter(log_dir='./runs/{}/{}_env_{}_number_{}_seed_{}'.format(self.args.algorithm, self.args.algorithm, self.env_name, self.number, self.seed))

        self.epsilon = self.args.epsilon  # Initialize the epsilon
        self.win_rates = []  # Record the win rates
        self.total_steps = 0
        if self.args.use_reward_norm:
            print("------use reward norm------")
            self.reward_norm = Normalization(shape=1)

    def run(self, ):
        evaluate_num = -1  # Record the number of evaluations
        while self.total_steps < self.args.max_train_steps:
            if self.total_steps // self.args.evaluate_freq > evaluate_num:
                self.evaluate_policy()  # Evaluate the policy every 'evaluate_freq' steps
                evaluate_num += 1

            _, _, episode_steps = self.run_episode_smac(evaluate=False)  # Run an episode
            self.total_steps += episode_steps

            if self.replay_buffer.current_size >= self.args.batch_size:
                self.agent_n.train(self.replay_buffer, self.total_steps)  # Training

        self.evaluate_policy()
        # self.env.close()

    def evaluate_policy(self, ):
        win_times = 0
        evaluate_reward = 0
        goal_num_buffer__ = []
        for _ in range(self.args.evaluate_times):
            win_tag, episode_reward, _ = self.run_episode_smac(evaluate=True)
            """Collect additional state statistics"""
            goal_num_buffer_ = self.env.get_state_data()  # Porting note: fetch extra state statistics
            goal_num_buffer_ = np.array(goal_num_buffer_)
            goal_num_buffer__.append(goal_num_buffer_)
            if win_tag:
                win_times += 1
            evaluate_reward += episode_reward

        goal_num_buffer = np.sum(goal_num_buffer__[:], axis=0) / self.args.evaluate_times
        log_flag = ["state/target_num", "state/target_num", "state/crash_num", "state/ratio"]
        for i in range(4):
            goal_num = goal_num_buffer[i]
            goal_num = {log_flag[i]: goal_num}
            log_state(name=log_flag[i], state=goal_num, step=self.total_steps)

        win_rate = win_times / self.args.evaluate_times
        evaluate_reward = evaluate_reward / self.args.evaluate_times
        reward_total = {"state/reward_total": evaluate_reward}
        log_state(name="state/reward_total", state=reward_total, step=self.total_steps)
        self.win_rates.append(win_rate)
        print("total_steps:{}\tepisode:{}\tevaluate_reward:{:.3f}\t"
              "target_num:{:.3f}\ttarget_num:{:.3f}\tcrash_num:{:.3f}\tratio:{:.3f}".format(
                  self.total_steps, int(self.total_steps / 1250 + 1), evaluate_reward,
                  goal_num_buffer[0], goal_num_buffer[1], goal_num_buffer[2], goal_num_buffer[3]))
        # self.writer.add_scalar('win_rate_{}'.format(self.env_name), win_rate, global_step=self.total_steps)
        # Save the win rates
        np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.args.algorithm, self.env_name,
                                                                      self.number, self.seed), np.array(self.win_rates))

    def run_episode_smac(self, evaluate=False):
        win_tag = False
        episode_reward = 0
        """
        Porting notes: running the environment
        1. Environment reset function
        2. Environment observation/state getters (pay attention to the formats)
        3. Environment step function (pay attention to the return values)
        """
        self.env.reset()  # Porting note: environment reset
        if self.args.use_rnn:  # If use RNN, before the beginning of each episode, reset the rnn_hidden of the Q network.
            self.agent_n.eval_Q_net.rnn_hidden = None
        last_onehot_a_n = np.zeros((self.args.N, self.args.action_dim))  # Last actions of N agents (one-hot)
        for episode_step in range(self.args.episode_limit):
            obs_n = self.env.get_obs()  # obs_n.shape=(N,obs_dim)  # Porting note: get observations
            s = self.env.get_state()  # s.shape=(state_dim,)  # Porting note: get global state
            # avail_a_n = self.env.get_avail_actions()  # Get available actions of N agents, avail_a_n.shape=(N,action_dim)
            avail_a_n = [[1] * 9 for _ in range(3)]
            epsilon = 0 if evaluate else self.epsilon
            a_n = self.agent_n.choose_action(obs_n, last_onehot_a_n, avail_a_n, epsilon)
            last_onehot_a_n = np.eye(self.args.action_dim)[a_n]  # Convert actions to one-hot vectors
            _, r_, done_, info = self.env.step(a_n)  # Porting note: environment step
            done = done_[0]
            r = sum(list(np.array(r_).flatten()))
            win_tag = True if done and 'battle_won' in info and info['battle_won'] else False
            episode_reward += r

            if not evaluate:
                if self.args.use_reward_norm:
                    r = self.reward_norm(r)
                """
                When dead or win or reaching the episode_limit, done will be True, and we need to distinguish them;
                dw means dead or win, in which case there is no next state s';
                but when reaching the max_episode_steps, there actually is a next state s'.
                """
                if done and episode_step + 1 != self.args.episode_limit:
                    dw = True
                else:
                    dw = False

                # Store the transition
                self.replay_buffer.store_transition(episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw)
                # Decay the epsilon
                self.epsilon = self.epsilon - self.args.epsilon_decay if self.epsilon - self.args.epsilon_decay > self.args.epsilon_min else self.args.epsilon_min

            if done:
                break

        if not evaluate:
            # An episode is over, store obs_n, s and avail_a_n in the last step
            obs_n = self.env.get_obs()  # Porting note
            s = self.env.get_state()  # Porting note
            # avail_a_n = self.env.get_avail_actions()
            avail_a_n = [[1] * 9 for _ in range(3)]
            self.replay_buffer.store_last_step(episode_step + 1, obs_n, s, avail_a_n)

        return win_tag, episode_reward, episode_step + 1


# Logging of training/evaluation results
from tensorboardX import SummaryWriter
writer = SummaryWriter()
def log_state(name, state, step):
    writer.add_scalars(name, state, step)


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX and VDN in SMAC environment")
    parser.add_argument("--max_train_steps", type=int, default=int(1e6), help="Maximum number of training steps")
    parser.add_argument("--evaluate_freq", type=float, default=1250, help="Evaluate the policy every 'evaluate_freq' steps")
    parser.add_argument("--evaluate_times", type=float, default=5, help="Evaluate times")
    parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency")
    parser.add_argument("--algorithm", type=str, default="QMIX", help="QMIX or VDN")
    parser.add_argument("--epsilon", type=float, default=1.0, help="Initial epsilon")
    parser.add_argument("--epsilon_decay_steps", type=float, default=50000, help="How many steps before the epsilon decays to the minimum")
    parser.add_argument("--epsilon_min", type=float, default=0.05, help="Minimum epsilon")
    parser.add_argument("--buffer_size", type=int, default=5000, help="The capacity of the replay buffer")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--qmix_hidden_dim", type=int, default=32, help="The dimension of the hidden layer of the QMIX network")
    parser.add_argument("--hyper_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of the hyper-network")
    parser.add_argument("--hyper_layers_num", type=int, default=1, help="The number of layers of hyper-network")
    parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN")
    parser.add_argument("--mlp_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of MLP")
    parser.add_argument("--use_rnn", type=bool, default=True, help="Whether to use RNN")
    parser.add_argument("--use_orthogonal_init", type=bool, default=True, help="Orthogonal initialization")
    parser.add_argument("--use_grad_clip", type=bool, default=True, help="Gradient clip")
    parser.add_argument("--use_lr_decay", type=bool, default=False, help="use lr decay")
    parser.add_argument("--use_RMS", type=bool, default=False, help="Whether to use RMS, if False, we will use Adam")
    parser.add_argument("--add_last_action", type=bool, default=True, help="Whether to add last actions into the observation")
    parser.add_argument("--add_agent_id", type=bool, default=True, help="Whether to add agent id into the observation")
    parser.add_argument("--use_double_q", type=bool, default=True, help="Whether to use double q-learning")
    parser.add_argument("--use_reward_norm", type=bool, default=False, help="Whether to use reward normalization")
    parser.add_argument("--use_hard_update", type=bool, default=True, help="Whether to use hard update")
    parser.add_argument("--target_update_freq", type=int, default=200, help="Update frequency of the target network")
    parser.add_argument("--tau", type=float, default=0.005, help="If use soft update")

    args = parser.parse_args()
    args.epsilon_decay = (args.epsilon - args.epsilon_min) / args.epsilon_decay_steps

    env_names = ['3m', '8m', '2s3z']
    env_index = 0
    runner = Runner_QMIX_SMAC(args, env_name=env_names[env_index], number=1, seed=0)
    runner.run()
replay_buffer.py file (provides the ReplayBuffer class imported above):
import numpy as np
import torch
import copy


class ReplayBuffer:
    def __init__(self, args):
        self.N = args.N
        self.obs_dim = args.obs_dim
        self.state_dim = args.state_dim
        self.action_dim = args.action_dim
        self.episode_limit = args.episode_limit
        self.buffer_size = args.buffer_size
        self.batch_size = args.batch_size
        self.episode_num = 0
        self.current_size = 0
        self.buffer = {'obs_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.obs_dim]),
                       's': np.zeros([self.buffer_size, self.episode_limit + 1, self.state_dim]),
                       'avail_a_n': np.ones([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]),  # Note: We use 'np.ones' to initialize 'avail_a_n'
                       'last_onehot_a_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]),
                       'a_n': np.zeros([self.buffer_size, self.episode_limit, self.N]),
                       'r': np.zeros([self.buffer_size, self.episode_limit, 1]),
                       'dw': np.ones([self.buffer_size, self.episode_limit, 1]),  # Note: We use 'np.ones' to initialize 'dw'
                       'active': np.zeros([self.buffer_size, self.episode_limit, 1])
                       }
        self.episode_len = np.zeros(self.buffer_size)

    def store_transition(self, episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw):
        self.buffer['obs_n'][self.episode_num][episode_step] = obs_n
        self.buffer['s'][self.episode_num][episode_step] = s
        self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n
        self.buffer['last_onehot_a_n'][self.episode_num][episode_step + 1] = last_onehot_a_n
        self.buffer['a_n'][self.episode_num][episode_step] = a_n
        self.buffer['r'][self.episode_num][episode_step] = r
        self.buffer['dw'][self.episode_num][episode_step] = dw
        self.buffer['active'][self.episode_num][episode_step] = 1.0

    def store_last_step(self, episode_step, obs_n, s, avail_a_n):
        self.buffer['obs_n'][self.episode_num][episode_step] = obs_n
        self.buffer['s'][self.episode_num][episode_step] = s
        self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n
        self.episode_len[self.episode_num] = episode_step  # Record the length of this episode
        self.episode_num = (self.episode_num + 1) % self.buffer_size
        self.current_size = min(self.current_size + 1, self.buffer_size)

    def sample(self):
        # Randomly sampling
        index = np.random.choice(self.current_size, size=self.batch_size, replace=False)
        max_episode_len = int(np.max(self.episode_len[index]))
        batch = {}
        for key in self.buffer.keys():
            if key == 'obs_n' or key == 's' or key == 'avail_a_n' or key == 'last_onehot_a_n':
                batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len + 1], dtype=torch.float32)
            elif key == 'a_n':
                batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.long)
            else:
                batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.float32)

        return batch, max_episode_len
qmix_smac.py file (provides the QMIX_SMAC class imported above):
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mix_net import QMIX_Net, VDN_Net


# orthogonal initialization
def orthogonal_init(layer, gain=1.0):
    for name, param in layer.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param, 0)
        elif 'weight' in name:
            nn.init.orthogonal_(param, gain=gain)


class Q_network_RNN(nn.Module):
    def __init__(self, args, input_dim):
        super(Q_network_RNN, self).__init__()
        self.rnn_hidden = None
        self.fc1 = nn.Linear(input_dim, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.action_dim)
        if args.use_orthogonal_init:
            print("------use_orthogonal_init------")
            orthogonal_init(self.fc1)
            orthogonal_init(self.rnn)
            orthogonal_init(self.fc2)

    def forward(self, inputs):
        # When 'choose_action', inputs.shape=(N, input_dim)
        # When 'train', inputs.shape=(batch_size*N, input_dim)
        x = F.relu(self.fc1(inputs))
        self.rnn_hidden = self.rnn(x, self.rnn_hidden)
        Q = self.fc2(self.rnn_hidden)
        return Q


class Q_network_MLP(nn.Module):
    def __init__(self, args, input_dim):
        super(Q_network_MLP, self).__init__()
        self.rnn_hidden = None
        self.fc1 = nn.Linear(input_dim, args.mlp_hidden_dim)
        self.fc2 = nn.Linear(args.mlp_hidden_dim, args.mlp_hidden_dim)
        self.fc3 = nn.Linear(args.mlp_hidden_dim, args.action_dim)
        if args.use_orthogonal_init:
            print("------use_orthogonal_init------")
            orthogonal_init(self.fc1)
            orthogonal_init(self.fc2)
            orthogonal_init(self.fc3)

    def forward(self, inputs):
        # When 'choose_action', inputs.shape=(N, input_dim)
        # When 'train', inputs.shape=(batch_size, max_episode_len, N, input_dim)
        x = F.relu(self.fc1(inputs))
        x = F.relu(self.fc2(x))
        Q = self.fc3(x)
        return Q


class QMIX_SMAC(object):
    def __init__(self, args):
        self.N = args.N
        self.action_dim = args.action_dim
        self.obs_dim = args.obs_dim
        self.state_dim = args.state_dim
        self.add_last_action = args.add_last_action
        self.add_agent_id = args.add_agent_id
        self.max_train_steps = args.max_train_steps
        self.lr = args.lr
        self.gamma = args.gamma
        self.use_grad_clip = args.use_grad_clip
        self.batch_size = args.batch_size  # Here, batch_size is the number of episodes per training batch
        self.target_update_freq = args.target_update_freq
        self.tau = args.tau
        self.use_hard_update = args.use_hard_update
        self.use_rnn = args.use_rnn
        self.algorithm = args.algorithm
        self.use_double_q = args.use_double_q
        self.use_RMS = args.use_RMS
        self.use_lr_decay = args.use_lr_decay

        # Compute the input dimension
        self.input_dim = self.obs_dim
        if self.add_last_action:
            print("------add last action------")
            self.input_dim += self.action_dim
        if self.add_agent_id:
            print("------add agent id------")
            self.input_dim += self.N

        if self.use_rnn:
            print("------use RNN------")
            self.eval_Q_net = Q_network_RNN(args, self.input_dim)
            self.target_Q_net = Q_network_RNN(args, self.input_dim)
        else:
            print("------use MLP------")
            self.eval_Q_net = Q_network_MLP(args, self.input_dim)
            self.target_Q_net = Q_network_MLP(args, self.input_dim)
        self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict())

        if self.algorithm == "QMIX":
            print("------algorithm: QMIX------")
            self.eval_mix_net = QMIX_Net(args)
            self.target_mix_net = QMIX_Net(args)
        elif self.algorithm == "VDN":
            print("------algorithm: VDN------")
            self.eval_mix_net = VDN_Net()
            self.target_mix_net = VDN_Net()
        else:
            print("wrong!!!")
        self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())

        self.eval_parameters = list(self.eval_mix_net.parameters()) + list(self.eval_Q_net.parameters())
        if self.use_RMS:
            print("------optimizer: RMSprop------")
            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=self.lr)
        else:
            print("------optimizer: Adam------")
            self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

        self.train_step = 0

    def choose_action(self, obs_n, last_onehot_a_n, avail_a_n, epsilon):
        with torch.no_grad():
            if np.random.uniform() < epsilon:  # epsilon-greedy
                # Only available actions can be chosen
                a_n = [np.random.choice(np.nonzero(avail_a)[0]) for avail_a in avail_a_n]
            else:
                inputs = []
                obs_n = torch.tensor(obs_n, dtype=torch.float32)  # obs_n.shape=(N, obs_dim)
                inputs.append(obs_n)
                if self.add_last_action:
                    last_a_n = torch.tensor(last_onehot_a_n, dtype=torch.float32)
                    inputs.append(last_a_n)
                if self.add_agent_id:
                    inputs.append(torch.eye(self.N))

                inputs = torch.cat([x for x in inputs], dim=-1)  # inputs.shape=(N, input_dim)
                q_value = self.eval_Q_net(inputs)

                avail_a_n = torch.tensor(avail_a_n, dtype=torch.float32)  # avail_a_n.shape=(N, action_dim)
                q_value[avail_a_n == 0] = -float('inf')  # Mask the unavailable actions
                a_n = q_value.argmax(dim=-1).numpy()
            return a_n

    def train(self, replay_buffer, total_steps):
        batch, max_episode_len = replay_buffer.sample()  # Get training data
        self.train_step += 1

        inputs = self.get_inputs(batch, max_episode_len)  # inputs.shape=(batch_size, max_episode_len+1, N, input_dim)
        if self.use_rnn:
            self.eval_Q_net.rnn_hidden = None
            self.target_Q_net.rnn_hidden = None
            q_evals, q_targets = [], []
            for t in range(max_episode_len):  # t=0,1,2,...(episode_len-1)
                q_eval = self.eval_Q_net(inputs[:, t].reshape(-1, self.input_dim))  # q_eval.shape=(batch_size*N, action_dim)
                q_target = self.target_Q_net(inputs[:, t + 1].reshape(-1, self.input_dim))
                q_evals.append(q_eval.reshape(self.batch_size, self.N, -1))  # q_eval.shape=(batch_size, N, action_dim)
                q_targets.append(q_target.reshape(self.batch_size, self.N, -1))

            # Stack them according to the time (dim=1)
            q_evals = torch.stack(q_evals, dim=1)  # q_evals.shape=(batch_size, max_episode_len, N, action_dim)
            q_targets = torch.stack(q_targets, dim=1)
        else:
            q_evals = self.eval_Q_net(inputs[:, :-1])  # q_evals.shape=(batch_size, max_episode_len, N, action_dim)
            q_targets = self.target_Q_net(inputs[:, 1:])

        with torch.no_grad():
            if self.use_double_q:  # If use double q-learning, we use eval_net to choose actions, and use target_net to compute q_target
                q_eval_last = self.eval_Q_net(inputs[:, -1].reshape(-1, self.input_dim)).reshape(self.batch_size, 1, self.N, -1)
                q_evals_next = torch.cat([q_evals[:, 1:], q_eval_last], dim=1)  # q_evals_next.shape=(batch_size, max_episode_len, N, action_dim)
                q_evals_next[batch['avail_a_n'][:, 1:] == 0] = -999999
                a_argmax = torch.argmax(q_evals_next, dim=-1, keepdim=True)  # a_argmax.shape=(batch_size, max_episode_len, N, 1)
                q_targets = torch.gather(q_targets, dim=-1, index=a_argmax).squeeze(-1)  # q_targets.shape=(batch_size, max_episode_len, N)
            else:
                q_targets[batch['avail_a_n'][:, 1:] == 0] = -999999
                q_targets = q_targets.max(dim=-1)[0]  # q_targets.shape=(batch_size, max_episode_len, N)

        # batch['a_n'].shape=(batch_size, max_episode_len, N)
        q_evals = torch.gather(q_evals, dim=-1, index=batch['a_n'].unsqueeze(-1)).squeeze(-1)  # q_evals.shape=(batch_size, max_episode_len, N)

        # Compute q_total using QMIX or VDN, q_total.shape=(batch_size, max_episode_len, 1)
        if self.algorithm == "QMIX":
            q_total_eval = self.eval_mix_net(q_evals, batch['s'][:, :-1])
            q_total_target = self.target_mix_net(q_targets, batch['s'][:, 1:])
        else:
            q_total_eval = self.eval_mix_net(q_evals)
            q_total_target = self.target_mix_net(q_targets)

        # targets.shape=(batch_size, max_episode_len, 1)
        targets = batch['r'] + self.gamma * (1 - batch['dw']) * q_total_target

        td_error = (q_total_eval - targets.detach())
        mask_td_error = td_error * batch['active']
        loss = (mask_td_error ** 2).sum() / batch['active'].sum()
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.eval_parameters, 10)
        self.optimizer.step()

        if self.use_hard_update:
            # hard update
            if self.train_step % self.target_update_freq == 0:
                self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict())
                self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())
        else:
            # Softly update the target networks
            for param, target_param in zip(self.eval_Q_net.parameters(), self.target_Q_net.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.eval_mix_net.parameters(), self.target_mix_net.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        if self.use_lr_decay:
            self.lr_decay(total_steps)

    def lr_decay(self, total_steps):  # Learning rate decay
        lr_now = self.lr * (1 - total_steps / self.max_train_steps)
        for p in self.optimizer.param_groups:
            p['lr'] = lr_now

    def get_inputs(self, batch, max_episode_len):
        inputs = []
        inputs.append(batch['obs_n'])
        if self.add_last_action:
            inputs.append(batch['last_onehot_a_n'])
        if self.add_agent_id:
            agent_id_one_hot = torch.eye(self.N).unsqueeze(0).unsqueeze(0).repeat(self.batch_size, max_episode_len + 1, 1, 1)
            inputs.append(agent_id_one_hot)

        # inputs.shape=(batch_size, max_episode_len+1, N, input_dim)
        inputs = torch.cat([x for x in inputs], dim=-1)
        return inputs

    def save_model(self, env_name, algorithm, number, seed, total_steps):
        torch.save(self.eval_Q_net.state_dict(), "./model/{}/{}_eval_rnn_number_{}_seed_{}_step_{}k.pth".format(env_name, algorithm, number, seed, int(total_steps / 1000)))
normalization.py file (provides the Normalization class imported above):
import numpy as np


class RunningMeanStd:
    # Dynamically calculate mean and std
    def __init__(self, shape):  # shape: the dimension of input data
        self.n = 0
        self.mean = np.zeros(shape)
        self.S = np.zeros(shape)
        self.std = np.sqrt(self.S)

    def update(self, x):
        x = np.array(x)
        self.n += 1
        if self.n == 1:
            self.mean = x
            self.std = x
        else:
            old_mean = self.mean.copy()
            self.mean = old_mean + (x - old_mean) / self.n
            self.S = self.S + (x - old_mean) * (x - self.mean)
            self.std = np.sqrt(self.S / self.n)


class Normalization:
    def __init__(self, shape):
        self.running_ms = RunningMeanStd(shape=shape)

    def __call__(self, x, update=True):
        # Whether to update the mean and std; during evaluation, update=False
        if update:
            self.running_ms.update(x)
        x = (x - self.running_ms.mean) / (self.running_ms.std + 1e-8)
        return x


class RewardScaling:
    def __init__(self, shape, gamma):
        self.shape = shape  # reward shape=1
        self.gamma = gamma  # discount factor
        self.running_ms = RunningMeanStd(shape=self.shape)
        self.R = np.zeros(self.shape)

    def __call__(self, x):
        self.R = self.gamma * self.R + x
        self.running_ms.update(self.R)
        x = x / (self.running_ms.std + 1e-8)  # Only divided by std
        return x

    def reset(self):  # When an episode is done, we should reset 'self.R'
        self.R = np.zeros(self.shape)
mix_net.py file (provides QMIX_Net and VDN_Net imported above):
import torch
import torch.nn as nn
import torch.nn.functional as F


class QMIX_Net(nn.Module):
    def __init__(self, args):
        super(QMIX_Net, self).__init__()
        self.N = args.N
        self.state_dim = args.state_dim
        self.batch_size = args.batch_size
        self.qmix_hidden_dim = args.qmix_hidden_dim
        self.hyper_hidden_dim = args.hyper_hidden_dim
        self.hyper_layers_num = args.hyper_layers_num
        """
        w1: (N, qmix_hidden_dim)
        b1: (1, qmix_hidden_dim)
        w2: (qmix_hidden_dim, 1)
        b2: (1, 1)
        Because hyper_w1 must produce a matrix while a PyTorch linear layer outputs a vector,
        we first output a vector of length (rows * cols) and then reshape it into the matrix.
        """
        if self.hyper_layers_num == 2:
            print("hyper_layers_num=2")
            self.hyper_w1 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(self.hyper_hidden_dim, self.N * self.qmix_hidden_dim))
            self.hyper_w2 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(self.hyper_hidden_dim, self.qmix_hidden_dim * 1))
        elif self.hyper_layers_num == 1:
            print("hyper_layers_num=1")
            self.hyper_w1 = nn.Linear(self.state_dim, self.N * self.qmix_hidden_dim)
            self.hyper_w2 = nn.Linear(self.state_dim, self.qmix_hidden_dim * 1)
        else:
            print("wrong!!!")

        self.hyper_b1 = nn.Linear(self.state_dim, self.qmix_hidden_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(self.state_dim, self.qmix_hidden_dim),
                                      nn.ReLU(),
                                      nn.Linear(self.qmix_hidden_dim, 1))

    def forward(self, q, s):
        # q.shape=(batch_size, max_episode_len, N)
        # s.shape=(batch_size, max_episode_len, state_dim)
        q = q.view(-1, 1, self.N)  # (batch_size * max_episode_len, 1, N)
        s = s.reshape(-1, self.state_dim)  # (batch_size * max_episode_len, state_dim)

        w1 = torch.abs(self.hyper_w1(s))  # (batch_size * max_episode_len, N * qmix_hidden_dim)
        b1 = self.hyper_b1(s)  # (batch_size * max_episode_len, qmix_hidden_dim)
        w1 = w1.view(-1, self.N, self.qmix_hidden_dim)  # (batch_size * max_episode_len, N, qmix_hidden_dim)
        b1 = b1.view(-1, 1, self.qmix_hidden_dim)  # (batch_size * max_episode_len, 1, qmix_hidden_dim)

        # torch.bmm: batched matrix multiplication on 3-D tensors
        q_hidden = F.elu(torch.bmm(q, w1) + b1)  # (batch_size * max_episode_len, 1, qmix_hidden_dim)

        w2 = torch.abs(self.hyper_w2(s))  # (batch_size * max_episode_len, qmix_hidden_dim * 1)
        b2 = self.hyper_b2(s)  # (batch_size * max_episode_len, 1)
        w2 = w2.view(-1, self.qmix_hidden_dim, 1)  # (batch_size * max_episode_len, qmix_hidden_dim, 1)
        b2 = b2.view(-1, 1, 1)  # (batch_size * max_episode_len, 1, 1)

        q_total = torch.bmm(q_hidden, w2) + b2  # (batch_size * max_episode_len, 1, 1)
        q_total = q_total.view(self.batch_size, -1, 1)  # (batch_size, max_episode_len, 1)
        return q_total


class VDN_Net(nn.Module):
    def __init__(self, ):
        super(VDN_Net, self).__init__()

    def forward(self, q):
        return torch.sum(q, dim=-1, keepdim=True)  # (batch_size, max_episode_len, 1)
If anything in this article is inappropriate or incorrect, your understanding is appreciated and corrections are welcome. Some text and images are sourced from the internet and their original provenance cannot be verified; if any dispute arises, please contact the blogger to have the material removed. For errors, questions, or copyright issues, leave a comment to reach the author, or follow the WeChat official account Rain21321 to contact the author.
