Overview
Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.
The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data and sample features from the input graph. Nodes in the graph correspond to samples and edges correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the `Keras` sequential, functional, or subclass API.
- Wrap the base model with the `GraphRegularization` wrapper class, which is provided by the NSL framework, to create a new graph `Keras` model. This new model will include a graph regularization loss as a regularization term in its training objective (a sketch of this objective follows the list).
- Train and evaluate the graph `Keras` model.
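For reference, the training objective assembled in the third step can be sketched as follows (our own notation, a simplification of what the NSL framework computes):

$$\mathcal{L} = \mathcal{L}_{\text{supervised}} + \alpha \sum_{(x_i, x_j)} w_{ij}\, d\big(h(x_i), h(x_j)\big)$$

where $\mathcal{L}_{\text{supervised}}$ is the usual supervised loss, the sum runs over sample-neighbor pairs from the graph, $w_{ij}$ is the edge weight, $h(\cdot)$ is the model's representation of a sample, $d$ is a distance metric such as L2, and $\alpha$ is the graph regularization multiplier.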
Setup
Install the Neural Structured Learning package.
pip install --quiet neural-structured-learning
Dependencies and imports
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
"GPU is",
"available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
2023-11-16 12:04:49.460421: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-16 12:04:49.460472: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-16 12:04:49.461916: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered Version: 2.15.0 Eager mode: True GPU is NOT AVAILABLE 2023-11-16 12:04:51.768240: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Cora dataset
The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
Graph
The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
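To make this symmetrization step concrete, here is a minimal sketch (our own illustration, not code from the NSL preprocessing script; the paper names are hypothetical) that turns a directed citation list into an undirected edge set:

# Symmetrize a directed citation list into an undirected edge set.
def make_undirected(edges):
  """Returns the union of each edge (a, b) with its reverse (b, a)."""
  undirected = set()
  for a, b in edges:
    undirected.add((a, b))
    undirected.add((b, a))
  return undirected

citations = [('paper_A', 'paper_B'), ('paper_B', 'paper_C')]
print(sorted(make_undirected(citations)))
# [('paper_A', 'paper_B'), ('paper_B', 'paper_A'),
#  ('paper_B', 'paper_C'), ('paper_C', 'paper_B')]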
Features
Each paper in the input effectively contains 2 features:
Words: A dense multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.
Label: A single integer representing the class ID (category) of the paper.
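As a quick illustration of the multi-hot encoding, here is a toy example (a hypothetical 5-word vocabulary, not the actual Cora vocabulary, which works the same way at length 1433):

# Toy multi-hot bag-of-words encoding over a 5-word vocabulary.
vocabulary = ['graph', 'neural', 'learning', 'citation', 'regularization']
paper_words = ['neural', 'graph', 'learning', 'neural']  # repeats collapse to 1
multi_hot = [1 if word in paper_words else 0 for word in vocabulary]
print(multi_hot)  # [1, 1, 1, 0, 0]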
Download the Cora dataset
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/ cora/README cora/cora.cites cora/cora.content
Convert the Cora data to the NSL format
In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:
- Generate neighbor features using the original node features and the graph.
- Generate train and test data splits containing `tf.train.Example` instances.
- Persist the resulting train and test data in the `TFRecord` format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2023-11-16 12:04:52-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11640 (11K) [text/plain] Saving to: ‘preprocess_cora_dataset.py’ preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s 2023-11-16 12:04:53 (75.6 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640] 2023-11-16 12:04:53.758687: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-16 12:04:53.758743: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-16 12:04:53.760530: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2023-11-16 12:04:55.968449: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected Reading graph file: /tmp/cora/cora.cites... Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds). Making all edges bi-directional... Done (0.01 seconds). Total graph nodes: 2708 Joining seed and neighbor tf.train.Examples with graph edges... Done creating and writing 2155 merged tf.train.Examples (1.44 seconds). Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)] Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr. Output test data written to TFRecord file: /tmp/cora/test_examples.tfr. Total running time: 0.05 minutes.
Global variables
The file paths to the train and test data are based on the command-line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
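For instance, the features of a sample's first neighbor are keyed as follows; this snippet just demonstrates the naming convention that the parsing code later in this notebook relies on:

# Keys for neighbor 0 under the NSL naming convention:
#   'NL_nbr_0_words'  -> the neighbor's bag-of-words feature
#   'NL_nbr_0_weight' -> the corresponding edge weight
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print(nbr_feature_key, nbr_weight_key)  # NL_nbr_0_words NL_nbr_0_weight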
Hyperparameters
We will use an instance of `HParams` to hold various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:
- num_classes: There are a total of 7 different classes.
- max_seq_length: This is the size of the vocabulary; all instances in the input have a dense multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.
- distance_type: This is the distance metric used to regularize a sample with its neighbors.
- graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.
- num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the `max_nbrs` command-line argument used when running `preprocess_cora_dataset.py`.
- num_fc_units: The dimensions of the fully connected (hidden) layers in our neural network; each entry specifies the number of units in one layer.
- train_epochs: The number of training epochs.
- batch_size: Batch size used for training and evaluation.
- dropout_rate: Controls the rate of dropout following each fully connected layer.
- eval_steps: The number of batches to process before deeming evaluation complete. If set to `None`, all instances in the test set are evaluated.
class HParams(object):
"""Hyperparameters used for training."""
def __init__(self):
### dataset parameters
self.num_classes = 7
self.max_seq_length = 1433
### neural graph learning parameters
self.distance_type = nsl.configs.DistanceType.L2
self.graph_regularization_multiplier = 0.1
self.num_neighbors = 1
### model architecture
self.num_fc_units = [50, 50]
### training parameters
self.train_epochs = 100
self.batch_size = 128
self.dropout_rate = 0.5
### eval parameters
self.eval_steps = None # All instances in the test set are evaluated.
HPARAMS = HParams()
Load train and test data
As mentioned earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two `tf.data.Dataset` objects -- one for training and one for testing.
In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features based on the `hparams.num_neighbors` value. Instances with fewer neighbors than `hparams.num_neighbors` will be assigned dummy values for those non-existent neighbor features.
def make_dataset(file_path, training=False):
"""Creates a `tf.data.TFRecordDataset`.
Args:
file_path: Name of the file in the `.tfrecord` format containing
`tf.train.Example` objects.
training: Boolean indicating if we are in training mode.
Returns:
An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
objects.
"""
def parse_example(example_proto):
"""Extracts relevant fields from the `example_proto`.
Args:
example_proto: An instance of `tf.train.Example`.
Returns:
A pair whose first value is a dictionary containing relevant features
and whose second value contains the ground truth label.
"""
# The 'words' feature is a multi-hot, bag-of-words representation of the
# original raw text. A default value is required for examples that don't
# have the feature.
feature_spec = {
'words':
tf.io.FixedLenFeature([HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0,
dtype=tf.int64,
shape=[HPARAMS.max_seq_length])),
'label':
tf.io.FixedLenFeature((), tf.int64, default_value=-1),
}
# We also extract corresponding neighbor features in a similar manner to
# the features above during training.
if training:
for i in range(HPARAMS.num_neighbors):
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
NBR_WEIGHT_SUFFIX)
feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
[HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
# We assign a default value of 0.0 for the neighbor weight so that
# graph regularization is done on samples based on their exact number
# of neighbors. In other words, non-existent neighbors are discounted.
feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
[1], tf.float32, default_value=tf.constant([0.0]))
features = tf.io.parse_single_example(example_proto, feature_spec)
label = features.pop('label')
return features, label
dataset = tf.data.TFRecordDataset([file_path])
if training:
dataset = dataset.shuffle(10000)
dataset = dataset.map(parse_example)
dataset = dataset.batch(HPARAMS.batch_size)
return dataset
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
Let's peek into the train dataset to look at its contents.
for feature_batch, label_batch in train_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
print('Batch of neighbor weights:',
tf.reshape(feature_batch[nbr_weight_key], [-1]))
print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 1 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [2 2 3 6 6 4 3 1 3 4 2 5 4 5 6 4 1 5 1 0 5 6 3 0 4 2 4 4 1 1 1 6 2 2 5 3 3 5 3 2 0 0 1 5 5 0 4 6 1 4 2 0 2 4 4 1 3 2 2 2 1 2 2 5 2 2 4 1 2 6 1 6 3 0 5 2 6 4 3 2 4 0 2 1 2 2 2 2 2 2 1 1 6 3 2 4 1 2 1 0 3 0 0 3 2 6 1 2 2 1 2 2 2 3 2 0 2 3 2 5 3 0 1 1 2 0 2 1], shape=(128,), dtype=int64)
Let's peek into the test dataset to look at its contents.
for feature_batch, label_batch in test_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
Model definition
To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the `tf.Keras` framework -- sequential, functional, and subclass.
Sequential base model
def make_mlp_sequential_model(hparams):
"""Creates a sequential multi-layer perceptron model."""
model = tf.keras.Sequential()
model.add(
tf.keras.layers.InputLayer(
input_shape=(hparams.max_seq_length,), name='words'))
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
model.add(
tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
for num_units in hparams.num_fc_units:
model.add(tf.keras.layers.Dense(num_units, activation='relu'))
# For sequential models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
model.add(tf.keras.layers.Dense(hparams.num_classes))
return model
Functional base model
def make_mlp_functional_model(hparams):
"""Creates a functional API-based multi-layer perceptron model."""
inputs = tf.keras.Input(
shape=(hparams.max_seq_length,), dtype='int64', name='words')
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
cur_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))(
inputs)
for num_units in hparams.num_fc_units:
cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
# For functional models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)
outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)
model = tf.keras.Model(inputs, outputs=outputs)
return model
Subclass base model
def make_mlp_subclass_model(hparams):
"""Creates a multi-layer perceptron subclass model in Keras."""
class MLP(tf.keras.Model):
"""Subclass model defining a multi-layer perceptron."""
def __init__(self):
super(MLP, self).__init__()
# Input is already one-hot encoded in the integer format. We create a
# layer to cast it to floating point format here.
self.cast_to_float_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))
self.dense_layers = [
tf.keras.layers.Dense(num_units, activation='relu')
for num_units in hparams.num_fc_units
]
self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
self.output_layer = tf.keras.layers.Dense(hparams.num_classes)
def call(self, inputs, training=False):
cur_layer = self.cast_to_float_layer(inputs['words'])
for dense_layer in self.dense_layers:
cur_layer = dense_layer(cur_layer)
cur_layer = self.dropout_layer(cur_layer, training=training)
outputs = self.output_layer(cur_layer)
return outputs
return MLP()
Create the base model
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 1433)] 0 lambda (Lambda) (None, 1433) 0 dense (Dense) (None, 50) 71700 dropout (Dropout) (None, 50) 0 dense_1 (Dense) (None, 50) 2550 dropout_1 (Dropout) (None, 50) 0 dense_2 (Dense) (None, 7) 357 ================================================================= Total params: 74607 (291.43 KB) Trainable params: 74607 (291.43 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
Train the base MLP model
# Compile and train the base MLP model
base_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/functional.py:642: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 6ms/step - loss: 1.9105 - accuracy: 0.2260 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8280 - accuracy: 0.3044 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7240 - accuracy: 0.3299 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5969 - accuracy: 0.3745 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4765 - accuracy: 0.4492 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3235 - accuracy: 0.5276 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1913 - accuracy: 0.5889 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0604 - accuracy: 0.6432 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9628 - accuracy: 0.6821 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8601 - accuracy: 0.7234 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7914 - accuracy: 0.7480 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7230 - accuracy: 0.7633 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6783 - accuracy: 0.7791 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6019 - accuracy: 0.8070 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5587 - accuracy: 0.8367 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5295 - accuracy: 0.8450 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4789 - accuracy: 0.8599 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4474 - accuracy: 0.8650 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4148 - accuracy: 0.8701 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3812 - accuracy: 0.8896 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3656 - accuracy: 0.8863 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3544 - accuracy: 0.8923 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3050 - accuracy: 0.9165 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2858 - accuracy: 0.9216 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2821 - accuracy: 0.9234 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2543 - accuracy: 0.9276 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2477 - accuracy: 0.9285 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2413 - accuracy: 0.9295 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2153 - accuracy: 0.9415 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2241 - accuracy: 0.9290 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2118 - accuracy: 0.9374 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2041 - accuracy: 0.9471 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 
0.1951 - accuracy: 0.9392 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1841 - accuracy: 0.9443 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1783 - accuracy: 0.9522 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1742 - accuracy: 0.9485 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1705 - accuracy: 0.9541 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1507 - accuracy: 0.9592 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1513 - accuracy: 0.9555 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1378 - accuracy: 0.9652 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1471 - accuracy: 0.9587 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1309 - accuracy: 0.9661 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1288 - accuracy: 0.9596 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1327 - accuracy: 0.9629 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1170 - accuracy: 0.9675 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1198 - accuracy: 0.9666 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1183 - accuracy: 0.9680 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1025 - accuracy: 0.9740 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0981 - accuracy: 0.9754 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9708 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0874 - accuracy: 0.9796 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1027 - accuracy: 0.9735 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0993 - accuracy: 0.9740 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0934 - accuracy: 0.9759 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0932 - accuracy: 0.9759 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0787 - accuracy: 0.9810 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0890 - accuracy: 0.9754 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0918 - accuracy: 0.9749 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0908 - accuracy: 0.9717 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0825 - accuracy: 0.9777 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0926 - accuracy: 0.9684 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0702 - accuracy: 0.9800 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0720 - accuracy: 0.9842 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0792 - accuracy: 0.9773 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0760 - accuracy: 0.9782 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0736 - accuracy: 0.9800 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0838 - accuracy: 0.9773 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0639 - accuracy: 0.9824 Epoch 69/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0742 - accuracy: 0.9805 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0798 - accuracy: 0.9782 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0694 - accuracy: 0.9805 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0635 - accuracy: 0.9833 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0587 - accuracy: 0.9824 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0689 - accuracy: 0.9828 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0628 - accuracy: 0.9828 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0570 - accuracy: 0.9842 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0632 - accuracy: 0.9824 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0673 - accuracy: 0.9782 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0573 - accuracy: 0.9828 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0640 - accuracy: 0.9824 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0610 - accuracy: 0.9810 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0553 - accuracy: 0.9861 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0482 - accuracy: 0.9879 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0548 - accuracy: 0.9842 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0537 - accuracy: 0.9865 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0540 - accuracy: 0.9828 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0528 - accuracy: 0.9838 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0505 - accuracy: 0.9865 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0473 - accuracy: 0.9833 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0604 - accuracy: 0.9810 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0469 - accuracy: 0.9879 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0554 - accuracy: 0.9810 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0427 - accuracy: 0.9875 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0581 - accuracy: 0.9824 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0488 - accuracy: 0.9842 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0466 - accuracy: 0.9875 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0465 - accuracy: 0.9875 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0411 - accuracy: 0.9879 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0539 - accuracy: 0.9852 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0451 - accuracy: 0.9870 <keras.src.callbacks.History at 0x7f459c2e9e50>
Evaluate the base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
"""Prints evaluation metrics.
Args:
model_desc: A description of the model.
eval_metrics: A dictionary mapping metric names to corresponding values. It
must contain the loss and accuracy metrics.
"""
print('\n')
print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
if 'graph_loss' in eval_metrics:
print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
zip(base_model.metrics_names,
base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4164 - accuracy: 0.7758 Eval accuracy for Base MLP model : 0.775768518447876 Eval loss for Base MLP model : 1.4164185523986816
Train the MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing `tf.Keras.Model` requires just a few lines of code. The base model is wrapped to create a new `tf.Keras` subclass model, whose loss includes graph regularization.
To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because `base_model` has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model would not be a fair comparison for `base_model`.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
max_neighbors=HPARAMS.num_neighbors,
multiplier=HPARAMS.graph_regularization_multiplier,
distance_type=HPARAMS.distance_type,
sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
graph_reg_config)
graph_reg_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 17/17 [==============================] - 2s 7ms/step - loss: 1.9586 - accuracy: 0.2107 - scaled_graph_loss: 0.0319 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8903 - accuracy: 0.2942 - scaled_graph_loss: 0.0282 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8290 - accuracy: 0.3262 - scaled_graph_loss: 0.0411 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7762 - accuracy: 0.3248 - scaled_graph_loss: 0.0604 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7334 - accuracy: 0.3568 - scaled_graph_loss: 0.0792 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6859 - accuracy: 0.3735 - scaled_graph_loss: 0.0920 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6506 - accuracy: 0.3935 - scaled_graph_loss: 0.1086 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6028 - accuracy: 0.4520 - scaled_graph_loss: 0.1249 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5690 - accuracy: 0.5012 - scaled_graph_loss: 0.1386 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5332 - accuracy: 0.5420 - scaled_graph_loss: 0.1577 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4792 - accuracy: 0.5842 - scaled_graph_loss: 0.1642 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4438 - accuracy: 0.6306 - scaled_graph_loss: 0.1909 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4155 - accuracy: 0.6617 - scaled_graph_loss: 0.2009 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3596 - accuracy: 0.6896 - scaled_graph_loss: 0.1964 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3462 - accuracy: 0.7077 - scaled_graph_loss: 0.2294 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3151 - accuracy: 0.7295 - scaled_graph_loss: 0.2312 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2848 - accuracy: 0.7555 - scaled_graph_loss: 0.2319 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2643 - accuracy: 0.7759 - scaled_graph_loss: 0.2469 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2434 - accuracy: 0.7921 - scaled_graph_loss: 0.2544 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2005 - accuracy: 0.8093 - scaled_graph_loss: 0.2473 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2007 - accuracy: 0.8070 - scaled_graph_loss: 0.2688 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1876 - accuracy: 0.8135 - scaled_graph_loss: 0.2708 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1729 - accuracy: 0.8274 - scaled_graph_loss: 0.2662 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8376 - scaled_graph_loss: 0.2707 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1228 - accuracy: 0.8538 - scaled_graph_loss: 0.2677 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1166 - accuracy: 0.8603 - scaled_graph_loss: 0.2785 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1176 - accuracy: 0.8473 - scaled_graph_loss: 0.2807 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1085 - accuracy: 0.8473 - 
scaled_graph_loss: 0.2649 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0751 - accuracy: 0.8691 - scaled_graph_loss: 0.2858 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0851 - accuracy: 0.8696 - scaled_graph_loss: 0.2996 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0932 - accuracy: 0.8770 - scaled_graph_loss: 0.2892 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0619 - accuracy: 0.8821 - scaled_graph_loss: 0.2880 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0531 - accuracy: 0.8886 - scaled_graph_loss: 0.2847 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0558 - accuracy: 0.8863 - scaled_graph_loss: 0.2962 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0375 - accuracy: 0.8891 - scaled_graph_loss: 0.2780 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0310 - accuracy: 0.8858 - scaled_graph_loss: 0.2932 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0269 - accuracy: 0.8872 - scaled_graph_loss: 0.2916 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0273 - accuracy: 0.8928 - scaled_graph_loss: 0.2948 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9935 - accuracy: 0.9123 - scaled_graph_loss: 0.2910 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0083 - accuracy: 0.9104 - scaled_graph_loss: 0.2951 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0196 - accuracy: 0.8951 - scaled_graph_loss: 0.2982 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9941 - accuracy: 0.9007 - scaled_graph_loss: 0.2898 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0069 - accuracy: 0.9012 - scaled_graph_loss: 0.3076 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9816 - accuracy: 0.9049 - scaled_graph_loss: 0.2930 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9910 - accuracy: 0.9104 - scaled_graph_loss: 0.2954 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9949 - accuracy: 0.9026 - scaled_graph_loss: 0.3111 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9715 - accuracy: 0.9114 - scaled_graph_loss: 0.2830 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9796 - accuracy: 0.9067 - scaled_graph_loss: 0.2970 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9570 - accuracy: 0.9114 - scaled_graph_loss: 0.2936 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9691 - accuracy: 0.9049 - scaled_graph_loss: 0.2940 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9803 - accuracy: 0.9114 - scaled_graph_loss: 0.3083 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9612 - accuracy: 0.9128 - scaled_graph_loss: 0.2860 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9627 - accuracy: 0.9216 - scaled_graph_loss: 0.3077 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9516 - accuracy: 0.9151 - scaled_graph_loss: 0.2906 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9431 - accuracy: 0.9197 - scaled_graph_loss: 0.2967 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - 
loss: 0.9622 - accuracy: 0.9132 - scaled_graph_loss: 0.3053 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9410 - accuracy: 0.9188 - scaled_graph_loss: 0.2830 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9531 - accuracy: 0.9230 - scaled_graph_loss: 0.3049 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9309 - accuracy: 0.9193 - scaled_graph_loss: 0.3009 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9300 - accuracy: 0.9248 - scaled_graph_loss: 0.2988 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9173 - accuracy: 0.9244 - scaled_graph_loss: 0.2884 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9228 - accuracy: 0.9248 - scaled_graph_loss: 0.2960 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9394 - accuracy: 0.9174 - scaled_graph_loss: 0.3102 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9182 - accuracy: 0.9174 - scaled_graph_loss: 0.2899 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9276 - accuracy: 0.9253 - scaled_graph_loss: 0.2996 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9229 - accuracy: 0.9244 - scaled_graph_loss: 0.2912 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9325 - accuracy: 0.9142 - scaled_graph_loss: 0.3088 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9091 - accuracy: 0.9216 - scaled_graph_loss: 0.2883 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8987 - accuracy: 0.9267 - scaled_graph_loss: 0.2924 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9188 - accuracy: 0.9216 - scaled_graph_loss: 0.2970 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9003 - accuracy: 0.9299 - scaled_graph_loss: 0.2962 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9086 - accuracy: 0.9206 - scaled_graph_loss: 0.2944 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9047 - accuracy: 0.9304 - scaled_graph_loss: 0.3174 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9214 - accuracy: 0.9202 - scaled_graph_loss: 0.2923 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9081 - accuracy: 0.9276 - scaled_graph_loss: 0.3020 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9043 - accuracy: 0.9220 - scaled_graph_loss: 0.2892 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9022 - accuracy: 0.9253 - scaled_graph_loss: 0.2998 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8871 - accuracy: 0.9332 - scaled_graph_loss: 0.2979 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8863 - accuracy: 0.9295 - scaled_graph_loss: 0.3021 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8893 - accuracy: 0.9225 - scaled_graph_loss: 0.2928 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8850 - accuracy: 0.9258 - scaled_graph_loss: 0.2997 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9013 - accuracy: 0.9165 - scaled_graph_loss: 0.2961 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8739 - accuracy: 0.9253 - scaled_graph_loss: 0.2886 Epoch 84/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.8840 - accuracy: 0.9318 - scaled_graph_loss: 0.3040 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8628 - accuracy: 0.9378 - scaled_graph_loss: 0.2886 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8745 - accuracy: 0.9313 - scaled_graph_loss: 0.3013 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8678 - accuracy: 0.9327 - scaled_graph_loss: 0.2980 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8614 - accuracy: 0.9397 - scaled_graph_loss: 0.2947 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8589 - accuracy: 0.9327 - scaled_graph_loss: 0.2957 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8688 - accuracy: 0.9346 - scaled_graph_loss: 0.2996 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8661 - accuracy: 0.9216 - scaled_graph_loss: 0.2881 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8828 - accuracy: 0.9318 - scaled_graph_loss: 0.3019 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8701 - accuracy: 0.9374 - scaled_graph_loss: 0.3051 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8572 - accuracy: 0.9383 - scaled_graph_loss: 0.2998 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9327 - scaled_graph_loss: 0.2999 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8685 - accuracy: 0.9336 - scaled_graph_loss: 0.3013 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8710 - accuracy: 0.9378 - scaled_graph_loss: 0.3023 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8746 - accuracy: 0.9327 - scaled_graph_loss: 0.2956 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8642 - accuracy: 0.9341 - scaled_graph_loss: 0.2984 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8638 - accuracy: 0.9318 - scaled_graph_loss: 0.2965 <keras.src.callbacks.History at 0x7f445862f130>
Evaluate the MLP model with graph regularization
eval_results = dict(
zip(graph_reg_model.metrics_names,
graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8791 - accuracy: 0.7993 Eval accuracy for MLP + graph regularization : 0.7992766499519348 Eval loss for MLP + graph regularization : 0.8790676593780518
The graph-regularized model's accuracy is about 2-3% higher than that of the base model (`base_model`).
Conclusion
We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.