第十章 Actor-Critic 算法
10.1 简介
在前面的章节中,我们分别学习了基于值函数的方法(如DQN)和基于策略的方法(如REINFORCE)。这两类方法各有特点:基于值函数的方法主要学习价值函数,策略是隐式地从值函数中推导出来的; 基于策略的方法:直接学习策略函数,不依赖显式的值函数。
一个自然的问题是:能否结合两者的优势,同时学习价值函数和策略函数?这就是 Actor-Critic 算法的核心思想。
Actor-Critic 算法本质上是基于策略的算法,因为其最终目标是优化参数化策略。价值函数的学习是为 了辅助策略函数更好地学习,属于算法架构中的辅助组件。
10.2 Actor-Critic 算法原理
10.2.1 从 REINFORCE 到 Actor-Critic
回顾 REINFORCE 算法,其策略梯度为:
其中 是蒙特卡洛回报估计。REINFORCE 的主要问题是:高方差:蒙特卡洛估计的方差较大;更新延迟:需要等待回合结束才能更新;效率低下:只能用于有限步数的任务。
Actor-Critic 算法通过引入价值函数估计来解决这些问题。
10.2.2 广义策略梯度形式
策略梯度可以表示为更一般的形式:
其中 可以是多种形式:
- :轨迹的总回报(REINFORCE)
- :动作 之后的回报
- :带基线的回报
- :动作价值函数
- :优势函数
- :时序差分残差
10.2.3 Actor-Critic 架构
我们将重点介绍形式(6),即时序差分残差。Actor-Critic 算法包含两个核心组件:
Actor(策略网络)。它负责与环境交互,在 Critic 的指导下通过策略梯度学习更好的策略,输出动作的概率分布。
Critic(价值网络)。学习状态价值函数 ,评估当前策略的好坏,为 Actor 提供学习信号。
两者关系如图所示:Actor 基于 Critic 的评估来改进策略,Critic 基于 Actor 收集的数据来改进价值估计。
10.2.4 算法优势
相比于 REINFORCE,Actor-Critic 具有以下优势:
- 方差更小:用价值函数估计代替蒙特卡洛回报
- 实时更新:可以在每一步之后进行更新
- 适用范围广:不受任务步数限制
- 样本效率高:可以重复利用经验
10.3 Actor-Critic 算法实现
10.3.1 Critic 的更新
Critic 价值网络表示为 ,参数为 。我们使用时序差分学习来更新 Critic:
价值函数损失:
其中 是目标网络,不参与梯度计算。
价值函数梯度:
10.3.2 Actor 的更新
Actor 策略网络表示为 ,参数为 。使用时序差分残差作为策略梯度的指导信号:
策略梯度:
10.3.3 算法流程
Actor-Critic 算法流程:
-
初始化策略网络参数 ,价值网络参数
-
循环每个序列 :
a. 轨迹采样:用当前策略 采样轨迹
b. 时序差分计算:对每一步数据 计算:
c. Critic 更新:更新价值网络参数
d. Actor 更新:更新策略网络参数
-
结束循环
10.3.4 Python 实现框架
import torch
import torch.nn as nn
import torch.optim as optim
class Actor(nn.Module):
"""策略网络"""
def __init__(self, state_dim, action_dim):
super(Actor, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, action_dim),
nn.Softmax(dim=-1)
)
def forward(self, state):
return self.network(state)
class Critic(nn.Module):
"""价值网络"""
def __init__(self, state_dim):
super(Critic, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
def forward(self, state):
return self.network(state)
class ActorCritic:
def __init__(self, state_dim, action_dim, lr_actor=0.001, lr_critic=0.001, gamma=0.99):
self.actor = Actor(state_dim, action_dim)
self.critic = Critic(state_dim)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)
self.gamma = gamma
def update(self, state, action, reward, next_state, done):
# 转换为张量
state = torch.FloatTensor(state)
next_state = torch.FloatTensor(next_state)
reward = torch.FloatTensor([reward])
# Critic 更新
current_value = self.critic(state)
next_value = torch.zeros(1) if done else self.critic(next_state)
target_value = reward + self.gamma * next_value.detach()
td_error = target_value - current_value
critic_loss = td_error.pow(2).mean()
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# Actor 更新
action_probs = self.actor(state)
log_prob = torch.log(action_probs[action])
actor_loss = -log_prob * td_error.detach()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()