Github1-3 萬星,迅猛發展的 JAX 對比 TensorFlow、PyTorch
=================================================================================================================================================
機器之心報道
JAX 是機器學習 (ML) 領域的新生力量,它有望使 ML 編程更加直觀、結構化和簡潔。
在機器學習領域,大家可能對 TensorFlow 和 PyTorch 已經耳熟能詳,但除了這兩個框架,一些新生力量也不容小覷,它就是谷歌推出的 JAX。很對研究者對其寄予厚望,希望它可以取代 TensorFlow 等衆多機器學習框架。
JAX 最初由谷歌大腦團隊的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人發起。
目前,JAX 在 GitHub 上已累積 13.7K 星。
項目地址:https://github.com/google/jax
迅速發展的 JAX
JAX 的前身是 Autograd,其藉助 Autograd 的更新版本,並且結合了 XLA,可對 Python 程序與 NumPy 運算執行自動微分,支持循環、分支、遞歸、閉包函數求導,也可以求三階導數;依賴於 XLA,JAX 可以在 GPU 和 TPU 上編譯和運行 NumPy 程序;通過 grad,可以支持自動模式反向傳播和正向傳播,且二者可以任意組合成任何順序。
開發 JAX 的出發點是什麼?說到這,就不得不提 NumPy。NumPy 是 Python 中的一個基礎數值運算庫,被廣泛使用。但是 numpy 不支持 GPU 或其他硬件加速器,也沒有對反向傳播的內置支持,此外,Python 本身的速度限制阻礙了 NumPy 使用,所以少有研究者在生產環境下直接用 numpy 訓練或部署深度學習模型。
在此情況下,出現了衆多的深度學習框架,如 PyTorch、TensorFlow 等。但是 numpy 具有靈活、調試方便、API 穩定等獨特的優勢。而 JAX 的主要出發點就是將 numpy 的以上優勢與硬件加速結合。
目前,基於 JAX 已有很多優秀的開源項目,如谷歌的神經網絡庫團隊開發了 Haiku,這是一個面向 Jax 的深度學習代碼庫,通過 Haiku,用戶可以在 Jax 上進行面向對象開發;又比如 RLax,這是一個基於 Jax 的強化學習庫,用戶使用 RLax 就能進行 Q-learning 模型的搭建和訓練;此外還包括基於 JAX 的深度學習庫 JAXnet,該庫一行代碼就能定義計算圖、可進行 GPU 加速。可以說,在過去幾年中,JAX 掀起了深度學習研究的風暴,推動了科學研究迅速發展。
JAX 的安裝
如何使用 JAX 呢?首先你需要在 Python 環境或 Google colab 中安裝 JAX,使用 pip 進行安裝:
pip install --upgrade jax jaxlib
注意,上述安裝方式只是支持在 CPU 上運行,如果你想在 GPU 執行程序,首先你需要有 CUDA、cuDNN ,然後運行以下命令(確保將 jaxlib 版本映射到 CUDA 版本):
pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
現在將 JAX 與 Numpy 一起導入:
import jax
import jax.numpy as jnp
import numpy as np
JAX 的一些特性
使用 grad() 函數自動微分:這對深度學習應用非常有用,這樣就可以很容易地運行反向傳播,下面爲一個簡單的二次函數並在點 1.0 上求導的示例:
from jax import grad
def f(x):
return 3*x**2 + 2*x + 5
def f_prime(x):
return 6*x +2
grad(f)(1.0)
# DeviceArray(8., dtype=float32)
f_prime(1.0)
# 8.0
jit(Just in time) :爲了利用 XLA 的強大功能,必須將代碼編譯到 XLA 內核中。這就是 jit 發揮作用的地方。要使用 XLA 和 jit,用戶可以使用 jit() 函數或 @jit 註釋。
from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
for _ in range(10):
x = 0.5*x + 0.1* jnp.sin(x)
return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 µs per loop
pmap:自動將計算分配到所有當前設備,並處理它們之間的所有通信。JAX 通過 pmap 轉換支持大規模的數據並行,從而將單個處理器無法處理的大數據進行處理。要檢查可用設備,可以運行 jax.devices():
from jax import pmap
def f(x):
return jnp.sin(x) + x**2
f(np.arange(4))
#DeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32)
pmap(f)(np.arange(4))
#ShardedDeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32)
vmap:是一種函數轉換,JAX 通過 vmap 變換提供了自動矢量化算法,大大簡化了這種類型的計算,這使得研究人員在處理新算法時無需再去處理批量化的問題。示例如下:
from jax import vmap
def f(x):
return jnp.square(x)
f(jnp.arange(10))
#DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
vmap(f)(jnp.arange(10))
#DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
TensorFlow vs PyTorch vs Jax
在深度學習領域有幾家巨頭公司,他們所提出的框架被廣大研究者使用。比如谷歌的 TensorFlow、Facebook 的 PyTorch、微軟的 CNTK、亞馬遜 AWS 的 MXnet 等。
每種框架都有其優缺點,選擇的時候需要根據自身需求進行選擇。
我們以 Python 中的 3 個主要深度學習框架——TensorFlow、PyTorch 和 Jax 爲例進行比較。這些框架雖然不同,但有兩個共同點:
-
它們是開源的。這意味着如果庫中存在錯誤,使用者可以在 GitHub 中發佈問題(並修復),此外你也可以在庫中添加自己的功能;
-
由於全局解釋器鎖,Python 在內部運行緩慢。所以這些框架使用 C/C++ 作爲後端來處理所有的計算和並行過程。
那麼它們的不同體現在哪些方面呢?如下表所示,爲 TensorFlow、PyTorch、JAX 三個框架的比較。
TensorFlow
TensorFlow 由谷歌開發,最初版本可追溯到 2015 年開源的 TensorFlow0.1,之後發展穩定,擁有強大的用戶羣體,成爲最受歡迎的深度學習框架。但是用戶在使用時,也暴露了 TensorFlow 缺點,例如 API 穩定性不足、靜態計算圖編程複雜等缺陷。因此在 TensorFlow2.0 版本,谷歌將 Keras 納入進來,成爲 tf.keras。
目前 TensorFlow 主要特點包括以下:
-
這是一個非常友好的框架,高級 API-Keras 的可用性使得模型層定義、損失函數和模型創建變得非常容易;
-
TensorFlow2.0 帶有 Eager Execution(動態圖機制),這使得該庫更加用戶友好,並且是對以前版本的重大升級;
-
Keras 這種高級接口有一定的缺點,由於 TensorFlow 抽象了許多底層機制(只是爲了方便最終用戶),這讓研究人員在處理模型方面的自由度更小;
-
Tensorflow 提供了 TensorBoard,它實際上是 Tensorflow 可視化工具包。它允許研究者可視化損失函數、模型圖、模型分析等。
PyTorch
PyTorch(Python-Torch) 是來自 Facebook 的機器學習庫。用 TensorFlow 還是 PyTorch?在一年前,這個問題毫無爭議,研究者大部分會選擇 TensorFlow。但現在的情況大不一樣了,使用 PyTorch 的研究者越來越多。PyTorch 的一些最重要的特性包括:
-
與 TensorFlow 不同,PyTorch 使用動態類型圖,這意味着執行圖是在運行中創建的。它允許我們隨時修改和檢查圖的內部結構;
-
除了用戶友好的高級 API 之外,PyTorch 還包括精心構建的低級 API,允許對機器學習模型進行越來越多的控制。我們可以在訓練期間對模型的前向和後向傳遞進行檢查和修改輸出。這被證明對於梯度裁剪和神經風格遷移非常有效;
-
PyTorch 允許用戶擴展代碼,可以輕鬆添加新的損失函數和用戶定義的層。PyTorch 的 Autograd 模塊實現了深度學習算法中的反向傳播求導數,在 Tensor 類上的所有操作, Autograd 都能自動提供微分,簡化了手動計算導數的複雜過程;
-
PyTorch 對數據並行和 GPU 的使用具有廣泛的支持;
-
PyTorch 比 TensorFlow 更 Python 化。PyTorch 非常適合 Python 生態系統,它允許使用 Python 類調試器工具來調試 PyTorch 代碼。
**JAX **
JAX 是來自 Google 的一個相對較新的機器學習庫。它更像是一個 autograd 庫,可以區分原生的 python 和 NumPy 代碼。JAX 的一些特性主要包括:
-
正如官方網站所描述的那樣,JAX 能夠執行 Python+NumPy 程序的可組合轉換:向量化、JIT 到 GPU/TPU 等等;
-
與 PyTorch 相比,JAX 最重要的方面是如何計算梯度。在 Torch 中,圖是在前向傳遞期間創建的,梯度在後向傳遞期間計算, 另一方面,在 JAX 中,計算表示爲函數。在函數上使用 grad() 返回一個梯度函數,該函數直接計算給定輸入的函數梯度;
-
JAX 是一個 autograd 工具,不建議單獨使用。有各種基於 JAX 的機器學習庫,其中值得注意的是 ObJax、Flax 和 Elegy。由於它們都使用相同的核心並且接口只是 JAX 庫的 wrapper,因此可以將它們放在同一個 bracket 下;
-
Flax 最初是在 PyTorch 生態系統下開發的,更注重使用的靈活性。另一方面,Elegy 受 Keras 啓發。ObJAX 主要是爲以研究爲導向的目的而設計的,它更注重簡單性和可理解性。
參考鏈接:
https://www.askpython.com/python-modules/tensorflow-vs-pytorch-vs-jax
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
https://www.zhihu.com/question/306496943/answer/557876584
本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源:https://mp.weixin.qq.com/s/TnqXcx8DSCNgnhI8cGXfRA