Copyright 2023 The TF-Agents Authors.
Introduction
A common pattern in reinforcement learning is to execute a policy in an environment for a specified number of steps or episodes. This happens, for example, during data collection, evaluation, and when generating a video of the agent.
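In plain Python, such a run loop looks roughly like the sketch below (not part of the original notebook; it assumes an env and policy that follow the TF-Agents PyEnvironment and PyPolicy interfaces):

# Rough sketch of the run loop described above: step a policy through one
# episode of an environment. `env` and `policy` are assumed to be a
# TF-Agents PyEnvironment and PyPolicy respectively.
time_step = env.reset()
while not time_step.is_last():
  action_step = policy.action(time_step)
  time_step = env.step(action_step.action)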
While this is relatively straightforward to write in Python, it is much more complex to write and debug in TensorFlow because it involves tf.while loops, tf.cond and tf.control_dependencies. Therefore we abstract this notion of a run loop into a class called driver, and provide well-tested implementations in both Python and TensorFlow.
Additionally, the data encountered by the driver at each step is saved in a named tuple called Trajectory and broadcast to a set of observers such as replay buffers and metrics. This data includes the observation from the environment, the action recommended by the policy, the reward obtained, the type of the current and the next step, etc.
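Any Python callable that accepts a Trajectory can serve as an observer. As a minimal illustration (a sketch, not from the original notebook), an observer that prints the reward of each step could look like this:

# Minimal example of a custom observer: the driver simply calls it with each
# Trajectory it produces. The name print_reward_observer is illustrative.
def print_reward_observer(traj):
  print('reward:', traj.reward, 'is_last:', traj.is_last())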
Setup
If you haven't installed tf-agents or gym yet, run:
pip install tf-agents
pip install tf-keras
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_py_policy
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver
Python Drivers
The PyDriver class takes a Python environment, a Python policy, and a list of observers to update at each step. The main method is run(), which steps the environment using actions from the policy until at least one of the following termination criteria is met: the number of steps reaches max_steps, or the number of episodes reaches max_episodes.
The implementation is roughly as follows:
# numpy and the trajectory module are needed by this sketch.
import numpy as np
from tf_agents.trajectories import trajectory

class PyDriver(object):

  def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):
    self._env = env
    self._policy = policy
    self._observers = observers or []
    self._max_steps = max_steps or np.inf
    self._max_episodes = max_episodes or np.inf

  def run(self, time_step, policy_state=()):
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:
      # Compute an action using the policy for the given time_step
      action_step = self._policy.action(time_step, policy_state)

      # Apply the action to the environment and get the next step
      next_time_step = self._env.step(action_step.action)

      # Package information into a trajectory
      traj = trajectory.Trajectory(
          time_step.step_type,
          time_step.observation,
          action_step.action,
          action_step.info,
          next_time_step.step_type,
          next_time_step.reward,
          next_time_step.discount)

      for observer in self._observers:
        observer(traj)

      # Update statistics to check termination
      num_episodes += np.sum(traj.is_last())
      num_steps += np.sum(~traj.is_boundary())

      time_step = next_time_step
      policy_state = action_step.state

    return time_step, policy_state
Now let us run an example that executes a random policy on the CartPole environment, saving the results to a replay buffer and computing some metrics.
env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(),
                                         action_spec=env.action_spec())
replay_buffer = []
metric = py_metrics.AverageReturnMetric()
observers = [replay_buffer.append, metric]
driver = py_driver.PyDriver(
    env, policy, observers, max_steps=20, max_episodes=1)
initial_time_step = env.reset()
final_time_step, _ = driver.run(initial_time_step)
print('Replay Buffer:')
for traj in replay_buffer:
print(traj)
print('Average Return: ', metric.result())
Replay Buffer: Trajectory( {'step_type': array(0, dtype=int32), 'observation': array([ 0.00374074, -0.02818722, -0.02798625, -0.0196638 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00317699, 0.16732468, -0.02837953, -0.3210437 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00652349, -0.02738187, -0.0348004 , -0.03744393], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00597585, -0.22198795, -0.03554928, 0.24405919], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00153609, -0.41658458, -0.0306681 , 0.5253204 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.0067956 , -0.61126184, -0.02016169, 0.80818397], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.01902084, -0.8061018 , -0.00399801, 1.0944574 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.03514287, -0.6109274 , 0.01789114, 0.8005227 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.04736142, -0.8062901 , 0.03390159, 1.0987796 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.06348722, -0.61163044, 0.05587719, 0.816923 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.07571983, -0.41731614, 0.07221565, 0.54232585], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.08406615, -0.61337477, 0.08306216, 0.8568603 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( 
{'step_type': array(1, dtype=int32), 'observation': array([-0.09633365, -0.8095243 , 0.10019937, 1.1744623 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.11252414, -0.6158369 , 0.12368862, 0.91479784], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.12484087, -0.8123951 , 0.14198457, 1.2436544 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.14108877, -0.61935145, 0.16685766, 0.9986062 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.1534758 , -0.42680538, 0.18682979, 0.7626272 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)}) Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.1620119 , -0.23468053, 0.20208232, 0.5340639 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(2, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(0., dtype=float32)}) Trajectory( {'step_type': array(2, dtype=int32), 'observation': array([-0.16670552, -0.43198496, 0.21276361, 0.8830067 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(0, dtype=int32), 'reward': array(0., dtype=float32), 'discount': array(1., dtype=float32)}) Average Return: 18.0
TensorFlow Drivers
We also have drivers in TensorFlow which are functionally similar to Python drivers, but use TF environments, TF policies, TF observers, etc. We currently have 2 TensorFlow drivers: DynamicStepDriver, which terminates after a given number of (valid) environment steps, and DynamicEpisodeDriver, which terminates after a given number of episodes. Let us see an example of DynamicEpisodeDriver in action.
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
                                            time_step_spec=tf_env.time_step_spec())
num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, tf_policy, observers, num_episodes=2)
# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep( {'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>, 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>, 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>, 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy= array([[-0.0367443 , 0.00652178, 0.04001181, -0.00376746]], dtype=float32)>}) Number of Steps: 34 Number of Episodes: 2
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep( {'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>, 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>, 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>, 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy= array([[-0.04702466, -0.04836502, 0.01751254, -0.00393545]], dtype=float32)>}) Number of Steps: 63 Number of Episodes: 4
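For completeness, DynamicStepDriver is constructed in the same way but terminates after a given number of valid environment steps rather than episodes. Below is a minimal sketch (not run in this notebook) that reuses the tf_env, tf_policy and observers defined above; the step budget of 100 is an arbitrary choice for illustration.

from tf_agents.drivers import dynamic_step_driver

# Same environment, policy and observers as above, but the driver stops
# after 100 valid environment steps instead of a fixed number of episodes.
step_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env, tf_policy, observers, num_steps=100)
final_time_step, policy_state = step_driver.run()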