izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

JAX学習記録①ーJAX As Accelerated NumPy

Preface

 このシリーズは、JAXのドキュメントを読みながら学習した際に、ノート代わりに記事に残したものである。JAXを理解する際の要所は、XLAやjust-in-time、Asynchronous dispatchなど多岐にわたるが、この記事ではまず第一弾として、Numpy-likeなAPIで利用でき、自動微分にも対応するJAXのアクセラレートされたNumpyとしての機能(便宜的にJAX numpyと呼ぶ。)についてまとめる。

jax.readthedocs.io

目次

What’s JAX?

 JAXは、CPUだけでなく、GPU, TPUなどのアクセラレータ上でも動き、自動微分を兼ね備えたNumPyである。Autogradを用いて、関数の微分だけでなくさらにその微分を行ったり、順方向の微分だけでなく逆方向の微分も行い、さらに両者の合成も行うことができる。しかもJAXは、XLA(Accelerated Linear Algebra)を用いてNumpyのコードをGPUやTPU上でコンパイルして実行することができる。ライブラリの呼び出しはjust-in-timeでコンパイルされ実行され、さらに一つのAPIを使って独自のPython関数をXLAに最適化されたカーネルにjust-in-timeコンパイルすることもできる。また、コンパイルと自動微分は任意に構成でき、Pythonを離れなくても高度なアルゴリズムを表現し、最大限の性能を得ることができる。

 ....と、JAXの概要を述べたが、僕もXLAやjust-in-timeコンパイルについてはまだ理解できていない。これからこのブログを書き進めていきながら理解していこうと思っているので、諦めそうになった方がいても安心してほしい。これから一緒に理解していきましょう。

 個人的には、JAXの特筆すべき特徴としては以下があると考えている。

  • CPUだけでなく、GPU, TPU上でも動くが、実装がPythonで完結している。
  • XLAによる高速化を行っている。
  • デフォルトで非同期処理を行う。
  • 関数型プログラミングを意識し、関数的に設計されている。
  • 実際の数学と同じように関数を微分することができ、勾配ベクトルだけではなく、勾配関数を求めることができる。
  • Pythonは動的型付けだが、型付けに厳しめに実装されている。
  • 乱数の管理が独特で厳密である。

特に、PyTorchなどのように、勾配の計算などがbackward()といった形でブラックボックスになっていると個人的に非常に気持ち悪い。しかしJAXは、関数型プログラミングを意識して設計されていることもあり、loss関数を本当にパラメータとデータの関数として定義でき、しかもAutoGradによりその勾配の関数を求めることができて、勾配の計算をブラックボックスにせずにしかも数式を書くのと同じように実装することができる。個人的にJAXのこの特徴は本当に好き。大好き。(PyTorchのここら辺の話は、研究室の同期が書いたこの記事「PyTorchの微分はいかにして動くのか?」を参照のこと。)

 ここではまず、JAXのNumpyとの違いや、APIの設計指針などについて述べ、次章でJAX numpyの具体的な使い方について入門していく。

JAX vs. Numpy

 JAXのkey-conceptは、以下。

  • 簡便のため、JAXはNumPyにインスパイアされたインタフェースを持つ。
  • duck-typingにより、JAX arrayは多くの場合、Numpy arrayの代用になりうる。
  • NumPy arrayと違い、JAX arrayは常にimmutableである。

JAXは基本的にはNumpyをwrapし、機能を追加したものと捉えられるが、JAX arrayは常にimmutableであることには注意。 例えば、JAX arrayを以下のように変更しようとするとTypeErrorが出る。

x = jnp.arange(10)
x[0] = 10
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-6b90817377fe> in <module>()
      1 # JAX: immutable arrays
      2 x = jnp.arange(10)
----> 3 x[0] = 10

TypeError: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

特定のindexにあるelementを変更したい場合には、以下のように、indexed update syntaxを使って複製する必要がある。

y = x.at[0].set(10)

このような違いは、JAXが関数型プログラミングのように、関数的に設計されていることに起因している。

