建立自訂反事實 Logit 配對資料集

將反事實 Logit 配對 (CLP) 應用於評估和改善模型的公平性,需要反事實資料集。您可以複製現有的資料集,並變更新的資料集以新增、移除或修改身分術語,來建立反事實資料集。本教學課程說明為現有文字資料集建立反事實資料集的方法和技巧。

您可以將反事實資料集與 CLP 技術搭配使用,方法是建立新的資料物件 CounterfactualPackedInputs,其中包含 original_inputcounterfactual_data,且看起來如下所示

CounterfactualPackedInputs 看起來如下所示

CounterfactualPackedInputs(
  original_input=(x, y, sample_weight),
  counterfactual_data=(original_x, counterfactual_x,
                       counterfactual_sample_weight)
)

original_input 應為用於訓練 Keras 模型的原始資料集。counterfactual_data 應為具有原始 x 值、對應的 counterfactual_x 值和 counterfactual_sample_weighttf.data.Datasetcounterfactual_x 值幾乎與原始值相同,但移除或取代了一個或多個屬性。此資料集用於配對原始值和反事實值之間的損失函數,目標是確保模型的預測在敏感屬性不同時不會變更。original_inputcounterfactual_data 需要具有相同的形狀。您可以複製 counterfactual_data 中的值,使其與 original_input 的元素數量相同。

counterfactual_data 的屬性

  • 所有 original_x 值都需要參考身分群組
  • 每個 counterfactual_x 值都與原始值相同,但移除或取代了一個或多個屬性
  • 具有與原始輸入相同的形狀 (您可以複製值,使其具有相同的形狀)

counterfactual_data 不需要

  • 與原始輸入中的資料重疊
  • 具有基本事實標籤

以下範例說明移除「gay」一詞時 counterfactual_data 的外觀。

original_x: “I am a gay man”
counterfactual_x: “I am a man” 
counterfactual_sample_weight”: 1

如果您有文字分類器,可以使用 build_counterfactual_data 來協助建立反事實資料集。對於所有其他資料類型,您需要直接提供反事實資料集。

設定

首先,您將安裝 TensorFlow 模型修復。

pip install --upgrade tensorflow-model-remediation
import tensorflow as tf
from tensorflow_model_remediation import counterfactual

建立簡單的資料集

為了示範目的,我們將使用 build_counterfactual_dataset 從原始輸入建立反事實資料。請注意,您也可以從未標記的資料建構反事實資料 (而不是從原始輸入建構)。您將建立一個包含一個句子的簡單資料集:「i am a gay man」,這將作為 original_input

建立反事實資料集

由於這是文字分類器,您可以使用兩種方式透過 build_counterfactual_data 建立反事實資料集

  1. 移除詞彙:使用 build_counterfactual_data 傳遞將透過 tf.strings.regex_replace 從資料集中移除的字詞清單。
  2. 取代詞彙:建立自訂函式以傳遞至 build_counterfactual_data。這可能包括使用更具體的 regex 函式來取代原始資料集中的字詞,或支援非文字特徵

build_counterfactual_dataset 接受 original_input,並根據您傳遞的選用參數移除或取代詞彙。在大多數情況下,移除詞彙 (選項 1) 應足以執行 CLP,但是,傳遞自訂函式 (選項 2) 可更精確地控制反事實值。

選項 1:要移除的字詞清單

傳入要使用 build_counterfactual_data 移除的性別相關詞彙清單。

使用簡單的 regex 建立反事實資料集時,請記住,這可能會擴增不應變更的字詞。最佳做法是檢查對 counterfactual_x 值所做的變更在 orginal_x 值的上下文中是否合理。此外,build_counterfactual_dataset 將僅傳回包含反事實執行個體的值。這可能會導致資料集形狀與 orginal_input 不同,但在傳遞至 pack_counterfactual_data 時會調整大小。

simple_dataset_x = tf.constant(
    ["I am a gay man" + str(i) for i in range(10)] +
    ["I am a man" + str(i) for i in range(10)])
print("Length of starting values: " + str(len(simple_dataset_x)))

simple_dataset = tf.data.Dataset.from_tensor_slices(
            (simple_dataset_x, None, None))

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'])

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value, _ in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))

選項 2:自訂函式

為了更彈性地修改原始資料集,您可以改為將自訂函式傳遞至 build_counterfactual_data

在範例中,您可以考慮將參考男性的身分術語取代為參考女性的術語。這可以透過編寫函式來取代字詞字典來完成。

請注意,自訂函式的唯一限制是它必須是可呼叫的,才能接受並傳回 Model.fit 中使用的格式的元組,並且應移除不包含任何變更的值,這可以透過將詞彙傳遞至 sensitive_terms_to_remove 來完成。

words_to_replace = {"man": "woman"}
print("Length of starting values: " + str(len(simple_dataset_x)))

def replace_words(original_batch):
  original_x, _, original_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(original_batch))
  for word in words_to_replace:
    counterfactual_x = tf.strings.regex_replace(
        original_x, f'{word}', words_to_replace[word])
  return tf.keras.utils.pack_x_y_sample_weight(
      original_x, counterfactual_x, sample_weight=original_sample_weight)

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'],
    custom_counterfactual_function=replace_words)

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))

若要瞭解詳情,請參閱 build_counterfactual_data 的 API 文件。