当前位置: 首页 > news >正文

24/8/17算法笔记 AC算法

AC算法通常指的是Actor-Critic(演员-评论家)算法,它是强化学习中的一种算法框架,用于解决代理(agent)在环境中的决策问题。Actor-Critic方法结合了价值估计和策略优化,通常比纯粹的价值迭代或策略迭代方法具有更高的学习效率。

以下是AC算法的一些关键概念:

  1. Actor(演员):Actor负责生成动作,即在给定状态下选择一个动作。在深度学习中,Actor通常由一个神经网络实现,该网络根据当前状态输出一个动作或动作的概率分布。

  2. Critic(评论家):Critic负责评估Actor所采取的动作或策略的好坏。它通过估计从当前状态开始,遵循当前策略所能获得的预期回报来工作。Critic通常由一个价值函数或优势函数(advantage function)实现。

  3. 联合学习:在AC算法中,Actor和Critic是联合学习的。Actor生成动作,Critic评估这些动作,并提供反馈给Actor以改进其策略。

  4. 策略梯度:AC算法通常使用策略梯度方法来优化策略。这意味着Actor的参数通过梯度上升来更新,以增加长期回报的预期值。

  5. 优势函数:在某些AC算法中,如ACER(Actor-Critic with Experience Replay),使用优势函数来减少方差并提高学习稳定性。

  6. 经验回放:一些AC算法,如ACER,使用经验回放来提高数据效率,允许从单个经验中进行多次学习。

  7. 目标网络:为了增加训练过程的稳定性,AC算法可能会使用目标网络(target network),这是Actor或Critic网络的慢速更新副本。

  8. 探索:AC算法需要平衡探索和利用。Actor网络生成的动作不仅要考虑最大化回报,还要鼓励探索新的动作。

AC算法的常见变体包括:

  • A2C(Advantage Actor-Critic):引入优势函数来指导学习过程。
  • A3C(Asynchronous Advantage Actor-Critic):使用异步更新和多个并行工作者来提高学习效率。
  • ACER(Actor-Critic with Experience Replay):结合了经验回放和目标网络来提高稳定性和效率。

AC算法广泛应用于各种强化学习任务,包括游戏、机器人控制和自然语言处理等领域。

import gym
from matplotlib import pyplot as plt
%matplotlib inline
#创建环境
env = gym.make('CartPole-v1')
env.reset()#打印游戏
def show():plt.imshow(env.render(mode='rgb_array'))plt.show()

网络模型

import torch#定义模型
model = torch.nn.Sequential(torch.nn.Linear(4,128),torch.nn.ReLU(),torch.nn.Linear(128,2),torch.nn.Softmax(dim=1),
)
model_td = torch.nn.Sequential(torch.nn.Linear(4,128),torch.nn.ReLU(),torch.nn.Linear(128,1),
)
model(torch.randn(2,4)),model_td(torch.randn(2,4))

import random#得到一个动作
def get_action(state):state = torch.FloatTensor(state).reshape(1,4)prob = model(state)#根据概率选择一个动作action = random.choice(range(2),weights= prob[0].tolist(),k=1)[0]return action

获取数据

def get_data():states=[]rewards=[]actions = []next_state = []overs = []#初始化游戏state = env.reset()#玩到游戏结束为止over = Falsewhile not over:#根据当前状态得到一个动作action = get_action(state)#执行动作next_state,reward,over,_ = env.step(action)#记录数据样本states.append(state)rewards.append(reward)actions.append(action)next_states.append(next_state)overs.append(over)#更新状态,在世下一个动作state = next_statestates = torch.FloatTensor(states).reshape(-1,4)rewards = torch.FloatTensor(rewards).reshape(-1,1)actions = torch.LongTensor(actions).reshape(-1,1)next_states = torch.FloatTensor(next_states).reshape(-1,4)overs = torch.LongTensor(overs).reshaoe(-1,1)return states,rewards,actions,next_states,overs

测试函数

from IPython import displaydef test(play):#初始化游戏state = env.reset()#记录反馈值的和,这个值越大越好reward_sum= 0#玩到游戏结束为止over =Falsewhile not over:#根据当前状态得到一个动作action = get_action(state)#执行动作,得到反馈staet,reward,over,_=env.step(action)reward_sum+=reward#打印动画if play and random.random()<0.2:display.clear_output(wait=True)show()return  reward_sum

训练函数

def train():optimizer = torch.optim.Adan(model.parameters(),lr =1e-3)optimizer_td = torch.optim.Adam(model_td.parameters(),lr=1e-2)loss_fn = torch.nn.MSELoss()#玩N局游戏,每局训练一次for i in range(1000):#玩一局游戏,得到数据states,rewards,actions,next_states,overs = get_data()#计算values 和targetsvalues= model_td(states)targets = model_td(next_states)*0.98targets = (1-overs)targets+=rewards#时序差分误差delta = (targets-values).detach()#重新计算对应动作的概率probs = model(states)probs = probs.gather(dim=1,index=actions)#根据策略梯度算法的导函数实现#只是把公式中的reward_sum替换为了时序差分的误差loss = (-probs.log()*delta).mean()#时序差分的loss就是简单的value和target求mse loss即可loss_td = loss_fn(values,targets.detach())optimizer.zero_grad()   #作用是清除(重置)模型参数的梯度loss.backward()       #反向传播计算梯度的标准方法optimizer.step()     #更新模型的参数optimizer_td.zero_grad()  loss_td.backward()       optimizer_td.step()  if i %100 ==0:test_result = sum([test(play=False)for _ in range(10)])/10print(i,test_result)


http://www.mrgr.cn/news/285.html

相关文章:

  • [STM32]如何正确的安装和配置keil?(详细)
  • STM32标准库学习笔记-4.定时器中断
  • 希尔排序 java
  • 用爬虫玩转石墨文档
  • 初探 Rust 语言与环境搭建
  • 【myz_tools】Python库 myz_tools:Python算法及文档自动化生成工具
  • 常用的数据结构有哪些?
  • pywebview 入门
  • 生物药物分离与纯化技术pdf文件分享
  • arm 的寄存器概述(8)
  • 哪些情况下你需要Turnitin查重,确保原创性?
  • Hive3:常用查询语句整理
  • 学习笔记第二十六天
  • Codeforces Round 965 (Div. 2)
  • redis list类型
  • C++流媒体面试题
  • 启动nginx报错
  • 剪映怎么剪辑视频?2024年剪辑软件精选!
  • vscode 阅读linux内核(vscode+clangd)
  • pdf查看密码