JAX API layering

 JAXのAPI設計のkey-conceptは、以下。

  • jax.numpyは、numpyと似たインタフェースを提供するより高いレベルのwrapper
  • jax.laxは、より低いレベルのAPIで、より静的(strict)だがよりパワフル
  • すべてのJAX演算は、XLAの演算として実装されている。

 特に、jax.laxはjax.numpyと比べて静的(strict)であることには注意する必要がある。 例えば、以下の演算はjax.numpyでは実行できるが、jax.laxではTypeErrorが出る。

import jax.numpy as jnp
jnp.add(1, 1.0)
DeviceArray(2., dtype=float32)
from jax import lax
lax.add(1, 1.0)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-63245925fccf> in <module>()
      1 from jax import lax
----> 2 lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.

TypeError: add requires arguments to have the same dtypes, got int32, float32.

この場合、しっかりと型を指定して、以下のように書く。

lax.add(jnp.float32(1), 1.0)
DeviceArray(2., dtype=float32)

Getting Start with JAX numpy

 JAXの概要やJAX numpyとNumPyやjax.laxとの違い、JAXの設計指針に触れたとこで、ここから以下の記事を読んでJAX numpyの使い方をまとめていく。 ここでまず念頭におくべきは、冒頭でも述べた通り、JAXは基本的にはアクセラレータ上で動き、自動微分を兼ね備えたNumpyと考えてよいということである。

jax.readthedocs.io

JAX array

ベクトルを作成するのは、Numpyと同様で、arangeを用いて以下のようにする。

import jax
import jax.numpy as jnp

x = jnp.arange(10)
print(x)
x
[0 1 2 3 4 5 6 7 8 9]
DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

変数のタイプは、DeviceArrayである。 また、この際、GPU/TPUが使用できないと、以下の警告が出る。

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

このように出るのは、JAX numpyが基本的にはGPU/TPUでの使用を想定されているからである。 JAXの便利な特徴の一つは、同じコードで異なるバックエンドを使用することができることである。 GPU/TPUを使用可能にし(例えばGoogle ColabではRuntimeをGPUに変更し)、例えば以下のようにdot積を実行すると、コードの変更なく別のデバイス(ここではGPU)で実行することができる。

long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()
The slowest run took 7.39 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 7.85 ms per loop

JAX first transformation: grad

 JAXの基本的な特徴として、「関数の変換を可能にする」、というものがある。 その中の最も基本的な変換がjax.gradである。これは、その名の通り、関数を与えると、勾配関数を返す。 例えば、以下のように二乗和を返す関数を定義する。

def sum_of_squares(x):
    return jnp.sum(x**2)

次にjax.gradを用いて勾配関数を求め、xの値を定義して勾配を計算すると、以下のようになる。

sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_of_squares(x))
print(sum_of_squares_dx(x))
30.0
[2. 4. 6. 8.]

このように、jax.gradは微分演算子  \nabla であると考えることができる。与えられた関数が f(x)であるとすると、 \nabla f fの勾配関数である。 したがって、jax.grad(f)は勾配を計算する関数であり、jax.grad(f)(x)は f xにおける勾配である。

これは、前述したようにTensorflowやPyTorchと大きく異なる特徴である。それらのライブラリでは、勾配はテンソル自体を使用してloss.backward()などで計算するが、JAX APIでは、loss関数は本当にパラメータとデータの関数であり、数学と同じようにその勾配を求めることができる。

微分する関数は複数の変数を持っている場合があるが、jax.gradは、デフォルトで第一引数として渡されている変数に対して微分を行う。 例えば、以下のような二乗誤差の勾配関数を求める場合には、第一引数として渡されている xに対して微分を行っている。

def sum_squared_error(x, y):
  return jnp.sum((x-y)**2)

sum_squared_error_dx = jax.grad(sum_squared_error)
y = jnp.asarray([1.1, 2.1, 3.1, 4.1])
print(sum_squared_error_dx(x, y))
[-0.20000005 -0.19999981 -0.19999981 -0.19999981]

異なる変数についてそれぞれ勾配関数を求めたい場合には、以下のようにargnumsを設定する。

jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # Find gradient wrt both x & y
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))

これだけ見ると、機械学習のように膨大なパラメータを持つ関数の勾配を求めたい場合には、いちいち引数を指定するのは非常に面倒なのではないかと思うかもしれないが、実用上は「pytree」というデータ構造でパラメータをを定義するため、その必要はなく、実際の実装は以下のようになる。(pytreeのついては、この記事では触れず、別の記事で書く。)

