![]() |
![]() |
![]() |
![]() |
本教學課程說明如何使用估算器在 TensorFlow 中解決鳶尾花分類問題。估算器是 TensorFlow 高階表示法的舊版,代表完整的模型。如需更多詳細資訊,請參閱估算器。
首先要知道的事
為了開始使用,您首先需要匯入 TensorFlow 和一些您會需要的程式庫。
import tensorflow as tf
import pandas as pd
資料集
本文件中的範例程式會建立並測試一個模型,該模型根據鳶尾花萼片和花瓣的大小,將鳶尾花分為三個不同的品種。
您將使用鳶尾花資料集訓練模型。鳶尾花資料集包含四個特徵和一個標籤。這四個特徵可識別出個別鳶尾花的下列植物學特徵
- 萼片長度
- 萼片寬度
- 花瓣長度
- 花瓣寬度
根據這些資訊,您可以定義一些有助於剖析資料的常數
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
接下來,使用 Keras 和 Pandas 下載並剖析鳶尾花資料集。請注意,您會保留不同的資料集以進行訓練和測試。
train_path = tf.keras.utils.get_file(
"iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
"iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
您可以檢查您的資料,以查看您有四個浮點特徵欄和一個 int32 標籤。
train.head()
針對每個資料集,拆分出標籤,模型將會經過訓練來預測這些標籤。
train_y = train.pop('Species')
test_y = test.pop('Species')
# The label column has now been removed from the features.
train.head()
使用估算器進行程式設計的總覽
現在您已設定好資料集,您可以使用 TensorFlow 估算器定義模型。估算器是衍生自 tf.estimator.Estimator
的任何類別。TensorFlow 提供 tf.estimator
(例如,LinearRegressor
) 的集合,以實作常見的 ML 演算法。除此之外,您可以編寫自己的自訂估算器。建議剛開始使用時使用預先建立的估算器。
若要根據預先建立的估算器編寫 TensorFlow 程式,您必須執行下列工作
- 建立一或多個輸入函式。
- 定義模型的特徵欄。
- 例項化估算器,指定特徵欄和各種超參數。
- 在 Estimator 物件上呼叫一或多個方法,並傳遞適當的輸入函式作為資料來源。
讓我們看看如何針對鳶尾花分類實作這些工作。
建立輸入函式
您必須建立輸入函式,以提供用於訓練、評估和預測的資料。
輸入函式是一個函式,會傳回 tf.data.Dataset
物件,此物件會輸出下列雙元素元組
為了示範輸入函式的格式,以下是一個簡單的實作
def input_evaluation_set():
features = {'SepalLength': np.array([6.4, 5.0]),
'SepalWidth': np.array([2.8, 2.3]),
'PetalLength': np.array([5.6, 3.3]),
'PetalWidth': np.array([2.2, 1.0])}
labels = np.array([2, 1])
return features, labels
您的輸入函式可以使用您喜歡的任何方式產生 features
字典和 label
清單。不過,建議使用 TensorFlow 的 Dataset API,它可以剖析各種資料。
Dataset API 可以為您處理許多常見案例。例如,使用 Dataset API,您可以輕鬆地從大型檔案集合中平行讀取記錄,並將其加入單一串流。
為了讓這個範例保持簡單,您將使用 pandas 載入資料,並從這個記憶體內資料建立輸入管道
def input_fn(features, labels, training=True, batch_size=256):
"""An input function for training or evaluating"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle and repeat if you are in training mode.
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)
定義特徵欄
特徵欄是一個物件,用於描述模型應如何使用來自特徵字典的原始輸入資料。當您建立 Estimator 模型時,您會將特徵欄清單傳遞給它,其中描述您希望模型使用的每個特徵。tf.feature_column
模組提供許多選項,可將資料表示到模型。
對於鳶尾花,4 個原始特徵都是數值,因此您將建立特徵欄清單,告知 Estimator 模型將四個特徵中的每一個都表示為 32 位元浮點值。因此,建立特徵欄的程式碼如下
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
特徵欄可能比這裡顯示的更複雜。您可以在本指南中閱讀更多關於特徵欄的資訊。
現在您已描述您希望模型如何表示原始特徵,您可以建立估算器。
例項化估算器
鳶尾花問題是經典的分類問題。幸運的是,TensorFlow 提供數個預先建立的分類器估算器,包括
tf.estimator.DNNClassifier
,適用於執行多類別分類的深度模型。tf.estimator.DNNLinearCombinedClassifier
,適用於廣度和深度模型。tf.estimator.LinearClassifier
,適用於以線性模型為基礎的分類器。
對於鳶尾花問題,tf.estimator.DNNClassifier
似乎是最佳選擇。以下說明如何例項化此估算器
# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
# Two hidden layers of 30 and 10 nodes respectively.
hidden_units=[30, 10],
# The model must choose between 3 classes.
n_classes=3)
訓練、評估和預測
現在您有一個 Estimator 物件,您可以呼叫方法來執行下列動作
- 訓練模型。
- 評估已訓練的模型。
- 使用已訓練的模型進行預測。
訓練模型
透過呼叫 Estimator 的 train
方法來訓練模型,如下所示
# Train the Model.
classifier.train(
input_fn=lambda: input_fn(train, train_y, training=True),
steps=5000)
請注意,您將 input_fn
呼叫包裝在 lambda
中,以擷取引數,同時提供一個不帶引數的輸入函式,正如 Estimator 所預期的那樣。steps
引數會告知方法在完成指定數量的訓練步驟後停止訓練。
評估已訓練的模型
現在模型已完成訓練,您可以取得一些關於其效能的統計資料。下列程式碼區塊會評估已訓練模型在測試資料上的準確度
eval_result = classifier.evaluate(
input_fn=lambda: input_fn(test, test_y, training=False))
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
與呼叫 train
方法不同,您沒有將 steps
引數傳遞給 evaluate。eval 的 input_fn
只會產生單一週期的資料。
eval_result
字典也包含 average_loss
(每個樣本的平均損失)、loss
(每個迷你批次的平均損失) 以及估算器的 global_step
值 (它經歷的訓練迭代次數)。
從已訓練的模型進行預測 (推論)
您現在有一個已訓練的模型,可產生良好的評估結果。您現在可以使用已訓練的模型,根據一些未標記的測量值來預測鳶尾花的品種。與訓練和評估一樣,您可以使用單一函式呼叫來進行預測
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
'SepalLength': [5.1, 5.9, 6.9],
'SepalWidth': [3.3, 3.0, 3.1],
'PetalLength': [1.7, 4.2, 5.4],
'PetalWidth': [0.5, 1.5, 2.1],
}
def input_fn(features, batch_size=256):
"""An input function for prediction."""
# Convert the inputs to a Dataset without labels.
return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)
predictions = classifier.predict(
input_fn=lambda: input_fn(predict_x))
predict
方法會傳回 Python 可迭代物件,為每個範例產生預測結果的字典。下列程式碼會印出一些預測及其機率
for pred_dict, expec in zip(predictions, expected):
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
SPECIES[class_id], 100 * probability, expec))