REINFORCE 代理程式

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

簡介

此範例示範如何使用 TF-Agents 程式庫,在 Cartpole 環境中訓練 REINFORCE 代理程式,類似於 DQN 教學課程

Cartpole environment

我們將逐步引導您完成強化學習 (RL) 管線中的所有元件,以進行訓練、評估和資料收集。

設定

如果您尚未安裝下列依附元件,請執行

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet xvfbwrapper
pip install tf-keras
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import IPython
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb

import tensorflow as tf

from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
2023-12-22 14:05:03.363396: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-22 14:05:03.363443: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-22 14:05:03.365008: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

超參數

env_name = "CartPole-v0" # @param {type:"string"}
num_iterations = 250 # @param {type:"integer"}
collect_episodes_per_iteration = 2 # @param {type:"integer"}
replay_buffer_capacity = 2000 # @param {type:"integer"}

fc_layer_params = (100,)

learning_rate = 1e-3 # @param {type:"number"}
log_interval = 25 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}
eval_interval = 50 # @param {type:"integer"}

環境

RL 中的環境代表我們嘗試解決的任務或問題。使用 suites 即可在 TF-Agents 中輕鬆建立標準環境。我們有不同的 suites,可從 OpenAI Gym、Atari、DM Control 等來源載入環境,並指定字串環境名稱。

現在讓我們從 OpenAI Gym 套件載入 CartPole 環境。

env = suite_gym.load(env_name)

我們可以轉譯此環境以查看其外觀。自由擺動的桿子連接到臺車。目標是左右移動臺車,以保持桿子指向上方。

env.reset()
PIL.Image.fromarray(env.render())

png

time_step = environment.step(action) 陳述式會在環境中取得 actionTimeStep tuple 傳回的值包含環境的下一個觀察結果和該動作的獎勵。環境中的 time_step_spec()action_spec() 方法會傳回 time_stepaction 的規格(類型、形狀、界限)。

print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

因此,我們看到觀察結果是一個包含 4 個浮點數的陣列:臺車的位置和速度,以及桿子的角位置和角速度。由於只有兩種可能的動作(向左移動或向右移動),因此 action_spec 是一個純量,其中 0 表示「向左移動」,1 表示「向右移動」。

