![]() |
![]() |
![]() |
![]() |
![]() |
總覽
公平性指標是一套工具,建構於 TensorFlow Model Analysis (TFMA) 之上,可在產品管線中定期評估公平性指標。TFMA 是一個用於評估 TensorFlow 和非 TensorFlow 機器學習模型的程式庫。它可讓您以分散式方式評估大量資料的模型、計算圖內和其他跨不同資料切片的指標,並在筆記本中將其視覺化。
公平性指標與 TensorFlow Data Validation (TFDV) 和 What-If Tool 搭配封裝。使用公平性指標可讓您
- 評估模型效能,依定義的使用者群組進行切片
- 透過信賴區間和多個閾值的評估,對結果更有信心
- 評估資料集的分布
- 深入探討個別切片,以探索根本原因和改進機會
在本筆記本中,您將使用公平性指標來修正您使用 Civil Comments 資料集訓練之模型中的公平性問題。觀看此影片以取得更多詳細資訊,並瞭解此問題所根據的真實世界情境,這也是建立公平性指標的主要動機之一。
資料集
在本筆記本中,您將使用 Civil Comments 資料集,這是 Civil Comments 平台於 2017 年公開發布約 200 萬則公開留言,以供持續研究之用。這項工作由 Jigsaw 贊助,他們曾在 Kaggle 上舉辦競賽,以協助分類有害留言並盡量減少非預期的模型偏見。
資料集中的每則文字留言都有毒性標籤,如果留言有害,標籤為 1,如果留言無害,標籤為 0。在資料中,部分留言會標記各種身分屬性,包括性別、性傾向、宗教和種族或族裔類別。
設定
安裝 fairness-indicators
和 witwidget
。
pip install -q -U pip==20.2
pip install -q fairness-indicators
pip install -q witwidget
安裝後,您必須重新啟動 Colab 執行階段。從 Colab 選單中選取「執行階段 > 重新啟動執行階段」。
請先重新啟動執行階段,再繼續進行本教學課程的其餘部分。
匯入所有其他必要的程式庫。
import os
import tempfile
import apache_beam as beam
import numpy as np
import pandas as pd
from datetime import datetime
import pprint
from google.protobuf import text_format
import tensorflow_hub as hub
import tensorflow as tf
import tensorflow_model_analysis as tfma
import tensorflow_data_validation as tfdv
from tfx_bsl.tfxio import tensor_adapter
from tfx_bsl.tfxio import tf_example_record
from tensorflow_model_analysis.addons.fairness.post_export_metrics import fairness_indicators
from tensorflow_model_analysis.addons.fairness.view import widget_view
from fairness_indicators.tutorial_utils import util
from witwidget.notebook.visualization import WitConfigBuilder
from witwidget.notebook.visualization import WitWidget
from tensorflow_metadata.proto.v0 import schema_pb2
下載並分析資料
預設情況下,本筆記本會下載此資料集的預先處理版本,但您可以視需要使用原始資料集並重新執行處理步驟。在原始資料集中,每則留言都會標記評分者認為留言對應特定身分的百分比。例如,留言可能會標記如下:{ male: 0.3, female: 1.0, transgender: 0.0, heterosexual: 0.8, homosexual_gay_or_lesbian: 1.0 } 處理步驟會依類別 (性別、性傾向等) 將身分分組,並移除分數低於 0.5 的身分。因此,上述範例會轉換為以下內容:評分者認為留言對應特定身分的百分比。例如,留言會標記如下:{ gender: [female], sexual_orientation: [heterosexual, homosexual_gay_or_lesbian] }
download_original_data = False
if download_original_data:
train_tf_file = tf.keras.utils.get_file('train_tf.tfrecord',
'https://storage.googleapis.com/civil_comments_dataset/train_tf.tfrecord')
validate_tf_file = tf.keras.utils.get_file('validate_tf.tfrecord',
'https://storage.googleapis.com/civil_comments_dataset/validate_tf.tfrecord')
# The identity terms list will be grouped together by their categories
# (see 'IDENTITY_COLUMNS') on threshould 0.5. Only the identity term column,
# text column and label column will be kept after processing.
train_tf_file = util.convert_comments_data(train_tf_file)
validate_tf_file = util.convert_comments_data(validate_tf_file)
else:
train_tf_file = tf.keras.utils.get_file('train_tf_processed.tfrecord',
'https://storage.googleapis.com/civil_comments_dataset/train_tf_processed.tfrecord')
validate_tf_file = tf.keras.utils.get_file('validate_tf_processed.tfrecord',
'https://storage.googleapis.com/civil_comments_dataset/validate_tf_processed.tfrecord')
使用 TFDV 分析資料,並找出其中可能存在的問題,例如遺漏值和資料不平衡,這些問題可能會導致公平性差異。
stats = tfdv.generate_statistics_from_tfrecord(data_location=train_tf_file)
tfdv.visualize_statistics(stats)
TFDV 顯示資料中存在一些顯著的不平衡,這可能會導致模型結果產生偏差。
毒性標籤 (模型預測的值) 不平衡。訓練集中只有 8% 的範例是有害的,這表示分類器可以透過預測所有留言都是無害的來獲得 92% 的準確度。
在與身分詞彙相關的欄位中,在 108 萬個訓練範例中,只有 6.6 千個 (0.61%) 與同性戀有關,而與雙性戀相關的範例則更少。這表示由於缺乏訓練資料,這些切片的效能可能會受到影響。
準備資料
定義特徵對應以剖析資料。每個範例都會有標籤、留言文字和與文字相關聯的身分特徵 性傾向
、性別
、宗教
、種族
和 身心障礙
。
BASE_DIR = tempfile.gettempdir()
TEXT_FEATURE = 'comment_text'
LABEL = 'toxicity'
FEATURE_MAP = {
# Label:
LABEL: tf.io.FixedLenFeature([], tf.float32),
# Text:
TEXT_FEATURE: tf.io.FixedLenFeature([], tf.string),
# Identities:
'sexual_orientation':tf.io.VarLenFeature(tf.string),
'gender':tf.io.VarLenFeature(tf.string),
'religion':tf.io.VarLenFeature(tf.string),
'race':tf.io.VarLenFeature(tf.string),
'disability':tf.io.VarLenFeature(tf.string),
}
接下來,設定輸入函式以將資料饋送至模型。在每個範例中新增權重欄,並提高有害範例的權重,以彌補 TFDV 識別出的類別不平衡問題。在評估階段僅使用身分特徵,因為在訓練期間只有留言會饋送至模型。
def train_input_fn():
def parse_function(serialized):
parsed_example = tf.io.parse_single_example(
serialized=serialized, features=FEATURE_MAP)
# Adds a weight column to deal with unbalanced classes.
parsed_example['weight'] = tf.add(parsed_example[LABEL], 0.1)
return (parsed_example,
parsed_example[LABEL])
train_dataset = tf.data.TFRecordDataset(
filenames=[train_tf_file]).map(parse_function).batch(512)
return train_dataset
訓練模型
在資料上建立並訓練深度學習模型。
model_dir = os.path.join(BASE_DIR, 'train', datetime.now().strftime(
"%Y%m%d-%H%M%S"))
embedded_text_feature_column = hub.text_embedding_column(
key=TEXT_FEATURE,
module_spec='https://tfhub.dev/google/nnlm-en-dim128/1')
classifier = tf.estimator.DNNClassifier(
hidden_units=[500, 100],
weight_column='weight',
feature_columns=[embedded_text_feature_column],
optimizer=tf.keras.optimizers.legacy.Adagrad(learning_rate=0.003),
loss_reduction=tf.losses.Reduction.SUM,
n_classes=2,
model_dir=model_dir)
classifier.train(input_fn=train_input_fn, steps=1000)
分析模型
取得經過訓練的模型後,分析模型以使用 TFMA 和公平性指標計算公平性指標。首先將模型匯出為 SavedModel。
匯出 SavedModel
def eval_input_receiver_fn():
serialized_tf_example = tf.compat.v1.placeholder(
dtype=tf.string, shape=[None], name='input_example_placeholder')
# This *must* be a dictionary containing a single key 'examples', which
# points to the input placeholder.
receiver_tensors = {'examples': serialized_tf_example}
features = tf.io.parse_example(serialized_tf_example, FEATURE_MAP)
features['weight'] = tf.ones_like(features[LABEL])
return tfma.export.EvalInputReceiver(
features=features,
receiver_tensors=receiver_tensors,
labels=features[LABEL])
tfma_export_dir = tfma.export.export_eval_savedmodel(
estimator=classifier,
export_dir_base=os.path.join(BASE_DIR, 'tfma_eval_model'),
eval_input_receiver_fn=eval_input_receiver_fn)
計算公平性指標
使用右側面板中的下拉式選單,選取要計算指標的身分,以及是否使用信賴區間執行。
公平性指標計算選項
使用 What-If Tool 將資料視覺化
在本節中,您將使用 What-If Tool 的互動式視覺介面,以在微觀層級探索和操作資料。
右側面板上的散佈圖中的每個點都代表載入工具的子集中的一個範例。按一下其中一個點,即可在左側面板中查看此特定範例的詳細資訊。系統會顯示留言文字、實際毒性和適用的身分。在此左側面板的底部,您會看到剛訓練之模型的推論結果。
修改範例的文字,然後按一下「執行推論」按鈕,以檢視您的變更如何導致感知到的毒性預測發生變化。
DEFAULT_MAX_EXAMPLES = 1000
# Load 100000 examples in memory. When first rendered,
# What-If Tool should only display 1000 of these due to browser constraints.
def wit_dataset(file, num_examples=100000):
dataset = tf.data.TFRecordDataset(
filenames=[file]).take(num_examples)
return [tf.train.Example.FromString(d.numpy()) for d in dataset]
wit_data = wit_dataset(train_tf_file)
config_builder = WitConfigBuilder(wit_data[:DEFAULT_MAX_EXAMPLES]).set_estimator_and_feature_spec(
classifier, FEATURE_MAP).set_label_vocab(['non-toxicity', LABEL]).set_target_feature(LABEL)
wit = WitWidget(config_builder)
轉譯公平性指標
使用匯出的評估結果轉譯公平性指標小工具。
您將在下方看到長條圖,顯示每個資料切片在選取指標上的效能。您可以使用視覺化頂端的下拉式選單,調整基準比較切片以及顯示的閾值。
公平性指標小工具與上方轉譯的 What-If Tool 整合。如果您在長條圖中選取一個資料切片,What-If Tool 就會更新以向您顯示來自所選切片的範例。當資料在上方 What-If Tool 中重新載入時,請嘗試將「Color By」修改為「toxicity」。這可讓您視覺化瞭解每個切片的範例毒性平衡。
event_handlers={'slice-selected':
wit.create_selection_callback(wit_data, DEFAULT_MAX_EXAMPLES)}
widget_view.render_fairness_indicator(eval_result=eval_result,
slicing_column=slice_selection,
event_handlers=event_handlers
)
對於此特定資料集和工作,某些身分系統性地出現較高的誤判率和偽陰性率可能會導致負面後果。例如,在內容審核系統中,特定群組的誤判率高於整體誤判率可能會導致這些聲音遭到壓制。因此,在您開發和改進模型時,定期評估這些類型的標準非常重要,並利用公平性指標、TFDV 和 WIT 等工具來協助闡明潛在問題。一旦您識別出公平性問題,就可以嘗試新的資料來源、資料平衡或其他技術,以改善效能不佳群組的效能。
如需更多資訊和關於如何使用公平性指標的指南,請參閱此處。
使用公平性評估結果
eval_result
物件 (在 render_fairness_indicator()
中於上方轉譯) 具有自己的 API,您可以運用此 API 將 TFMA 結果讀取到您的程式中。
取得評估的切片和指標
使用 get_slice_names()
和 get_metric_names()
分別取得評估的切片和指標。
pp = pprint.PrettyPrinter()
print("Slices:")
pp.pprint(eval_result.get_slice_names())
print("\nMetrics:")
pp.pprint(eval_result.get_metric_names())
使用 get_metrics_for_slice()
取得特定切片的指標,以字典形式將指標名稱對應到 指標值。
baseline_slice = ()
heterosexual_slice = (('sexual_orientation', 'heterosexual'),)
print("Baseline metric values:")
pp.pprint(eval_result.get_metrics_for_slice(baseline_slice))
print("\nHeterosexual metric values:")
pp.pprint(eval_result.get_metrics_for_slice(heterosexual_slice))
使用 get_metrics_for_all_slices()
取得所有切片的指標,以字典形式將每個切片對應到您從對其執行 get_metrics_for_slice()
取得的對應指標字典。
pp.pprint(eval_result.get_metrics_for_all_slices())