izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

JAX学習記録③ーAutomatic Vectorization and Differentiation

Preface

 この記事では、JAX学習記録の第三弾として、JAXにおける自動微分(jax.grad)および関数の自動ベクトル化(jax.vmap)についてまとめる。まず、jax.gradとセットで使われることが多いAutomatic Vectorization (jax.vmap)について説明する。jax.vmapは、jax.gradやjax.jitなどと並び、JAXの大きな特徴の一つである「関数の変換を行う」JAX変換の一つである。その後、第一弾から登場しているjax.gradの使い方や機能について説明する。また、前の記事を読んでる前提で書くのでまだ読んでいない方は読むことをお勧めする。

izmyon.hatenablog.com

参考にした記事は、以下。

目次

Automatic Vectorization in JAX

 jax.vmapは、バッチ処理を行わない関数を、バッチ処理を行えるように変換する際に用いられるJAX変換である。例えばパラメータとバッチにまとめられていない単一のデータから損失を計算する関数を、バッチデータを受けとり、バッチ内の各データに対して損失を計算し、それをベクトルで返すような関数に変換する。おそらく出力にバッチ軸加え、ベクトル化させるために自動ベクトル化(Automatic Vectorization)という名前がついてるが、入力に対してもバッチ軸を追加することを意識してほしい。つまり、n階のテンソルデータを受けとりスカラー(0階のテンソル)を返す関数を、n+1階のテンソルデータを受けとり、ベクトル(1階のテンソル)を返す関数に変換する。ここでは、二つの一次元ベクトルの畳み込みを例に説明する。

Manual Vectorization

 まず、jax.vmapを使わずに実装する方法を見てみる。

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
DeviceArray([11., 20., 29.], dtype=float32)

ここでは、x=[0, 1, 2, 3, 4, 5]とw=[2., 3., 4.]の畳み込みを計算している。今回の計算では、[0*2.+1*3.+2*4., 1*2.+2*3.+3*4., 2*2.+3*3.+4*4.] = [11,, 20., 29.]という計算が行われている。

 次に、ここで定義されたconvolve関数を、以下のようにxとwの複数のペアからなるバッチに適用させることを考える。

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

シンプルに実装するのであれば、以下のようにfor文を使い、バッチ内の各データに対して順番に関数を適用し、その結果を得る。

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

この方法では正しい答えを得ることができるが、あまり効率的なやり方ではない。

 この計算を効率的に行うためには、関数をベクトル化された形式で行われるように書き換える必要がある。今回の例では、以下のようにconvolve関数を書き換えて、バッチ次元でのベクトル化された計算を行うことができる。

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

このようにバッチに含まれるすべてのデータに対してまとめて計算を行うように実装を書き換えるには、インデックスや軸など入力の一部をどのように扱うかを変更する必要がある。このような再実装は厄介で、エラーが発生しやすい。そこで、JAXではこの再実装を自動化する仕組みを、jax.vmapで提供している。

Automatic Vectorization

 前節でみたように、ベクトル化された関数の実装を自動的に生成するのが、jax.vmapである。先の例の場合、次のようにjax.vmap関数をconvolve関数に作用させるだけで変換できる。

auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

jax.vmapでは、各入力の最初にバッチ軸を自動的に追加する。もし入力したいデータにおいてバッチ次元が最初でない場合には、in_axesとout_axes引数を使用し、入力と出力におけるバッチ次元の位置を指定することができる。これらは、バッチ軸がすべての入力と出力と同じである場合は整数、そうでない場合はリストとなる。以下では、in_axes=1, out_axes=1として、入力データを転置して入力してみる。

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)
DeviceArray([[11., 11.],
             [20., 20.],
             [29., 29.]], dtype=float32)

 jax.vmapは、引数の内の一つだけがバッチ処理される場合もサポートしている。例えば、ベクトルxのバッチと、単一の重みwに対して畳み込みたい場合、重みwのin_axes引数をNoneに設定する。

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

Combining transformations

 また、他のJAX変換と同様に、jax.jitとjax.vmapはComposableに設計されている。つまり、vmapの関数をjitでラップしたり、JIT化した関数をvmapでラップすることができ、すべて正しく動作する。

jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

How to use Automatic Differentiation in JAX

 それではJAXにおける自動微分(autodiff)について、その使い方および機能を説明する。

Higher-order derivatives

 JAXのjax.gradによる自動微分では、数学関数 fを表すPython関数fを与えると、その数学関数の導関数 \nabla fを表すPython関数を返す。その関数もまた微分可能である場合にはさらにそれを微分することができるため、単にjax.gradによる関数の変換を繰り返すだけで高階微分を行うことができる。

