從 Keras 模型建立 Estimator

在 TensorFlow.org 上檢視 在 Google Colab 中執行 在 GitHub 上檢視原始碼 下載筆記本

總覽

TensorFlow Estimator 在 TensorFlow 中受到支援,而且可以從新的和現有的 tf.keras 模型建立。本教學課程包含該程序的完整、最簡範例。

設定

import tensorflow as tf

import numpy as np
import tensorflow_datasets as tfds

建立簡單的 Keras 模型。

在 Keras 中,您組裝來建構模型。模型 (通常) 是層的圖表。最常見的模型類型是層堆疊:tf.keras.Sequential 模型。

若要建構簡單的全連接網路 (即多層感知器)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3)
])

編譯模型並取得摘要。

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer='adam')
model.summary()

建立輸入函式

使用 Datasets API 擴展到大型資料集或多裝置訓練。

Estimator 需要控制其輸入管線的建構時間和方式。為了允許這樣做,它們需要「輸入函式」或 input_fnEstimator 將在不帶引數的情況下呼叫此函式。input_fn 必須傳回 tf.data.Dataset

def input_fn():
  split = tfds.Split.TRAIN
  dataset = tfds.load('iris', split=split, as_supervised=True)
  dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
  dataset = dataset.batch(32).repeat()
  return dataset

測試您的 input_fn

for features_batch, labels_batch in input_fn().take(1):
  print(features_batch)
  print(labels_batch)

從 tf.keras 模型建立 Estimator。

可以使用 tf.estimator API 訓練 tf.keras.Model,方法是使用 tf.keras.estimator.model_to_estimator 將模型轉換為 tf.estimator.Estimator 物件。

import tempfile
model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)

訓練和評估 Estimator。

keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))