有監督對比學習的一個簡單的例子

作者:Dimitre Oliveira

編譯:ronghuaiyang

導讀

使用有監督對比學習來進行木薯葉病害識別。

論文鏈接:https://arxiv.org/abs/2004.11362

監督對比學習 (Prannay Khosla 等人) 是一種訓練方法,它在分類任務上優於使用交叉熵的監督訓練。

這個想法是,使用監督對比學習 (SCL) 的訓練模型可以使模型編碼器從樣本學習更好的類表示,這應該導致更好的泛化,並對於圖像和標籤的錯誤更具魯棒性。

在本文中,你將瞭解什麼是監督對比學習,以及監督對比學習是如何工作的,你會看到代碼實現、一個應用程序的例子,最後將看到 SCL 和常規交叉熵之間的比較。

簡而言之,SCL 就是這樣工作的:

在嵌入空間中將屬於同一類的聚類點聚在一起,同時將來自不同類的樣本簇分離。

有許多對比學習方法,如 " 監督對比學習 "," 自監督對比學習 "," SimCLR " 等,它們的比對部分都是共同的,它們學習來自一個域的樣本和來自另一個域的樣本的差別,但 SCL 以監督的方式利用標籤信息完成這項任務。

不同的訓練方法的結構

本質上,用監督對比學習對分類模型進行訓練分爲兩個階段:

  1. 訓練編碼器,學習生成輸入圖像的向量表示,這樣,同類別圖像的表示將比不同類別圖像的表示更加相似。

  2. 在參數凍結的編碼器上訓練一個分類器。

例子

我們將把監督比較學習應用於 Kaggle 競賽的數據集 (Cassava Leaf Disease Classification),目的是將木薯葉的圖像分類爲 5 類:

0: Cassava Bacterial Blight (CBB)
1: Cassava Brown Streak Disease (CBSD)
2: Cassava Green Mottle (CGM)
3: Cassava Mosaic Disease (CMD)
4: Healthy

我們有四種疾病和一種健康的葉子,下面是一些圖像樣本:

來自比賽的木薯葉圖像樣本

數據有 21397 圖像用於訓練,大約有 15000 圖像用於測試集。

實驗設置

你可以在這裏查看:https://www.kaggle.com/dimitreoliveira/cassava-leaf-supervised-contrastive-learning

通常,對比學習方法能更好地工作,如果每個訓練一個 batch 都有每個類的樣本,這將有助於編碼器學會對比不同域之間的差別,這意味着需要使用一個大的 batch size,在這種情況下,我已經對每個類進行了過採樣,所以每個 batch 的樣本中每個類樣本的概率大致相同。

數據集中的類別分佈,過採樣之後

數據增強通常有助於計算機視覺任務,在我的實驗中,我也看到了數據增強的改進,這裏我使用剪切,旋轉,翻轉,作物,剪切,飽和度,對比度和亮度的變化,它可能看起來很多,但圖像沒有和原始圖像有太大不同。

增強後的數據樣本

現在我們可以看看代碼了

編碼器

我們的編碼器將是一個 “EfficientNet B3”,但是在編碼器的頂部有一個平均池化層,這個池化層將輸出一個 2048 大小的向量,稍後它將用於檢查編碼器學習到的表示。

def encoder_fn(input_shape):
    inputs = L.Input(shape=input_shape, name=’inputs’)
    base_model = efn.EfficientNetB3(input_tensor=inputs, 
                                    include_top=False,
                                    weights=’noisy-student’, 
                                    pooling=’avg’)
 
    model = Model(inputs=inputs, outputs=base_model.outputs)
    return model

投影頭

投影頭位於編碼器的頂部,負責將編碼器嵌入層的輸出投影到更小的尺寸中,在我們的例子中,它將 2048 維的編碼器投影到 128 維的向量中。

def add_projection_head(input_shape, encoder):
    inputs = L.Input(shape=input_shape, name='inputs')
    features = encoder(inputs)
    outputs = L.Dense(128, activation='relu', 
                      name='projection_head', 
                      dtype='float32')(features)
    
    model = Model(inputs=inputs, outputs=outputs)
    return model

分類頭

分類器頭用於的可選的第二階段訓練,在 SCL 訓練階段之後,我們可以去掉投影頭,把這個分類器頭加到編碼器上,並使用常規的交叉熵損失來 finetune 模型,這樣做的時候,需要凍結編碼器層。

