Classification on imbalanced data


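This tutorial demonstrates how to classify a highly imbalanced dataset, in which the number of examples in one class greatly outnumbers the examples in another. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions out of 284,807 transactions in total. You will use Keras to define the model and class weights to help the model learn from the imbalanced data.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).
  • Select a threshold for a probabilistic classifier to get a deterministic classifier.
  • Try and compare class-weighted modelling and oversampling.

Setup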

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

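Data processing and exploration

Download the Kaggle Credit Card Fraud data set

Pandas is a Python library with many helpful utilities for loading and working with structured data. It can be used to download CSVs into a Pandas DataFrame.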

raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Examine the class label imbalance

Let's look at the dataset imbalance:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)

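This shows the small fraction of positive samples.

Clean, split and normalize the data

The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.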

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)
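The small offset eps is added before taking the log because a zero amount would otherwise map to -inf. The sketch below uses made-up amounts, not values from the dataset. It shows that adding eps keeps every value finite and that the log squeezes several orders of magnitude into a narrow range.

```python
import numpy as np

# Hypothetical transaction amounts in dollars, for illustration only.
amounts = np.array([0.0, 1.0, 100.0, 10000.0])
eps = 0.001  # same offset as in the cell above (0.1 cent)

# Without the offset, log(0) is -inf (silence NumPy's divide-by-zero warning).
with np.errstate(divide='ignore'):
    raw_log = np.log(amounts)

# With the offset, a zero amount maps to log(0.001), which is finite.
safe_log = np.log(amounts + eps)

print(raw_log)
print(safe_log)
```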

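Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics, but the model is not fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern because of the lack of training data.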

# Use a utility from sklearn to split and shuffle your dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
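The split above is purely random, so the share of fraud cases can differ slightly between the three sets. If you want the class ratio to match exactly, train_test_split accepts a stratify argument. Here is a minimal sketch on synthetic labels rather than the credit-card data. It uses the same two-step 80/20 split, which leaves 64% / 16% / 20% of the rows for training, validation and testing.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced data: 10,000 rows, 50 positives (0.5%).
rng = np.random.default_rng(0)
features = rng.normal(size=(10_000, 3))
labels = np.zeros(10_000, dtype=int)
labels[:50] = 1

# Hold out 20% for testing, keeping the class ratio fixed.
x_trainval, x_test, y_trainval, y_test = train_test_split(
    features, labels, test_size=0.2, stratify=labels, random_state=0)

# Split 20% of the remainder off for validation, again stratified.
x_train, x_val, y_train, y_val = train_test_split(
    x_trainval, y_trainval, test_size=0.2, stratify=y_trainval, random_state=0)

for name, y in [('train', y_train), ('val', y_val), ('test', y_test)]:
    print(f'{name}: {len(y)} rows, positive rate {y.mean():.4f}')
```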

Check that the distribution of classes is roughly the same across the three sets.

print(f'Average class probability in training set:   {train_labels.mean():.4f}')
print(f'Average class probability in validation set: {val_labels.mean():.4f}')
print(f'Average class probability in test set:       {test_labels.mean():.4f}')
Average class probability in training set:   0.0016
Average class probability in validation set: 0.0018
Average class probability in test set:       0.0019

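Given the small number of positive labels, this seems about right.

Normalize the input features using the sklearn StandardScaler. This sets the mean to 0 and the standard deviation to 1. The scaler is fit only on the training features, so the model does not peek at the validation or test sets. The scaled values are then clipped to the range [-5, 5].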

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)
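To make the scaler's behavior concrete, the sketch below reproduces StandardScaler by hand on random synthetic data, which is an assumption for illustration only. It subtracts the training mean and divides by the training standard deviation, in population form with ddof=0. The validation data reuses the training statistics, so its own mean and standard deviation are only approximately 0 and 1. np.clip then bounds every value to [-5, 5].

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
# Skewed synthetic features standing in for the real ones.
train = rng.exponential(scale=3.0, size=(1000, 2))
val = rng.exponential(scale=3.0, size=(200, 2))

scaler = StandardScaler()
train_scaled = scaler.fit_transform(train)  # statistics learned from train only
val_scaled = scaler.transform(val)          # reuses the train statistics

# Manual equivalent of scaler.transform(val).
mu = train.mean(axis=0)
sigma = train.std(axis=0)  # population std (ddof=0)
manual_val = (val - mu) / sigma

# Bound extreme values, as in the cell above.
val_clipped = np.clip(val_scaled, -5, 5)
```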


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

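Look at the data distribution

Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:

  • Do these distributions make sense?
    • Yes. You've normalized the input, and the values are mostly concentrated in the +/- 2 range.
  • Can you see the difference between the distributions?
    • Yes, the positive examples contain a much higher rate of extreme values.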
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")

[Figures: hexbin joint plots of V5 vs. V6, titled "Positive distribution" and "Negative distribution"]
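You can also check the claim that positive examples contain more extreme values numerically. This is a minimal sketch built on a hypothetical helper, extreme_rate. It uses |value| > 2 as an arbitrary cutoff for "extreme". With the tutorial's data you would call it on pos_df and neg_df; the frames below are synthetic stand-ins.

```python
import numpy as np
import pandas as pd

def extreme_rate(df, cutoff=2.0):
    """Fraction of values per column whose magnitude exceeds `cutoff`.

    Hypothetical helper for illustration; `cutoff` is an arbitrary choice.
    """
    return (df.abs() > cutoff).mean()

rng = np.random.default_rng(1)
# Synthetic stand-ins: "positives" are drawn with a wider spread.
neg_like = pd.DataFrame(rng.normal(0, 1, size=(5000, 2)), columns=['V5', 'V6'])
pos_like = pd.DataFrame(rng.normal(0, 3, size=(5000, 2)), columns=['V5', 'V6'])

print(pd.DataFrame({'positive': extreme_rate(pos_like),
                    'negative': extreme_rate(neg_like)}))
```

For a standard normal column, roughly 4.6% of values fall outside +/- 2, which gives a reference point for the negative class.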

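Define the model and metrics

Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent.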

METRICS = [
      keras.metrics.BinaryCrossentropy(name='cross entropy'),  # same as model's loss
      keras.metrics.MeanSquaredError(name='Brier score'),
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
      keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

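Understand useful metrics

Notice that a few metrics are defined above. The model can compute them, and they will be helpful when evaluating performance. They can be divided into three groups.

Metrics for probability predictions

Because the network is trained with cross entropy as the loss function, it is fully capable of predicting class probabilities; that is, it is a probabilistic classifier. Good metrics for assessing probabilistic predictions are proper scoring rules. Their key property is that predicting the true probability is optimal. Two well-known examples are:

  • Cross entropy, also known as log loss
  • Mean squared error, also known as the Brier score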
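Both scores are simple averages over examples. The sketch below uses made-up labels and probabilities. It computes both scores by hand with NumPy and cross-checks them against scikit-learn's log_loss and brier_score_loss. The second half illustrates the "proper" property: if an event truly happens with probability q = 0.2, the expected score for each is best when the prediction equals 0.2.

```python
import numpy as np
from sklearn.metrics import brier_score_loss, log_loss

y_true = np.array([0, 0, 1, 0, 1])
p_pred = np.array([0.1, 0.3, 0.8, 0.05, 0.4])  # made-up P(fraud)

# Log loss: mean of -[y log p + (1 - y) log(1 - p)].
cross_entropy = -np.mean(y_true * np.log(p_pred) + (1 - y_true) * np.log(1 - p_pred))

# Brier score: mean squared difference between probability and 0/1 label.
brier = np.mean((p_pred - y_true) ** 2)

# Proper scoring: expected score as a function of the prediction p,
# when the true positive probability is q.
q = 0.2
grid = np.linspace(0.01, 0.99, 99)
expected_brier = q * (1 - grid) ** 2 + (1 - q) * grid ** 2
expected_ce = -(q * np.log(grid) + (1 - q) * np.log(1 - grid))
best_brier = grid[np.argmin(expected_brier)]
best_ce = grid[np.argmin(expected_ce)]

print(cross_entropy, brier, best_brier, best_ce)
```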

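Metrics for deterministic 0/1 predictions

In the end, one often wants to predict a class label, 0 or 1: no fraud or fraud. This is called a deterministic classifier. To get a label prediction from our probabilistic classifier, one needs to choose a probability threshold \(t\). The default is to predict label 1 (fraud) if the predicted probability is larger than \(t=50\%\), and all the following metrics implicitly use this default.

  • False negatives and false positives are samples that were incorrectly classified.
  • True negatives and true positives are samples that were correctly classified.
  • Accuracy is the percentage of examples correctly classified: \(\frac{\text{true positives} + \text{true negatives}}{\text{total samples}}\)
  • Precision is the percentage of predicted positives that were correctly classified: \(\frac{\text{true positives}}{\text{true positives} + \text{false positives}}\)
  • Recall is the percentage of actual positives that were correctly classified: \(\frac{\text{true positives}}{\text{true positives} + \text{false negatives}}\)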
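With a threshold t = 0.5, the four counts and the three ratios can be computed directly from predicted probabilities. This sketch uses ten made-up examples. It predicts fraud when the probability is strictly greater than t, then cross-checks the counts against sklearn.metrics.confusion_matrix.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 0, 0, 0, 1, 1, 1, 0, 0, 1])
p_pred = np.array([0.1, 0.6, 0.2, 0.05, 0.9, 0.4, 0.7, 0.3, 0.55, 0.8])
t = 0.5

y_hat = (p_pred > t).astype(int)  # deterministic 0/1 prediction

tp = np.sum((y_hat == 1) & (y_true == 1))
fp = np.sum((y_hat == 1) & (y_true == 0))
tn = np.sum((y_hat == 0) & (y_true == 0))
fn = np.sum((y_hat == 0) & (y_true == 1))

accuracy = (tp + tn) / len(y_true)
precision = tp / (tp + fp)
recall = tp / (tp + fn)

print(f'tp={tp} fp={fp} tn={tn} fn={fn}')
print(f'accuracy={accuracy:.2f} precision={precision:.2f} recall={recall:.2f}')
```

On this dataset, accuracy alone is misleading. A model that always predicts "no fraud" would be right on 1 - 492/284807, or about 99.83%, of the examples, while its recall would be zero.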

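Other metrics

The following metrics take into account all possible choices of the threshold \(t\).

  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that the classifier ranks a random positive sample higher than a random negative sample.
  • AUPRC refers to the Area Under the Curve of the Precision-Recall curve. This metric is built from precision-recall pairs computed at different probability thresholds.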
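The ranking interpretation of ROC-AUC can be checked directly. Compare every (positive, negative) pair and count the pairs in which the positive gets the higher score, counting ties as half. Then divide by the number of pairs. A sketch with synthetic scores, cross-checked against sklearn.metrics.roc_auc_score:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(7)
# Synthetic scores: positives tend to score higher than negatives.
y_true = np.concatenate([np.ones(50, dtype=int), np.zeros(500, dtype=int)])
scores = np.concatenate([rng.normal(1.0, 1.0, 50), rng.normal(0.0, 1.0, 500)])

pos_scores = scores[y_true == 1]
neg_scores = scores[y_true == 0]

# Compare every positive with every negative via broadcasting.
diff = pos_scores[:, None] - neg_scores[None, :]
pairwise_auc = (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size

print(pairwise_auc, roc_auc_score(y_true, scores))
```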

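Read more

Baseline model

Build the model

Now create and train your model using the function that was defined earlier. Notice that the model is fit with a larger-than-default batch size of 2048. This is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, many batches would likely contain no fraudulent transactions to learn from.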
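A back-of-the-envelope check shows why the batch size matters. It assumes each example in a batch is independently fraudulent with probability p = 492/284807, about 0.17%. This treats batches as drawn with replacement, which is only an approximation. The chance that a batch of size B contains at least one positive is then 1 - (1 - p)^B. The default batch_size of Keras's fit is 32.

```python
pos, total = 492, 284_807
p = pos / total  # fraction of fraudulent transactions

def prob_at_least_one_positive(batch_size: int) -> float:
    # Under the independence approximation, P(no fraud in batch) = (1 - p) ** B.
    return 1.0 - (1.0 - p) ** batch_size

for batch_size in (32, 2048):
    expected = batch_size * p
    print(f'B={batch_size:5d}: expected positives {expected:.2f}, '
          f'P(at least one) {prob_at_least_one_positive(batch_size):.3f}')
```

Under these assumptions, about 19 of every 20 batches of size 32 would contain no fraud at all. A batch of 2048 holds about 3.5 positives on average.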

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_prc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 16)                480       
                                                                 
 dropout (Dropout)           (None, 16)                0         
                                                                 
 dense_1 (Dense)             (None, 1)                 17        
                                                                 
=================================================================
Total params: 497 (1.94 KB)
Trainable params: 497 (1.94 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
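The parameter counts in the summary follow from the layer sizes. A Dense layer with n_in inputs and n_out units has n_in * n_out weights plus n_out biases, and Dropout has no parameters. With 29 input features:

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

n_features = 29  # train_features.shape[-1] in the tutorial
hidden = dense_params(n_features, 16)  # 29*16 + 16
output = dense_params(16, 1)           # 16*1 + 1
total_params = hidden + output

# float32 parameters take 4 bytes each; the summary reports KiB (1024 bytes).
size_kib = total_params * 4 / 1024
print(hidden, output, total_params, round(size_kib, 2))
```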

Test run the model

model.predict(train_features[:10])
1/1 [==============================] - 0s 471ms/step
array([[0.16263928],
       [0.35204744],
       [0.19377157],
       [0.72603256],
       [0.30116165],
       [0.25605297],
       [0.66053736],
       [0.31973222],
       [0.25077152],
       [0.26151225]], dtype=float32)

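Optional: Set the correct initial bias.

These initial guesses are not great. You know the dataset is imbalanced, so set the output layer's bias to reflect that (see A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.

A model that predicted 0.5 for every example would have a loss of math.log(2) = 0.69314. With the default (zero) bias initialization, the initial loss should be of that order.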

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.4088

The correct bias to set can be derived from:

\[ p_0 = \frac{pos}{pos + neg} = \frac{1}{1+e^{-b_0}} \]

\[ b_0 = -\log_e\left(\frac{1}{p_0} - 1\right) \]

\[ b_0 = \log_e\left(\frac{pos}{neg}\right) \]

initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])
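Plugging the bias back into the sigmoid recovers the positive rate, which is a quick way to check the derivation. This sketch uses the counts printed earlier, 492 positives out of 284,807 examples, and computes the logistic sigmoid with plain NumPy.

```python
import numpy as np

pos, total = 492, 284_807
neg = total - pos

b0 = np.log(pos / neg)          # b_0 = log_e(pos / neg)
p0 = 1.0 / (1.0 + np.exp(-b0))  # sigmoid(b_0)

print(f'b0 = {b0:.8f}, sigmoid(b0) = {p0:.6f}, pos/total = {pos / total:.6f}')
```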

Set that as the initial bias, and the model will give much more reasonable initial guesses.

The predictions should be near pos/total ≈ 0.0017:

model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
1/1 [==============================] - 0s 75ms/step
array([[0.00135984],
       [0.00134607],
       [0.00213977],
       [0.01406598],
       [0.0021732 ],
       [0.00640495],
       [0.00814889],
       [0.00254694],
       [0.00572464],
       [0.00216844]], dtype=float32)

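With this initialization, the initial loss should be approximately the binary entropy of \(p_0\). With \(p_0 = pos/total \approx 0.0017\), this is:

\[-p_0\log(p_0)-(1-p_0)\log(1-p_0) \approx 0.0127\]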
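The expected initial loss is the binary entropy of p_0, so it is easy to evaluate numerically. This small sketch shows how sensitive the value is to the assumed rate. At the full-dataset rate p_0 = 492/284807 ≈ 0.00173, it gives about 0.0127; at p_0 = 0.0018, it gives about 0.0132. At p_0 = 0.5, it reduces to math.log(2).

```python
import numpy as np

def binary_entropy(p0: float) -> float:
    """Cross-entropy of always predicting p0 when the positive rate is p0."""
    return -p0 * np.log(p0) - (1 - p0) * np.log(1 - p0)

for p0 in (492 / 284_807, 0.0018, 0.5):
    print(f'p0 = {p0:.5f} -> expected initial loss {binary_entropy(p0):.5f}')
```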

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0087

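This initial loss is about 50 times lower than the loss with the naive initialization.

This way, the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes plots of the loss during training easier to read.

Checkpoint the initial weights

To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training: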

initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)

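Confirm that the bias fix helps

Before moving on, quickly confirm that the careful bias initialization actually helped.

Train the model for 20 epochs, with and without this careful initialization, and compare the losses: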

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale on y-axis to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

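[Figure: training and validation loss (log scale) per epoch for the "Zero Bias" and "Careful Bias" runs]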

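The figure above makes it clear: in terms of validation loss, careful initialization gives a clear advantage on this problem.

Train the model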

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 2s 11ms/step - loss: 0.0109 - cross entropy: 0.0092 - Brier score: 0.0013 - tp: 179.0000 - fp: 128.0000 - tn: 227336.0000 - fn: 202.0000 - accuracy: 0.9986 - precision: 0.5831 - recall: 0.4698 - auc: 0.8759 - prc: 0.4240 - val_loss: 0.0053 - val_cross entropy: 0.0053 - val_Brier score: 7.6563e-04 - val_tp: 44.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 38.0000 - val_accuracy: 0.9991 - val_precision: 0.8980 - val_recall: 0.5366 - val_auc: 0.9188 - val_prc: 0.7535
Epoch 2/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0070 - cross entropy: 0.0070 - Brier score: 9.7767e-04 - tp: 137.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 162.0000 - accuracy: 0.9989 - precision: 0.8155 - recall: 0.4582 - auc: 0.8800 - prc: 0.5341 - val_loss: 0.0044 - val_cross entropy: 0.0044 - val_Brier score: 6.3545e-04 - val_tp: 54.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 28.0000 - val_accuracy: 0.9992 - val_precision: 0.8852 - val_recall: 0.6585 - val_auc: 0.9263 - val_prc: 0.7737
Epoch 3/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0061 - cross entropy: 0.0061 - Brier score: 8.9642e-04 - tp: 146.0000 - fp: 33.0000 - tn: 181944.0000 - fn: 153.0000 - accuracy: 0.9990 - precision: 0.8156 - recall: 0.4883 - auc: 0.9033 - prc: 0.5771 - val_loss: 0.0040 - val_cross entropy: 0.0040 - val_Brier score: 6.0828e-04 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 27.0000 - val_accuracy: 0.9993 - val_precision: 0.9016 - val_recall: 0.6707 - val_auc: 0.9266 - val_prc: 0.7869
Epoch 4/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0057 - cross entropy: 0.0057 - Brier score: 8.9241e-04 - tp: 147.0000 - fp: 30.0000 - tn: 181947.0000 - fn: 152.0000 - accuracy: 0.9990 - precision: 0.8305 - recall: 0.4916 - auc: 0.9045 - prc: 0.6121 - val_loss: 0.0037 - val_cross entropy: 0.0037 - val_Brier score: 5.6512e-04 - val_tp: 58.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8923 - val_recall: 0.7073 - val_auc: 0.9327 - val_prc: 0.7996
Epoch 5/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0050 - cross entropy: 0.0050 - Brier score: 8.0944e-04 - tp: 163.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8402 - recall: 0.5452 - auc: 0.9091 - prc: 0.6557 - val_loss: 0.0035 - val_cross entropy: 0.0035 - val_Brier score: 5.4862e-04 - val_tp: 58.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8923 - val_recall: 0.7073 - val_auc: 0.9327 - val_prc: 0.8041
Epoch 6/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0046 - cross entropy: 0.0046 - Brier score: 7.5796e-04 - tp: 168.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8615 - recall: 0.5619 - auc: 0.9214 - prc: 0.6995 - val_loss: 0.0034 - val_cross entropy: 0.0034 - val_Brier score: 5.3008e-04 - val_tp: 60.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.8955 - val_recall: 0.7317 - val_auc: 0.9388 - val_prc: 0.8149
Epoch 7/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0045 - cross entropy: 0.0045 - Brier score: 7.0728e-04 - tp: 183.0000 - fp: 33.0000 - tn: 181944.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8472 - recall: 0.6120 - auc: 0.9133 - prc: 0.6901 - val_loss: 0.0033 - val_cross entropy: 0.0033 - val_Brier score: 5.3596e-04 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.9062 - val_recall: 0.7073 - val_auc: 0.9388 - val_prc: 0.8227
Epoch 8/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0048 - cross entropy: 0.0048 - Brier score: 8.0575e-04 - tp: 169.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8284 - recall: 0.5652 - auc: 0.9183 - prc: 0.6610 - val_loss: 0.0032 - val_cross entropy: 0.0032 - val_Brier score: 5.4781e-04 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.9062 - val_recall: 0.7073 - val_auc: 0.9389 - val_prc: 0.8321
Epoch 9/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0043 - cross entropy: 0.0043 - Brier score: 7.4602e-04 - tp: 170.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 129.0000 - accuracy: 0.9991 - precision: 0.8629 - recall: 0.5686 - auc: 0.9186 - prc: 0.7075 - val_loss: 0.0031 - val_cross entropy: 0.0031 - val_Brier score: 5.1218e-04 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.9062 - val_recall: 0.7073 - val_auc: 0.9388 - val_prc: 0.8314
Epoch 10/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0040 - cross entropy: 0.0040 - Brier score: 6.7102e-04 - tp: 178.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 121.0000 - accuracy: 0.9992 - precision: 0.8768 - recall: 0.5953 - auc: 0.9203 - prc: 0.7351 - val_loss: 0.0030 - val_cross entropy: 0.0030 - val_Brier score: 4.8812e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8293
Epoch 11/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 6.3325e-04 - tp: 191.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 108.0000 - accuracy: 0.9993 - precision: 0.8843 - recall: 0.6388 - auc: 0.9170 - prc: 0.7323 - val_loss: 0.0030 - val_cross entropy: 0.0030 - val_Brier score: 4.8228e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8301
Epoch 12/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0042 - cross entropy: 0.0042 - Brier score: 7.6081e-04 - tp: 173.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8317 - recall: 0.5786 - auc: 0.9254 - prc: 0.7097 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.7943e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8330
Epoch 13/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0043 - cross entropy: 0.0043 - Brier score: 7.4700e-04 - tp: 175.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 124.0000 - accuracy: 0.9992 - precision: 0.8578 - recall: 0.5853 - auc: 0.9238 - prc: 0.6897 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.7884e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8350
Epoch 14/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0040 - cross entropy: 0.0040 - Brier score: 7.1931e-04 - tp: 177.0000 - fp: 30.0000 - tn: 181947.0000 - fn: 122.0000 - accuracy: 0.9992 - precision: 0.8551 - recall: 0.5920 - auc: 0.9171 - prc: 0.7144 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.8724e-04 - val_tp: 64.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9275 - val_recall: 0.7805 - val_auc: 0.9388 - val_prc: 0.8409
Epoch 15/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0042 - cross entropy: 0.0042 - Brier score: 7.5652e-04 - tp: 167.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 132.0000 - accuracy: 0.9991 - precision: 0.8608 - recall: 0.5585 - auc: 0.9238 - prc: 0.6964 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7200e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8410
Epoch 16/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0039 - cross entropy: 0.0039 - Brier score: 7.1767e-04 - tp: 177.0000 - fp: 34.0000 - tn: 181943.0000 - fn: 122.0000 - accuracy: 0.9991 - precision: 0.8389 - recall: 0.5920 - auc: 0.9239 - prc: 0.7223 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6891e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8418
Epoch 17/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0041 - cross entropy: 0.0041 - Brier score: 7.5757e-04 - tp: 166.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8601 - recall: 0.5552 - auc: 0.9255 - prc: 0.7017 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6881e-04 - val_tp: 64.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9143 - val_recall: 0.7805 - val_auc: 0.9388 - val_prc: 0.8419
Epoch 18/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 6.7869e-04 - tp: 185.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 114.0000 - accuracy: 0.9992 - precision: 0.8685 - recall: 0.6187 - auc: 0.9289 - prc: 0.7328 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.5812e-04 - val_tp: 67.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9178 - val_recall: 0.8171 - val_auc: 0.9449 - val_prc: 0.8473
Epoch 19/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0039 - cross entropy: 0.0039 - Brier score: 6.9306e-04 - tp: 184.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8558 - recall: 0.6154 - auc: 0.9222 - prc: 0.7129 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7472e-04 - val_tp: 64.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9275 - val_recall: 0.7805 - val_auc: 0.9389 - val_prc: 0.8439
Epoch 20/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0037 - cross entropy: 0.0037 - Brier score: 6.5706e-04 - tp: 191.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 108.0000 - accuracy: 0.9992 - precision: 0.8604 - recall: 0.6388 - auc: 0.9240 - prc: 0.7368 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8355e-04 - val_tp: 60.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.9375 - val_recall: 0.7317 - val_auc: 0.9388 - val_prc: 0.8442
Epoch 21/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0039 - cross entropy: 0.0039 - Brier score: 7.2760e-04 - tp: 180.0000 - fp: 33.0000 - tn: 181944.0000 - fn: 119.0000 - accuracy: 0.9992 - precision: 0.8451 - recall: 0.6020 - auc: 0.9223 - prc: 0.7170 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.9822e-04 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9365 - val_recall: 0.7195 - val_auc: 0.9388 - val_prc: 0.8441
Epoch 22/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.5225e-04 - tp: 181.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 118.0000 - accuracy: 0.9992 - precision: 0.8660 - recall: 0.6054 - auc: 0.9273 - prc: 0.7439 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8637e-04 - val_tp: 60.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.9375 - val_recall: 0.7317 - val_auc: 0.9388 - val_prc: 0.8438
Epoch 23/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0037 - cross entropy: 0.0037 - Brier score: 6.5348e-04 - tp: 186.0000 - fp: 26.0000 - tn: 181951.0000 - fn: 113.0000 - accuracy: 0.9992 - precision: 0.8774 - recall: 0.6221 - auc: 0.9355 - prc: 0.7402 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6483e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8427
Epoch 24/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0037 - cross entropy: 0.0037 - Brier score: 6.7939e-04 - tp: 193.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 106.0000 - accuracy: 0.9992 - precision: 0.8465 - recall: 0.6455 - auc: 0.9340 - prc: 0.7279 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 5.1275e-04 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9355 - val_recall: 0.7073 - val_auc: 0.9449 - val_prc: 0.8509
Epoch 25/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.7560e-04 - tp: 180.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 119.0000 - accuracy: 0.9992 - precision: 0.8654 - recall: 0.6020 - auc: 0.9290 - prc: 0.7396 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8990e-04 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9365 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8503
Epoch 26/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.4978e-04 - tp: 188.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 111.0000 - accuracy: 0.9992 - precision: 0.8704 - recall: 0.6288 - auc: 0.9307 - prc: 0.7594 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7567e-04 - val_tp: 63.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9265 - val_recall: 0.7683 - val_auc: 0.9388 - val_prc: 0.8439
Epoch 27/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 7.1788e-04 - tp: 183.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8632 - recall: 0.6120 - auc: 0.9289 - prc: 0.7194 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6391e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8454
Epoch 28/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.2664e-04 - tp: 200.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 99.0000 - accuracy: 0.9993 - precision: 0.8734 - recall: 0.6689 - auc: 0.9306 - prc: 0.7426 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 5.1824e-04 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9355 - val_recall: 0.7073 - val_auc: 0.9388 - val_prc: 0.8435
Epoch 29/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 6.9012e-04 - tp: 185.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 114.0000 - accuracy: 0.9992 - precision: 0.8685 - recall: 0.6187 - auc: 0.9289 - prc: 0.7251 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7472e-04 - val_tp: 60.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.9231 - val_recall: 0.7317 - val_auc: 0.9449 - val_prc: 0.8510
Epoch 30/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.0173e-04 - tp: 197.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 102.0000 - accuracy: 0.9993 - precision: 0.8874 - recall: 0.6589 - auc: 0.9290 - prc: 0.7333 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8578e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9510 - val_prc: 0.8570
Epoch 31/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.9174e-04 - tp: 187.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 112.0000 - accuracy: 0.9992 - precision: 0.8423 - recall: 0.6254 - auc: 0.9373 - prc: 0.7320 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 5.2550e-04 - val_tp: 58.0000 - val_fp: 3.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9508 - val_recall: 0.7073 - val_auc: 0.9510 - val_prc: 0.8546
Epoch 32/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.6040e-04 - tp: 183.0000 - fp: 21.0000 - tn: 181956.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8971 - recall: 0.6120 - auc: 0.9356 - prc: 0.7430 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9123e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9510 - val_prc: 0.8581
Epoch 33/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.3240e-04 - tp: 198.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 101.0000 - accuracy: 0.9993 - precision: 0.8800 - recall: 0.6622 - auc: 0.9339 - prc: 0.7473 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 5.0220e-04 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9206 - val_recall: 0.7073 - val_auc: 0.9510 - val_prc: 0.8544
Epoch 34/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0034 - cross entropy: 0.0034 - Brier score: 6.3294e-04 - tp: 193.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 106.0000 - accuracy: 0.9993 - precision: 0.8694 - recall: 0.6455 - auc: 0.9373 - prc: 0.7536 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8504e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8489
Epoch 35/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.8906e-04 - tp: 184.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8638 - recall: 0.6154 - auc: 0.9239 - prc: 0.7403 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9829e-04 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9206 - val_recall: 0.7073 - val_auc: 0.9449 - val_prc: 0.8482
Epoch 36/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.5897e-04 - tp: 193.0000 - fp: 30.0000 - tn: 181947.0000 - fn: 106.0000 - accuracy: 0.9993 - precision: 0.8655 - recall: 0.6455 - auc: 0.9340 - prc: 0.7307 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9601e-04 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9206 - val_recall: 0.7073 - val_auc: 0.9449 - val_prc: 0.8474
Epoch 37/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 7.0205e-04 - tp: 184.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8558 - recall: 0.6154 - auc: 0.9373 - prc: 0.7124 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9088e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8476
Epoch 38/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0033 - cross entropy: 0.0033 - Brier score: 6.4066e-04 - tp: 195.0000 - fp: 26.0000 - tn: 181951.0000 - fn: 104.0000 - accuracy: 0.9993 - precision: 0.8824 - recall: 0.6522 - auc: 0.9374 - prc: 0.7656 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8218e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8483
Epoch 39/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0034 - cross entropy: 0.0034 - Brier score: 6.2081e-04 - tp: 195.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 104.0000 - accuracy: 0.9993 - precision: 0.8744 - recall: 0.6522 - auc: 0.9274 - prc: 0.7673 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7334e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8511
Epoch 40/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0033 - cross entropy: 0.0033 - Brier score: 6.0397e-04 - tp: 202.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 97.0000 - accuracy: 0.9993 - precision: 0.8783 - recall: 0.6756 - auc: 0.9358 - prc: 0.7739 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7153e-04 - val_tp: 62.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 20.0000 - val_accuracy: 0.9995 - val_precision: 0.9394 - val_recall: 0.7561 - val_auc: 0.9449 - val_prc: 0.8499
Epoch 41/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.6866e-04 - tp: 186.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 113.0000 - accuracy: 0.9992 - precision: 0.8815 - recall: 0.6221 - auc: 0.9407 - prc: 0.7539 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6169e-04 - val_tp: 66.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9296 - val_recall: 0.8049 - val_auc: 0.9510 - val_prc: 0.8571
Epoch 42/100
86/90 [===========================>..] - ETA: 0s - loss: 0.0033 - cross entropy: 0.0033 - Brier score: 6.4029e-04 - tp: 188.0000 - fp: 30.0000 - tn: 175806.0000 - fn: 104.0000 - accuracy: 0.9992 - precision: 0.8624 - recall: 0.6438 - auc: 0.9445 - prc: 0.7663Restoring model weights from the end of the best epoch: 32.
90/90 [==============================] - 0s 5ms/step - loss: 0.0033 - cross entropy: 0.0033 - Brier score: 6.3642e-04 - tp: 193.0000 - fp: 32.0000 - tn: 181945.0000 - fn: 106.0000 - accuracy: 0.9992 - precision: 0.8578 - recall: 0.6455 - auc: 0.9441 - prc: 0.7655 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7751e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9510 - val_prc: 0.8563
Epoch 42: early stopping

Check training history

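In this section, you will produce plots of your model's loss and selected metrics on the training and validation sets. These plots are useful for checking for overfitting, which you can learn more about in the Overfit and underfit tutorial.

You can also produce these plots for any of the metrics you created above. The function below plots loss, PRC, precision, and recall as examples.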

def plot_metrics(history):
  metrics = ['loss', 'prc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)


Evaluate metrics

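You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label.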
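
As a quick sanity check of that layout, here is a minimal sketch on made-up toy labels and scores (not the credit card data). It builds a confusion matrix with scikit-learn and unpacks it the same way the plot_cm helper in this section indexes it: row = actual label, column = predicted label, so cm[0][0] is TN, cm[0][1] is FP, cm[1][0] is FN, and cm[1][1] is TP.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy data for illustration only: 0 = legitimate, 1 = fraudulent.
labels = np.array([0, 0, 0, 0, 1, 1, 1])
scores = np.array([0.1, 0.2, 0.7, 0.3, 0.9, 0.4, 0.8])

# Threshold the scores at 0.5; rows of cm are actual labels, columns are predictions.
cm = confusion_matrix(labels, (scores > 0.5).astype(int))
tn, fp, fn, tp = cm.ravel()

print(cm)
print("TN:", tn, "FP:", fp, "FN:", fn, "TP:", tp)
```

Here one legitimate example (score 0.7) is a false positive and one fraudulent example (score 0.4) is a false negative, so both off-diagonal cells equal 1.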

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 1ms/step
def plot_cm(labels, predictions, threshold=0.5):
  cm = confusion_matrix(labels, predictions > threshold)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(threshold))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

Evaluate your model on the test dataset and display the results for the metrics you created above.

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.0038855739403516054
cross entropy :  0.0038855739403516054
Brier score :  0.0006162827485240996
tp :  81.0
fp :  11.0
tn :  56840.0
fn :  30.0
accuracy :  0.9992802143096924
precision :  0.8804348111152649
recall :  0.7297297120094299
auc :  0.9096326231956482
prc :  0.7863917350769043

Legitimate Transactions Detected (True Negatives):  56840
Legitimate Transactions Incorrectly Detected (False Positives):  11
Fraudulent Transactions Missed (False Negatives):  30
Fraudulent Transactions Detected (True Positives):  81
Total Fraudulent Transactions:  111


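If the model had predicted everything perfectly (which is impossible with true randomness), this would be a diagonal matrix: the values off the main diagonal, which represent incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that relatively few legitimate transactions were incorrectly flagged.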

Changing the threshold

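The default threshold of \(t=50\%\) corresponds to equal costs for false negatives and false positives. In fraud detection, however, you would likely associate a higher cost with false negatives than with false positives. This trade-off may be preferable because a false negative lets a fraudulent transaction go through, whereas a false positive may only cause an email to be sent to the customer asking them to verify their card activity.

By decreasing the threshold, we attribute a higher cost to false negatives, which reduces the number of missed fraudulent transactions at the price of more false positives. We test thresholds of 10% and 1%.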

plot_cm(test_labels, test_predictions_baseline, threshold=0.1)
plot_cm(test_labels, test_predictions_baseline, threshold=0.01)
Legitimate Transactions Detected (True Negatives):  56834
Legitimate Transactions Incorrectly Detected (False Positives):  17
Fraudulent Transactions Missed (False Negatives):  23
Fraudulent Transactions Detected (True Positives):  88
Total Fraudulent Transactions:  111
Legitimate Transactions Detected (True Negatives):  56806
Legitimate Transactions Incorrectly Detected (False Positives):  45
Fraudulent Transactions Missed (False Negatives):  22
Fraudulent Transactions Detected (True Positives):  89
Total Fraudulent Transactions:  111
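
One way to reason about where to put the threshold is a simple expected-cost argument, sketched below under the assumption that the model output p is a calibrated probability of fraud (the costs are hypothetical inputs, not values from this tutorial). Flagging a transaction costs (1 - p) * cost_fp in expectation, and letting it through costs p * cost_fn, so you flag when p > cost_fp / (cost_fp + cost_fn). Equal costs give 50%; a false negative that is 9 or 99 times as costly as a false positive gives the 10% and 1% thresholds tried above.

```python
def cost_threshold(cost_fp: float, cost_fn: float) -> float:
    """Threshold t where flagging and not flagging have equal expected cost.

    Assumes the model output p is a calibrated probability of fraud:
    flag as fraud when p * cost_fn > (1 - p) * cost_fp, i.e. when p > t.
    """
    return cost_fp / (cost_fp + cost_fn)


# Hypothetical cost ratios (false-negative cost relative to false-positive cost).
for cost_fn in (1.0, 9.0, 99.0):
    print(f"cost_fn/cost_fp = {cost_fn:>4}: threshold = {cost_threshold(1.0, cost_fn):.2f}")
```

In practice the network's probabilities may not be well calibrated, so treat this as a starting point and check the resulting confusion matrix as above.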


Plot the ROC

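Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach by tuning the output threshold over its full range (0 to 1). Each point therefore corresponds to a single value of the threshold.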

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');
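
To see the "one point per threshold" idea concretely, the sketch below calls sklearn.metrics.roc_curve on a tiny made-up example (not the credit card data) and recomputes each (FPR, TPR) point by hand, predicting "positive" whenever the score is greater than or equal to the returned threshold.

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy data for illustration only.
labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

fpr, tpr, thresholds = roc_curve(labels, scores)

# Each ROC point is the (FPR, TPR) obtained by predicting positive for score >= threshold.
for f, t, thr in zip(fpr, tpr, thresholds):
    preds = scores >= thr
    manual_tpr = preds[labels == 1].mean()
    manual_fpr = preds[labels == 0].mean()
    print(f"threshold={thr:.2f}  FPR={f:.2f} (manual {manual_fpr:.2f})  "
          f"TPR={t:.2f} (manual {manual_tpr:.2f})")
```

The first threshold is above every score (so nothing is flagged), which is why the curve starts at (0, 0).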


Plot the PRC

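Now plot the AUPRC: the area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it is calculated, PR AUC may be equivalent to the average precision of the model.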

def plot_prc(name, labels, predictions, **kwargs):
    precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)

    plt.plot(precision, recall, label=name, linewidth=2, **kwargs)
    plt.xlabel('Precision')
    plt.ylabel('Recall')
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');
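
The "depending on how it is calculated" caveat can be made concrete with scikit-learn. The sketch below, on made-up toy data, computes average precision as a step-wise sum of precision weighted by each increase in recall, checks it against sklearn.metrics.average_precision_score, and also shows the trapezoidal area under the same points for comparison. The Keras AUC(curve='PR') metric used during training applies its own interpolation, so its value need not match either number exactly.

```python
import numpy as np
from sklearn.metrics import auc, average_precision_score, precision_recall_curve

# Toy data for illustration only.
labels = np.array([0, 0, 1, 1, 0, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.9])

precision, recall, _ = precision_recall_curve(labels, scores)

# Average precision: sum of precision at each threshold, weighted by the recall gained.
# precision_recall_curve returns recall in decreasing order, hence the minus sign.
ap_manual = -np.sum(np.diff(recall) * precision[:-1])
ap_sklearn = average_precision_score(labels, scores)

# Trapezoidal (linearly interpolated) area under the same precision-recall points.
pr_auc_trapz = auc(recall, precision)

print(f"AP (manual)  = {ap_manual:.4f}")
print(f"AP (sklearn) = {ap_sklearn:.4f}")
print(f"trapezoidal PR AUC = {pr_auc_trapz:.4f}")
```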


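The precision looks relatively high, but the recall and the area under the ROC curve (AUC) are not as high as you might like. Classifiers often struggle to maximize both precision and recall at once, and this is especially true with imbalanced datasets. It is important to consider the costs of the different types of errors in the context of the problem you care about. In this example, a false negative (a missed fraudulent transaction) may have a financial cost, while a false positive (a transaction incorrectly flagged as fraudulent) may decrease user satisfaction.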

Class weights

Calculate class weights

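The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you want the classifier to heavily weight the few examples that are available. You can do this by passing Keras a weight for each class through the class_weight argument of Model.fit. These weights cause the model to "pay more attention" to examples from the under-represented class. Note, however, that this does not increase the amount of information in your dataset in any way. In the end, using class weights is more or less equivalent to changing the output bias or changing the threshold. Let's see how it works out.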

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44
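
The formula above is the same "balanced" heuristic that scikit-learn implements, n_samples / (n_classes * count_per_class). The sketch below reproduces both weights from the dataset's class counts (284,315 legitimate and 492 fraudulent transactions out of 284,807), checks them against sklearn.utils.class_weight.compute_class_weight, and confirms the claim in the code comment that the total weight over all examples still equals the number of examples.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Class counts of the Kaggle credit card dataset: 492 frauds out of 284,807 transactions.
neg, pos = 284_315, 492
total = neg + pos

# Same formula as in the tutorial.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

# scikit-learn's "balanced" heuristic: n_samples / (n_classes * bincount(y)).
y = np.concatenate([np.zeros(neg, dtype=int), np.ones(pos, dtype=int)])
balanced = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=y)

# Each class contributes total/2 of weight, so the sum over all examples equals total.
total_weight = neg * weight_for_0 + pos * weight_for_1

print("tutorial weights:", round(weight_for_0, 2), round(weight_for_1, 2))
print("sklearn 'balanced':", balanced)
print("total weight:", total_weight, "examples:", total)
```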

Train a model with class weights

Now try retraining and evaluating the model with class weights to see how that affects the predictions.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 2s 11ms/step - loss: 0.9262 - cross entropy: 0.0166 - Brier score: 0.0027 - tp: 233.0000 - fp: 530.0000 - tn: 238298.0000 - fn: 177.0000 - accuracy: 0.9970 - precision: 0.3054 - recall: 0.5683 - auc: 0.8803 - prc: 0.4222 - val_loss: 0.0116 - val_cross entropy: 0.0116 - val_Brier score: 0.0011 - val_tp: 67.0000 - val_fp: 35.0000 - val_tn: 45452.0000 - val_fn: 15.0000 - val_accuracy: 0.9989 - val_precision: 0.6569 - val_recall: 0.8171 - val_auc: 0.9519 - val_prc: 0.7255
Epoch 2/100
90/90 [==============================] - 0s 5ms/step - loss: 0.6152 - cross entropy: 0.0319 - Brier score: 0.0059 - tp: 202.0000 - fp: 1117.0000 - tn: 180860.0000 - fn: 97.0000 - accuracy: 0.9933 - precision: 0.1531 - recall: 0.6756 - auc: 0.9051 - prc: 0.4410 - val_loss: 0.0172 - val_cross entropy: 0.0172 - val_Brier score: 0.0019 - val_tp: 69.0000 - val_fp: 72.0000 - val_tn: 45415.0000 - val_fn: 13.0000 - val_accuracy: 0.9981 - val_precision: 0.4894 - val_recall: 0.8415 - val_auc: 0.9577 - val_prc: 0.7220
Epoch 3/100
90/90 [==============================] - 0s 5ms/step - loss: 0.4397 - cross entropy: 0.0461 - Brier score: 0.0095 - tp: 229.0000 - fp: 1929.0000 - tn: 180048.0000 - fn: 70.0000 - accuracy: 0.9890 - precision: 0.1061 - recall: 0.7659 - auc: 0.9307 - prc: 0.4134 - val_loss: 0.0236 - val_cross entropy: 0.0236 - val_Brier score: 0.0029 - val_tp: 69.0000 - val_fp: 106.0000 - val_tn: 45381.0000 - val_fn: 13.0000 - val_accuracy: 0.9974 - val_precision: 0.3943 - val_recall: 0.8415 - val_auc: 0.9662 - val_prc: 0.7291
Epoch 4/100
90/90 [==============================] - 0s 5ms/step - loss: 0.4155 - cross entropy: 0.0619 - Brier score: 0.0136 - tp: 231.0000 - fp: 2898.0000 - tn: 179079.0000 - fn: 68.0000 - accuracy: 0.9837 - precision: 0.0738 - recall: 0.7726 - auc: 0.9272 - prc: 0.3804 - val_loss: 0.0319 - val_cross entropy: 0.0319 - val_Brier score: 0.0046 - val_tp: 70.0000 - val_fp: 188.0000 - val_tn: 45299.0000 - val_fn: 12.0000 - val_accuracy: 0.9956 - val_precision: 0.2713 - val_recall: 0.8537 - val_auc: 0.9697 - val_prc: 0.7095
Epoch 5/100
90/90 [==============================] - 0s 5ms/step - loss: 0.3247 - cross entropy: 0.0773 - Brier score: 0.0178 - tp: 241.0000 - fp: 3872.0000 - tn: 178105.0000 - fn: 58.0000 - accuracy: 0.9784 - precision: 0.0586 - recall: 0.8060 - auc: 0.9471 - prc: 0.3673 - val_loss: 0.0405 - val_cross entropy: 0.0405 - val_Brier score: 0.0068 - val_tp: 71.0000 - val_fp: 334.0000 - val_tn: 45153.0000 - val_fn: 11.0000 - val_accuracy: 0.9924 - val_precision: 0.1753 - val_recall: 0.8659 - val_auc: 0.9714 - val_prc: 0.6518
Epoch 6/100
90/90 [==============================] - 0s 6ms/step - loss: 0.3481 - cross entropy: 0.0976 - Brier score: 0.0225 - tp: 248.0000 - fp: 4880.0000 - tn: 177097.0000 - fn: 51.0000 - accuracy: 0.9729 - precision: 0.0484 - recall: 0.8294 - auc: 0.9351 - prc: 0.3069 - val_loss: 0.0494 - val_cross entropy: 0.0494 - val_Brier score: 0.0093 - val_tp: 73.0000 - val_fp: 511.0000 - val_tn: 44976.0000 - val_fn: 9.0000 - val_accuracy: 0.9886 - val_precision: 0.1250 - val_recall: 0.8902 - val_auc: 0.9742 - val_prc: 0.6313
Epoch 7/100
90/90 [==============================] - 0s 5ms/step - loss: 0.2719 - cross entropy: 0.1078 - Brier score: 0.0253 - tp: 257.0000 - fp: 5673.0000 - tn: 176304.0000 - fn: 42.0000 - accuracy: 0.9686 - precision: 0.0433 - recall: 0.8595 - auc: 0.9564 - prc: 0.2894 - val_loss: 0.0565 - val_cross entropy: 0.0565 - val_Brier score: 0.0112 - val_tp: 73.0000 - val_fp: 633.0000 - val_tn: 44854.0000 - val_fn: 9.0000 - val_accuracy: 0.9859 - val_precision: 0.1034 - val_recall: 0.8902 - val_auc: 0.9757 - val_prc: 0.6267
Epoch 8/100
90/90 [==============================] - 0s 6ms/step - loss: 0.2623 - cross entropy: 0.1179 - Brier score: 0.0275 - tp: 262.0000 - fp: 6123.0000 - tn: 175854.0000 - fn: 37.0000 - accuracy: 0.9662 - precision: 0.0410 - recall: 0.8763 - auc: 0.9554 - prc: 0.2609 - val_loss: 0.0607 - val_cross entropy: 0.0607 - val_Brier score: 0.0124 - val_tp: 73.0000 - val_fp: 686.0000 - val_tn: 44801.0000 - val_fn: 9.0000 - val_accuracy: 0.9847 - val_precision: 0.0962 - val_recall: 0.8902 - val_auc: 0.9754 - val_prc: 0.6069
Epoch 9/100
90/90 [==============================] - 0s 5ms/step - loss: 0.2915 - cross entropy: 0.1184 - Brier score: 0.0280 - tp: 257.0000 - fp: 6295.0000 - tn: 175682.0000 - fn: 42.0000 - accuracy: 0.9652 - precision: 0.0392 - recall: 0.8595 - auc: 0.9494 - prc: 0.2652 - val_loss: 0.0653 - val_cross entropy: 0.0653 - val_Brier score: 0.0135 - val_tp: 74.0000 - val_fp: 742.0000 - val_tn: 44745.0000 - val_fn: 8.0000 - val_accuracy: 0.9835 - val_precision: 0.0907 - val_recall: 0.9024 - val_auc: 0.9773 - val_prc: 0.5856
Epoch 10/100
90/90 [==============================] - 0s 6ms/step - loss: 0.2632 - cross entropy: 0.1336 - Brier score: 0.0313 - tp: 259.0000 - fp: 6976.0000 - tn: 175001.0000 - fn: 40.0000 - accuracy: 0.9615 - precision: 0.0358 - recall: 0.8662 - auc: 0.9561 - prc: 0.2365 - val_loss: 0.0700 - val_cross entropy: 0.0700 - val_Brier score: 0.0146 - val_tp: 76.0000 - val_fp: 801.0000 - val_tn: 44686.0000 - val_fn: 6.0000 - val_accuracy: 0.9823 - val_precision: 0.0867 - val_recall: 0.9268 - val_auc: 0.9773 - val_prc: 0.5876
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2336 - cross entropy: 0.1282 - Brier score: 0.0299 - tp: 269.0000 - fp: 6690.0000 - tn: 175287.0000 - fn: 30.0000 - accuracy: 0.9631 - precision: 0.0387 - recall: 0.8997 - auc: 0.9586 - prc: 0.2494 - val_loss: 0.0679 - val_cross entropy: 0.0679 - val_Brier score: 0.0140 - val_tp: 76.0000 - val_fp: 757.0000 - val_tn: 44730.0000 - val_fn: 6.0000 - val_accuracy: 0.9833 - val_precision: 0.0912 - val_recall: 0.9268 - val_auc: 0.9777 - val_prc: 0.5891
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2399 - cross entropy: 0.1289 - Brier score: 0.0298 - tp: 265.0000 - fp: 6654.0000 - tn: 175323.0000 - fn: 34.0000 - accuracy: 0.9633 - precision: 0.0383 - recall: 0.8863 - auc: 0.9602 - prc: 0.2601 - val_loss: 0.0684 - val_cross entropy: 0.0684 - val_Brier score: 0.0141 - val_tp: 76.0000 - val_fp: 762.0000 - val_tn: 44725.0000 - val_fn: 6.0000 - val_accuracy: 0.9831 - val_precision: 0.0907 - val_recall: 0.9268 - val_auc: 0.9784 - val_prc: 0.5848
Epoch 13/100
79/90 [=========================>....] - ETA: 0s - loss: 0.2286 - cross entropy: 0.1265 - Brier score: 0.0295 - tp: 237.0000 - fp: 5838.0000 - tn: 155684.0000 - fn: 33.0000 - accuracy: 0.9637 - precision: 0.0390 - recall: 0.8778 - auc: 0.9696 - prc: 0.2645Restoring model weights from the end of the best epoch: 3.
90/90 [==============================] - 1s 6ms/step - loss: 0.2341 - cross entropy: 0.1275 - Brier score: 0.0297 - tp: 262.0000 - fp: 6631.0000 - tn: 175346.0000 - fn: 37.0000 - accuracy: 0.9634 - precision: 0.0380 - recall: 0.8763 - auc: 0.9665 - prc: 0.2538 - val_loss: 0.0757 - val_cross entropy: 0.0757 - val_Brier score: 0.0159 - val_tp: 76.0000 - val_fp: 834.0000 - val_tn: 44653.0000 - val_fn: 6.0000 - val_accuracy: 0.9816 - val_precision: 0.0835 - val_recall: 0.9268 - val_auc: 0.9789 - val_prc: 0.5709
Epoch 13: early stopping

Check training history

plot_metrics(weighted_history)


Evaluate metrics

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 1ms/step
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.024716919288039207
cross entropy :  0.024716919288039207
Brier score :  0.0029473488684743643
tp :  88.0
fp :  134.0
tn :  56717.0
fn :  23.0
accuracy :  0.9972437620162964
precision :  0.3963963985443115
recall :  0.792792797088623
auc :  0.9477326273918152
prc :  0.6732124090194702

Legitimate Transactions Detected (True Negatives):  56717
Legitimate Transactions Incorrectly Detected (False Positives):  134
Fraudulent Transactions Missed (False Negatives):  23
Fraudulent Transactions Detected (True Positives):  88
Total Fraudulent Transactions:  111


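Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions than the baseline model at a 50% threshold). Of course, both types of error have a cost (you also don't want to bother users by flagging too many legitimate transactions as fraudulent). Carefully consider the trade-offs between these different types of errors for your application.

Compared to the baseline model with a changed threshold, the class-weighted model is clearly inferior. The superiority of the baseline model is further confirmed by its lower test loss values (cross entropy and Brier score, i.e. the mean squared error of the predicted probabilities), and it can also be seen by plotting the ROC curves of both models together.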
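
The comparison can be checked directly from the test-set confusion-matrix counts printed earlier in this section. The small sketch below recomputes precision and recall for the baseline at 50% and 10% and for the class-weighted model at 50%. At the 10% threshold the baseline catches exactly as many frauds as the weighted model (88 of 111) but with far fewer false positives (17 vs. 134).

```python
def precision_recall(tp: int, fp: int, fn: int) -> tuple[float, float]:
    """Return (precision, recall) from confusion-matrix counts."""
    return tp / (tp + fp), tp / (tp + fn)


# Test-set counts reported above (111 fraudulent transactions in total).
results = {
    "baseline @ 0.50": dict(tp=81, fp=11, fn=30),
    "baseline @ 0.10": dict(tp=88, fp=17, fn=23),
    "weighted @ 0.50": dict(tp=88, fp=134, fn=23),
}

for name, counts in results.items():
    p, r = precision_recall(**counts)
    print(f"{name}: precision={p:.3f} recall={r:.3f} false positives={counts['fp']}")

# Test loss (cross entropy) reported above for each model.
baseline_loss, weighted_loss = 0.0038855739403516054, 0.024716919288039207
print("baseline loss < weighted loss:", baseline_loss < weighted_loss)
```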

Plot the ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right');


Plot the PRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right');


Oversampling

Oversample the minority class

A related approach is to resample the dataset by oversampling the minority class.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

Using NumPy

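You can balance the dataset manually by choosing the right number of random indices from the positive examples: np.random.choice samples with replacement by default, so drawing as many indices as there are negative examples repeats each positive example many times.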

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181977, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363954, 29)

Using tf.data

If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset and merge them. See the tf.data guide for more examples.

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

Each dataset provides (feature, label) pairs.

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [ 4.57437149e-03  1.41282803e+00 -1.70738347e+00  7.86145002e-01
  2.34322123e+00 -1.32760854e+00  1.68238195e+00 -7.10272314e-01
  8.18760297e-01 -3.09684905e+00  2.01295966e+00 -3.98984767e+00
  1.02827419e+00 -5.00000000e+00 -1.25820263e+00  1.91494135e+00
  5.00000000e+00  3.32009026e+00 -2.75342824e+00 -8.47588695e-03
 -7.83382558e-01 -1.24259811e+00 -6.45039879e-01 -1.71393384e-02
  1.13211907e+00 -1.52256293e+00 -1.08919872e+00 -1.06657977e+00
 -1.45889491e+00]

Label:  1

Merge the two together using tf.data.Dataset.sample_from_datasets.

resampled_ds = tf.data.Dataset.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.50341796875
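
The printed batch mean is not exactly 0.5 because sample_from_datasets picks each element at random according to the weights. A quick back-of-the-envelope check, assuming BATCH_SIZE is 2048 (the value defined earlier in the tutorial is not shown in this section; 0.50341796875 is exactly 1031/2048), shows the deviation is well within sampling noise:

```python
import math

BATCH_SIZE = 2048  # assumption: batch size defined earlier in the tutorial
p = 0.5            # weights=[0.5, 0.5] passed to sample_from_datasets

# The positive fraction in one batch is the mean of BATCH_SIZE Bernoulli(p) draws.
std = math.sqrt(p * (1 - p) / BATCH_SIZE)

observed = 0.50341796875  # batch mean printed above
z = (observed - p) / std

print(f"positives in batch: {round(observed * BATCH_SIZE)}")
print(f"std of batch mean: {std:.4f}, z-score of observed mean: {z:.2f}")
```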

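To use this dataset, you'll need the number of steps per epoch.

The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once. Since roughly half of each resampled batch is negative, that is about 2 * neg / BATCH_SIZE batches.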

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0
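
The arithmetic behind 278 can be reproduced as follows, again assuming BATCH_SIZE is 2048. Note that neg, as computed at the start of the tutorial, counts the negatives in the full dataset (284,315), not only the 181,977 negatives in the training split (the length of neg_features, per the res_pos_features.shape output above). The sketch shows both, in case you prefer to define the epoch over the training split.

```python
import math

BATCH_SIZE = 2048      # assumption: batch size defined earlier in the tutorial
neg_all = 284_315      # negatives in the full dataset (284,807 - 492)
neg_train = 181_977    # negatives in the training split, len(neg_features)

# Roughly half of each balanced batch is negative, so seeing every negative
# once takes about 2 * negatives / BATCH_SIZE batches.
steps_full = math.ceil(2.0 * neg_all / BATCH_SIZE)
steps_train = math.ceil(2.0 * neg_train / BATCH_SIZE)

print(steps_full)   # matches resampled_steps_per_epoch above
print(steps_train)  # alternative definition using only training negatives
```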

Train on the oversampled data

Now try training the model on the resampled dataset instead of using class weights, to see how these methods compare.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 8s 22ms/step - loss: 0.3612 - cross entropy: 0.3306 - Brier score: 0.1096 - tp: 258749.0000 - fp: 76958.0000 - tn: 264513.0000 - fn: 26086.0000 - accuracy: 0.8355 - precision: 0.7708 - recall: 0.9084 - auc: 0.9490 - prc: 0.9536 - val_loss: 0.2021 - val_cross entropy: 0.2021 - val_Brier score: 0.0446 - val_tp: 75.0000 - val_fp: 1144.0000 - val_tn: 44343.0000 - val_fn: 7.0000 - val_accuracy: 0.9747 - val_precision: 0.0615 - val_recall: 0.9146 - val_auc: 0.9741 - val_prc: 0.7919
Epoch 2/100
278/278 [==============================] - 5s 20ms/step - loss: 0.1757 - cross entropy: 0.1757 - Brier score: 0.0515 - tp: 262968.0000 - fp: 15885.0000 - tn: 269124.0000 - fn: 21367.0000 - accuracy: 0.9346 - precision: 0.9430 - recall: 0.9249 - auc: 0.9817 - prc: 0.9852 - val_loss: 0.1003 - val_cross entropy: 0.1003 - val_Brier score: 0.0205 - val_tp: 76.0000 - val_fp: 858.0000 - val_tn: 44629.0000 - val_fn: 6.0000 - val_accuracy: 0.9810 - val_precision: 0.0814 - val_recall: 0.9268 - val_auc: 0.9777 - val_prc: 0.7702
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1358 - cross entropy: 0.1358 - Brier score: 0.0398 - tp: 266700.0000 - fp: 11006.0000 - tn: 273451.0000 - fn: 18187.0000 - accuracy: 0.9487 - precision: 0.9604 - recall: 0.9362 - auc: 0.9891 - prc: 0.9904 - val_loss: 0.0725 - val_cross entropy: 0.0725 - val_Brier score: 0.0158 - val_tp: 76.0000 - val_fp: 790.0000 - val_tn: 44697.0000 - val_fn: 6.0000 - val_accuracy: 0.9825 - val_precision: 0.0878 - val_recall: 0.9268 - val_auc: 0.9766 - val_prc: 0.7553
Epoch 4/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1151 - cross entropy: 0.1151 - Brier score: 0.0341 - tp: 269190.0000 - fp: 9719.0000 - tn: 274441.0000 - fn: 15994.0000 - accuracy: 0.9548 - precision: 0.9652 - recall: 0.9439 - auc: 0.9925 - prc: 0.9930 - val_loss: 0.0596 - val_cross entropy: 0.0596 - val_Brier score: 0.0136 - val_tp: 76.0000 - val_fp: 725.0000 - val_tn: 44762.0000 - val_fn: 6.0000 - val_accuracy: 0.9840 - val_precision: 0.0949 - val_recall: 0.9268 - val_auc: 0.9726 - val_prc: 0.7292
Epoch 5/100
278/278 [==============================] - 6s 20ms/step - loss: 0.1006 - cross entropy: 0.1006 - Brier score: 0.0299 - tp: 270949.0000 - fp: 8853.0000 - tn: 275916.0000 - fn: 13626.0000 - accuracy: 0.9605 - precision: 0.9684 - recall: 0.9521 - auc: 0.9945 - prc: 0.9946 - val_loss: 0.0525 - val_cross entropy: 0.0525 - val_Brier score: 0.0124 - val_tp: 76.0000 - val_fp: 668.0000 - val_tn: 44819.0000 - val_fn: 6.0000 - val_accuracy: 0.9852 - val_precision: 0.1022 - val_recall: 0.9268 - val_auc: 0.9717 - val_prc: 0.7216
Epoch 6/100
278/278 [==============================] - 6s 20ms/step - loss: 0.0904 - cross entropy: 0.0904 - Brier score: 0.0268 - tp: 272681.0000 - fp: 8122.0000 - tn: 276344.0000 - fn: 12197.0000 - accuracy: 0.9643 - precision: 0.9711 - recall: 0.9572 - auc: 0.9958 - prc: 0.9956 - val_loss: 0.0456 - val_cross entropy: 0.0456 - val_Brier score: 0.0108 - val_tp: 76.0000 - val_fp: 576.0000 - val_tn: 44911.0000 - val_fn: 6.0000 - val_accuracy: 0.9872 - val_precision: 0.1166 - val_recall: 0.9268 - val_auc: 0.9737 - val_prc: 0.7304
Epoch 7/100
278/278 [==============================] - 6s 20ms/step - loss: 0.0828 - cross entropy: 0.0828 - Brier score: 0.0244 - tp: 273911.0000 - fp: 7426.0000 - tn: 277008.0000 - fn: 10999.0000 - accuracy: 0.9676 - precision: 0.9736 - recall: 0.9614 - auc: 0.9965 - prc: 0.9963 - val_loss: 0.0408 - val_cross entropy: 0.0408 - val_Brier score: 0.0099 - val_tp: 77.0000 - val_fp: 546.0000 - val_tn: 44941.0000 - val_fn: 5.0000 - val_accuracy: 0.9879 - val_precision: 0.1236 - val_recall: 0.9390 - val_auc: 0.9752 - val_prc: 0.7232
Epoch 8/100
278/278 [==============================] - 5s 20ms/step - loss: 0.0775 - cross entropy: 0.0775 - Brier score: 0.0228 - tp: 274985.0000 - fp: 6904.0000 - tn: 277146.0000 - fn: 10309.0000 - accuracy: 0.9698 - precision: 0.9755 - recall: 0.9639 - auc: 0.9970 - prc: 0.9968 - val_loss: 0.0387 - val_cross entropy: 0.0387 - val_Brier score: 0.0096 - val_tp: 77.0000 - val_fp: 568.0000 - val_tn: 44919.0000 - val_fn: 5.0000 - val_accuracy: 0.9874 - val_precision: 0.1194 - val_recall: 0.9390 - val_auc: 0.9761 - val_prc: 0.7145
Epoch 9/100
278/278 [==============================] - 5s 20ms/step - loss: 0.0743 - cross entropy: 0.0743 - Brier score: 0.0219 - tp: 274086.0000 - fp: 6704.0000 - tn: 278828.0000 - fn: 9726.0000 - accuracy: 0.9711 - precision: 0.9761 - recall: 0.9657 - auc: 0.9971 - prc: 0.9969 - val_loss: 0.0344 - val_cross entropy: 0.0344 - val_Brier score: 0.0085 - val_tp: 76.0000 - val_fp: 492.0000 - val_tn: 44995.0000 - val_fn: 6.0000 - val_accuracy: 0.9891 - val_precision: 0.1338 - val_recall: 0.9268 - val_auc: 0.9767 - val_prc: 0.7147
Epoch 10/100
278/278 [==============================] - 5s 20ms/step - loss: 0.0712 - cross entropy: 0.0712 - Brier score: 0.0211 - tp: 275221.0000 - fp: 6399.0000 - tn: 278199.0000 - fn: 9525.0000 - accuracy: 0.9720 - precision: 0.9773 - recall: 0.9665 - auc: 0.9973 - prc: 0.9970 - val_loss: 0.0311 - val_cross entropy: 0.0311 - val_Brier score: 0.0077 - val_tp: 76.0000 - val_fp: 434.0000 - val_tn: 45053.0000 - val_fn: 6.0000 - val_accuracy: 0.9903 - val_precision: 0.1490 - val_recall: 0.9268 - val_auc: 0.9772 - val_prc: 0.7140
Epoch 11/100
276/278 [============================>.] - ETA: 0s - loss: 0.0695 - cross entropy: 0.0695 - Brier score: 0.0206 - tp: 273841.0000 - fp: 6329.0000 - tn: 275888.0000 - fn: 9190.0000 - accuracy: 0.9725 - precision: 0.9774 - recall: 0.9675 - auc: 0.9973 - prc: 0.9970Restoring model weights from the end of the best epoch: 1.
278/278 [==============================] - 5s 20ms/step - loss: 0.0695 - cross entropy: 0.0695 - Brier score: 0.0206 - tp: 275842.0000 - fp: 6384.0000 - tn: 277849.0000 - fn: 9269.0000 - accuracy: 0.9725 - precision: 0.9774 - recall: 0.9675 - auc: 0.9973 - prc: 0.9970 - val_loss: 0.0302 - val_cross entropy: 0.0302 - val_Brier score: 0.0075 - val_tp: 76.0000 - val_fp: 433.0000 - val_tn: 45054.0000 - val_fn: 6.0000 - val_accuracy: 0.9904 - val_precision: 0.1493 - val_recall: 0.9268 - val_auc: 0.9775 - val_prc: 0.7154
Epoch 11: early stopping

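If the training process considered the whole dataset on each gradient update, this oversampling would be basically identical to class weighting.

But when you train the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example appearing in one batch with a large weight, it appears in many different batches, each time with a small weight.

This smoother gradient signal makes the model easier to train.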
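
The full-dataset equivalence can be seen with a toy calculation. In the sketch below the per-example losses are random made-up numbers, and the class weight k for positives is chosen to equal the number of copies made when oversampling. Summed over the whole dataset, the weighted loss and the oversampled loss are identical (a mean-based loss differs only by a constant normalization factor); what changes in mini-batch training is how those k copies are spread across batches.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up per-example losses: 8 negatives and 2 positives.
neg_loss = rng.uniform(0.0, 1.0, size=8)
pos_loss = rng.uniform(0.0, 1.0, size=2)
k = 4  # class weight for positives == number of copies when oversampling

# Class weighting: each positive example's loss is multiplied by k.
weighted_sum = neg_loss.sum() + k * pos_loss.sum()

# Oversampling: each positive example literally appears k times.
oversampled = np.concatenate([neg_loss, np.repeat(pos_loss, k)])
oversampled_sum = oversampled.sum()

print(f"class-weighted total loss: {weighted_sum:.6f}")
print(f"oversampled total loss:    {oversampled_sum:.6f}")
```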

Check training history

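Note that the distributions of the metrics will be different here, because the training data has a totally different distribution from the validation and test data: the resampled training batches are about 50% positive, while the validation and test sets keep the original imbalance.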

plot_metrics(resampled_history)


Retrain

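Because training is easier on balanced data, the training procedure above may overfit quickly (note that early stopping restored the weights from epoch 1 in the run above).

So break up the epochs to give tf.keras.callbacks.EarlyStopping finer control over when to stop training.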
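
To see how much finer that control is, here is the arithmetic, assuming BATCH_SIZE is 2048 (not shown in this section). Each "epoch" below is only 20 steps, so EarlyStopping gets to check the validation metrics about 14 times during what used to be a single 278-step epoch.

```python
BATCH_SIZE = 2048              # assumption: batch size defined earlier in the tutorial
steps_per_pseudo_epoch = 20    # steps_per_epoch passed to fit() below
full_epoch_steps = 278         # resampled_steps_per_epoch computed above

# Examples seen between two consecutive validation checks.
examples_per_pseudo_epoch = steps_per_pseudo_epoch * BATCH_SIZE

# Validation checks per former "full" epoch.
checks_per_full_epoch = full_epoch_steps / steps_per_pseudo_epoch

print(examples_per_pseudo_epoch)
print(checks_per_full_epoch)
```

The 40,960 examples per pseudo-epoch match the tp + fp + tn + fn totals in the training logs from epoch 2 onward.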

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/1000
20/20 [==============================] - 2s 47ms/step - loss: 0.6826 - cross entropy: 0.3390 - Brier score: 0.1176 - tp: 18430.0000 - fp: 14493.0000 - tn: 51299.0000 - fn: 2307.0000 - accuracy: 0.8058 - precision: 0.5598 - recall: 0.8887 - auc: 0.9464 - prc: 0.8794 - val_loss: 1.0004 - val_cross entropy: 1.0004 - val_Brier score: 0.3825 - val_tp: 79.0000 - val_fp: 36434.0000 - val_tn: 9053.0000 - val_fn: 3.0000 - val_accuracy: 0.2004 - val_precision: 0.0022 - val_recall: 0.9634 - val_auc: 0.9261 - val_prc: 0.5749
Epoch 2/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.5885 - cross entropy: 0.5885 - Brier score: 0.2089 - tp: 18504.0000 - fp: 12626.0000 - tn: 7862.0000 - fn: 1968.0000 - accuracy: 0.6437 - precision: 0.5944 - recall: 0.9039 - auc: 0.8752 - prc: 0.9121 - val_loss: 0.8348 - val_cross entropy: 0.8348 - val_Brier score: 0.3117 - val_tp: 79.0000 - val_fp: 28884.0000 - val_tn: 16603.0000 - val_fn: 3.0000 - val_accuracy: 0.3661 - val_precision: 0.0027 - val_recall: 0.9634 - val_auc: 0.9395 - val_prc: 0.6709
Epoch 3/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.5070 - cross entropy: 0.5070 - Brier score: 0.1790 - tp: 18418.0000 - fp: 10425.0000 - tn: 10193.0000 - fn: 1924.0000 - accuracy: 0.6985 - precision: 0.6386 - recall: 0.9054 - auc: 0.8991 - prc: 0.9280 - val_loss: 0.6975 - val_cross entropy: 0.6975 - val_Brier score: 0.2495 - val_tp: 78.0000 - val_fp: 19535.0000 - val_tn: 25952.0000 - val_fn: 4.0000 - val_accuracy: 0.5712 - val_precision: 0.0040 - val_recall: 0.9512 - val_auc: 0.9499 - val_prc: 0.7048
Epoch 4/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.4413 - cross entropy: 0.4413 - Brier score: 0.1530 - tp: 18483.0000 - fp: 8228.0000 - tn: 12349.0000 - fn: 1900.0000 - accuracy: 0.7527 - precision: 0.6920 - recall: 0.9068 - auc: 0.9179 - prc: 0.9406 - val_loss: 0.5893 - val_cross entropy: 0.5893 - val_Brier score: 0.1998 - val_tp: 77.0000 - val_fp: 11782.0000 - val_tn: 33705.0000 - val_fn: 5.0000 - val_accuracy: 0.7413 - val_precision: 0.0065 - val_recall: 0.9390 - val_auc: 0.9552 - val_prc: 0.7246
Epoch 5/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3914 - cross entropy: 0.3914 - Brier score: 0.1335 - tp: 18615.0000 - fp: 6548.0000 - tn: 13896.0000 - fn: 1901.0000 - accuracy: 0.7937 - precision: 0.7398 - recall: 0.9073 - auc: 0.9304 - prc: 0.9500 - val_loss: 0.5045 - val_cross entropy: 0.5045 - val_Brier score: 0.1613 - val_tp: 77.0000 - val_fp: 7135.0000 - val_tn: 38352.0000 - val_fn: 5.0000 - val_accuracy: 0.8433 - val_precision: 0.0107 - val_recall: 0.9390 - val_auc: 0.9595 - val_prc: 0.7424
Epoch 6/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.3563 - cross entropy: 0.3563 - Brier score: 0.1183 - tp: 18429.0000 - fp: 5050.0000 - tn: 15533.0000 - fn: 1948.0000 - accuracy: 0.8292 - precision: 0.7849 - recall: 0.9044 - auc: 0.9391 - prc: 0.9552 - val_loss: 0.4395 - val_cross entropy: 0.4395 - val_Brier score: 0.1328 - val_tp: 77.0000 - val_fp: 4727.0000 - val_tn: 40760.0000 - val_fn: 5.0000 - val_accuracy: 0.8962 - val_precision: 0.0160 - val_recall: 0.9390 - val_auc: 0.9616 - val_prc: 0.7625
Epoch 7/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3220 - cross entropy: 0.3220 - Brier score: 0.1047 - tp: 18807.0000 - fp: 4065.0000 - tn: 16241.0000 - fn: 1847.0000 - accuracy: 0.8557 - precision: 0.8223 - recall: 0.9106 - auc: 0.9485 - prc: 0.9631 - val_loss: 0.3867 - val_cross entropy: 0.3867 - val_Brier score: 0.1105 - val_tp: 77.0000 - val_fp: 3192.0000 - val_tn: 42295.0000 - val_fn: 5.0000 - val_accuracy: 0.9298 - val_precision: 0.0236 - val_recall: 0.9390 - val_auc: 0.9635 - val_prc: 0.7711
Epoch 8/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3012 - cross entropy: 0.3012 - Brier score: 0.0959 - tp: 18607.0000 - fp: 3384.0000 - tn: 17165.0000 - fn: 1804.0000 - accuracy: 0.8733 - precision: 0.8461 - recall: 0.9116 - auc: 0.9545 - prc: 0.9661 - val_loss: 0.3438 - val_cross entropy: 0.3438 - val_Brier score: 0.0932 - val_tp: 77.0000 - val_fp: 2361.0000 - val_tn: 43126.0000 - val_fn: 5.0000 - val_accuracy: 0.9481 - val_precision: 0.0316 - val_recall: 0.9390 - val_auc: 0.9644 - val_prc: 0.7748
Epoch 9/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2795 - cross entropy: 0.2795 - Brier score: 0.0880 - tp: 18636.0000 - fp: 2891.0000 - tn: 17616.0000 - fn: 1817.0000 - accuracy: 0.8851 - precision: 0.8657 - recall: 0.9112 - auc: 0.9589 - prc: 0.9692 - val_loss: 0.3087 - val_cross entropy: 0.3087 - val_Brier score: 0.0799 - val_tp: 76.0000 - val_fp: 1892.0000 - val_tn: 43595.0000 - val_fn: 6.0000 - val_accuracy: 0.9583 - val_precision: 0.0386 - val_recall: 0.9268 - val_auc: 0.9658 - val_prc: 0.7797
Epoch 10/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2620 - cross entropy: 0.2620 - Brier score: 0.0812 - tp: 18743.0000 - fp: 2432.0000 - tn: 17955.0000 - fn: 1830.0000 - accuracy: 0.8959 - precision: 0.8851 - recall: 0.9110 - auc: 0.9625 - prc: 0.9724 - val_loss: 0.2798 - val_cross entropy: 0.2798 - val_Brier score: 0.0695 - val_tp: 76.0000 - val_fp: 1615.0000 - val_tn: 43872.0000 - val_fn: 6.0000 - val_accuracy: 0.9644 - val_precision: 0.0449 - val_recall: 0.9268 - val_auc: 0.9674 - val_prc: 0.7834
Epoch 11/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2480 - cross entropy: 0.2480 - Brier score: 0.0757 - tp: 18645.0000 - fp: 2154.0000 - tn: 18383.0000 - fn: 1778.0000 - accuracy: 0.9040 - precision: 0.8964 - recall: 0.9129 - auc: 0.9668 - prc: 0.9748 - val_loss: 0.2551 - val_cross entropy: 0.2551 - val_Brier score: 0.0611 - val_tp: 76.0000 - val_fp: 1428.0000 - val_tn: 44059.0000 - val_fn: 6.0000 - val_accuracy: 0.9685 - val_precision: 0.0505 - val_recall: 0.9268 - val_auc: 0.9691 - val_prc: 0.7858
Epoch 12/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2368 - cross entropy: 0.2368 - Brier score: 0.0722 - tp: 18706.0000 - fp: 1922.0000 - tn: 18565.0000 - fn: 1767.0000 - accuracy: 0.9099 - precision: 0.9068 - recall: 0.9137 - auc: 0.9682 - prc: 0.9759 - val_loss: 0.2341 - val_cross entropy: 0.2341 - val_Brier score: 0.0543 - val_tp: 75.0000 - val_fp: 1301.0000 - val_tn: 44186.0000 - val_fn: 7.0000 - val_accuracy: 0.9713 - val_precision: 0.0545 - val_recall: 0.9146 - val_auc: 0.9710 - val_prc: 0.7888
Epoch 13/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2223 - cross entropy: 0.2223 - Brier score: 0.0667 - tp: 18874.0000 - fp: 1694.0000 - tn: 18675.0000 - fn: 1717.0000 - accuracy: 0.9167 - precision: 0.9176 - recall: 0.9166 - auc: 0.9720 - prc: 0.9785 - val_loss: 0.2162 - val_cross entropy: 0.2162 - val_Brier score: 0.0488 - val_tp: 75.0000 - val_fp: 1235.0000 - val_tn: 44252.0000 - val_fn: 7.0000 - val_accuracy: 0.9727 - val_precision: 0.0573 - val_recall: 0.9146 - val_auc: 0.9732 - val_prc: 0.7912
Epoch 14/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.2172 - cross entropy: 0.2172 - Brier score: 0.0648 - tp: 18681.0000 - fp: 1627.0000 - tn: 18898.0000 - fn: 1754.0000 - accuracy: 0.9175 - precision: 0.9199 - recall: 0.9142 - auc: 0.9732 - prc: 0.9789 - val_loss: 0.2011 - val_cross entropy: 0.2011 - val_Brier score: 0.0444 - val_tp: 75.0000 - val_fp: 1167.0000 - val_tn: 44320.0000 - val_fn: 7.0000 - val_accuracy: 0.9742 - val_precision: 0.0604 - val_recall: 0.9146 - val_auc: 0.9748 - val_prc: 0.7927
Epoch 15/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2088 - cross entropy: 0.2088 - Brier score: 0.0619 - tp: 18878.0000 - fp: 1484.0000 - tn: 18949.0000 - fn: 1649.0000 - accuracy: 0.9235 - precision: 0.9271 - recall: 0.9197 - auc: 0.9749 - prc: 0.9806 - val_loss: 0.1872 - val_cross entropy: 0.1872 - val_Brier score: 0.0405 - val_tp: 75.0000 - val_fp: 1100.0000 - val_tn: 44387.0000 - val_fn: 7.0000 - val_accuracy: 0.9757 - val_precision: 0.0638 - val_recall: 0.9146 - val_auc: 0.9760 - val_prc: 0.7931
Epoch 16/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2011 - cross entropy: 0.2011 - Brier score: 0.0596 - tp: 18797.0000 - fp: 1439.0000 - tn: 19068.0000 - fn: 1656.0000 - accuracy: 0.9244 - precision: 0.9289 - recall: 0.9190 - auc: 0.9768 - prc: 0.9818 - val_loss: 0.1743 - val_cross entropy: 0.1743 - val_Brier score: 0.0369 - val_tp: 75.0000 - val_fp: 1029.0000 - val_tn: 44458.0000 - val_fn: 7.0000 - val_accuracy: 0.9773 - val_precision: 0.0679 - val_recall: 0.9146 - val_auc: 0.9769 - val_prc: 0.7935
Epoch 17/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1961 - cross entropy: 0.1961 - Brier score: 0.0575 - tp: 18762.0000 - fp: 1337.0000 - tn: 19238.0000 - fn: 1623.0000 - accuracy: 0.9277 - precision: 0.9335 - recall: 0.9204 - auc: 0.9773 - prc: 0.9821 - val_loss: 0.1636 - val_cross entropy: 0.1636 - val_Brier score: 0.0342 - val_tp: 75.0000 - val_fp: 997.0000 - val_tn: 44490.0000 - val_fn: 7.0000 - val_accuracy: 0.9780 - val_precision: 0.0700 - val_recall: 0.9146 - val_auc: 0.9777 - val_prc: 0.7943
Epoch 18/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1891 - cross entropy: 0.1891 - Brier score: 0.0554 - tp: 18751.0000 - fp: 1286.0000 - tn: 19292.0000 - fn: 1631.0000 - accuracy: 0.9288 - precision: 0.9358 - recall: 0.9200 - auc: 0.9789 - prc: 0.9833 - val_loss: 0.1544 - val_cross entropy: 0.1544 - val_Brier score: 0.0320 - val_tp: 75.0000 - val_fp: 981.0000 - val_tn: 44506.0000 - val_fn: 7.0000 - val_accuracy: 0.9783 - val_precision: 0.0710 - val_recall: 0.9146 - val_auc: 0.9780 - val_prc: 0.7971
Epoch 19/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1833 - cross entropy: 0.1833 - Brier score: 0.0534 - tp: 18789.0000 - fp: 1144.0000 - tn: 19432.0000 - fn: 1595.0000 - accuracy: 0.9331 - precision: 0.9426 - recall: 0.9218 - auc: 0.9802 - prc: 0.9842 - val_loss: 0.1461 - val_cross entropy: 0.1461 - val_Brier score: 0.0300 - val_tp: 76.0000 - val_fp: 949.0000 - val_tn: 44538.0000 - val_fn: 6.0000 - val_accuracy: 0.9790 - val_precision: 0.0741 - val_recall: 0.9268 - val_auc: 0.9782 - val_prc: 0.7972
Epoch 20/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1775 - cross entropy: 0.1775 - Brier score: 0.0517 - tp: 18845.0000 - fp: 1120.0000 - tn: 19463.0000 - fn: 1532.0000 - accuracy: 0.9353 - precision: 0.9439 - recall: 0.9248 - auc: 0.9814 - prc: 0.9849 - val_loss: 0.1394 - val_cross entropy: 0.1394 - val_Brier score: 0.0287 - val_tp: 76.0000 - val_fp: 969.0000 - val_tn: 44518.0000 - val_fn: 6.0000 - val_accuracy: 0.9786 - val_precision: 0.0727 - val_recall: 0.9268 - val_auc: 0.9788 - val_prc: 0.7971
Epoch 21/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1727 - cross entropy: 0.1727 - Brier score: 0.0506 - tp: 19042.0000 - fp: 1056.0000 - tn: 19310.0000 - fn: 1552.0000 - accuracy: 0.9363 - precision: 0.9475 - recall: 0.9246 - auc: 0.9818 - prc: 0.9855 - val_loss: 0.1331 - val_cross entropy: 0.1331 - val_Brier score: 0.0274 - val_tp: 76.0000 - val_fp: 965.0000 - val_tn: 44522.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0730 - val_recall: 0.9268 - val_auc: 0.9789 - val_prc: 0.7973
Epoch 22/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1711 - cross entropy: 0.1711 - Brier score: 0.0501 - tp: 19041.0000 - fp: 1102.0000 - tn: 19283.0000 - fn: 1534.0000 - accuracy: 0.9356 - precision: 0.9453 - recall: 0.9254 - auc: 0.9826 - prc: 0.9859 - val_loss: 0.1275 - val_cross entropy: 0.1275 - val_Brier score: 0.0262 - val_tp: 76.0000 - val_fp: 965.0000 - val_tn: 44522.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0730 - val_recall: 0.9268 - val_auc: 0.9784 - val_prc: 0.7879
Epoch 23/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1657 - cross entropy: 0.1657 - Brier score: 0.0479 - tp: 19074.0000 - fp: 1045.0000 - tn: 19372.0000 - fn: 1469.0000 - accuracy: 0.9386 - precision: 0.9481 - recall: 0.9285 - auc: 0.9838 - prc: 0.9867 - val_loss: 0.1215 - val_cross entropy: 0.1215 - val_Brier score: 0.0249 - val_tp: 76.0000 - val_fp: 939.0000 - val_tn: 44548.0000 - val_fn: 6.0000 - val_accuracy: 0.9793 - val_precision: 0.0749 - val_recall: 0.9268 - val_auc: 0.9785 - val_prc: 0.7882
Epoch 24/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1631 - cross entropy: 0.1631 - Brier score: 0.0478 - tp: 19006.0000 - fp: 1055.0000 - tn: 19442.0000 - fn: 1457.0000 - accuracy: 0.9387 - precision: 0.9474 - recall: 0.9288 - auc: 0.9839 - prc: 0.9868 - val_loss: 0.1166 - val_cross entropy: 0.1166 - val_Brier score: 0.0239 - val_tp: 76.0000 - val_fp: 924.0000 - val_tn: 44563.0000 - val_fn: 6.0000 - val_accuracy: 0.9796 - val_precision: 0.0760 - val_recall: 0.9268 - val_auc: 0.9780 - val_prc: 0.7886
Epoch 25/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1586 - cross entropy: 0.1586 - Brier score: 0.0464 - tp: 19058.0000 - fp: 971.0000 - tn: 19476.0000 - fn: 1455.0000 - accuracy: 0.9408 - precision: 0.9515 - recall: 0.9291 - auc: 0.9847 - prc: 0.9875 - val_loss: 0.1119 - val_cross entropy: 0.1119 - val_Brier score: 0.0229 - val_tp: 76.0000 - val_fp: 908.0000 - val_tn: 44579.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0772 - val_recall: 0.9268 - val_auc: 0.9783 - val_prc: 0.7886
Epoch 26/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1568 - cross entropy: 0.1568 - Brier score: 0.0459 - tp: 18807.0000 - fp: 974.0000 - tn: 19740.0000 - fn: 1439.0000 - accuracy: 0.9411 - precision: 0.9508 - recall: 0.9289 - auc: 0.9851 - prc: 0.9874 - val_loss: 0.1072 - val_cross entropy: 0.1072 - val_Brier score: 0.0219 - val_tp: 76.0000 - val_fp: 881.0000 - val_tn: 44606.0000 - val_fn: 6.0000 - val_accuracy: 0.9805 - val_precision: 0.0794 - val_recall: 0.9268 - val_auc: 0.9779 - val_prc: 0.7889
Epoch 27/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1562 - cross entropy: 0.1562 - Brier score: 0.0457 - tp: 19045.0000 - fp: 1010.0000 - tn: 19477.0000 - fn: 1428.0000 - accuracy: 0.9405 - precision: 0.9496 - recall: 0.9302 - auc: 0.9854 - prc: 0.9876 - val_loss: 0.1032 - val_cross entropy: 0.1032 - val_Brier score: 0.0211 - val_tp: 76.0000 - val_fp: 864.0000 - val_tn: 44623.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0809 - val_recall: 0.9268 - val_auc: 0.9774 - val_prc: 0.7704
Epoch 28/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1525 - cross entropy: 0.1525 - Brier score: 0.0442 - tp: 19016.0000 - fp: 881.0000 - tn: 19650.0000 - fn: 1413.0000 - accuracy: 0.9440 - precision: 0.9557 - recall: 0.9308 - auc: 0.9862 - prc: 0.9882 - val_loss: 0.0998 - val_cross entropy: 0.0998 - val_Brier score: 0.0205 - val_tp: 76.0000 - val_fp: 866.0000 - val_tn: 44621.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0807 - val_recall: 0.9268 - val_auc: 0.9778 - val_prc: 0.7706
Epoch 29/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1465 - cross entropy: 0.1465 - Brier score: 0.0429 - tp: 19105.0000 - fp: 852.0000 - tn: 19596.0000 - fn: 1407.0000 - accuracy: 0.9448 - precision: 0.9573 - recall: 0.9314 - auc: 0.9870 - prc: 0.9891 - val_loss: 0.0968 - val_cross entropy: 0.0968 - val_Brier score: 0.0200 - val_tp: 76.0000 - val_fp: 868.0000 - val_tn: 44619.0000 - val_fn: 6.0000 - val_accuracy: 0.9808 - val_precision: 0.0805 - val_recall: 0.9268 - val_auc: 0.9770 - val_prc: 0.7709
Epoch 30/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1465 - cross entropy: 0.1465 - Brier score: 0.0431 - tp: 19112.0000 - fp: 860.0000 - tn: 19584.0000 - fn: 1404.0000 - accuracy: 0.9447 - precision: 0.9569 - recall: 0.9316 - auc: 0.9867 - prc: 0.9888 - val_loss: 0.0941 - val_cross entropy: 0.0941 - val_Brier score: 0.0195 - val_tp: 76.0000 - val_fp: 850.0000 - val_tn: 44637.0000 - val_fn: 6.0000 - val_accuracy: 0.9812 - val_precision: 0.0821 - val_recall: 0.9268 - val_auc: 0.9774 - val_prc: 0.7712
Epoch 31/1000
Restoring model weights from the end of the best epoch: 21.
20/20 [==============================] - 0s 25ms/step - loss: 0.1436 - cross entropy: 0.1436 - Brier score: 0.0420 - tp: 19077.0000 - fp: 857.0000 - tn: 19655.0000 - fn: 1371.0000 - accuracy: 0.9456 - precision: 0.9570 - recall: 0.9330 - auc: 0.9876 - prc: 0.9893 - val_loss: 0.0912 - val_cross entropy: 0.0912 - val_Brier score: 0.0189 - val_tp: 76.0000 - val_fp: 826.0000 - val_tn: 44661.0000 - val_fn: 6.0000 - val_accuracy: 0.9817 - val_precision: 0.0843 - val_recall: 0.9268 - val_auc: 0.9767 - val_prc: 0.7622
Epoch 31: early stopping
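The log ends with Keras restoring the weights from epoch 21 and stopping at epoch 31. That pattern is consistent with an early-stopping callback that monitors `val_prc` (higher is better) with a patience of 10 epochs. The callback's configuration is not shown in this section, so treat those settings as an assumption. The sketch below is a simplified re-implementation of the patience rule in plain Python, not Keras's `EarlyStopping` source, and it ignores options such as `min_delta` and `baseline`. It replays the `val_prc` values from epochs 6 to 31 of the log above, because earlier epochs are outside this excerpt.

```python
# val_prc for epochs 6-31, copied from the training log above.
# (Epochs 1-5 fall outside this excerpt.)
VAL_PRC = [
    0.7625, 0.7711, 0.7748, 0.7797, 0.7834, 0.7858, 0.7888, 0.7912, 0.7927,
    0.7931, 0.7935, 0.7943, 0.7971, 0.7972, 0.7971, 0.7973, 0.7879, 0.7882,
    0.7886, 0.7886, 0.7889, 0.7704, 0.7706, 0.7709, 0.7712, 0.7622,
]


def simulate_early_stopping(values, patience, first_epoch=1):
    """Simplified patience rule for a metric where higher is better.

    Returns (stopped_epoch, best_epoch). stopped_epoch is None if the rule
    never fires within the given values.
    """
    best, best_epoch, wait = float("-inf"), None, 0
    for epoch, value in enumerate(values, start=first_epoch):
        if value > best:
            # A strict improvement records a new best and resets the counter.
            best, best_epoch, wait = value, epoch, 0
        else:
            # No improvement: count it, and stop once patience runs out.
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


print(simulate_early_stopping(VAL_PRC, patience=10, first_epoch=6))
```

With `patience=10` the replay stops at epoch 31 and reports epoch 21 as the best, matching the log. Because the log reports restoring the epoch-21 weights, the evaluation that follows measures that epoch's model rather than the epoch-31 one.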

Re-check training history

plot_metrics(resampled_history)


Evaluate metrics

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 1ms/step
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_resampled)
loss :  0.13269135355949402
cross entropy :  0.13269135355949402
Brier score :  0.02699681930243969
tp :  96.0
fp :  1177.0
tn :  55674.0
fn :  15.0
accuracy :  0.9790737628936768
precision :  0.07541241496801376
recall :  0.8648648858070374
auc :  0.9722627401351929
prc :  0.703483521938324

Legitimate Transactions Detected (True Negatives):  55674
Legitimate Transactions Incorrectly Detected (False Positives):  1177
Fraudulent Transactions Missed (False Negatives):  15
Fraudulent Transactions Detected (True Positives):  96
Total Fraudulent Transactions:  111

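The metric values printed above follow directly from the four confusion-matrix counts. The plain-Python sketch below recomputes them from those counts. It also adds one comparison that the printout does not show: the accuracy of a trivial rule that labels every transaction as legitimate. That rule is a hypothetical reference point, not a model from this tutorial.

```python
# Counts copied from the resampled model's test-set evaluation above.
tp, fp, tn, fn = 96, 1177, 55674, 15

total = tp + fp + tn + fn                  # all test transactions
accuracy = (tp + tn) / total               # fraction classified correctly
precision = tp / (tp + fp)                 # of flagged transactions, share that are fraud
recall = tp / (tp + fn)                    # of fraudulent transactions, share caught
false_positive_rate = fp / (fp + tn)       # of legitimate transactions, share flagged

# Hypothetical trivial rule: call every transaction legitimate.
# It gets every negative right and every positive wrong.
all_legit_accuracy = (tn + fp) / total

print(f"accuracy            {accuracy:.4f}")
print(f"precision           {precision:.4f}")
print(f"recall              {recall:.4f}")
print(f"false positive rate {false_positive_rate:.4f}")
print(f"all-legit accuracy  {all_legit_accuracy:.4f}")
```

The trivial rule reaches about 99.8% accuracy, higher than the resampled model's 97.9%, while catching no fraud at all. This is why the tutorial relies on precision, recall, and the curve-based metrics rather than accuracy.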

Plot the ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');

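Each point on an ROC curve is the (false positive rate, true positive rate) pair obtained by thresholding the scores at one value. The sketch below recomputes those points by brute force with NumPy on a made-up four-example input (not the tutorial's data), then integrates them with the trapezoidal rule. It only illustrates what the curves above summarize. `roc_points` and `trapezoid_auc` are hypothetical helper names and are unrelated to `plot_roc`, which is defined earlier in the tutorial.

```python
import numpy as np


def roc_points(labels, scores):
    """(FPR, TPR) at every distinct score used as a threshold, highest first.

    A score >= threshold counts as a positive prediction. The curve starts at
    (0, 0), which corresponds to a threshold above every score.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    n_pos = np.sum(labels == 1)
    n_neg = np.sum(labels == 0)
    fpr, tpr = [0.0], [0.0]
    for t in np.unique(scores)[::-1]:  # thresholds in descending order
        pred = scores >= t
        tpr.append(np.sum(pred & (labels == 1)) / n_pos)
        fpr.append(np.sum(pred & (labels == 0)) / n_neg)
    return np.array(fpr), np.array(tpr)


def trapezoid_auc(x, y):
    """Area under y(x) by the trapezoidal rule (x must be non-decreasing)."""
    return float(np.sum((x[1:] - x[:-1]) * (y[1:] + y[:-1]) / 2))


# Toy input, not the tutorial's data.
fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(fpr, tpr)
print(trapezoid_auc(fpr, tpr))
```

For this input the area is 0.75; a library routine such as `sklearn.metrics.roc_auc_score` should report the same value. Production code sorts the scores once and uses cumulative sums instead of re-thresholding in a loop, but the loop makes the definition explicit.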

Plot the AUPRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_prc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_prc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');

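The area under the precision-recall curve can be summarized in more than one way. The sketch below uses a common step-wise form, average precision, which sums precision weighted by each increase in recall. It runs on the same made-up four-example input as a plain ROC example would, not on the tutorial's data, and the helper names are hypothetical. If, as the metric name suggests, the `prc` value in the logs comes from `tf.keras.metrics.AUC(curve='PR')`, that metric approximates the area over a fixed grid of thresholds with interpolation, so its numbers need not match this sketch exactly.

```python
import numpy as np


def precision_recall_points(labels, scores):
    """Precision and recall at every distinct score used as a threshold, highest first."""
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    n_pos = np.sum(labels == 1)
    precision, recall = [], []
    for t in np.unique(scores)[::-1]:
        # t is an actual score, so at least one example is predicted positive.
        pred = scores >= t
        tp = np.sum(pred & (labels == 1))
        fp = np.sum(pred & (labels == 0))
        precision.append(tp / (tp + fp))
        recall.append(tp / n_pos)
    return np.array(precision), np.array(recall)


def average_precision(precision, recall):
    """Step-wise area: sum over thresholds of (R_k - R_{k-1}) * P_k, with R_0 = 0."""
    prev_recall = np.concatenate([[0.0], recall[:-1]])
    return float(np.sum((recall - prev_recall) * precision))


# Toy input, not the tutorial's data.
precision, recall = precision_recall_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(average_precision(precision, recall))
```

Unlike ROC AUC, where random ranking scores about 0.5, a randomly ranking classifier has an AUPRC roughly equal to the fraction of positives. On this test set that is 111 / 56,962, or about 0.2%, which is the right reference point for the test-set curves above.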

Applying this tutorial to your problem

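Imbalanced data classification is an inherently difficult task, because there are so few samples to learn from. Always start with the data: do your best to collect as many samples as possible, and think carefully about which features may be relevant so the model can get the most out of your minority class. At some point your model may struggle to improve and yield the results you want, so keep the context of your problem in mind, along with the trade-offs between different types of errors, such as a missed fraudulent transaction versus a legitimate transaction that is incorrectly flagged.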
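One way to make the trade-off between error types concrete is to assign each error type a cost and pick the decision threshold that minimizes the total. The sketch below does this by brute force over candidate thresholds. The costs, the toy labels and scores, and the `pick_threshold` helper are all hypothetical. In practice the scores would be a model's predicted probabilities on the validation set (not the test set), and the costs would come from the application.

```python
import numpy as np


def pick_threshold(labels, scores, cost_fp, cost_fn):
    """Brute-force the threshold minimizing cost_fp * FP + cost_fn * FN.

    A score >= threshold is predicted positive (fraud). np.inf is included
    as a candidate so that "never flag anything" is also considered.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    best_t, best_cost = None, float("inf")
    for t in np.append(np.unique(scores), np.inf):
        pred = scores >= t
        fp = np.sum(pred & (labels == 0))   # legitimate, flagged as fraud
        fn = np.sum(~pred & (labels == 1))  # fraud that slipped through
        cost = cost_fp * fp + cost_fn * fn
        if cost < best_cost:
            best_t, best_cost = float(t), float(cost)
    return best_t, best_cost


# Hypothetical toy data: 4 legitimate (0) and 2 fraudulent (1) examples.
labels = [0, 0, 0, 0, 1, 1]
scores = [0.05, 0.2, 0.3, 0.6, 0.55, 0.9]

# Hypothetical costs: a missed fraud is 10x worse than a false alarm, then the reverse.
print(pick_threshold(labels, scores, cost_fp=1, cost_fn=10))
print(pick_threshold(labels, scores, cost_fp=10, cost_fn=1))
```

Making missed fraud ten times as expensive as a false alarm pulls the threshold down to 0.55, while reversing the ratio pushes it up to 0.9. The scores are the same in both cases; only the relative cost of the two error types changes the decision.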