Function of one variable

例えば、 f(x) = x^ 3+2x^ 2-3x+1の高階微分は以下のように求めることができる。

import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

 fの高階微分は、それぞれ

 \displaystyle  f'(x) = 3x^2+4x-3
 \displaystyle  f''(x) = 6x+4
 \displaystyle  f'''(x) = 6
 \displaystyle  f''''(x) = 0

である。 よって、 x=0では、

 \displaystyle  f'(0) = -3  \displaystyle  f''(0) = 4  \displaystyle  f'''(0) = 6  \displaystyle  f''''(0) = 0

であり、 x=1では、

 \displaystyle  f'(1) = 4 ,  \displaystyle  f''(1) = 10 ,  \displaystyle  f'''(1) = 6 ,  \displaystyle  f''''(1) = 0

である。 実際、以下を実行すると、

print('x=0')
print(dfdx(0.))
print(d2fdx(0.))
print(d3fdx(0.))
print(d4fdx(0.))
print('x=1')
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
x=0
-3.0
4.0
6.0
0.0
x=1
4.0
10.0
6.0
0.0

となる。

Function of multiple variables

 多変数関数の場合には、高階微分はもっと複雑になる。多変数関数の二次の微分関数はヘッシアン行列によって表される。

 \displaystyle  (Hf)_{i,j} = \dfrac{\partial^2 f}{\partial_i \partial_j}

特に、二変数関数の場合は、

 \displaystyle  (H_f) = \dfrac{\partial^2 f}{\partial_i \partial_j} = \begin{pmatrix} \dfrac{\partial^2 f}{\partial_x^2}&\dfrac{\partial^2 f}{\partial_x \partial_y}\\ \dfrac{\partial^2 f}{\partial_y \partial_x}&\dfrac{\partial^2 f}{\partial_y^2} \end{pmatrix}

と表される。

多変数実数値関数 f: \mathbb{R}^n \rightarrow \mathbb{R}のヘッシアンは、以下のように、勾配のヤコビアンによって求められる。

 \displaystyle  (H_f) = J (\nabla f) = J \left( \dfrac{\partial f}{\partial_{x_1}}, \cdots , \dfrac{\partial f}{\partial_{x_n}} \right)
= \begin{pmatrix} 
\dfrac{\partial^2 f}{\partial_{x_1}^2}& \cdots &\dfrac{\partial^2 f}{\partial_{x_1} \partial_{x_n}}\\
\vdots & \ddots & \vdots \\
\dfrac{\partial^2 f}{\partial_{x_n} \partial_{x_1}}& \cdots &\dfrac{\partial^2 f}{\partial_{x_n}^2} \end{pmatrix}

JAXでは、関数のヤコビアンを求めるために、jax.jacfwdとjax.jacrevという二つのJAX変換を用意している。jax.jacfwdとjax.jacrevはそれぞれ、forward-mode、reverse-modeの自動微分に対応しており、その答えは一致する。(JAXのforward-mode、reverse-modeの計算方法及び計算効率の違いについては第四弾で説明する。)これを用いると、ヘッシアンは、以下のように書ける。

def hessian(f):
  return jax.jacfwd(jax.grad(f))

ドット積 f: x \rightarrow x^ T xのヘッシアンを求めることで確認してみる。ドット積のヘッシアンは、 if \quad i=j, \dfrac{\partial^ 2 f}{\partial_i \partial_j}(x) =2. \quad Otherwise, \dfrac{\partial^ 2 f}{\partial_i \partial_j}(x) =0である。

import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

hessian(f)(jnp.array([1., 2., 3.]))
DeviceArray([[2., 0., 0.],
             [0., 2., 0.],
             [0., 0., 2.]], dtype=float32)

ちゃんと計算できてることが分かる。

Stopping Gradients

 JAXの自動微分は、関数の持つそれぞれの変数に対し、勾配を求めることができる。しかし時には、例えば計算グラフの一部に対する逆伝播を止めるといった、追加的な制御が必要になる場合がある。  例えばTD(0)強化学習の更新を考えてみる。この学習は、環境との相互作用の経験から、環境中の状態の値を推定するために使用される。ここで、状態 s_ {t-1}における価値の推定値関数 v_ {\theta}(s_ {t-1})が一時関数で定義されているとする。

# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])

報酬 r_tを観測する状態 s_tから状態 s_{t-1}への遷移を考える。

# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

TD(0)学習ではネットワークパラメータを以下のように更新する。

 \delta \theta = (r_ t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})\nabla v_{\theta}(s_{t-1})

この更新では、損失関数の勾配は用いない。 しかし、ターゲット r_t + v_{\theta}(s_t)のパラメータ \theta への依存を無視すれば、疑似損失関数の勾配として書くことができる。

 L(\theta) = \left( r_t + v_{\theta}(s_t) - v_{\theta} (s_{t-1}) \right)^ 2

これをJAXでナイーブに実装すると、次のようになる。

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (target - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
DeviceArray([ 2.4, -2.4,  2.4], dtype=float32)

しかし、td_updatはTD(0)の更新を計算しない。なぜなら、勾配の計算には、ターゲット r_t + v_{\theta}(s_t) \thetaへの依存が含まれているからである。

そこで、jax.lax.stop_gradientを使い、JAXにターゲットのパラメータ \thetaへの依存を無視させる。

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (jax.lax.stop_gradient(target) - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

これにより、targetがパラメータに依存しないかのように扱い、パラメータの正しい更新を計算することができる。

jax.lax.stop_gradientは、例えば、他のパラメータが別の損失関数を使って学習されている場合、ある損失関数による勾配をニューラルネットワークのパラメータのある一部だけに対して求め、更新する際にも有効である。

また、stop_gradientは、ストレートスルー推定の実装にも使うことができる。ストレートスルー推定とは、微分不可能な関数の「勾配」を定義するためのトリックである。微分不可能な関数 f: \mathbb{R}^ n \rightarrow \mathbb{R}^ nが、勾配を求めたいより大きな関数の一部として用いられている場合、逆伝播を求める際には、 fが恒等関数であるかのように装う。

def f(x):
  return jnp.round(x)  # non-differentiable

def straight_through_f(x):
  # Create an exactly-zero expression with Sterbenz lemma that has
  # an exactly-one gradient.
  zero = x - jax.lax.stop_gradient(x)
  return zero + jax.lax.stop_gradient(f(x))

print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x):  3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0

Per-example gradients

 多くのMLシステムでは、バッチデータから勾配を計算しパラメータの更新を行うが、計算効率や分散の削減などの理由から、バッチ内の各特定サンプルに関連する勾配や更新にアクセスすることが必要な場合がある。

例えば、勾配の大きさに基づいてデータに優先順位をつけたり、サンプル単位でクリッピングや正規化を行う場合である。

多くのフレームワーク(PyTorch、TF、Theano)では、ライブラリがバッチ全体の勾配を直接蓄積するため、サンプルごとの勾配を計算するのは簡単ではないことが多い。しかも、サンプルごとに個別に損失を計算し、その結果の勾配を集計するような素朴な回避策は、一般に非常に非効率である。

JAXでは、サンプルごとに勾配を計算するコードを、jit、vmap、gradの各変換を一緒にするだけで、簡単かつ効率的に実装することができる。

perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

各変換について一つずつ見てみる。

まず、td_lossにjax.gradを作用させ、単一(バッチ単位ではない)の入力におけるパラメータに対する損失の勾配を計算する関数を得る。

dtdloss_dtheta = jax.grad(td_loss)

dtdloss_dtheta(theta, s_tm1, r_t, s_t)
DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

この関数は、先の配列の1行を計算する。

次に、この関数をjax.vmapでベクトル化する。これにより、すべての入力と出力にバッチ次元が追加され、バッチ内の各出力は、入力バッチ内の対応する各データに対する勾配である。

almost_perex_grads = jax.vmap(dtdloss_dtheta)

batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

これは我々が望むものとは異なる。なぜなら、我々は実際には単一のthetaを利用したいのであるが、この関数にthetaのバッチを手動で入力しなければならないからである。しかし、前の記事でも書いた通り、jax.vmapにin_axesを追加し、thetaをNone、他の引数を0と指定することで解決できる。これにより、結果として得られる関数は、他の引数にのみバッチ軸を追加し、thetaに対してはバッチ化されていない、単一のthetaを用いることができる。

inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

これで欲しい機能は実装できたものの、計算効率が悪く必要以上に遅い。そこで、jax.jitをラップして、同じ関数のコンパイル済みで効率的なバージョンを用いる。

perex_grads = jax.jit(inefficient_perex_grads)

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
100 loops, best of 5: 7.74 ms per loop
10000 loops, best of 5: 86.2 µs per loop

かなり早くなった。

次は、自動微分の実装について書く。