Drivers

Introduction

A common pattern in reinforcement learning is to execute a policy in an environment for a specified number of steps or episodes. This happens, for example, during data collection, evaluation, and when generating a video of the agent.

While this is relatively straightforward to write in Python, it is much more complex to write and debug in TensorFlow because it involves tf.while loops, tf.cond and tf.control_dependencies. Therefore we abstract this notion of a run loop into a class called driver, and provide well-tested implementations in both Python and TensorFlow.

Additionally, the data encountered by the driver at each step is saved in a named tuple called Trajectory and broadcast to a set of observers such as replay buffers and metrics. This data includes the observation from the environment, the action recommended by the policy, the reward obtained, the type of the current and the next step, and so on.
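For example, an observer can be any Python callable that accepts a Trajectory. Below is a minimal sketch of a hypothetical episode-counting observer (the EpisodeCounter name is illustrative and not part of TF-Agents); it relies only on the Trajectory methods used later in this tutorial.

import numpy as np


class EpisodeCounter(object):
  # Hypothetical observer: counts completed episodes seen by the driver.

  def __init__(self):
    self.count = 0

  def __call__(self, traj):
    # traj.is_last() marks the final step of an episode.
    self.count += np.sum(traj.is_last())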

Setup

If you haven't installed tf-agents or gym yet, run:

pip install tf-agents
pip install tf-keras
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_py_policy
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver

Python Drivers

The PyDriver class takes a Python environment, a Python policy and a list of observers to update at each step. The main method is run(), which steps the environment using actions from the policy until at least one of the following termination criteria is met: the number of steps reaches max_steps or the number of episodes reaches max_episodes.

The implementation is roughly as follows:

import numpy as np

from tf_agents.trajectories import trajectory


class PyDriver(object):

  def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):
    self._env = env
    self._policy = policy
    self._observers = observers or []
    self._max_steps = max_steps or np.inf
    self._max_episodes = max_episodes or np.inf

  def run(self, time_step, policy_state=()):
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:

      # Compute an action using the policy for the given time_step
      action_step = self._policy.action(time_step, policy_state)

      # Apply the action to the environment and get the next step
      next_time_step = self._env.step(action_step.action)

      # Package information into a trajectory
      traj = trajectory.Trajectory(
         time_step.step_type,
         time_step.observation,
         action_step.action,
         action_step.info,
         next_time_step.step_type,
         next_time_step.reward,
         next_time_step.discount)

      for observer in self._observers:
        observer(traj)

      # Update statistics to check termination
      num_episodes += np.sum(traj.is_last())
      num_steps += np.sum(~traj.is_boundary())

      time_step = next_time_step
      policy_state = action_step.state

    return time_step, policy_state

Now let us run an example that executes a random policy on the CartPole environment, saving the results to a replay buffer and computing some metrics.

env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(), 
                                         action_spec=env.action_spec())
replay_buffer = []
metric = py_metrics.AverageReturnMetric()
observers = [replay_buffer.append, metric]
driver = py_driver.PyDriver(
    env, policy, observers, max_steps=20, max_episodes=1)

initial_time_step = env.reset()
final_time_step, _ = driver.run(initial_time_step)

print('Replay Buffer:')
for traj in replay_buffer:
  print(traj)

print('Average Return: ', metric.result())
Replay Buffer:
Trajectory(
{'step_type': array(0, dtype=int32),
 'observation': array([ 0.00374074, -0.02818722, -0.02798625, -0.0196638 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00317699,  0.16732468, -0.02837953, -0.3210437 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00652349, -0.02738187, -0.0348004 , -0.03744393], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00597585, -0.22198795, -0.03554928,  0.24405919], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00153609, -0.41658458, -0.0306681 ,  0.5253204 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.0067956 , -0.61126184, -0.02016169,  0.80818397], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.01902084, -0.8061018 , -0.00399801,  1.0944574 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.03514287, -0.6109274 ,  0.01789114,  0.8005227 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.04736142, -0.8062901 ,  0.03390159,  1.0987796 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.06348722, -0.61163044,  0.05587719,  0.816923  ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.07571983, -0.41731614,  0.07221565,  0.54232585], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.08406615, -0.61337477,  0.08306216,  0.8568603 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.09633365, -0.8095243 ,  0.10019937,  1.1744623 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.11252414, -0.6158369 ,  0.12368862,  0.91479784], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.12484087, -0.8123951 ,  0.14198457,  1.2436544 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.14108877, -0.61935145,  0.16685766,  0.9986062 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.1534758 , -0.42680538,  0.18682979,  0.7626272 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.1620119 , -0.23468053,  0.20208232,  0.5340639 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(2, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(0., dtype=float32)})
Trajectory(
{'step_type': array(2, dtype=int32),
 'observation': array([-0.16670552, -0.43198496,  0.21276361,  0.8830067 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(0, dtype=int32),
 'reward': array(0., dtype=float32),
 'discount': array(1., dtype=float32)})
Average Return:  18.0

TensorFlow Drivers

We also have drivers in TensorFlow that are functionally similar to Python drivers, but use TF environments, TF policies, TF observers, etc. We currently have two TensorFlow drivers: DynamicStepDriver, which terminates after a given number of (valid) environment steps, and DynamicEpisodeDriver, which terminates after a given number of episodes. Let us look at an example of DynamicEpisodeDriver in action; a brief DynamicStepDriver sketch follows after it.

env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)

tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
                                            time_step_spec=tf_env.time_step_spec())


num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, tf_policy, observers, num_episodes=2)

# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()

print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.0367443 ,  0.00652178,  0.04001181, -0.00376746]],
      dtype=float32)>})
Number of Steps:  34
Number of Episodes:  2
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)

print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.04702466, -0.04836502,  0.01751254, -0.00393545]],
      dtype=float32)>})
Number of Steps:  63
Number of Episodes:  4
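For comparison, here is a minimal sketch of the analogous DynamicStepDriver, which terminates after a given number of valid environment steps instead of episodes. It reuses tf_env and tf_policy from the example above; the choice of num_steps=10 is arbitrary.

from tf_agents.drivers import dynamic_step_driver

# EnvironmentSteps counts valid (non-boundary) environment steps.
env_steps = tf_metrics.EnvironmentSteps()
step_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env, tf_policy, observers=[env_steps], num_steps=10)

# As with DynamicEpisodeDriver, the first run() resets the environment
# and initializes the policy state.
final_time_step, policy_state = step_driver.run()
print('Number of Steps: ', env_steps.result().numpy())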