TorchMetrics:PyTorch 的指標度量庫

作者:PyTorch Lightning team

編譯:ronghuaiyang 來源:AI 公園

導讀

非常簡單實用的 PyTorch 模型的分佈式指標度量庫,配合 PyTorch Lighting 實用更加方便。

找出你需要評估的指標是深度學習的關鍵。有各種各樣的指標,我們可以評估 ML 算法的性能。TorchMetrics 是一個 PyTorch 度量的實現的集合,是 PyTorch Lightning 高性能深度學習的框架的一部分。在本文中,我們將介紹如何使用 TorchMetrics 評估你的深度學習模型,甚至使用一個簡單易用的 API 創建你自己的度量。

什麼是 TorchMetrics?

TorchMetrics 是一個開源的 PyTorch 原生的函數和度量模塊的集合,用於簡單的性能評估。你可以使用開箱即用的實現來實現常見的指標,如準確性,召回率,精度,AUROC, RMSE, R² 等,或者創建你自己的指標。我們目前支持超過 25 個指標,並不斷增加更多的通用任務和特定領域的標準 (目標檢測,NLP 等)。

TorchMetrics 最初是作爲 Pytorch Lightning (PL) 的一部分創建的,被設計爲分佈式硬件兼容,並在默認情況下與 DistributedDataParalel(DDP) 一起工作。所有指標都在 cpu 和 gpu 上經過嚴格測試。

使用 TorchMetrics

安裝

這個包可以通過以下方式從 PyPI 簡單安裝:

pip install torchmetrics

或者直接從 GitHub 倉庫的源代碼安裝:

# with git
pip install git+https://github.com/PytorchLightning/metrics.git@master

函數形式的 metrics

類似於 torch.nn,大多數度量指標都有基於模塊和函數的版本。函數版本實現了計算每個度量所需的基本操作。它們是作爲輸入的簡單的 python 函數。並返回相應的 torch.tensor 的指標。下面的代碼片段展示了一個使用函數接口計算精度的簡單示例:

模塊形式的 metrics


幾乎所有函數 metrics 都有一個對應的基於模塊的 metrics,該度量將其稱爲底層的函數等價模塊。基於模塊的度量的特點是有一個或多個內部度量狀態 (類似於 PyTorch 模塊的參數),允許它們提供額外的功能:

下面的代碼展示瞭如何使用基於模塊的接口:

每次調用度量的 forward 函數時,我們同時計算當前看到的一批數據上的度量值,並更新內部度量狀態,以跟蹤到目前爲止看到的所有數據。內部狀態需要在不同時期之間重置,不應該在訓練、驗證和測試之間混合。因此我們強烈建議按如下方式重新初始化度量:

Lightning 中使用 TorchMetrics

=============================

下面的例子展示瞭如何在你的 LightningModule 中使用 metric :

雖然 TorchMetrics 被構建爲與原生的 PyTorch 一起使用,但 TorchMetrics 與 Lightning 一起使用提供了額外的好處:

Lightning 的轉換


已經熟悉 Lightning 的 metric 接口的用戶應該能夠輕鬆地適應 TorchMetrics。簡單地替換:

from pytorchlightning import metrics

with:

import torchmetrics

注意,在 1.3 版本之前,metrics 將是 PyTorchLightning 的一部分,但不再接收任何更新。我們強烈建議用戶切換到 TorchMetrics,以得到我們可能實現的所有的 bug 修復和增強。

實現自己的 metrics

如果你想使用一個還不被支持的指標,你可以使用 TorchMetrics 的 API 來實現你自己的自定義指標,只需子類化torchmetrics.Metric並實現以下方法:

  1. __init__():每個狀態變量都應該使用self.add_state(…)調用。

  2. update():任何需要更新內部度量狀態的代碼。

  3. compute():從度量值的狀態計算一個最終值。

例子:均方根誤差

均方根誤差是一個很好的例子,說明了爲什麼許多度量計算需要劃分爲兩個函數。定義爲:

爲了正確地計算 RMSE,我們需要兩個度量狀態:sum_squared_error來跟蹤目標 y 和預測 y 之間的平方誤差,以及n_observations來知道我們有多少觀測結果。

因爲 sqrt(a+b) != sqrt(a) + sqrt(b),我們不能把這個度量實現爲每個 batch 計算的 RMSE 分數的簡單平均值,而是需要實現更新步驟中需要在平方根之前發生的所有邏輯,以及在 compute 步驟中需要實現剩餘的邏輯。

爲你的模型選擇正確的度量

選擇正確的度量對於確定你的模型是否按照應該的方式運行,或者是否有什麼地方出了問題非常重要。

預測冠狀病毒

假設你的任務是建立一個分類網絡,可以通過一套非侵入性測量來確定患者是否是冠狀病毒陽性。你會得到數千份觀察報告,並使用你最喜歡的網絡架構,優化以正確識別哪些患者感染了冠狀病毒。這種模式可用於確保檢測呈陽性的患者被隔離,以避免傳播病毒並迅速得到治療。

爲了評估你的模型,你計算了 4 個指標:準確性、混淆矩陣、精確度和召回率。你得到了以下結果:

準確率: 99.9%

混淆矩陣

精確率: 1.0

召回率:0.28

評估得分

你怎麼看?這個模型足夠好嗎?讓我們更深入地瞭解這些指標的含義。在分類中,準確率是指我們的模型得到正確預測的比例。

我們的模型得到了非常高的準確率:99.9%。看來網絡正在做你要求它做的事情,你可以準確地檢測到患者是否感染了冠狀病毒。

對於二元分類,另一個有用的度量是混淆矩陣,這給了我們下面的真、假陽性和陰性的組合。

我們可以從混淆矩陣中快速確定兩件事:

從準確性來看,這個模型似乎表現得很好,但考慮到混淆矩陣,我們發現這個模型過於專注於預測陰性患者,而未能預測陽性患者。在這種設置下,它應該清楚正確識別新冠患者和正確識別非新冠患者之間的巨大的區別,正確識別患者將確保患者得到早期治療,最重要的是隔離,不要傳染給別人。

爲什麼準確率指標沒有顯示出模型有什麼問題?準確率捕獲了整體性能,以正確地預測所有類,在這種情況下,我們感興趣的是捕獲我們預測的 ground truth 的情況有多好。因此,你可以將注意力轉向精確率和召回率。

精確率定義爲實際正確的正樣本的比例。

其中 TP 和 FP 分別表示 true p positive 個數,false positive 個數。一個有 0 個誤報的模型的精確率爲 1.0,而一個模型輸出的結果都是陽性,而實際上都是假的模型的精度分數爲 0。

Recall 定義爲真實的陽性被正確識別的比例。

其中 TP 和 FN 分別表示 true positives 數,false negatives 數。類似地,如果沒有錯誤否定,一個模型的召回分數將爲 1.0。從定義上我們可以得出結論,精確率聚焦於在不能識別所有假陽性的 “成本” 上,而召回率聚焦在不能識別所有假陰性的 “成本” 上。因爲我們在這裏感興趣的是假陰性,所以我們應該在 recall metric 下重新評估我們的模型,現在我們得到了 0.28 的分數。現在,你已經量化了模型的性能不佳,並且在訓練機器學習算法時可能需要處理數據集中存在的巨大類不平衡。

這個小例子展示了選擇正確度量來評估機器學習算法的重要性。通常,建議使用一組度量標準來評估算法,因爲它們都關注數據和模型預測的不同方面。

英文原文:https://pytorch-lightning.medium.com/torchmetrics-pytorch-metrics-built-to-scale-7091b1bec919

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/pshWcxiUgkCeIQFNR8byAg