Copyright 2023 The TF-Agents Authors.
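Introduction
This example shows how to train a Categorical DQN (C51) agent on the Cartpole environment using the TF-Agents library.
Make sure you work through the DQN tutorial first as a prerequisite. This tutorial assumes familiarity with the DQN tutorial and focuses mainly on the differences between DQN and C51.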
Setup
If you haven't installed tf-agents yet, run:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents
pip install pyglet
pip install tf-keras
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import PIL.Image
import pyvirtualdisplay
import tensorflow as tf
from tf_agents.agents.categorical_dqn import categorical_dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import categorical_q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
Hyperparameters
env_name = "CartPole-v1" # @param {type:"string"}
num_iterations = 15000 # @param {type:"integer"}
initial_collect_steps = 1000 # @param {type:"integer"}
collect_steps_per_iteration = 1 # @param {type:"integer"}
replay_buffer_capacity = 100000 # @param {type:"integer"}
fc_layer_params = (100,)
batch_size = 64 # @param {type:"integer"}
learning_rate = 1e-3 # @param {type:"number"}
gamma = 0.99
log_interval = 200 # @param {type:"integer"}
num_atoms = 51 # @param {type:"integer"}
min_q_value = -20 # @param {type:"integer"}
max_q_value = 20 # @param {type:"integer"}
n_step_update = 2 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}
eval_interval = 1000 # @param {type:"integer"}
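Environment
Load the environment as before, with one instance for training and one for evaluation. Here we use CartPole-v1 (vs. CartPole-v0 in the DQN tutorial), whose maximum reward is 500 rather than 200.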
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
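Agent
C51 is a Q-learning algorithm based on DQN. Like DQN, it can be used on any environment with a discrete action space.
The main difference between C51 and DQN is that rather than simply predicting the Q-value for each state-action pair, C51 predicts a histogram model for the probability distribution of the Q-value.
By learning the distribution rather than simply the expected value, the algorithm is able to stay more stable during training, leading to improved final performance. This is particularly true in situations with bimodal or even multimodal value distributions, where a single average does not provide an accurate picture.
In order to train on probability distributions rather than on values, C51 must perform some complex distributional computations to calculate its loss function. But don't worry, all of this is taken care of for you in TF-Agents!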
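To create a C51 agent, we first need to create a CategoricalQNetwork. The API of the CategoricalQNetwork is the same as that of the QNetwork, except that there is an additional argument num_atoms. This represents the number of support points in our probability distribution estimates. (The histogram figure that accompanies this section shows 10 support points, each represented by a vertical blue bar.) As you can tell from the name, the default number of atoms is 51.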
categorical_q_net = categorical_q_network.CategoricalQNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    num_atoms=num_atoms,
    fc_layer_params=fc_layer_params)
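To make the role of num_atoms more concrete, the following is a minimal pure-Python sketch (not TF-Agents code) of how a categorical distribution over atoms can be reduced to a scalar Q-value. It assumes the atoms are evenly spaced between two endpoints, taken here as -20 and 20 to match the hyperparameters above. The helper names make_support, softmax, and expected_q are hypothetical and exist only for this illustration.

```python
import math


def make_support(v_min, v_max, num_atoms):
    """Returns num_atoms evenly spaced values from v_min to v_max inclusive."""
    delta = (v_max - v_min) / (num_atoms - 1)
    return [v_min + i * delta for i in range(num_atoms)]


def softmax(logits):
    """Converts raw per-atom scores (logits) into probabilities."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_q(logits, support):
    """Expected value of the categorical distribution defined by the logits."""
    probs = softmax(logits)
    return sum(p * z for p, z in zip(probs, support))


support = make_support(-20.0, 20.0, 51)

# Uniform distribution: every atom gets the same probability.
uniform_logits = [0.0] * 51
q_uniform = expected_q(uniform_logits, support)

# Bimodal distribution: half the mass on atom 10, half on atom 40.
bimodal_logits = [-1e9] * 51
bimodal_logits[10] = 0.0
bimodal_logits[40] = 0.0
q_bimodal = expected_q(bimodal_logits, support)
```

Both distributions have an expected value of approximately zero even though their shapes are very different, which is the kind of information a single scalar Q-value cannot capture.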
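We also need an optimizer to train the network we just created, and a train_step_counter variable to keep track of how many times the network was updated.
Note that one other significant difference from the vanilla DqnAgent is that we now need to specify min_q_value and max_q_value as arguments. These specify the most extreme values of the support (in other words, the most extreme of the 51 atoms on either side). Make sure to choose these appropriately for your particular environment. Here we use -20 and 20.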
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
agent = categorical_dqn_agent.CategoricalDqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()
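The min_q_value and max_q_value arguments matter because the target distribution has to be mapped back onto the fixed set of atoms. As a rough illustration of that step, here is a pure-Python sketch of the categorical projection described in the C51 paper, for a single transition. It is not the TF-Agents implementation (which is batched and also handles n-step returns); the function name project_distribution and its parameters are hypothetical.

```python
import math


def project_distribution(next_probs, reward, discount, v_min, v_max):
    """Projects the distribution of reward + discount * z onto the fixed support."""
    num_atoms = len(next_probs)
    delta = (v_max - v_min) / (num_atoms - 1)
    projected = [0.0] * num_atoms
    for j, p in enumerate(next_probs):
        z_j = v_min + j * delta
        # Apply the Bellman update to the atom and clip it to the support range.
        tz = min(max(reward + discount * z_j, v_min), v_max)
        # Fractional index of the updated atom on the fixed support.
        b = (tz - v_min) / delta
        # Guard against floating-point rounding at the edges.
        b = min(max(b, 0.0), float(num_atoms - 1))
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:
            # The updated atom lands exactly on a support point.
            projected[lower] += p
        else:
            # Split the mass between the two neighbouring support points.
            projected[lower] += p * (upper - b)
            projected[upper] += p * (b - lower)
    return projected


# Next-state distribution with all mass on the atom at z = 0 (index 25 of 51).
next_probs = [0.0] * 51
next_probs[25] = 1.0
projected = project_distribution(
    next_probs, reward=1.0, discount=0.99, v_min=-20.0, v_max=20.0)

# A very large reward is clipped, so all mass ends up near the top atom.
clipped = project_distribution(
    next_probs, reward=100.0, discount=0.99, v_min=-20.0, v_max=20.0)
```

In the first call the updated atom sits at 1.0, between the support points 0.8 and 1.6, so its mass is split between indices 26 and 27. The second call shows why the support range should fit the environment: any return beyond max_q_value is squashed onto the last atom.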
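One last thing to note is that we also added an argument to use n-step updates with \(n\) = 2. In single-step Q-learning (\(n\) = 1), we only compute the error between the Q-values at the current time step and the next time step, using the single-step return (based on the Bellman optimality equation). The single-step return is defined as: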
\(G_t = R_{t + 1} + \gamma V(s_{t + 1})\)
where we define \(V(s) = \max_a{Q(s, a)}\).
N-step updates involve expanding the standard single-step return function \(n\) times:
\(G_t^n = R_{t + 1} + \gamma R_{t + 2} + \gamma^2 R_{t + 3} + \dots + \gamma^n V(s_{t + n})\)
N-step updates enable the agent to bootstrap from further in the future, and with the right value of \(n\), this often leads to faster learning.
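As a small worked example of the formula above, the sketch below computes the return for \(n\) = 1 and \(n\) = 2 from a list of rewards and a bootstrap value. The function n_step_return and the numbers used are illustrative assumptions, not part of TF-Agents.

```python
def n_step_return(rewards, bootstrap_value, gamma):
    """Computes R_{t+1} + gamma * R_{t+2} + ... + gamma^n * V(s_{t+n}).

    rewards holds the n rewards R_{t+1}, ..., R_{t+n}, and bootstrap_value
    is the estimate V(s_{t+n}).
    """
    n = len(rewards)
    discounted_rewards = sum(gamma ** k * r for k, r in enumerate(rewards))
    return discounted_rewards + gamma ** n * bootstrap_value


# n = 1: 1.0 + 0.99 * 10.0
one_step = n_step_return([1.0], bootstrap_value=10.0, gamma=0.99)

# n = 2: 1.0 + 0.99 * 1.0 + 0.99 ** 2 * 10.0
two_step = n_step_return([1.0, 1.0], bootstrap_value=10.0, gamma=0.99)
```

With \(n\) = 2 the target relies on one more observed reward and discounts the bootstrapped estimate by an extra factor of \(\gamma\).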
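Although C51 and n-step updates are often combined with prioritized replay to form the core of the Rainbow agent, we saw no measurable improvement from implementing prioritized replay. Moreover, we find that when combining our C51 agent with n-step updates alone, our agent performs as well as other Rainbow agents on the sample of Atari environments we've tested.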
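Metrics and Evaluation
The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for one episode, and we usually average this over a few episodes. We can compute the average return metric as follows.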
def compute_avg_return(environment, policy, num_episodes=10):

    total_return = 0.0
    for _ in range(num_episodes):

        time_step = environment.reset()
        episode_return = 0.0

        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]


random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())
compute_avg_return(eval_env, random_policy, num_eval_episodes)
# Please also see the metrics module for standard implementations of different
# metrics.
37.7
Data Collection
As in the DQN tutorial, set up the replay buffer and the initial data collection with the random policy.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

def collect_step(environment, policy):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)

    # Add trajectory to the replay buffer
    replay_buffer.add_batch(traj)

for _ in range(initial_collect_steps):
    collect_step(train_env, random_policy)
# This loop is so common in RL, that we provide standard implementations of
# these. For more details see the drivers module.
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=n_step_update + 1).prefetch(3)
iterator = iter(dataset)
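The dataset above yields windows of n_step_update + 1 consecutive steps because an n-step update needs the starting step plus the n steps that follow it. The pure-Python sketch below illustrates that sampling pattern on a plain list. It is only an analogy for the idea: it ignores details such as episode boundaries and circular overwriting, and the function sample_windows is hypothetical, not a TF-Agents API.

```python
import random


def sample_windows(buffer, batch_size, num_steps, rng):
    """Samples batch_size windows of num_steps consecutive items from buffer."""
    max_start = len(buffer) - num_steps
    if max_start < 0:
        raise ValueError("buffer holds fewer than num_steps items")
    batch = []
    for _ in range(batch_size):
        start = rng.randint(0, max_start)  # Inclusive on both ends.
        batch.append(buffer[start:start + num_steps])
    return batch


n_step_update = 2
buffer = list(range(1000))  # Stand-in for 1000 stored time steps.
rng = random.Random(0)
batch = sample_windows(
    buffer, batch_size=64, num_steps=n_step_update + 1, rng=rng)
```

Each element of batch holds three adjacent steps, mirroring the [B x T x ...] shape with T = n_step_update + 1 mentioned in the comment above.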
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:377: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version. Instructions for updating: Use `as_dataset(..., single_deterministic_pass=False) instead.
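Training the agent
The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we occasionally evaluate the agent's policy to see how we are doing.
The following takes about 7 minutes to run.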
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

    # Collect a few steps using collect_policy and save to the replay buffer.
    for _ in range(collect_steps_per_iteration):
        collect_step(train_env, agent.collect_policy)

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience)

    step = agent.train_step_counter.numpy()

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss.loss))

    if step % eval_interval == 0:
        avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
        print('step = {0}: Average Return = {1:.2f}'.format(step, avg_return))
        returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 3.2159409523010254
step = 400: loss = 2.422974109649658
step = 600: loss = 1.9803032875061035
step = 800: loss = 1.733839750289917
step = 1000: loss = 1.705157995223999
step = 1000: Average Return = 88.60
step = 1200: loss = 1.655350923538208
step = 1400: loss = 1.419114351272583
step = 1600: loss = 1.2578476667404175
step = 1800: loss = 1.3189895153045654
step = 2000: loss = 0.9676651954650879
step = 2000: Average Return = 130.80
step = 2200: loss = 0.7909003496170044
step = 2400: loss = 0.9291537404060364
step = 2600: loss = 0.8300429582595825
step = 2800: loss = 0.9739845991134644
step = 3000: loss = 0.5435967445373535
step = 3000: Average Return = 261.40
step = 3200: loss = 0.7065144777297974
step = 3400: loss = 0.8492055535316467
step = 3600: loss = 0.808651864528656
step = 3800: loss = 0.48259130120277405
step = 4000: loss = 0.9187874794006348
step = 4000: Average Return = 280.90
step = 4200: loss = 0.7415772676467896
step = 4400: loss = 0.621947169303894
step = 4600: loss = 0.5226543545722961
step = 4800: loss = 0.7011302709579468
step = 5000: loss = 0.7732619047164917
step = 5000: Average Return = 271.70
step = 5200: loss = 0.8493011593818665
step = 5400: loss = 0.6786139011383057
step = 5600: loss = 0.5639233589172363
step = 5800: loss = 0.48468759655952454
step = 6000: loss = 0.6366198062896729
step = 6000: Average Return = 350.70
step = 6200: loss = 0.4855012893676758
step = 6400: loss = 0.4458327889442444
step = 6600: loss = 0.6745614409446716
step = 6800: loss = 0.5021890997886658
step = 7000: loss = 0.4639193117618561
step = 7000: Average Return = 343.00
step = 7200: loss = 0.4711253345012665
step = 7400: loss = 0.5891958475112915
step = 7600: loss = 0.3957907557487488
step = 7800: loss = 0.4868921637535095
step = 8000: loss = 0.5140666365623474
step = 8000: Average Return = 396.10
step = 8200: loss = 0.6051771640777588
step = 8400: loss = 0.6179391741752625
step = 8600: loss = 0.5253893733024597
step = 8800: loss = 0.3697047531604767
step = 9000: loss = 0.7271263599395752
step = 9000: Average Return = 320.20
step = 9200: loss = 0.5285177826881409
step = 9400: loss = 0.4590812921524048
step = 9600: loss = 0.4743385910987854
step = 9800: loss = 0.47938746213912964
step = 10000: loss = 0.5290409326553345
step = 10000: Average Return = 433.00
step = 10200: loss = 0.4573556184768677
step = 10400: loss = 0.352144718170166
step = 10600: loss = 0.39160820841789246
step = 10800: loss = 0.3254846930503845
step = 11000: loss = 0.37145161628723145
step = 11000: Average Return = 414.60
step = 11200: loss = 0.382583349943161
step = 11400: loss = 0.44465434551239014
step = 11600: loss = 0.4484185576438904
step = 11800: loss = 0.248131662607193
step = 12000: loss = 0.5516679883003235
step = 12000: Average Return = 375.40
step = 12200: loss = 0.3307253420352936
step = 12400: loss = 0.19486135244369507
step = 12600: loss = 0.31668007373809814
step = 12800: loss = 0.4462052285671234
step = 13000: loss = 0.241848886013031
step = 13000: Average Return = 326.80
step = 13200: loss = 0.20919030904769897
step = 13400: loss = 0.2044396996498108
step = 13600: loss = 0.428558886051178
step = 13800: loss = 0.1880824714899063
step = 14000: loss = 0.34256821870803833
step = 14000: Average Return = 345.50
step = 14200: loss = 0.22452744841575623
step = 14400: loss = 0.29694461822509766
step = 14600: loss = 0.4149337410926819
step = 14800: loss = 0.41922691464424133
step = 15000: loss = 0.4064670205116272
step = 15000: Average Return = 242.10
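Visualization
Plots
We can plot return vs. global steps to see the performance of our agent. In Cartpole-v1, the environment gives a reward of +1 for every time step the pole stays up, and since the maximum number of steps is 500, the maximum possible return is also 500.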
steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=550)
(-11.255000400543214, 550.0)
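Videos
It is helpful to visualize the performance of an agent by rendering the environment at each step. Before we do that, let us first create a function to embed videos in this colab.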
def embed_mp4(filename):
    """Embeds an mp4 file in the notebook."""
    with open(filename, 'rb') as f:
        video = f.read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
      <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return IPython.display.HTML(tag)
The following code visualizes the agent's policy for a few episodes:
num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
    for _ in range(num_episodes):
        time_step = eval_env.reset()
        video.append_data(eval_py_env.render())
        while not time_step.is_last():
            action_step = agent.policy.action(time_step)
            time_step = eval_env.step(action_step.action)
            video.append_data(eval_py_env.render())

embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned. [swscaler @ 0x55f48c41d880] Warning: data is not aligned! This can lead to a speed loss
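C51 tends to do slightly better than DQN on CartPole-v1, but the difference between the two agents becomes more and more significant in increasingly complex environments. For example, on the full Atari 2600 benchmark, C51 demonstrates a mean score improvement of 126% over DQN after normalizing with respect to a random agent. Additional improvements can be gained by including n-step updates.
For a deeper dive into the C51 algorithm, see A Distributional Perspective on Reinforcement Learning (2017).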