def classifier_fn(input_shape, N_CLASSES, encoder, trainable=False):
    for layer in encoder.layers:
        layer.trainable = trainable
        
    inputs = L.Input(shape=input_shape, name='inputs')
    
    features = encoder(inputs)
    features = L.Dropout(.5)(features)
    features = L.Dense(1000, activation='relu')(features)
    features = L.Dropout(.5)(features)
    outputs = L.Dense(N_CLASSES, activation='softmax', 
                      name='outputs'dtype='float32')(features)

    model = Model(inputs=inputs, outputs=outputs)
    return model

監督對比學習損失

這是 SCL 損失的代碼實現,這裏唯一的參數是 temperature,“0.1” 是默認值,但它可以調整,較大的 temperatures 可以導致類更分離,但較小的 temperatures 有益於較長的訓練。

class SupervisedContrastiveLoss(losses.Loss):
    def __init__(self, temperature=0.1, name=None):
        super(SupervisedContrastiveLoss, self).__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, ft_vectors, sample_weight=None):
        # Normalize feature vectors
        ft_vec_normalized = tf.math.l2_normalize(ft_vectors, axis=1)
        # Compute logits
        logits = tf.divide(
            tf.matmul(ft_vec_normalized, 
                      tf.transpose(ft_vec_normalized)
            ), temperature
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)

訓練

我將跳過 Tensorflow 樣板訓練代碼,因爲它非常標準,但是你可以在這裏:https://www.kaggle.com/dimitreoliveira/cassava-leaf-supervised-contrastive-learning/notebook#Training-(supervised-contrastive-learning 查看完整的代碼。

第一個訓練步驟 (編碼器 + 投影頭)

第一階段的訓練是用編碼器 + 投影頭,使用有監督對比學習損失。

構建模型

with strategy.scope()# Inside a strategy because I am using a TPU
  encoder = encoder_fn((None, None, CHANNELS)) # Get the encoder
  encoder_proj = add_projection_head((None, None, CHANNELS),encoder)
  # Add the projection head to the encoderencoder_proj.compile(optimizer=optimizers.Adam(lr=3e-4), 
                    loss=SupervisedContrastiveLoss(temperature=0.1))

訓練

model.fit(x=get_dataset(TRAIN_FILENAMES, 
                        repeated=True, 
                        augment=True), 
          validation_data=get_dataset(VALID_FILENAMES, 
                                      ordered=True), 
          steps_per_epoch=100, 
          epochs=10)

第二個訓練步驟 (編碼器 + 分類頭)

對於訓練的第二階段,我們刪除投影頭,並在編碼器的頂部添加分類器頭,現在該編碼器已經訓練了權值。對於這一步,我們可以使用常規的交叉熵損失,像往常一樣訓練模型。

構建模型

with strategy.scope():
    model = classifier_fn((None, None, CHANNELS), N_CLASSES, 
                          encoder, # trained encoder
                          trainable=False) # with frozen weights    model.compile(optimizer=optimizers.Adam(lr=3e-4),
                  loss=losses.SparseCategoricalCrossentropy(), 
                  metrics=[metrics.SparseCategoricalAccuracy()])

訓練

和之前幾乎一樣

model.fit(x=get_dataset(TRAIN_FILENAMES, 
                        repeated=True, 
                        augment=True), 
          validation_data=get_dataset(VALID_FILENAMES, 
                                      ordered=True), 
          steps_per_epoch=100, 
          epochs=10)

可視化輸出向量

評估編碼器的學習表示的一種有趣的方法是可視化特徵嵌入的輸出,在我們的例子中,它是編碼器的最後一層,即平均池化層。在這裏,我們將比較用 SCL 訓練的模型和另一個用常規交叉熵訓練的模型,你可以在:https://www.kaggle.com/dimitreoliveira/cassava-leaf-supervised-contrastive-learning 中看到完整的訓練。可視化是通過在驗證數據集的嵌入輸出上應用 t-SNE 生成的。

交叉熵的嵌入

對使用交叉熵訓練的模型嵌入進行可視化

監督對比學習的嵌入

使用 SCL 訓練出的模型的嵌入的可視化。

我們可以看到,兩種模型在對每個類進行樣本聚類的時候似乎都可以做的不錯,但看下 SCL 模型訓練出來的嵌入,每個類的簇相互之間的距離要更遠,這就是對比學習的效果。我們也可以認爲,這種行爲將導致更好的泛化,因爲類的判別邊界會更清晰、如果去嘗試畫一下類別之間的邊界,就可以得到一個很直觀的理解。

總結

我們看到,使用監督對比學習方法的訓練既容易實現又有效,它可以帶來更好的準確性和更好的類表示,這反過來也可以產生更健壯的模型,能夠更好地泛化。

英文原文:https://pub.towardsai.net/supervised-contrastive-learning-for-cassava-leaf-disease-classification-9dd47779a966

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/qnG0YLf0yjs4aT9URRMDyw