Semi-supervised image classification using contrastive pretraining with SimCLR
Author: András Béres
Compiled by: ronghuaiyang
Overview
On the STL-10 dataset, we first run contrastive pretraining with SimCLR, then finetune with supervised training on a small amount of labeled data.
Semi-supervised learning
========================
Semi-supervised learning is a machine learning paradigm that deals with partially labeled datasets. When applying deep learning in the real world, one usually has to gather a large dataset to make it work well. However, while the cost of labeling scales linearly with the dataset size (labeling each example takes a constant amount of time), model performance only scales sublinearly with it (https://arxiv.org/abs/2001.08361). This means that labeling more and more samples becomes less and less cost-efficient, while gathering unlabeled data is generally cheap, as it is usually readily available in large quantities.
Semi-supervised learning addresses this issue by only requiring a partially labeled dataset, and by using the unlabeled examples for learning as well, which makes it label-efficient.
In this example, we will pretrain an encoder with contrastive learning on the STL-10 semi-supervised dataset using no labels at all, and then finetune it using only its labeled subset.
Contrastive learning
At the highest level, the main idea behind contrastive learning is to learn representations that are invariant to image augmentations in a self-supervised manner. One problem with this objective is that it has a trivial degenerate solution: the case where the representations are constant and do not depend on the input images at all.
Contrastive learning avoids this trap by modifying the objective as follows: it pulls representations of augmented versions/views of the same image closer to each other (contracting positives), while pushing representations of different images away from each other (contrasting negatives).
SimCLR is one such contrastive approach, which essentially identifies the core components needed to optimize this objective, and can reach high performance by scaling this simple approach.
Another approach is SimSiam, whose main difference from SimCLR is that it does not use any negatives in its loss. Therefore, it does not explicitly prevent the degenerate solution; instead, it avoids it implicitly by architecture design: an asymmetric encoding path using a predictor network, and batch normalization (BatchNorm) applied in the final layers.
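For contrast, here is a minimal sketch (my addition, not part of the original example) of the SimSiam objective: a symmetrized negative cosine similarity in which gradients flow through only one branch at a time, via a stop-gradient on the other.
import tensorflow as tf

# Sketch of the SimSiam loss: p_i are predictor outputs, z_i are projection
# outputs of the two augmented views; no negative pairs are needed.
def simsiam_loss(p_1, p_2, z_1, z_2):
    def negative_cosine_similarity(p, z):
        p = tf.math.l2_normalize(p, axis=1)
        # stop_gradient makes the encoding paths of the two views asymmetric
        z = tf.math.l2_normalize(tf.stop_gradient(z), axis=1)
        return -tf.reduce_mean(tf.reduce_sum(p * z, axis=1))

    loss_1_2 = negative_cosine_similarity(p_1, z_2)
    loss_2_1 = negative_cosine_similarity(p_2, z_1)
    return (loss_1_2 + loss_2_1) / 2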
Setup
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
Hyperparameters
# Dataset hyperparameters
unlabeled_dataset_size = 100000
labeled_dataset_size = 5000
image_size = 96
image_channels = 3
# Algorithm hyperparameters
num_epochs = 20
batch_size = 525 # Corresponds to 200 steps per epoch
width = 128
temperature = 0.1
# Stronger augmentations for contrastive, weaker ones for supervised training
contrastive_augmentation = {"min_area": 0.25, "brightness": 0.6, "jitter": 0.2}
classification_augmentation = {"min_area": 0.75, "brightness": 0.3, "jitter": 0.1}
Dataset
During training we will simultaneously load a large batch of unlabeled images along with a smaller batch of labeled images.
def prepare_dataset():
    # Labeled and unlabeled samples are loaded synchronously
    # with batch sizes selected accordingly
    steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
    unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
    labeled_batch_size = labeled_dataset_size // steps_per_epoch
    print(
        f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
    )
    unlabeled_train_dataset = (
        tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=True)
        .shuffle(buffer_size=10 * unlabeled_batch_size)
        .batch(unlabeled_batch_size)
    )
    labeled_train_dataset = (
        tfds.load("stl10", split="train", as_supervised=True, shuffle_files=True)
        .shuffle(buffer_size=10 * labeled_batch_size)
        .batch(labeled_batch_size)
    )
    test_dataset = (
        tfds.load("stl10", split="test", as_supervised=True)
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )
    # Labeled and unlabeled datasets are zipped together
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=tf.data.AUTOTUNE)
    return train_dataset, labeled_train_dataset, test_dataset
# Load STL10 dataset
train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()
batch size is 500 (unlabeled) + 25 (labeled)
Image augmentations
The two most important image augmentations for contrastive learning are the following:
- Cropping: forces the model to encode different parts of the same image similarly; we implement it with the RandomTranslation and RandomZoom layers.
- Color jitter: prevents a trivial color-histogram-based solution to the task by distorting color histograms. A principled way to implement this is by affine transformations in color space.
In this example we also use random horizontal flips. Stronger augmentations are applied for contrastive learning, along with weaker ones for supervised classification, to avoid overfitting on the few labeled examples.
We implement random color jitter as a custom preprocessing layer. Using preprocessing layers for data augmentation has the following two advantages:
- The data augmentation will run on the GPU in batches, so the training will not be bottlenecked by the data pipeline in environments with constrained CPU resources (such as a Colab notebook or a personal machine).
- Deployment is easier, because the data preprocessing pipeline is encapsulated in the model, and does not have to be reimplemented when deploying it.
# Distorts the color distributions of images
class RandomColorAffine(layers.Layer):
    def __init__(self, brightness=0, jitter=0, **kwargs):
        super().__init__(**kwargs)
        self.brightness = brightness
        self.jitter = jitter

    def call(self, images, training=True):
        if training:
            batch_size = tf.shape(images)[0]
            # Same for all colors
            brightness_scales = 1 + tf.random.uniform(
                (batch_size, 1, 1, 1), minval=-self.brightness, maxval=self.brightness
            )
            # Different for all colors
            jitter_matrices = tf.random.uniform(
                (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
            )
            color_transforms = (
                tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
                + jitter_matrices
            )
            images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
        return images
# Image augmentation module
def get_augmenter(min_area, brightness, jitter):
    zoom_factor = 1.0 - tf.sqrt(min_area)
    return keras.Sequential(
        [
            keras.Input(shape=(image_size, image_size, image_channels)),
            preprocessing.Rescaling(1 / 255),
            preprocessing.RandomFlip("horizontal"),
            preprocessing.RandomTranslation(zoom_factor / 2, zoom_factor / 2),
            preprocessing.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),
            RandomColorAffine(brightness, jitter),
        ]
    )
def visualize_augmentations(num_images):
    # Sample a batch from a dataset
    images = next(iter(train_dataset))[0][0][:num_images]
    # Apply augmentations
    augmented_images = zip(
        images,
        get_augmenter(**classification_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
    )
    row_titles = [
        "Original:",
        "Weakly augmented:",
        "Strongly augmented:",
        "Strongly augmented:",
    ]
    plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
    for column, image_row in enumerate(augmented_images):
        for row, image in enumerate(image_row):
            plt.subplot(4, num_images, row * num_images + column + 1)
            plt.imshow(image)
            if column == 0:
                plt.title(row_titles[row], loc="left")
            plt.axis("off")
    plt.tight_layout()


visualize_augmentations(num_images=8)
Encoder architecture
# Define the encoder architecture
def get_encoder():
    return keras.Sequential(
        [
            keras.Input(shape=(image_size, image_size, image_channels)),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )
Supervised baseline model
A baseline supervised model is trained using random initialization.
# Baseline supervised training with random initialization
baseline_model = keras.Sequential(
    [
        keras.Input(shape=(image_size, image_size, image_channels)),
        get_augmenter(**classification_augmentation),
        get_encoder(),
        layers.Dense(10),
    ],
    name="baseline_model",
)
baseline_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

baseline_history = baseline_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)
Epoch 1/20
200/200 [==============================] - 8s 26ms/step - loss: 2.1769 - acc: 0.1794 - val_loss: 1.7424 - val_acc: 0.3341
Epoch 2/20
200/200 [==============================] - 3s 16ms/step - loss: 1.8366 - acc: 0.3139 - val_loss: 1.6184 - val_acc: 0.3989
Epoch 3/20
200/200 [==============================] - 3s 16ms/step - loss: 1.6331 - acc: 0.3912 - val_loss: 1.5344 - val_acc: 0.4125
Epoch 4/20
200/200 [==============================] - 3s 16ms/step - loss: 1.5439 - acc: 0.4216 - val_loss: 1.4052 - val_acc: 0.4712
Epoch 5/20
200/200 [==============================] - 4s 17ms/step - loss: 1.4576 - acc: 0.4575 - val_loss: 1.4337 - val_acc: 0.4729
Epoch 6/20
200/200 [==============================] - 3s 17ms/step - loss: 1.3723 - acc: 0.4875 - val_loss: 1.4054 - val_acc: 0.4746
Epoch 7/20
200/200 [==============================] - 3s 17ms/step - loss: 1.3445 - acc: 0.5066 - val_loss: 1.3030 - val_acc: 0.5200
Epoch 8/20
200/200 [==============================] - 3s 17ms/step - loss: 1.3015 - acc: 0.5255 - val_loss: 1.2720 - val_acc: 0.5378
Epoch 9/20
200/200 [==============================] - 3s 16ms/step - loss: 1.2244 - acc: 0.5452 - val_loss: 1.3211 - val_acc: 0.5220
Epoch 10/20
200/200 [==============================] - 3s 17ms/step - loss: 1.2204 - acc: 0.5494 - val_loss: 1.2898 - val_acc: 0.5381
Epoch 11/20
200/200 [==============================] - 4s 17ms/step - loss: 1.1359 - acc: 0.5766 - val_loss: 1.2138 - val_acc: 0.5648
Epoch 12/20
200/200 [==============================] - 3s 17ms/step - loss: 1.1228 - acc: 0.5855 - val_loss: 1.2602 - val_acc: 0.5429
Epoch 13/20
200/200 [==============================] - 3s 17ms/step - loss: 1.0853 - acc: 0.6000 - val_loss: 1.2716 - val_acc: 0.5591
Epoch 14/20
200/200 [==============================] - 3s 17ms/step - loss: 1.0632 - acc: 0.6078 - val_loss: 1.2832 - val_acc: 0.5591
Epoch 15/20
200/200 [==============================] - 3s 16ms/step - loss: 1.0268 - acc: 0.6157 - val_loss: 1.1712 - val_acc: 0.5882
Epoch 16/20
200/200 [==============================] - 3s 17ms/step - loss: 0.9594 - acc: 0.6440 - val_loss: 1.2904 - val_acc: 0.5573
Epoch 17/20
200/200 [==============================] - 3s 17ms/step - loss: 0.9524 - acc: 0.6517 - val_loss: 1.1854 - val_acc: 0.5955
Epoch 18/20
200/200 [==============================] - 3s 17ms/step - loss: 0.9118 - acc: 0.6672 - val_loss: 1.1974 - val_acc: 0.5845
Epoch 19/20
200/200 [==============================] - 3s 17ms/step - loss: 0.9187 - acc: 0.6686 - val_loss: 1.1703 - val_acc: 0.6025
Epoch 20/20
200/200 [==============================] - 3s 17ms/step - loss: 0.8520 - acc: 0.6911 - val_loss: 1.1312 - val_acc: 0.6149
Maximal validation accuracy: 61.49%
Self-supervised model for contrastive pretraining
We pretrain an encoder on unlabeled images with a contrastive loss. A nonlinear projection head is attached to the top of the encoder, as it improves the quality of the encoder's representations.
We use the InfoNCE/NT-Xent/N-pairs loss, which can be interpreted in the following way (a toy illustration follows the list):
- We treat each image in the batch as if it had its own class.
- Then, we have two examples of each "class" (a pair of augmented views).
- Each view's representation is compared to every possible pair's representation (for both augmented versions).
- We use the temperature-scaled cosine similarity of the compared representations as logits.
- Finally, we use categorical cross-entropy as the "classification" loss.
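To make this concrete, here is a toy illustration (my addition, not part of the original example) with a batch of 3 images and 4-dimensional projections: row k of the logit matrix classifies view 1 of image k against all 3 candidates from view 2, and the diagonal holds the positive pairs.
import tensorflow as tf
from tensorflow import keras

batch = 3
projections_1 = tf.math.l2_normalize(tf.random.normal((batch, 4)), axis=1)
projections_2 = tf.math.l2_normalize(tf.random.normal((batch, 4)), axis=1)

# Temperature-scaled pairwise cosine similarities serve as classification logits
logits = tf.matmul(projections_1, projections_2, transpose_b=True) / 0.1
labels = tf.range(batch)  # image k is "class" k; positives lie on the diagonal
loss = tf.reduce_mean(
    keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
)
print(float(loss))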
The following two metrics are used for monitoring the pretraining performance:
- Contrastive accuracy (SimCLR Table 5): a self-supervised metric, the ratio of cases in which the representation of an image is more similar to its differently augmented version's one than to the representation of any other image in the current batch. Self-supervised metrics can be used for hyperparameter tuning even in the case when there are no labeled examples.
- Linear probing accuracy: linear probing is a popular metric for evaluating self-supervised classifiers. It is computed as the accuracy of a logistic regression classifier trained on top of the encoder's features. In our case, this is done by training a single dense layer on top of the frozen encoder. Note that, contrary to the traditional approach where the classifier is trained after the pretraining phase, in this example we train it during pretraining. This might slightly decrease its accuracy, but that way we can monitor its value during training, which helps with experimentation and debugging.
Another widely used supervised metric is KNN accuracy, which is the accuracy of a KNN classifier trained on top of the encoder's features; it is not implemented in this example, but a sketch follows.
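Since the KNN metric is not part of the example, here is a rough sketch of how it could be computed (my addition, to be run after the contrastive pretraining below; it assumes scikit-learn is available, and k=20 is an arbitrary choice).
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def encode(dataset):
    # Deterministic preprocessing + frozen encoder, as in test_step below
    features, labels = [], []
    for images, batch_labels in dataset:
        preprocessed = pretraining_model.classification_augmenter(
            images, training=False
        )
        features.append(pretraining_model.encoder(preprocessed, training=False).numpy())
        labels.append(batch_labels.numpy())
    return np.concatenate(features), np.concatenate(labels)

train_features, train_labels = encode(labeled_train_dataset)
test_features, test_labels = encode(test_dataset)

knn = KNeighborsClassifier(n_neighbors=20).fit(train_features, train_labels)
print("KNN accuracy: {:.2f}%".format(100 * knn.score(test_features, test_labels)))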
# Define the contrastive model with model-subclassing
class ContrastiveModel(keras.Model):
    def __init__(self):
        super().__init__()

        self.temperature = temperature
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        self.classification_augmenter = get_augmenter(**classification_augmentation)
        self.encoder = get_encoder()
        # Non-linear MLP as projection head
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        # Single dense layer for linear probing
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)], name="linear_probe"
        )

        self.encoder.summary()
        self.projection_head.summary()
        self.linear_probe.summary()
    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)

        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

        # self.contrastive_loss will be defined as a method
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )
        self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")
    @property
    def metrics(self):
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]

    def contrastive_loss(self, projections_1, projections_2):
        # InfoNCE loss (information noise-contrastive estimation)
        # NT-Xent loss (normalized temperature-scaled cross entropy)

        # Cosine similarity: the dot product of the l2-normalized feature vectors
        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
        similarities = (
            tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
        )

        # The similarity between the representations of two augmented views of the
        # same image should be higher than their similarity with other views
        batch_size = tf.shape(projections_1)[0]
        contrastive_labels = tf.range(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, tf.transpose(similarities)
        )

        # The temperature-scaled similarities are used as logits for cross-entropy;
        # a symmetrized version of the loss is used here
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, tf.transpose(similarities), from_logits=True
        )
        return (loss_1_2 + loss_2_1) / 2
    def train_step(self, data):
        (unlabeled_images, _), (labeled_images, labels) = data

        # Both labeled and unlabeled images are used, without labels
        images = tf.concat((unlabeled_images, labeled_images), axis=0)
        # Each image is augmented twice, differently
        augmented_images_1 = self.contrastive_augmenter(images)
        augmented_images_2 = self.contrastive_augmenter(images)
        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1)
            features_2 = self.encoder(augmented_images_2)
            # The representations are passed through a projection mlp
            projections_1 = self.projection_head(features_1)
            projections_2 = self.projection_head(features_2)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.contrastive_loss_tracker.update_state(contrastive_loss)

        # Labels are only used in evaluation for an on-the-fly logistic regression
        preprocessed_images = self.classification_augmenter(labeled_images)
        with tf.GradientTape() as tape:
            features = self.encoder(preprocessed_images)
            class_logits = self.linear_probe(features)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        return {m.name: m.result() for m in self.metrics}
    def test_step(self, data):
        labeled_images, labels = data

        # For testing the components are used with a training=False flag
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        # Only the probe metrics are logged at test time
        return {m.name: m.result() for m in self.metrics[2:]}
# Contrastive pretraining
pretraining_model = ContrastiveModel()
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
)

pretraining_history = pretraining_model.fit(
    train_dataset, epochs=num_epochs, validation_data=test_dataset
)
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(pretraining_history.history["val_p_acc"]) * 100
    )
)
Model: "encoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_4 (Conv2D) (None, 47, 47, 128) 3584
_________________________________________________________________
conv2d_5 (Conv2D) (None, 23, 23, 128) 147584
_________________________________________________________________
conv2d_6 (Conv2D) (None, 11, 11, 128) 147584
_________________________________________________________________
conv2d_7 (Conv2D) (None, 5, 5, 128) 147584
_________________________________________________________________
flatten_1 (Flatten) (None, 3200) 0
_________________________________________________________________
dense_2 (Dense) (None, 128) 409728
=================================================================
Total params: 856,064
Trainable params: 856,064
Non-trainable params: 0
_________________________________________________________________
Model: "projection_head"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_3 (Dense) (None, 128) 16512
_________________________________________________________________
dense_4 (Dense) (None, 128) 16512
=================================================================
Total params: 33,024
Trainable params: 33,024
Non-trainable params: 0
_________________________________________________________________
Model: "linear_probe"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_5 (Dense) (None, 10) 1290
=================================================================
Total params: 1,290
Trainable params: 1,290
Non-trainable params: 0
_________________________________________________________________
Epoch 1/20
200/200 [==============================] - 70s 325ms/step - c_loss: 4.7788 - c_acc: 0.1340 - p_loss: 2.2030 - p_acc: 0.1922 - val_p_loss: 2.1043 - val_p_acc: 0.2540
Epoch 2/20
200/200 [==============================] - 67s 323ms/step - c_loss: 3.4836 - c_acc: 0.3047 - p_loss: 2.0159 - p_acc: 0.3030 - val_p_loss: 1.9833 - val_p_acc: 0.3120
Epoch 3/20
200/200 [==============================] - 65s 322ms/step - c_loss: 2.9157 - c_acc: 0.4187 - p_loss: 1.8896 - p_acc: 0.3598 - val_p_loss: 1.8621 - val_p_acc: 0.3556
Epoch 4/20
200/200 [==============================] - 67s 322ms/step - c_loss: 2.5837 - c_acc: 0.4867 - p_loss: 1.7965 - p_acc: 0.3912 - val_p_loss: 1.7400 - val_p_acc: 0.4006
Epoch 5/20
200/200 [==============================] - 67s 322ms/step - c_loss: 2.3462 - c_acc: 0.5403 - p_loss: 1.6961 - p_acc: 0.4138 - val_p_loss: 1.6655 - val_p_acc: 0.4190
Epoch 6/20
200/200 [==============================] - 65s 321ms/step - c_loss: 2.2214 - c_acc: 0.5714 - p_loss: 1.6325 - p_acc: 0.4322 - val_p_loss: 1.6242 - val_p_acc: 0.4366
Epoch 7/20
200/200 [==============================] - 67s 322ms/step - c_loss: 2.0618 - c_acc: 0.6098 - p_loss: 1.5793 - p_acc: 0.4470 - val_p_loss: 1.5348 - val_p_acc: 0.4663
Epoch 8/20
200/200 [==============================] - 65s 322ms/step - c_loss: 1.9532 - c_acc: 0.6360 - p_loss: 1.5173 - p_acc: 0.4652 - val_p_loss: 1.5248 - val_p_acc: 0.4700
Epoch 9/20
200/200 [==============================] - 65s 322ms/step - c_loss: 1.8487 - c_acc: 0.6602 - p_loss: 1.4631 - p_acc: 0.4798 - val_p_loss: 1.4587 - val_p_acc: 0.4905
Epoch 10/20
200/200 [==============================] - 65s 322ms/step - c_loss: 1.7837 - c_acc: 0.6767 - p_loss: 1.4310 - p_acc: 0.4992 - val_p_loss: 1.4265 - val_p_acc: 0.4924
Epoch 11/20
200/200 [==============================] - 65s 321ms/step - c_loss: 1.7133 - c_acc: 0.6955 - p_loss: 1.3764 - p_acc: 0.5090 - val_p_loss: 1.3663 - val_p_acc: 0.5169
Epoch 12/20
200/200 [==============================] - 66s 322ms/step - c_loss: 1.6655 - c_acc: 0.7064 - p_loss: 1.3511 - p_acc: 0.5140 - val_p_loss: 1.3779 - val_p_acc: 0.5071
Epoch 13/20
200/200 [==============================] - 67s 322ms/step - c_loss: 1.6110 - c_acc: 0.7198 - p_loss: 1.3182 - p_acc: 0.5282 - val_p_loss: 1.3259 - val_p_acc: 0.5303
Epoch 14/20
200/200 [==============================] - 66s 321ms/step - c_loss: 1.5727 - c_acc: 0.7312 - p_loss: 1.2965 - p_acc: 0.5308 - val_p_loss: 1.2858 - val_p_acc: 0.5422
Epoch 15/20
200/200 [==============================] - 67s 322ms/step - c_loss: 1.5477 - c_acc: 0.7361 - p_loss: 1.2751 - p_acc: 0.5432 - val_p_loss: 1.2795 - val_p_acc: 0.5472
Epoch 16/20
200/200 [==============================] - 65s 321ms/step - c_loss: 1.5127 - c_acc: 0.7448 - p_loss: 1.2562 - p_acc: 0.5498 - val_p_loss: 1.2731 - val_p_acc: 0.5461
Epoch 17/20
200/200 [==============================] - 67s 321ms/step - c_loss: 1.4811 - c_acc: 0.7517 - p_loss: 1.2306 - p_acc: 0.5574 - val_p_loss: 1.2439 - val_p_acc: 0.5630
Epoch 18/20
200/200 [==============================] - 67s 321ms/step - c_loss: 1.4598 - c_acc: 0.7576 - p_loss: 1.2215 - p_acc: 0.5544 - val_p_loss: 1.2352 - val_p_acc: 0.5623
Epoch 19/20
200/200 [==============================] - 65s 321ms/step - c_loss: 1.4349 - c_acc: 0.7631 - p_loss: 1.2161 - p_acc: 0.5662 - val_p_loss: 1.2670 - val_p_acc: 0.5479
Epoch 20/20
200/200 [==============================] - 66s 321ms/step - c_loss: 1.4159 - c_acc: 0.7691 - p_loss: 1.2044 - p_acc: 0.5656 - val_p_loss: 1.2204 - val_p_acc: 0.5624
Maximal validation accuracy: 56.30%
Supervised finetuning of the pretrained encoder
We then finetune the encoder on the labeled examples, by attaching a single randomly initialized fully connected classification layer on its top.
# Supervised finetuning of the pretrained encoder
finetuning_model = keras.Sequential(
    [
        layers.Input(shape=(image_size, image_size, image_channels)),
        get_augmenter(**classification_augmentation),
        pretraining_model.encoder,
        layers.Dense(10),
    ],
    name="finetuning_model",
)
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

finetuning_history = finetuning_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(finetuning_history.history["val_acc"]) * 100
    )
)
Epoch 1/20
200/200 [==============================] - 4s 17ms/step - loss: 1.9942 - acc: 0.2554 - val_loss: 1.4278 - val_acc: 0.4647
Epoch 2/20
200/200 [==============================] - 3s 16ms/step - loss: 1.5209 - acc: 0.4373 - val_loss: 1.3119 - val_acc: 0.5170
Epoch 3/20
200/200 [==============================] - 3s 17ms/step - loss: 1.3210 - acc: 0.5132 - val_loss: 1.2328 - val_acc: 0.5529
Epoch 4/20
200/200 [==============================] - 3s 17ms/step - loss: 1.1932 - acc: 0.5603 - val_loss: 1.1328 - val_acc: 0.5872
Epoch 5/20
200/200 [==============================] - 3s 17ms/step - loss: 1.1217 - acc: 0.5984 - val_loss: 1.1508 - val_acc: 0.5906
Epoch 6/20
200/200 [==============================] - 3s 16ms/step - loss: 1.0665 - acc: 0.6176 - val_loss: 1.2544 - val_acc: 0.5753
Epoch 7/20
200/200 [==============================] - 3s 16ms/step - loss: 0.9890 - acc: 0.6510 - val_loss: 1.0107 - val_acc: 0.6409
Epoch 8/20
200/200 [==============================] - 3s 16ms/step - loss: 0.9775 - acc: 0.6468 - val_loss: 1.0907 - val_acc: 0.6150
Epoch 9/20
200/200 [==============================] - 3s 17ms/step - loss: 0.9105 - acc: 0.6736 - val_loss: 1.1057 - val_acc: 0.6183
Epoch 10/20
200/200 [==============================] - 3s 17ms/step - loss: 0.8658 - acc: 0.6895 - val_loss: 1.1794 - val_acc: 0.5938
Epoch 11/20
200/200 [==============================] - 3s 17ms/step - loss: 0.8503 - acc: 0.6946 - val_loss: 1.0764 - val_acc: 0.6325
Epoch 12/20
200/200 [==============================] - 3s 17ms/step - loss: 0.7973 - acc: 0.7193 - val_loss: 1.0065 - val_acc: 0.6561
Epoch 13/20
200/200 [==============================] - 3s 16ms/step - loss: 0.7516 - acc: 0.7319 - val_loss: 1.0955 - val_acc: 0.6345
Epoch 14/20
200/200 [==============================] - 3s 16ms/step - loss: 0.7504 - acc: 0.7406 - val_loss: 1.1041 - val_acc: 0.6386
Epoch 15/20
200/200 [==============================] - 3s 16ms/step - loss: 0.7419 - acc: 0.7324 - val_loss: 1.0680 - val_acc: 0.6492
Epoch 16/20
200/200 [==============================] - 3s 17ms/step - loss: 0.7318 - acc: 0.7265 - val_loss: 1.1635 - val_acc: 0.6313
Epoch 17/20
200/200 [==============================] - 3s 17ms/step - loss: 0.6904 - acc: 0.7505 - val_loss: 1.0826 - val_acc: 0.6503
Epoch 18/20
200/200 [==============================] - 3s 17ms/step - loss: 0.6389 - acc: 0.7714 - val_loss: 1.1260 - val_acc: 0.6364
Epoch 19/20
200/200 [==============================] - 3s 16ms/step - loss: 0.6355 - acc: 0.7829 - val_loss: 1.0750 - val_acc: 0.6554
Epoch 20/20
200/200 [==============================] - 3s 17ms/step - loss: 0.6279 - acc: 0.7758 - val_loss: 1.0465 - val_acc: 0.6604
Maximal validation accuracy: 66.04%
Comparison against the baseline
# The classification accuracies of the baseline and the pretraining + finetuning process:
def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
    for metric_key, metric_name in zip(["acc", "loss"], ["accuracy", "loss"]):
        plt.figure(figsize=(8, 5), dpi=100)
        plt.plot(
            baseline_history.history[f"val_{metric_key}"], label="supervised baseline"
        )
        plt.plot(
            pretraining_history.history[f"val_p_{metric_key}"],
            label="self-supervised pretraining",
        )
        plt.plot(
            finetuning_history.history[f"val_{metric_key}"],
            label="supervised finetuning",
        )
        plt.legend()
        plt.title(f"Classification {metric_name} during training")
        plt.xlabel("epochs")
        plt.ylabel(f"validation {metric_name}")


plot_training_curves(pretraining_history, finetuning_history, baseline_history)
By comparing the training curves, we can see that when using contrastive pretraining, a higher validation accuracy can be reached, paired with a lower validation loss, which means that the pretrained network was able to generalize better when seeing only a small amount of labeled examples.
Further improvements
Architecture
The experiments in the original paper demonstrated that increasing the width and depth of the model improves performance at a higher rate than for supervised learning. Also, using a ResNet-50 encoder is quite standard in the literature. However, keep in mind that more powerful models will not only increase training time but will also require more memory, and will limit the maximal batch size you can use. A sketch of such an encoder follows.
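As an illustration, a wider encoder could be swapped in along these lines (a sketch of my own, not the example's code; it assumes keras.applications is an acceptable dependency here):
# Sketch of a ResNet-50 based encoder as a drop-in replacement for get_encoder().
# Randomly initialized (weights=None), since pretraining is done from scratch.
def get_resnet_encoder():
    return keras.Sequential(
        [
            keras.Input(shape=(image_size, image_size, image_channels)),
            keras.applications.ResNet50(include_top=False, weights=None, pooling="avg"),
            layers.Dense(width, activation="relu"),
        ],
        name="resnet_encoder",
    )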
It has been reported that the usage of BatchNorm layers can sometimes degrade performance, as it introduces an intra-batch dependency between samples, which is why I did not use them in this example. In my experiments, however, using BatchNorm, especially in the projection head, improves performance; a sketch is shown below.
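For instance, a projection head variant with BatchNorm might look like this (a sketch, not the code used in this example):
# Sketch of a projection head with BatchNorm between the dense layers.
# The bias is redundant before BatchNorm, hence use_bias=False.
def get_projection_head_with_bn():
    return keras.Sequential(
        [
            keras.Input(shape=(width,)),
            layers.Dense(width, use_bias=False),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.Dense(width),
        ],
        name="projection_head_bn",
    )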
Hyperparameters
The hyperparameters used in this example have been tuned manually for this task and architecture. Therefore, without changing them, only marginal gains can be expected from further hyperparameter tuning.
However, for a different task or model architecture these would need tuning, so here are my notes on the most important ones:
- Batch size: since the objective can be interpreted as a classification over a batch of images (loosely speaking), the batch size is actually a more important hyperparameter than usual. The higher, the better.
- Temperature: the temperature defines the "softness" of the softmax distribution used in the cross-entropy loss, and is an important hyperparameter. Lower values generally lead to a higher contrastive accuracy. A recent trick is to learn the temperature's value, which can be done by defining it as a tf.Variable and applying gradients on it (see the sketch after this list). Even though this provides a good baseline value, in my experiments the learned temperature was somewhat lower than optimal, as it is optimized with respect to the contrastive loss, which is not a perfect proxy for representation quality.
- Image augmentation strength: during pretraining, stronger augmentations increase the difficulty of the task; however, after a point, too strong augmentations will degrade performance. During finetuning, stronger augmentations reduce overfitting, while in my experience too strong augmentations decrease the performance gains from pretraining. The whole data augmentation pipeline can be seen as an important hyperparameter of the algorithm; implementations of other custom image augmentation layers in Keras can be found at: https://github.com/beresandras/contrastive-classification-keras.
- Learning rate schedule: a constant schedule is used here, but it is quite common in the literature to use a cosine decay schedule, which can further improve performance.
- Optimizer: Adam is used in this example, as it provides good performance with default parameters. SGD with momentum requires more tuning, however it could slightly increase performance.
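As an illustration of the learned temperature trick mentioned above, here is a minimal sketch (my assumption of how it could be wired in, not the example's code):
# Sketch: temperature as a trainable variable. Dividing the similarities by it
# inside the GradientTape lets gradients flow into the variable; it must then be
# appended to the weight lists passed to tape.gradient() and apply_gradients()
# in ContrastiveModel.train_step.
temperature_variable = tf.Variable(0.1, trainable=True, dtype=tf.float32)

def scaled_similarities(projections_1, projections_2):
    projections_1 = tf.math.l2_normalize(projections_1, axis=1)
    projections_2 = tf.math.l2_normalize(projections_2, axis=1)
    return tf.matmul(projections_1, projections_2, transpose_b=True) / temperature_variable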
Original English article: https://keras.io/examples/vision/semisupervised_simclr/