DQN C51/Rainbow

Introduction

This example shows how to train a Categorical DQN (C51) agent on the Cartpole environment using the TF-Agents library.

Cartpole environment

Make sure you take a look through the DQN tutorial as a prerequisite. This tutorial assumes familiarity with the DQN tutorial; it will mainly focus on the differences between DQN and C51.

Setup

If you haven't installed tf-agents yet, run:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents
pip install pyglet
pip install tf-keras
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.categorical_dqn import categorical_dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import categorical_q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

Hyperparameters

env_name = "CartPole-v1" # @param {type:"string"}
num_iterations = 15000 # @param {type:"integer"}

initial_collect_steps = 1000  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_capacity = 100000  # @param {type:"integer"}

fc_layer_params = (100,)

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
gamma = 0.99
log_interval = 200  # @param {type:"integer"}

num_atoms = 51  # @param {type:"integer"}
min_q_value = -20  # @param {type:"integer"}
max_q_value = 20  # @param {type:"integer"}
n_step_update = 2  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

Environment

Load the environment as before, with one for training and one for evaluation. Here we use CartPole-v1 (vs. CartPole-v0 in the DQN tutorial), which has a larger max reward of 500 rather than 200.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agent

C51 is a Q-learning algorithm based on DQN. Like DQN, it can be used on any environment with a discrete action space.

The main difference between C51 and DQN is that rather than simply predicting the Q-value for each state-action pair, C51 predicts a histogram model for the probability distribution of the Q-value:

Example C51 Distribution

By learning the distribution rather than just the expected value, the algorithm is able to stay more stable during training, leading to improved final performance. This is particularly true in situations with bimodal or even multimodal value distributions, where a single mean does not provide an accurate picture.

In order to train on probability distributions rather than on values, C51 must perform some complex distributional computations in order to calculate its loss function. But don't worry, all of this is taken care of for you in TF-Agents!
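
Concretely, each Q-value is just the expectation of the learned categorical distribution over a fixed support of evenly spaced points ("atoms"). The following minimal sketch is not part of the tutorial's code; it uses made-up uniform probabilities purely to show how an expected Q-value would be recovered from the support defined by the min_q_value, max_q_value, and num_atoms hyperparameters above.

import numpy as np

atoms = np.linspace(min_q_value, max_q_value, num_atoms)  # support points z_i
probabilities = np.full(num_atoms, 1.0 / num_atoms)       # hypothetical p_i for a single action
expected_q = np.sum(probabilities * atoms)                # Q(s, a) = sum_i p_i * z_i
print(expected_q)  # 0.0 here, since the uniform distribution is symmetric about 0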

To create a C51 agent, we first need to create a CategoricalQNetwork. The API of the CategoricalQNetwork is the same as that of the QNetwork, except that there is an additional argument num_atoms. This represents the number of support points in our probability distribution estimates. (The above image includes 10 support points, each represented by a vertical blue bar.) As you can tell from the name, the default number of atoms is 51.

categorical_q_net = categorical_q_network.CategoricalQNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    num_atoms=num_atoms,
    fc_layer_params=fc_layer_params)

We also need an optimizer to train the network we just created, and a train_step_counter variable to keep track of how many times the network has been updated.

Note that one other significant difference from the vanilla DqnAgent is that we now need to specify min_q_value and max_q_value as arguments. These specify the most extreme values of the support (in other words, the most extreme of the 51 atoms on either side). Make sure to choose these appropriately for your particular environment. Here we use -20 and 20.

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = categorical_dqn_agent.CategoricalDqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()

One last point to note is that we also added an argument to use n-step updates with \(n\) = 2. In single-step Q-learning (\(n\) = 1), we only compute the error between the Q-values at the current time step and the next time step using the single-step return (based on the Bellman optimality equation). The single-step return is defined as:

\(G_t = R_{t + 1} + \gamma V(s_{t + 1})\)

where we define \(V(s) = \max_a{Q(s, a)}\).

N-step updates involve expanding the standard single-step return function \(n\) times:

\(G_t^n = R_{t + 1} + \gamma R_{t + 2} + \gamma^2 R_{t + 3} + \dots + \gamma^n V(s_{t + n})\)

N-step updates enable the agent to bootstrap from further in the future, and with the right value of \(n\), this often leads to faster learning.
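
As a concrete illustration, here is a minimal sketch (not part of the tutorial's code) that computes the n-step return formula above from a hypothetical list of rewards and a bootstrapped value, using the gamma defined earlier:

def n_step_return(rewards, bootstrap_value, gamma, n):
  """G_t^n = R_{t+1} + gamma*R_{t+2} + ... + gamma^(n-1)*R_{t+n} + gamma^n * V(s_{t+n})."""
  discounted_rewards = sum(gamma**i * r for i, r in enumerate(rewards[:n]))
  return discounted_rewards + gamma**n * bootstrap_value

# Hypothetical example with n = 2: two rewards of 1.0 and a bootstrapped value of 10.0.
print(n_step_return([1.0, 1.0], 10.0, gamma, n=2))  # 1.0 + 0.99*1.0 + 0.99**2 * 10.0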

While C51 and n-step updates are often combined with prioritized replay to form the core of the Rainbow agent, we saw no measurable improvement from implementing prioritized replay. Moreover, we find that when combining our C51 agent with n-step updates alone, our agent performs as well as other Rainbow agents on the sample of Atari environments we've tested.

Metrics and Evaluation

The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode, and we usually average this over a few episodes. We can compute the average return metric as follows.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

compute_avg_return(eval_env, random_policy, num_eval_episodes)

# Please also see the metrics module for standard implementations of different
# metrics.
37.7

Data Collection

As in the DQN tutorial, set up the replay buffer and the initial data collection with the random policy.

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  replay_buffer.add_batch(traj)

for _ in range(initial_collect_steps):
  collect_step(train_env, random_policy)

# This loop is so common in RL, that we provide standard implementations of
# these. For more details see the drivers module.
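
As a rough sketch of what the drivers module mentioned above provides (this is an alternative to the manual collect_step loop, not code the tutorial itself runs), the same collection step could be expressed with the dynamic_step_driver imported earlier:

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=collect_steps_per_iteration)

# Each call to run() collects num_steps environment steps with the collect
# policy and passes the resulting trajectories to the observers (here, the
# replay buffer).
# collect_driver.run()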

# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=n_step_update + 1).prefetch(3)

iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:377: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Training the agent

The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we will occasionally evaluate the agent's policy to see how we are doing.

The following will take ~7 minutes to run.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  for _ in range(collect_steps_per_iteration):
    collect_step(train_env, agent.collect_policy)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience)

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss.loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1:.2f}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 3.2159409523010254
step = 400: loss = 2.422974109649658
step = 600: loss = 1.9803032875061035
step = 800: loss = 1.733839750289917
step = 1000: loss = 1.705157995223999
step = 1000: Average Return = 88.60
step = 1200: loss = 1.655350923538208
step = 1400: loss = 1.419114351272583
step = 1600: loss = 1.2578476667404175
step = 1800: loss = 1.3189895153045654
step = 2000: loss = 0.9676651954650879
step = 2000: Average Return = 130.80
step = 2200: loss = 0.7909003496170044
step = 2400: loss = 0.9291537404060364
step = 2600: loss = 0.8300429582595825
step = 2800: loss = 0.9739845991134644
step = 3000: loss = 0.5435967445373535
step = 3000: Average Return = 261.40
step = 3200: loss = 0.7065144777297974
step = 3400: loss = 0.8492055535316467
step = 3600: loss = 0.808651864528656
step = 3800: loss = 0.48259130120277405
step = 4000: loss = 0.9187874794006348
step = 4000: Average Return = 280.90
step = 4200: loss = 0.7415772676467896
step = 4400: loss = 0.621947169303894
step = 4600: loss = 0.5226543545722961
step = 4800: loss = 0.7011302709579468
step = 5000: loss = 0.7732619047164917
step = 5000: Average Return = 271.70
step = 5200: loss = 0.8493011593818665
step = 5400: loss = 0.6786139011383057
step = 5600: loss = 0.5639233589172363
step = 5800: loss = 0.48468759655952454
step = 6000: loss = 0.6366198062896729
step = 6000: Average Return = 350.70
step = 6200: loss = 0.4855012893676758
step = 6400: loss = 0.4458327889442444
step = 6600: loss = 0.6745614409446716
step = 6800: loss = 0.5021890997886658
step = 7000: loss = 0.4639193117618561
step = 7000: Average Return = 343.00
step = 7200: loss = 0.4711253345012665
step = 7400: loss = 0.5891958475112915
step = 7600: loss = 0.3957907557487488
step = 7800: loss = 0.4868921637535095
step = 8000: loss = 0.5140666365623474
step = 8000: Average Return = 396.10
step = 8200: loss = 0.6051771640777588
step = 8400: loss = 0.6179391741752625
step = 8600: loss = 0.5253893733024597
step = 8800: loss = 0.3697047531604767
step = 9000: loss = 0.7271263599395752
step = 9000: Average Return = 320.20
step = 9200: loss = 0.5285177826881409
step = 9400: loss = 0.4590812921524048
step = 9600: loss = 0.4743385910987854
step = 9800: loss = 0.47938746213912964
step = 10000: loss = 0.5290409326553345
step = 10000: Average Return = 433.00
step = 10200: loss = 0.4573556184768677
step = 10400: loss = 0.352144718170166
step = 10600: loss = 0.39160820841789246
step = 10800: loss = 0.3254846930503845
step = 11000: loss = 0.37145161628723145
step = 11000: Average Return = 414.60
step = 11200: loss = 0.382583349943161
step = 11400: loss = 0.44465434551239014
step = 11600: loss = 0.4484185576438904
step = 11800: loss = 0.248131662607193
step = 12000: loss = 0.5516679883003235
step = 12000: Average Return = 375.40
step = 12200: loss = 0.3307253420352936
step = 12400: loss = 0.19486135244369507
step = 12600: loss = 0.31668007373809814
step = 12800: loss = 0.4462052285671234
step = 13000: loss = 0.241848886013031
step = 13000: Average Return = 326.80
step = 13200: loss = 0.20919030904769897
step = 13400: loss = 0.2044396996498108
step = 13600: loss = 0.428558886051178
step = 13800: loss = 0.1880824714899063
step = 14000: loss = 0.34256821870803833
step = 14000: Average Return = 345.50
step = 14200: loss = 0.22452744841575623
step = 14400: loss = 0.29694461822509766
step = 14600: loss = 0.4149337410926819
step = 14800: loss = 0.41922691464424133
step = 15000: loss = 0.4064670205116272
step = 15000: Average Return = 242.10

Visualization

Plots

We can plot return vs. global steps to see the performance of our agent. In CartPole-v1, the environment gives a reward of +1 for every time step the pole stays up, and since the maximum number of steps is 500, the maximum possible return is also 500.

steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=550)
(-11.255000400543214, 550.0)

[Plot: Average Return vs. Step]

Videos

It is helpful to visualize the performance of an agent by rendering the environment at each step. Before we do that, let us first create a function to embed videos in this colab.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

The following code visualizes the agent's policy for a few episodes:

num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
  for _ in range(num_episodes):
    time_step = eval_env.reset()
    video.append_data(eval_py_env.render())
    while not time_step.is_last():
      action_step = agent.policy.action(time_step)
      time_step = eval_env.step(action_step.action)
      video.append_data(eval_py_env.render())

embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x55f48c41d880] Warning: data is not aligned! This can lead to a speed loss

C51 tends to do slightly better than DQN on CartPole-v1, but the difference between the two agents becomes more and more significant in increasingly complex environments. For example, on the full Atari 2600 benchmark, C51 demonstrates a mean score improvement of 126% over DQN after normalizing with respect to a random agent. Additional improvements can be gained by including n-step updates.

For a deeper dive into the C51 algorithm, take a look at A Distributional Perspective on Reinforcement Learning (2017).