![]() |
![]() |
![]() |
![]() |
總覽
TensorFlow Estimator 在 TensorFlow 中受到支援,而且可以從新的和現有的 tf.keras
模型建立。本教學課程包含該程序的完整、最簡範例。
設定
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
建立簡單的 Keras 模型。
在 Keras 中,您組裝層來建構模型。模型 (通常) 是層的圖表。最常見的模型類型是層堆疊:tf.keras.Sequential
模型。
若要建構簡單的全連接網路 (即多層感知器)
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(3)
])
編譯模型並取得摘要。
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer='adam')
model.summary()
建立輸入函式
使用 Datasets API 擴展到大型資料集或多裝置訓練。
Estimator 需要控制其輸入管線的建構時間和方式。為了允許這樣做,它們需要「輸入函式」或 input_fn
。Estimator
將在不帶引數的情況下呼叫此函式。input_fn
必須傳回 tf.data.Dataset
。
def input_fn():
split = tfds.Split.TRAIN
dataset = tfds.load('iris', split=split, as_supervised=True)
dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
dataset = dataset.batch(32).repeat()
return dataset
測試您的 input_fn
for features_batch, labels_batch in input_fn().take(1):
print(features_batch)
print(labels_batch)
從 tf.keras 模型建立 Estimator。
可以使用 tf.estimator
API 訓練 tf.keras.Model
,方法是使用 tf.keras.estimator.model_to_estimator
將模型轉換為 tf.estimator.Estimator
物件。
import tempfile
model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
keras_model=model, model_dir=model_dir)
訓練和評估 Estimator。
keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))