time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(
{'step_type': array(0, dtype=int32),
 'reward': array(0., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([-0.00907558,  0.02627698, -0.01019297,  0.04808202], dtype=float32)})
Next time step:
TimeStep(
{'step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([-0.00855004,  0.2215436 , -0.00923133, -0.24779937], dtype=float32)})

通常我們會建立兩個環境:一個用於訓練,另一個用於評估。大多數環境都是以純 Python 撰寫,但可以使用 TFPyEnvironment 包裝函式輕鬆轉換為 TensorFlow。TFPyEnvironment 會將原始環境的 API 使用的 numpy 陣列轉換為 Tensors,或從 Tensors 轉換為 numpy 陣列,讓您更輕鬆地與 TensorFlow 策略和代理程式互動。

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

代理程式

我們用來解決 RL 問題的演算法表示為 Agent。除了 REINFORCE 代理程式之外,TF-Agents 還提供各種 Agents 的標準實作,例如 DQNDDPGTD3PPOSAC

若要建立 REINFORCE 代理程式,我們首先需要一個 Actor Network,它可以學習預測在給定環境觀察結果時應採取的動作。

我們可以使用觀察結果和動作的規格輕鬆建立 Actor Network。我們可以指定網路中的層,在此範例中,網路中的層是設定為 ints tuple 的 fc_layer_params 引數,表示每個隱藏層的大小(請參閱上方的「超參數」章節)。

actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

我們還需要一個 optimizer 來訓練我們剛建立的網路,以及一個 train_step_counter 變數來追蹤網路更新的次數。

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()

策略

在 TF-Agents 中,策略代表 RL 中策略的標準概念:在給定 time_step 的情況下產生動作或動作分佈。主要方法是 policy_step = policy.action(time_step),其中 policy_step 是具名 tuple PolicyStep(action, state, info)policy_step.action 是要套用至環境的 actionstate 代表有狀態 (RNN) 策略的狀態,而 info 可能包含輔助資訊,例如動作的對數機率。

代理程式包含兩個策略:主要策略 (agent.policy) 用於評估/部署,另一個策略 (agent.collect_policy) 用於資料收集。

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy

指標和評估

用於評估策略的最常見指標是平均報酬。報酬是在劇集中執行環境中的策略時獲得的獎勵總和,我們通常會將其在幾個劇集中取平均值。我們可以按如下方式計算平均報酬指標。

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# Please also see the metrics module for standard implementations of different
# metrics.

重播緩衝區

為了追蹤從環境收集的資料,我們將使用 Reverb,這是 Deepmind 開發的高效率、可擴充且易於使用的重播系統。它會在我們收集軌跡時儲存經驗資料,並在訓練期間使用。

此重播緩衝區是使用規格建構而成,規格描述要儲存的張量,這些張量可以使用 tf_agent.collect_data_spec 從代理程式取得。

table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
      tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
      replay_buffer_signature)
table = reverb.Table(
    table_name,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    table_name=table_name,
    sequence_length=None,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddEpisodeObserver(
    replay_buffer.py_client,
    table_name,
    replay_buffer_capacity
)
[reverb/cc/platform/tfrecord_checkpointer.cc:162]  Initializing TFRecordCheckpointer in /tmpfs/tmp/tmpkagdqs1n.
[reverb/cc/platform/tfrecord_checkpointer.cc:565] Loading latest checkpoint from /tmpfs/tmp/tmpkagdqs1n
[reverb/cc/platform/default/server.cc:71] Started replay server on port 41705

對於大多數代理程式,collect_data_specTrajectory 具名 tuple,其中包含觀察結果、動作、獎勵等。

資料收集

由於 REINFORCE 從整個劇集中學習,因此我們定義一個函式,以使用給定的資料收集策略收集劇集,並將資料(觀察結果、動作、獎勵等)儲存為重播緩衝區中的軌跡。在這裡,我們使用「PyDriver」來執行經驗收集迴圈。您可以在我們的 驅動程式教學課程中瞭解更多關於 TF Agents 驅動程式的資訊。

def collect_episode(environment, policy, num_episodes):

  driver = py_driver.PyDriver(
    environment,
    py_tf_eager_policy.PyTFEagerPolicy(
      policy, use_tf_function=True),
    [rb_observer],
    max_episodes=num_episodes)
  initial_time_step = environment.reset()
  driver.run(initial_time_step)

訓練代理程式

訓練迴圈包含從環境收集資料和最佳化代理程式網路。在此過程中,我們會偶爾評估代理程式的策略,以查看我們的進展。

以下步驟約需 3 分鐘才能完成。

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)

# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few episodes using collect_policy and save to the replay buffer.
  collect_episode(
      train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)

  # Use data from the buffer and update the agent's network.
  iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
  trajectories, _ = next(iterator)
  train_loss = tf_agent.train(experience=trajectories)

  replay_buffer.clear()

  step = tf_agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss.loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1703253913.189247   48625 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
step = 25: loss = 1.8318419456481934
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
step = 50: loss = 0.0070743560791015625
step = 50: Average Return = 9.800000190734863
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
step = 75: loss = 1.1006038188934326
step = 100: loss = 0.5719594955444336
step = 100: Average Return = 50.29999923706055
step = 125: loss = -1.2458715438842773
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.
step = 150: loss = 1.9363441467285156
step = 150: Average Return = 98.30000305175781
step = 175: loss = 0.8784818649291992
step = 200: loss = 1.9726766347885132
step = 200: Average Return = 143.6999969482422
step = 225: loss = 2.316105842590332
step = 250: loss = 2.5175299644470215
step = 250: Average Return = 191.5

視覺化

繪圖

我們可以繪製報酬與全域步數的關係圖,以查看代理程式的效能。在 Cartpole-v0 中,環境會在桿子保持豎立的每個時間步給予 +1 的獎勵,由於最大步數為 200,因此最大可能的報酬也是 200。

steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=250)
(0.7150002002716054, 250.0)

png

影片

透過在每個步驟轉譯環境來視覺化代理程式的效能會很有幫助。在執行此操作之前,我們先建立一個函式,將影片嵌入到此 Colab 中。

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

以下程式碼會視覺化代理程式在幾個劇集中的策略

num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
  for _ in range(num_episodes):
    time_step = eval_env.reset()
    video.append_data(eval_py_env.render())
    while not time_step.is_last():
      action_step = tf_agent.policy.action(time_step)
      time_step = eval_env.step(action_step.action)
      video.append_data(eval_py_env.render())

embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x5563cf186880] Warning: data is not aligned! This can lead to a speed loss