策略梯度算法
在前面的章节中,我们介绍了基于价值函数的一系列强化学习算法,包括 Q-learning、DQN 及其改进算法。这些方法的核心思想是学习值函数,然后基于值函数推导出最优策略。然而,在强化学习的广阔领域中,还存在另一类经典的方法——基于策略的方法。
基于价值函数的方法:
- 核心是学习状态价值函数 或动作价值函数
- 策略是隐式地从值函数中推导出来的(如 -greedy 策略)
- 代表性算法:Q-learning、DQN、Double DQN 等
基于策略的方法:
- 直接学习一个显式的参数化策略
- 策略本身就是学习的目标,值函数可能作为辅助
- 代表性算法:REINFORCE、Actor-Critic 等
策略梯度算法是基于策略方法的核心基础,本章将深入探讨这一重要技术。
1. 策略梯度
基于策略的方法首先需要将策略参数化。假设我们的目标策略 是一个随机策略,其中 是策略的参数向量。这个策略函数可以使用各种模型来表示:
- 线性模型:
- 神经网络模型:使用深度神经网络来建模复杂的策略函数
策略函数的输入是状态 ,输出是在该状态下各个动作的概率分布。
我们的目标是找到一个最优策略,使得在该策略下智能体在环境中获得的期望回报最大化。为此,我们定义目标函数为:
其中:
- 表示初始状态的分布
- 表示从初始状态 开始,遵循策略 所能获得的期望累积回报
为了优化目标函数 ,我们需要计算其关于参数 的梯度。
策略梯度定理的完整推导过程如下:
Details
-
初始形式:
- 目标函数梯度与状态访问分布 和动作价值函数 的加权和有关
- 表示在策略 下状态 的稳态分布
-
技巧性变换:
- 引入 并同时除以 ,保持等式不变
- 将求和形式转换为期望形式做准备
-
最终形式:
- 利用对数导数恒等式:
- 得到紧凑的期望形式,便于采样估计
由于期望的下标是 ,策略梯度算法必须使用当前策略采样得到的数据来计算梯度,因此它是在线策略算法。
策略梯度公式具有很直观的解释:
- 当某个动作 在状态 下具有较高的 值时,梯度更新会使策略 的概率增加
- 当某个动作 在状态 下具有较低的 值时,梯度更新会使策略 的概率减少
这种机制使得智能体更倾向于选择那些能够带来高回报的动作,如图所示。
1.1 REINFORCE算法
对于一个有限步数的环境,REINFORCE 算法中的策略梯度为:
其中:
- 是和环境交互的最大步数
- 是从时刻 开始的累积折扣回报,作为 的蒙特卡洛估计
- 期望 表示在策略 下轨迹的期望
以车杆环境(CartPole)为例:
- ,即每个回合最多进行 200 步
- 折扣因子 通常设置为 0.99 或 1.0
- 对于每个时间步 ,使用从该步到回合结束的实际累积回报来估计 值
REINFORCE 算法的伪代码如下:
输入:
- 学习率
- 折扣因子
- 初始策略参数
算法流程:
-
初始化策略参数
-
循环每个训练序列 :
a. 轨迹采样:使用当前策略 与环境交互,采样一条完整轨迹
b. 回报计算:对于轨迹中的每个时间步 ,计算从该时刻开始的累积折扣回报
c. 参数更新:对策略参数进行梯度上升更新
-
结束循环
1.2 策略梯度定理的证明
策略梯度定理的证明是强化学习理论中的重要组成部分。我们旨在证明:
其中 是目标函数。
我们从单个状态的价值函数梯度开始:
定义辅助函数:
定义 为从状态 出发,在 步后到达状态 的概率。
继续推导:
定义折扣状态访问分布:
回到目标函数:
至此,我们成功证明了策略梯度定理:
2. REINFORCE算法的实现
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
from collections import deque
import random
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 策略网络
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(PolicyNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return F.softmax(x, dim=-1)
def act(self, state):
state = torch.from_numpy(state).float().unsqueeze(0)
probs = self.forward(state)
m = Categorical(probs)
action = m.sample()
return action.item(), m.log_prob(action)
# REINFORCE 算法
class REINFORCE:
def __init__(self, state_dim, action_dim, learning_rate=1e-3, gamma=0.99):
self.policy_net = PolicyNetwork(state_dim, action_dim)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
self.gamma = gamma
self.saved_log_probs = []
self.rewards = []
def select_action(self, state):
action, log_prob = self.policy_net.act(state)
self.saved_log_probs.append(log_prob)
return action
def update_policy(self):
R = 0
policy_loss = []
returns = []
# 计算每个时间步的折扣回报
for r in self.rewards[::-1]:
R = r + self.gamma * R
returns.insert(0, R)
returns = torch.tensor(returns)
# 标准化回报以减少方差
returns = (returns - returns.mean()) / (returns.std() + 1e-9)
for log_prob, R in zip(self.saved_log_probs, returns):
policy_loss.append(-log_prob * R)
self.optimizer.zero_grad()
policy_loss = torch.cat(policy_loss).sum()
policy_loss.backward()
self.optimizer.step()
# 清空当前回合的数据
self.saved_log_probs = []
self.rewards = []
# 训练函数
def train_reinforce(env, agent, num_episodes=1000, max_steps=1000):
scores = []
scores_deque = deque(maxlen=100)
for i_episode in range(1, num_episodes+1):
state, _ = env.reset()
episode_reward = 0
for t in range(max_steps):
action = agent.select_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
agent.rewards.append(reward)
episode_reward += reward
if done:
break
state = next_state
agent.update_policy()
scores.append(episode_reward)
scores_deque.append(episode_reward)
if i_episode % 100 == 0:
print('Episode {}\tAverage Score: {:.2f}'.format(
i_episode, np.mean(scores_deque)))
# 如果最近100个回合平均分达到195,认为问题已解决
if np.mean(scores_deque) >= 195.0:
print('Environment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(
i_episode-100, np.mean(scores_deque)))
break
return scores
# 可视化训练过程
def plot_training(scores):
plt.figure(figsize=(12, 5))
# 原始分数
plt.subplot(1, 2, 1)
plt.plot(scores)
plt.xlabel('