跳到主要内容

第十章 Actor-Critic 算法

10.1 简介

在前面的章节中,我们分别学习了基于值函数的方法(如DQN)和基于策略的方法(如REINFORCE)。这两类方法各有特点:基于值函数的方法主要学习价值函数,策略是隐式地从值函数中推导出来的; 基于策略的方法:直接学习策略函数,不依赖显式的值函数。

一个自然的问题是:能否结合两者的优势,同时学习价值函数和策略函数?这就是 Actor-Critic 算法的核心思想。

Actor-Critic 算法本质上是基于策略的算法,因为其最终目标是优化参数化策略。价值函数的学习是为了辅助策略函数更好地学习,属于算法架构中的辅助组件。

10.2 Actor-Critic 算法原理

10.2.1 从 REINFORCE 到 Actor-Critic

回顾 REINFORCE 算法,其策略梯度为:

θJ(θ)=E[t=0TGtθlogπθ(atst)]\nabla_\theta J(\theta) = \mathbb{E} \left[ \sum_{t=0}^{T} G_t \nabla_\theta \log \pi_\theta(a_t|s_t) \right]

其中 GtG_t 是蒙特卡洛回报估计。REINFORCE 的主要问题是:高方差:蒙特卡洛估计的方差较大;更新延迟:需要等待回合结束才能更新;效率低下:只能用于有限步数的任务。

Actor-Critic 算法通过引入价值函数估计来解决这些问题。

10.2.2 广义策略梯度形式

策略梯度可以表示为更一般的形式:

g=E[t=0Tψtθlogπθ(atst)]g = \mathbb{E} \left[ \sum_{t=0}^{T} \psi_t \nabla_\theta \log \pi_\theta (a_t | s_t) \right]

其中 ψt\psi_t 可以是多种形式:

  1. t=0Tγtrt\sum_{t'=0}^{T} \gamma^{t'} r_{t'}:轨迹的总回报(REINFORCE)
  2. t=tTγttrt\sum_{t'=t}^{T} \gamma^{t'-t} r_{t'}:动作 ata_t 之后的回报
  3. t=tTγttrtb(st)\sum_{t'=t}^{T} \gamma^{t'-t} r_{t'} - b(s_t):带基线的回报
  4. Qπ(st,at)Q^\pi(s_t, a_t):动作价值函数
  5. Aπ(st,at)A^\pi(s_t, a_t):优势函数
  6. rt+γVπ(st+1)Vπ(st)r_t + \gamma V^\pi(s_{t+1}) - V^\pi(s_t):时序差分残差

10.2.3 Actor-Critic 架构

我们将重点介绍形式(6),即时序差分残差。Actor-Critic 算法包含两个核心组件:

Actor(策略网络)。它负责与环境交互,在 Critic 的指导下通过策略梯度学习更好的策略,输出动作的概率分布。

Critic(价值网络)。学习状态价值函数 V(s)V(s),评估当前策略的好坏,为 Actor 提供学习信号。

两者关系如图所示:Actor 基于 Critic 的评估来改进策略,Critic 基于 Actor 收集的数据来改进价值估计。

Actor-Critic 架构

10.2.4 算法优势

相比于 REINFORCE,Actor-Critic 具有以下优势:

  • 方差更小:用价值函数估计代替蒙特卡洛回报
  • 实时更新:可以在每一步之后进行更新
  • 适用范围广:不受任务步数限制
  • 样本效率高:可以重复利用经验

10.3 Actor-Critic 算法实现

10.3.1 Critic 的更新

Critic 价值网络表示为 Vϕ(s)V_\phi(s),参数为 ϕ\phi。我们使用时序差分学习来更新 Critic:

价值函数损失

L(ϕ)=12(rt+γVϕ(st+1)Vϕ(st))2\mathcal{L}(\phi) = \frac{1}{2} \left( r_t + \gamma V_{\phi^-}(s_{t+1}) - V_\phi(s_t) \right)^2

其中 VϕV_{\phi^-} 是目标网络,不参与梯度计算。

价值函数梯度

ϕL(ϕ)=(rt+γVϕ(st+1)Vϕ(st))ϕVϕ(st)\nabla_\phi \mathcal{L}(\phi) = - \left( r_t + \gamma V_{\phi^-}(s_{t+1}) - V_\phi(s_t) \right) \nabla_\phi V_\phi(s_t)

10.3.2 Actor 的更新

Actor 策略网络表示为 πθ(as)\pi_\theta(a|s),参数为 θ\theta。使用时序差分残差作为策略梯度的指导信号:

策略梯度

θJ(θ)=E[(rt+γVϕ(st+1)Vϕ(st))θlogπθ(atst)]\nabla_\theta J(\theta) = \mathbb{E} \left[ \left( r_t + \gamma V_{\phi^-}(s_{t+1}) - V_\phi(s_t) \right) \nabla_\theta \log \pi_\theta(a_t|s_t) \right]

10.3.3 算法流程

Actor-Critic 算法流程

  1. 初始化策略网络参数 θ\theta,价值网络参数 ϕ\phi

  2. 循环每个序列 e=1,2,,Ee = 1, 2, \dots, E

    a. 轨迹采样:用当前策略 πθ\pi_\theta 采样轨迹

    b. 时序差分计算:对每一步数据 (st,at,rt,st+1)(s_t, a_t, r_t, s_{t+1}) 计算: δt=rt+γVϕ(st+1)Vϕ(st)\delta_t = r_t + \gamma V_{\phi^-}(s_{t+1}) - V_\phi(s_t)

    c. Critic 更新:更新价值网络参数 ϕϕαϕδtϕVϕ(st)\phi \leftarrow \phi - \alpha_\phi \delta_t \nabla_\phi V_\phi(s_t)

    d. Actor 更新:更新策略网络参数 θθ+αθδtθlogπθ(atst)\theta \leftarrow \theta + \alpha_\theta \delta_t \nabla_\theta \log \pi_\theta(a_t|s_t)

  3. 结束循环

10.3.4 Python 实现框架

import torch
import torch.nn as nn
import torch.optim as optim

class Actor(nn.Module):
"""策略网络"""
def __init__(self, state_dim, action_dim):
super(Actor, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, action_dim),
nn.Softmax(dim=-1)
)

def forward(self, state):
return self.network(state)

class Critic(nn.Module):
"""价值网络"""
def __init__(self, state_dim):
super(Critic, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, 1)
)

def forward(self, state):
return self.network(state)

class ActorCritic:
def __init__(self, state_dim, action_dim, lr_actor=0.001, lr_critic=0.001, gamma=0.99):
self.actor = Actor(state_dim, action_dim)
self.critic = Critic(state_dim)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)
self.gamma = gamma

def update(self, state, action, reward, next_state, done):
# 转换为张量
state = torch.FloatTensor(state)
next_state = torch.FloatTensor(next_state)
reward = torch.FloatTensor([reward])

# Critic 更新
current_value = self.critic(state)
next_value = torch.zeros(1) if done else self.critic(next_state)
target_value = reward + self.gamma * next_value.detach()
td_error = target_value - current_value

critic_loss = td_error.pow(2).mean()
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

# Actor 更新
action_probs = self.actor(state)
log_prob = torch.log(action_probs[action])
actor_loss = -log_prob * td_error.detach()

self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()