![]() |
![]() |
![]() |
![]() |
總覽
這個端對端逐步解說使用 tf.estimator
API 訓練邏輯迴歸模型。此模型通常用作其他更複雜演算法的基準。
設定
pip install sklearn
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from six.moves import urllib
載入鐵達尼號資料集
您將使用鐵達尼號資料集,其目標 (相當病態) 是根據性別、年齡、艙等特徵等預測乘客的存活率。
import tensorflow.compat.v2.feature_column as fc
import tensorflow as tf
# Load dataset.
dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
dfeval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
y_train = dftrain.pop('survived')
y_eval = dfeval.pop('survived')
探索資料
資料集包含以下特徵
dftrain.head()
dftrain.describe()
訓練集和評估集中分別有 627 個和 264 個範例。
dftrain.shape[0], dfeval.shape[0]
大多數乘客的年齡介於 20 歲到 30 歲之間。
dftrain.age.hist(bins=20)
男性乘客人數約為女性乘客的兩倍。
dftrain.sex.value_counts().plot(kind='barh')
大多數乘客都搭乘「三等艙」。
dftrain['class'].value_counts().plot(kind='barh')
相較於男性,女性的存活機率高出許多。這顯然是模型的預測特徵。
pd.concat([dftrain, y_train], axis=1).groupby('sex').survived.mean().plot(kind='barh').set_xlabel('% survive')
模型特徵工程
Estimators 使用名為特徵資料欄的系統,說明模型應如何解譯每個原始輸入特徵。Estimator 預期會有數值輸入向量,而特徵資料欄說明模型應如何轉換每個特徵。
選取和設計正確的特徵資料欄集是學習有效模型的關鍵。特徵資料欄可以是原始特徵 dict
中的原始輸入 (基本特徵資料欄),也可以是使用在一個或多個基本資料欄上定義的轉換建立的任何新資料欄 (衍生特徵資料欄)。
線性 Estimator 同時使用數值和類別特徵。特徵資料欄適用於所有 TensorFlow Estimators,其目的是定義用於建模的特徵。此外,它們還提供一些特徵工程功能,例如單熱編碼、正規化和分桶化。
基本特徵資料欄
CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',
'embark_town', 'alone']
NUMERIC_COLUMNS = ['age', 'fare']
feature_columns = []
for feature_name in CATEGORICAL_COLUMNS:
vocabulary = dftrain[feature_name].unique()
feature_columns.append(tf.feature_column.categorical_column_with_vocabulary_list(feature_name, vocabulary))
for feature_name in NUMERIC_COLUMNS:
feature_columns.append(tf.feature_column.numeric_column(feature_name, dtype=tf.float32))
input_function
指定資料如何轉換為 tf.data.Dataset
,以串流方式饋送輸入管線。tf.data.Dataset
可以接收多個來源,例如資料框架、CSV 格式檔案等等。
def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):
def input_function():
ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))
if shuffle:
ds = ds.shuffle(1000)
ds = ds.batch(batch_size).repeat(num_epochs)
return ds
return input_function
train_input_fn = make_input_fn(dftrain, y_train)
eval_input_fn = make_input_fn(dfeval, y_eval, num_epochs=1, shuffle=False)
您可以檢查資料集
ds = make_input_fn(dftrain, y_train, batch_size=10)()
for feature_batch, label_batch in ds.take(1):
print('Some feature keys:', list(feature_batch.keys()))
print()
print('A batch of class:', feature_batch['class'].numpy())
print()
print('A batch of Labels:', label_batch.numpy())
您也可以使用 tf.keras.layers.DenseFeatures
層檢查特定特徵資料欄的結果
age_column = feature_columns[7]
tf.keras.layers.DenseFeatures([age_column])(feature_batch).numpy()
DenseFeatures
只接受密集張量,若要檢查類別資料欄,您需要先將其轉換為指標資料欄
gender_column = feature_columns[0]
tf.keras.layers.DenseFeatures([tf.feature_column.indicator_column(gender_column)])(feature_batch).numpy()
將所有基本特徵新增至模型後,讓我們訓練模型。訓練模型只是使用 tf.estimator
API 的單一指令
linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns)
linear_est.train(train_input_fn)
result = linear_est.evaluate(eval_input_fn)
clear_output()
print(result)
衍生特徵資料欄
現在您已達到 75% 的準確度。單獨使用每個基本特徵資料欄可能不足以解釋資料。例如,年齡與標籤之間的相關性對於不同性別可能有所不同。因此,如果您只為 gender="Male"
和 gender="Female"
學習單一模型權重,您將無法擷取每個年齡性別組合 (例如,區分 gender="Male"
AND age="30"
AND gender="Male"
AND age="40"
)。
若要瞭解不同特徵組合之間的差異,您可以將交叉特徵資料欄新增至模型 (您也可以在交叉資料欄之前將年齡資料欄分桶化)
age_x_gender = tf.feature_column.crossed_column(['age', 'sex'], hash_bucket_size=100)
將組合特徵新增至模型後,讓我們再次訓練模型
derived_feature_columns = [age_x_gender]
linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns+derived_feature_columns)
linear_est.train(train_input_fn)
result = linear_est.evaluate(eval_input_fn)
clear_output()
print(result)
現在準確度達到 77.6%,略優於僅在基本特徵中訓練的準確度。您可以嘗試使用更多特徵和轉換,看看是否能做得更好!
現在您可以使用訓練模型,針對評估集中的乘客進行預測。TensorFlow 模型已針對同時對批次或範例集合進行預測進行最佳化。稍早,eval_input_fn
是使用整個評估集定義的。
pred_dicts = list(linear_est.predict(eval_input_fn))
probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])
probs.plot(kind='hist', bins=20, title='predicted probabilities')
最後,查看結果的接收者操作特徵 (ROC) 曲線,這將讓我們更瞭解真陽性率與假陽性率之間的權衡。
from sklearn.metrics import roc_curve
from matplotlib import pyplot as plt
fpr, tpr, _ = roc_curve(y_eval, probs)
plt.plot(fpr, tpr)
plt.title('ROC curve')
plt.xlabel('false positive rate')
plt.ylabel('true positive rate')
plt.xlim(0,)
plt.ylim(0,)