def loss_fn(params, data):
  ...

grads = jax.grad(loss_fn)(params, data_batch)

ここで、paramsは、例えばネストされたarrayの辞書であり、返されるgradsもまた、同じ構造を持つネストされたarrayの辞書となる。

Value and Grad

 あるデータを与えたときの関数の返り値と勾配の両方を欲しい場合には、value_and_grad関数を使う。

jax.value_and_grad(sum_squared_error)(x, y)
(DeviceArray(0.03999995, dtype=float32),
 DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))

上のように、value_and_grad関数は(value, grad)という形で、結果をタプルで返す。 つまり、vlue_and_gradは、関数 fおよびその変数*xsに対して、以下のように定義される。

jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs)) 

Auxiliary data

 勾配を計算する際、中間的な計算結果を得たい場合、以下のようにするとエラーが出る。

def squared_error_with_aux(x, y):
  return sum_squared_error(x, y), x-y

jax.grad(squared_error_with_aux)(x, y)
---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-9-7433a86e7375> in <module>()
      3 
----> 4 jax.grad(squared_error_with_aux)(x, y)

FilteredStackTrace: TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).

The stack trace above excludes JAX-internal frames.

これは、jax.gradがスカラーを返す関数にのみ対応しているのに対し、tupleを返す関数を与えているからである。 この場合、引数has_auxをTrueにする。

jax.grad(squared_error_with_aux, has_aux=True)(x, y)
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))

has_aux = Trueは、与えられた関数は(out, aux)を返すということをjax.gradに知らせている。これにより、jax.gradはauxを無視し、outで計算した勾配とともにまとめて返す。

First JAX Training Loop

 ここまでの知識を用い、linear regressionの学習を行ってみる。 データは、 y = w_{true} x + b_{true}+ ε からサンプリングされるとする。

import numpy as np
import matplotlib.pyplot as plt

xs = np.random.normal(size=(100,))
noise = np.random.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise

plt.scatter(xs, ys);

用いるモデルは、 \hat{y} (x; \theta ) = wx + bとする。 theta = [w, b]として二つのパラメータを \thetaにまとめ、modelはあくまでパラメータとデータの関数として定義する。

def model(theta, x):
  """Computes wx + b on a batch of input x."""
  w, b = theta
  return w * x + b

ロス関数は  J (x, y; \theta ) = \left( \hat{y} - y \right)^{2} とする。

def loss_fn(theta, x, y):
  prediction = model(theta, x)
  return jnp.mean((prediction-y)**2)

勾配降下法により、各更新ステップで各バラメータに対する損失の勾配を求め、最急降下する方向にわずかにパラメータを更新する。

 \theta_{new} = \theta - 0.1 \left( \nabla_{\theta} J \right) \left(x, y; \theta \right)

def update(theta, x, y, lr=0.1):
  return theta - lr * jax.grad(loss_fn)(theta, x, y)

JAXでは、ステップごとに呼び出されるupdate()関数を定義し、現在のパラメータを入力として新たなパラメータを返す。これはJAXの関数的性質の帰結なのだが、これも数式と同じように実装できて好き。

theta = jnp.array([1., 1.])

for _ in range(1000):
  theta = update(theta, xs, ys)

plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs))

w, b = theta
print(f"w: {w:<.2f}, b: {b:<.2f}")
w: 3.00, b: -1.00

上で行った処理は、jax.jitでupdata()をデコレートすることで、JIT-compile(just-in-timeコンパイル)をすれば効率的に処理することが可能だが、jitに関してはまた後の記事で解説する。

JAXで実装されるトレーニングループはほとんどがこの例のように実装される。特に、のちの記事で説明するpytreeによって、パラメータをthetaにまとめjax.grad(loss_fn)(theta, x, y)のような形で勾配を計算することができた。このように数式と同じように実装できるのはほんとに好き。めっちゃ好き。ただし、今回はパラメータが二つしかなかったから一つのarrayにまとめることができたが、複雑なモデルになり多くのパラメータを扱う必要が出てくると、前述したように、パラメータは例えばネストされたarrayの辞書というような複雑な形で定義する。