Github1-3 萬星,迅猛發展的 JAX 對比 TensorFlow、PyTorch

=================================================================================================================================================

機器之心報道

JAX 是機器學習 (ML) 領域的新生力量,它有望使 ML 編程更加直觀、結構化和簡潔。

在機器學習領域,大家可能對 TensorFlow 和 PyTorch 已經耳熟能詳,但除了這兩個框架,一些新生力量也不容小覷,它就是谷歌推出的 JAX。很對研究者對其寄予厚望,希望它可以取代 TensorFlow 等衆多機器學習框架。

JAX 最初由谷歌大腦團隊的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人發起。

目前,JAX 在 GitHub 上已累積 13.7K 星。

項目地址:https://github.com/google/jax

迅速發展的 JAX

JAX 的前身是 Autograd,其藉助 Autograd 的更新版本,並且結合了 XLA,可對 Python 程序與 NumPy 運算執行自動微分,支持循環、分支、遞歸、閉包函數求導,也可以求三階導數;依賴於 XLA,JAX 可以在 GPU 和 TPU 上編譯和運行 NumPy 程序;通過 grad,可以支持自動模式反向傳播和正向傳播,且二者可以任意組合成任何順序。

開發 JAX 的出發點是什麼?說到這,就不得不提 NumPy。NumPy 是 Python 中的一個基礎數值運算庫,被廣泛使用。但是 numpy 不支持 GPU 或其他硬件加速器,也沒有對反向傳播的內置支持,此外,Python 本身的速度限制阻礙了 NumPy 使用,所以少有研究者在生產環境下直接用 numpy 訓練或部署深度學習模型。

在此情況下,出現了衆多的深度學習框架,如 PyTorch、TensorFlow 等。但是 numpy 具有靈活、調試方便、API 穩定等獨特的優勢。而 JAX 的主要出發點就是將 numpy 的以上優勢與硬件加速結合。

目前,基於 JAX 已有很多優秀的開源項目,如谷歌的神經網絡庫團隊開發了 Haiku,這是一個面向 Jax 的深度學習代碼庫,通過 Haiku,用戶可以在 Jax 上進行面向對象開發;又比如 RLax,這是一個基於 Jax 的強化學習庫,用戶使用 RLax 就能進行 Q-learning 模型的搭建和訓練;此外還包括基於 JAX 的深度學習庫 JAXnet,該庫一行代碼就能定義計算圖、可進行 GPU 加速。可以說,在過去幾年中,JAX 掀起了深度學習研究的風暴,推動了科學研究迅速發展。

JAX 的安裝

如何使用 JAX 呢?首先你需要在 Python 環境或 Google colab 中安裝 JAX,使用 pip 進行安裝:

pip install --upgrade jax jaxlib

注意,上述安裝方式只是支持在 CPU 上運行,如果你想在 GPU 執行程序,首先你需要有 CUDA、cuDNN ,然後運行以下命令(確保將 jaxlib 版本映射到 CUDA 版本):

pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

現在將 JAX 與 Numpy 一起導入:

import jax
import jax.numpy as jnp
import numpy as np

JAX 的一些特性

使用 grad() 函數自動微分:這對深度學習應用非常有用,這樣就可以很容易地運行反向傳播,下面爲一個簡單的二次函數並在點 1.0 上求導的示例:

from jax import grad
def f(x):
  return 3*x**2 + 2*x + 5
def f_prime(x):
  return 6*x +2
grad(f)(1.0)
# DeviceArray(8., dtype=float32)
f_prime(1.0)
# 8.0

jit(Just in time) :爲了利用 XLA 的強大功能,必須將代碼編譯到 XLA 內核中。這就是 jit 發揮作用的地方。要使用 XLA 和 jit,用戶可以使用 jit() 函數或 @jit 註釋。

from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
  for _ in range(10):
      x = 0.5*x + 0.1* jnp.sin(x)
  return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 µs per loop

pmap:自動將計算分配到所有當前設備,並處理它們之間的所有通信。JAX 通過 pmap 轉換支持大規模的數據並行,從而將單個處理器無法處理的大數據進行處理。要檢查可用設備,可以運行 jax.devices():

from jax import pmap
def f(x):
  return jnp.sin(x) + x**2
f(np.arange(4))
#DeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)
pmap(f)(np.arange(4))
#ShardedDeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)

vmap:是一種函數轉換,JAX 通過 vmap 變換提供了自動矢量化算法,大大簡化了這種類型的計算,這使得研究人員在處理新算法時無需再去處理批量化的問題。示例如下:

from jax import vmap
def f(x):
  return jnp.square(x)
f(jnp.arange(10))
#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)
vmap(f)(jnp.arange(10))
#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)

TensorFlow vs PyTorch vs Jax

在深度學習領域有幾家巨頭公司,他們所提出的框架被廣大研究者使用。比如谷歌的 TensorFlow、Facebook 的 PyTorch、微軟的 CNTK、亞馬遜 AWS 的 MXnet 等。

每種框架都有其優缺點,選擇的時候需要根據自身需求進行選擇。

我們以 Python 中的 3 個主要深度學習框架——TensorFlow、PyTorch 和 Jax 爲例進行比較。這些框架雖然不同,但有兩個共同點:

那麼它們的不同體現在哪些方面呢?如下表所示,爲 TensorFlow、PyTorch、JAX 三個框架的比較。

TensorFlow

TensorFlow 由谷歌開發,最初版本可追溯到 2015 年開源的 TensorFlow0.1,之後發展穩定,擁有強大的用戶羣體,成爲最受歡迎的深度學習框架。但是用戶在使用時,也暴露了 TensorFlow 缺點,例如 API 穩定性不足、靜態計算圖編程複雜等缺陷。因此在 TensorFlow2.0 版本,谷歌將 Keras 納入進來,成爲 tf.keras。

目前 TensorFlow 主要特點包括以下:

PyTorch

PyTorch(Python-Torch) 是來自 Facebook 的機器學習庫。用 TensorFlow 還是 PyTorch?在一年前,這個問題毫無爭議,研究者大部分會選擇 TensorFlow。但現在的情況大不一樣了,使用 PyTorch 的研究者越來越多。PyTorch 的一些最重要的特性包括:

**JAX **

JAX 是來自 Google 的一個相對較新的機器學習庫。它更像是一個 autograd 庫,可以區分原生的 python 和 NumPy 代碼。JAX 的一些特性主要包括:

參考鏈接:

https://www.askpython.com/python-modules/tensorflow-vs-pytorch-vs-jax

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://www.zhihu.com/question/306496943/answer/557876584

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/TnqXcx8DSCNgnhI8cGXfRA