izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

JAX学習記録④ーHow Autodiff is made?

Preface

 この記事ではJAX学習記録シリーズ第四弾として、JAXの自動微分がどのように実現されているか、詳細を見てみる。正直なところ、特に数学的な説明に関して理解しきれていないところが多く、基本的には原文に忠実に訳しており、自分の理解による説明はあまり加えられていない。もし修正が必要な個所があったり、もっと分かりやすい説明の仕方があれば、修正案または補足案と共に教えて頂けると非常にありがたい。また、この記事は以下の第三弾の続きであり、これを読んでる前提で書くのでまだ読んでいない方は読むことをお勧めする。

izmyon.hatenablog.com

 今回の内容の元記事は、以下。

目次

How it's made: two foundational autodiff functions

 JAXはforward-mode、reverse-modeの自動微分を持つ。ちなみに、おなじみのjax.gradは、reverse-modeをベースとして作られている。この章では、これら二つのmodeの違いを説明し、それをうまく使うための説明をおこなう。

Jacobian-Vector products (JVPs, aka forward-mode autodiff)

JVPs in math

 数学的には、ある関数 f: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}が与えられたとき、 \partial f(x)で表されるある入力点 x \in \mathbb{R}^nにおける fヤコビアン(日本語では行列自体はヤコビ行列と呼び、その行列式ヤコビアンと呼んで区別することもあるが、ここでは原文に従い、行列の方をヤコビアンと呼ぶことで統一する。また、行列式の方は出てこない。)は、多くの場合 \mathbb{R}^{m×n}の行列の形をしている。 \partial f(x)は、空間 \mathbb{R}^nにおける入力点 xの接空間を、空間 \mathbb{R}^mにおける出力点 f(x)の接空間に移す線形写像となる。

 \partial f(x): \mathbb{R}^ n \rightarrow \mathbb{R}^ m

この変換は f xにおける押し出し(pushforward) と呼ばれ、ヤコビアンはこの線形写像を標準基底で表したものである。もし、特定の入力点にこだわらないのであれば、 \partial fを最初に入力点を取り、その点のヤコビアンを線形写像として返す関数であると考えることができる。

 \partial f(x): \mathbb{R}^ n \rightarrow \mathbb{R}^ n \rightarrow \mathbb{R}^ m

特に、入力点 x \in \mathbb{R}^nと接(勾配)ベクトル v \in \mathbb{R}^nが与えられたとき、出力の接ベクトル \mathbb{R}^mを得ることができると考えることができる。この (x, v)ペアから出力となる接ベクトルへの写像を、ヤコビアン-ベクトル積と呼び、次のように書く。

 (x, v) \rightarrow \partial f(x)v

JVPs in JAX code

 JAXでは、jvpという関数でこの変換を実装している。数学関数 fを評価するPython関数が与えられたとき、JAXのjvpは、  (x, v) \rightarrow (f(x), \partial f(x)v)を評価するPython関数を返す。

from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))

Haskellのように型を定義するとするなら、以下のようになる。

jvp :: (a -> b) -> a -> T a -> (b, T b)

つまり、jvpはa->b型の関数、a型の値、T a型の接ベクトルを引数にとり、b型の値とT b型の出力接ベクトルの組を返す。

jvp変換された関数は、元の関数と同様に評価されるが、a型の各原始値と対になり、T a型の接ベクトル値を押し出す。元の関数が適用するはずの各原始数値演算に対して、jvp変換された関数は、原始値に対する原始値の評価と、それらの原始値における原始値のJVP適用という「JVP規則」を実行する。

この評価方法は、計算量に影響を与える。JVPを評価しながら計算するので、後で保存する必要がなく、メモリコストは計算の深さに関係なく発生する。さらに、jvp変換された関数のFLOPコストは、関数を評価するだけのコストの約3倍(sin(x)などの元の関数の評価に1ユニット、cos(x)などの線形化に1ユニット、cos_x * vなど線形化した関数をベクトルに適用して1ユニットが必要)となる。別の言い方をすれば、固定原点 xに対して、 fを評価するのとほぼ同じ限界コストで v \rightarrow \partial f(x) \cdot vを評価することができる。

このメモリの複雑さは、かなり説得力がある。では、なぜ機械学習ではforward-modeをあまり見かけないのだろうか?

その答えとして、まず、JVPを使って完全なヤコビアン行列を構築する方法を考えてみる。もし、JVPを1つのone-hotの接ベクトルに適用すると、ヤコビアン行列の1列が明らかになり、入力した非ゼロの要素に対応する。つまり、一度に1列ずつ完全なヤコビアンを構築することができ、各列を得るのに、1回の関数評価と同じぐらいのコストがかかる。これは「高い」(行方向の要素数が列方向の要素数より多い。)ヤコビアンを持つ関数には効率的だが、「広い」(行方向の要素数が列方向の要素数より少ない。)ヤコビアンには非効率的である。

機械学習で勾配に基づく最適化を行う場合、おそらくはパラメータとデータからスカラー損失値を求める損失関数を最小化したいと思うことが多いだろう。これは、この関数のヤコビアンが非常に広い行列であることを意味する。この行列を1列ずつ構築し、それぞれの呼び出しで元の関数を評価するために同数のFLOPsを必要とするのは、確かに非効率的である。特に、学習損失関数が数百万から数十億にもなるニューラルネットワークを学習する場合、この方法では、スケールアップは望めない。

これらの理由から、reverse-modeを用いる。

Vector-Jacobian products (VJPs, aka reverse-mode autodiff)

 forward-modeがヤコビアン-ベクトル積を評価する関数を返し、それを使ってヤコビアンを一列ずつ作ることができたのに対し、reverse-modeでは、ベクトル-ヤコビアン積(ヤコビアン-転置ベクトル積と同等)の評価関数を返す方法で、一行ずつヤコビアンを作ることができる。

VJPs in math

 もう一度関数 f: \mathbb{R}^n \rightarrow \mathbb{R}^mについて考える。VJPは以下で表される。

 (x, v) \rightarrow v\partial f(x)

ここで、 v fxにおける余接空間の要素である。厳密には、 vは線形写像 v: \mathbb{R}^m \rightarrow \mathbb{R}であると考えるべきであり、 v \partial f(x)と書くと、合成関数 v \circ \partial f(x)を意味していると考える。しかし、多くの場合、 vは、単に\mathbb{R}^mのベクトルであると考え、両者を互換的に用いることができる。

また、VJPの線形部はJVPの線形部の転置(または随伴行列)であると考えることもできる。

 (x, v) \rightarrow \partial f(x)^ T v

余接空間上に対応する写像は、しばしば f xにおける引き戻し(pullback)と呼ばれる。重要なのは、 fの出力のようなものから、入力のようなものへ引き戻すことである。

VJPs in JAX code

 JAX変換vjpでは、数学関数 fを評価するPython関数が与えられたとき、VJP  (x, v) \rightarrow (f(x), v^ T \partial f(x)) を評価するPython関数を返すということである。

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)

Haskellのように型を定義するとするなら、以下のようになる。

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

つまり、vjpはa -> b型の関数とa型の点を引数にとり、b型の値とCT b -> CT a型の線形写像からなるペアを返す。

これは、ヤコビアン行列を一度に1行ずつ構築することを可能にし、[tex: (x, v) \rightarrow (f(x), vT \partial f(x))]を評価するためのFLOPコストは、 fの評価のコストの約3倍となる。特に、ある関数 f: \mathbb{R}^n \rightarrow \mathbb{R}の勾配を求める場合、1回の呼び出しで求めることができる。このようにgradは、数百万から数十億のパラメータを持つニューラルネットワークの学習損失関数のような目的に対しても、勾配に基づく最適化を効率的に行うことができる。

しかし、gradにはコストがかかる。FLOPsが少ないとはいえ、メモリは計算の深さに比例して増加する。また、JAXにはいくつかのトリックがあるが、伝統的に実装は順方向モードよりも複雑となる。

reverse-modeの仕組みについては、2017年に開催されたDeep Learning Summer Schoolのチュートリアルビデオを観ると良い。

Hessian-vector products using both forward- and reverse-mode

 前の記事(第三弾)で、ヘッシアン-ベクトル積関数を、(二階導関数の連続性を仮定して、)reverse-modeで実装した。

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

この実装は効率的ではあるが、forward-modeとreverse-modeを組み合わせることで、メモリを節約しつつより効率的に実装できる。

数学的には、微分する関数 f: \mathbb{R}^n \rightarrow \mathbb{R}、関数を線形化する点 x \in \mathbb{R}^ n、ベクトル v \in \mathbb{R}^nが与えられたとき、ヘッシアン-ベクトル積関数は以下のようになる。

 (x, v) \rightarrow \partial ^ 2 f(x) v

 f微分(または勾配)を g: \mathbb{R}^n \rightarrow \mathbb{R}^nとすると、 g(x) = \partial f(x)である。求めるのはそのJVPであり、

 (x, v) \rightarrow \partial g(x) v = \partial ^ 2 f(x) v

を得る。

from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

さらに良いことに、jnp.dotを直接呼び出す必要がないため、このhvp関数はpytreeのように、どんな形の配列でも、どんなコンテナ型でも動作する。しかも、jax.numpyに依存することさえない。

以下は、hvp関数の使用例である。

def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True

以前定義したhessian関数でも、hvp関数でも、ほぼ同様の答えが得られていることが分かる。

もう一つの方法として、reverse-over-forwardを使う方法がある。

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals)

forward-modeは、reverse-modeよりもオーバーヘッドが少なく、また、ここでの外側の微分演算子は内側の演算子よりも多くの微分しなければならないので、forward-modeを外側にしておくことが最もうまくいくのである。

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)


print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
5.45 ms ± 187 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
9.52 ms ± 4.27 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
14.1 ms ± 7.22 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
53.6 ms ± 831 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Composing VJPs, JVPs, and vmap

Jacobian-Matrix and Matrix-Jacobian products

 これで、jvpとvjp関数がそろい、一度に一つのベクトルを押し出し(pushforward)たり、引き戻し(pullback)たりすることができるようになった。ここで、一番最初に紹介したJAXのvmap変換を使うと、基底全体を押し出したり、引き戻したりすることができるようになる。特に、行列-ヤコビアン積およびヤコビアン-行列積を早く計算できるようになる。

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
132 ms ± 398 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
5.53 ms ± 109 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
246 ms ± 494 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
2.94 ms ± 20.9 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

The implementation of jacfwd and jacrev

 前の節で高速なヤコビアン-行列積、行列-ヤコビアン積を見てきた。同じ手法で標準基底全体を一度に押し出し(pushforward)または引き戻し(pullback)することで、jacfwdとfacrevを実装する。

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once. 
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

興味深いのは、AutoGradでは、これはできなかったことである。Autogradでのreverse-modeでのヤコビアンの実装は、map関数のループで一度に一つのベクトルを引き戻す(pullback)する必要があった。しかし、一度に一つのベクトルを計算で押し出す(pushforward)ことは、vmapですべてのバッチ処理をするよりもはるかに効率が悪いのである。

もう一つ、AutoGradでできなかったことは、jit(just-in-timeコンパイル)である。面白いことに、微分する関数にどれだけPythonのダイナミズムを使っても、計算の線形部分には常にjitを使うことができる。

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(DeviceArray(3.1415927, dtype=float32, weak_type=True),)

Complex numbers and differentiation

 JAXは複素数と、その微分も扱うことができる。正則関数と非正則関数の両方の微分を行うには、JVPとVJPの両方の観点から考える必要がある。

まず、複素数から複素数への関数 f : \mathbb{C} \rightarrow \mathbb{C}を考え、それに対応する関数 g: \mathbb{R}^2 \rightarrow \mathbb{R}^2を同定する。

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

ここでは、複素数 z=x+yiを受けとり、 f(z) = u(x,y) + v(x,y)iを返す関数と、複素数空間 \mathbb{C} \mathbb{R}^2と見なす関数 gを定義した。 関数 gは実数の入力と出力しか持たないため、例え接ベクトル (c,d) \in \mathbb{R}^2 を使ったヤコビアン-ベクトル積は以下のようになる。

 \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x,y)  \\ \partial_0 v(x, y) & \partial_1 v(x,y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}

元の関数 fを接ベクトル c+di \in \mathbb{C}に適用してJVPを得るには、同じ定義を使って、結果を別の複素数として特定するだけで良い。

 \displaystyle \partial f (x+yi)(c+di) = 
\begin{pmatrix} 1 & i \end{pmatrix} \begin{pmatrix} \partial_0 u(x, y) & \partial_1 u(x,y)  \\
\partial_0 v(x, y) & \partial_1 v(x,y) \end{pmatrix} \begin{pmatrix} c \\ d \end{pmatrix}

これが関数JVPの定義である。ここで、 fが正則かどうかは関係ない。

以下で確認してみる。

def check(seed):
  key = random.PRNGKey(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True

次にVJPについてみてみる。余接ベクトル c + di \in \mathbb{C}に対して、 fのVJPを以下のように定義する。

 (c+di)^* \partial f (x+yi) = \begin{bmatrix} c & -d \end{bmatrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x,y)  \\ 
\partial_0 v(x, y) & \partial_1 v(x,y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}

ここで、複素共役と余接を扱っているために、マイナスが出てきている。

以下で確認してみる。

def check(seed):
  key = random.PRNGKey(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control

  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
True
True
True

それでは、grad、jacfwd、jacrevのようなラッパーに関してはどうだろうか?

 \mathbb{R} \rightarrow \mathbb{R}の関数の場合、grad(f)(x)をvjp(f, x)[1](1.0)と定義したことを踏まえると、VJPを1.0の値に適用させると勾配(またはヤコビアン微分)を求めることができる。 \mathbb{C} \rightarrow \mathbb{C}の関数の場合でも同様であり、1.0を余接ベクトルとして利用し、ヤコビアン複素数で表した結果を得ることができる。

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)
DeviceArray(6.-8.j, dtype=complex64)

一般的な \mathbb{C} \rightarrow \mathbb{C}関数では、ヤコビアンは実数値の自由度が4である(上で定義した2×2のヤコビアンのように)ため、そのすべてを複素数で表現することは望めない。しかし、正則関数であればそれが可能である。正則関数は、その微分が一つの複素数表現できるという特殊な性質を持っている特殊な \mathbb{C} \rightarrow \mathbb{C}関数である(コーシー・リーマン方程式が、上記の2×2のヤコビアンが、複素数平面上でスケールと回転の行列という特殊な形式、つまり、乗算を行うことで、1つの複素数の作用を持つことを示している)。そして、1.0の余ベクトル(covector)を持つvjpを一度呼び出すだけで、その複素数を求めることができる。

これは正則関数にしか使えないため、これを利用するには、以下のようにholomorphic=Trueとし、関数が正則であることをJAXに約束する。もしそうでなければ、JAXはgradを複素数が出力の関数に使ったときにエラーを出す。

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)
DeviceArray(-27.034946-3.8511534j, dtype=complex64, weak_type=True)

holomorphic=Trueの約束をすることで、出力が複素数値であるときのエラーを無効にする。関数が正則ではない場合にも、holomorphic=Trueと書けるが、出力されるのは完全なヤコビアンを表すものではなく、出力の虚数部を捨てた関数のヤコビアンが出力される。

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!
DeviceArray(1.-0.j, dtype=complex64, weak_type=True)

ここでは、gradの働きについて、いくつかの有用な知見を示す。

  1. 正則関数にgradを使うことができる。
  2. 例えば、複素数パラメータxの実数値損失関数のように、grad(f)(x)の共役の方向にパラメータを更新することによって、関数 f: \mathbb{C} \rightarrow \mathbb{R}を最適化するためにgradを使用することができる。
  3. もし、内部でたまたま複素数演算(畳み込みに使われるFFTのような非正則演算)を使っている関数 \mathbb{R} \rightarrow \mathbb{R}があったとしても、gradは機能し、実数値だけを使った実装と同じ結果を得ることができる。

いずれにしろ、JVPとVJPは常に曖昧性を持たず、もし非正則関数の完全なヤコビアンを計算したい場合、JVPやVJPで計算することができる。

JAXでは、複素数はどこでも使えると思ってよい。以下は、複素数行列をコレスキー分解で微分している。

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)
DeviceArray([[-0.7534182  +0.j      , -3.0509028 -10.940544j,
               5.9896846  +3.542303j],
             [-3.0509028 +10.940544j, -8.904491   +0.j      ,
              -5.1351523  -6.559373j],
             [ 5.9896846  -3.542303j, -5.1351523  +6.559373j,
               0.01320427 +0.j      ]], dtype=complex64)

次は、jitについて書く。

JAX学習記録③ーAutomatic Vectorization and Differentiation

Preface

 この記事では、JAX学習記録の第三弾として、JAXにおける自動微分(jax.grad)および関数の自動ベクトル化(jax.vmap)についてまとめる。まず、jax.gradとセットで使われることが多いAutomatic Vectorization (jax.vmap)について説明する。jax.vmapは、jax.gradやjax.jitなどと並び、JAXの大きな特徴の一つである「関数の変換を行う」JAX変換の一つである。その後、第一弾から登場しているjax.gradの使い方や機能について説明する。また、前の記事を読んでる前提で書くのでまだ読んでいない方は読むことをお勧めする。

izmyon.hatenablog.com

参考にした記事は、以下。

目次

Automatic Vectorization in JAX

 jax.vmapは、バッチ処理を行わない関数を、バッチ処理を行えるように変換する際に用いられるJAX変換である。例えばパラメータとバッチにまとめられていない単一のデータから損失を計算する関数を、バッチデータを受けとり、バッチ内の各データに対して損失を計算し、それをベクトルで返すような関数に変換する。おそらく出力にバッチ軸加え、ベクトル化させるために自動ベクトル化(Automatic Vectorization)という名前がついてるが、入力に対してもバッチ軸を追加することを意識してほしい。つまり、n階のテンソルデータを受けとりスカラー(0階のテンソル)を返す関数を、n+1階のテンソルデータを受けとり、ベクトル(1階のテンソル)を返す関数に変換する。ここでは、二つの一次元ベクトルの畳み込みを例に説明する。

Manual Vectorization

 まず、jax.vmapを使わずに実装する方法を見てみる。

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
DeviceArray([11., 20., 29.], dtype=float32)

ここでは、x=[0, 1, 2, 3, 4, 5]とw=[2., 3., 4.]の畳み込みを計算している。今回の計算では、[0*2.+1*3.+2*4., 1*2.+2*3.+3*4., 2*2.+3*3.+4*4.] = [11,, 20., 29.]という計算が行われている。

 次に、ここで定義されたconvolve関数を、以下のようにxとwの複数のペアからなるバッチに適用させることを考える。

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

シンプルに実装するのであれば、以下のようにfor文を使い、バッチ内の各データに対して順番に関数を適用し、その結果を得る。

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

この方法では正しい答えを得ることができるが、あまり効率的なやり方ではない。

 この計算を効率的に行うためには、関数をベクトル化された形式で行われるように書き換える必要がある。今回の例では、以下のようにconvolve関数を書き換えて、バッチ次元でのベクトル化された計算を行うことができる。

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

このようにバッチに含まれるすべてのデータに対してまとめて計算を行うように実装を書き換えるには、インデックスや軸など入力の一部をどのように扱うかを変更する必要がある。このような再実装は厄介で、エラーが発生しやすい。そこで、JAXではこの再実装を自動化する仕組みを、jax.vmapで提供している。

Automatic Vectorization

 前節でみたように、ベクトル化された関数の実装を自動的に生成するのが、jax.vmapである。先の例の場合、次のようにjax.vmap関数をconvolve関数に作用させるだけで変換できる。

auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

jax.vmapでは、各入力の最初にバッチ軸を自動的に追加する。もし入力したいデータにおいてバッチ次元が最初でない場合には、in_axesとout_axes引数を使用し、入力と出力におけるバッチ次元の位置を指定することができる。これらは、バッチ軸がすべての入力と出力と同じである場合は整数、そうでない場合はリストとなる。以下では、in_axes=1, out_axes=1として、入力データを転置して入力してみる。

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)
DeviceArray([[11., 11.],
             [20., 20.],
             [29., 29.]], dtype=float32)

 jax.vmapは、引数の内の一つだけがバッチ処理される場合もサポートしている。例えば、ベクトルxのバッチと、単一の重みwに対して畳み込みたい場合、重みwのin_axes引数をNoneに設定する。

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

Combining transformations

 また、他のJAX変換と同様に、jax.jitとjax.vmapはComposableに設計されている。つまり、vmapの関数をjitでラップしたり、JIT化した関数をvmapでラップすることができ、すべて正しく動作する。

jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

How to use Automatic Differentiation in JAX

 それではJAXにおける自動微分(autodiff)について、その使い方および機能を説明する。

Higher-order derivatives

 JAXのjax.gradによる自動微分では、数学関数 fを表すPython関数fを与えると、その数学関数の導関数 \nabla fを表すPython関数を返す。その関数もまた微分可能である場合にはさらにそれを微分することができるため、単にjax.gradによる関数の変換を繰り返すだけで高階微分を行うことができる。

Function of one variable

例えば、 f(x) = x^ 3+2x^ 2-3x+1の高階微分は以下のように求めることができる。

import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

 fの高階微分は、それぞれ

 \displaystyle  f'(x) = 3x^2+4x-3
 \displaystyle  f''(x) = 6x+4
 \displaystyle  f'''(x) = 6
 \displaystyle  f''''(x) = 0

である。 よって、 x=0では、

 \displaystyle  f'(0) = -3  \displaystyle  f''(0) = 4  \displaystyle  f'''(0) = 6  \displaystyle  f''''(0) = 0

であり、 x=1では、

 \displaystyle  f'(1) = 4 ,  \displaystyle  f''(1) = 10 ,  \displaystyle  f'''(1) = 6 ,  \displaystyle  f''''(1) = 0

である。 実際、以下を実行すると、

print('x=0')
print(dfdx(0.))
print(d2fdx(0.))
print(d3fdx(0.))
print(d4fdx(0.))
print('x=1')
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
x=0
-3.0
4.0
6.0
0.0
x=1
4.0
10.0
6.0
0.0

となる。

Function of multiple variables

 多変数関数の場合には、高階微分はもっと複雑になる。多変数関数の二次の微分関数はヘッシアン行列によって表される。

 \displaystyle  (Hf)_{i,j} = \dfrac{\partial^2 f}{\partial_i \partial_j}

特に、二変数関数の場合は、

 \displaystyle  (H_f) = \dfrac{\partial^2 f}{\partial_i \partial_j} = \begin{pmatrix} \dfrac{\partial^2 f}{\partial_x^2}&\dfrac{\partial^2 f}{\partial_x \partial_y}\\ \dfrac{\partial^2 f}{\partial_y \partial_x}&\dfrac{\partial^2 f}{\partial_y^2} \end{pmatrix}

と表される。

多変数実数値関数 f: \mathbb{R}^n \rightarrow \mathbb{R}のヘッシアンは、以下のように、勾配のヤコビアンによって求められる。

 \displaystyle  (H_f) = J (\nabla f) = J \left( \dfrac{\partial f}{\partial_{x_1}}, \cdots , \dfrac{\partial f}{\partial_{x_n}} \right)
= \begin{pmatrix} 
\dfrac{\partial^2 f}{\partial_{x_1}^2}& \cdots &\dfrac{\partial^2 f}{\partial_{x_1} \partial_{x_n}}\\
\vdots & \ddots & \vdots \\
\dfrac{\partial^2 f}{\partial_{x_n} \partial_{x_1}}& \cdots &\dfrac{\partial^2 f}{\partial_{x_n}^2} \end{pmatrix}

JAXでは、関数のヤコビアンを求めるために、jax.jacfwdとjax.jacrevという二つのJAX変換を用意している。jax.jacfwdとjax.jacrevはそれぞれ、forward-mode、reverse-modeの自動微分に対応しており、その答えは一致する。(JAXのforward-mode、reverse-modeの計算方法及び計算効率の違いについては第四弾で説明する。)これを用いると、ヘッシアンは、以下のように書ける。

def hessian(f):
  return jax.jacfwd(jax.grad(f))

ドット積 f: x \rightarrow x^ T xのヘッシアンを求めることで確認してみる。ドット積のヘッシアンは、 if \quad i=j, \dfrac{\partial^ 2 f}{\partial_i \partial_j}(x) =2. \quad Otherwise, \dfrac{\partial^ 2 f}{\partial_i \partial_j}(x) =0である。

import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

hessian(f)(jnp.array([1., 2., 3.]))
DeviceArray([[2., 0., 0.],
             [0., 2., 0.],
             [0., 0., 2.]], dtype=float32)

ちゃんと計算できてることが分かる。

Stopping Gradients

 JAXの自動微分は、関数の持つそれぞれの変数に対し、勾配を求めることができる。しかし時には、例えば計算グラフの一部に対する逆伝播を止めるといった、追加的な制御が必要になる場合がある。  例えばTD(0)強化学習の更新を考えてみる。この学習は、環境との相互作用の経験から、環境中の状態の値を推定するために使用される。ここで、状態 s_ {t-1}における価値の推定値関数 v_ {\theta}(s_ {t-1})が一時関数で定義されているとする。

# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])

報酬 r_tを観測する状態 s_tから状態 s_{t-1}への遷移を考える。

# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

TD(0)学習ではネットワークパラメータを以下のように更新する。

 \delta \theta = (r_ t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})\nabla v_{\theta}(s_{t-1})

この更新では、損失関数の勾配は用いない。 しかし、ターゲット r_t + v_{\theta}(s_t)のパラメータ \theta への依存を無視すれば、疑似損失関数の勾配として書くことができる。

 L(\theta) = \left( r_t + v_{\theta}(s_t) - v_{\theta} (s_{t-1}) \right)^ 2

これをJAXでナイーブに実装すると、次のようになる。

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (target - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
DeviceArray([ 2.4, -2.4,  2.4], dtype=float32)

しかし、td_updatはTD(0)の更新を計算しない。なぜなら、勾配の計算には、ターゲット r_t + v_{\theta}(s_t) \thetaへの依存が含まれているからである。

そこで、jax.lax.stop_gradientを使い、JAXにターゲットのパラメータ \thetaへの依存を無視させる。

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (jax.lax.stop_gradient(target) - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

これにより、targetがパラメータに依存しないかのように扱い、パラメータの正しい更新を計算することができる。

jax.lax.stop_gradientは、例えば、他のパラメータが別の損失関数を使って学習されている場合、ある損失関数による勾配をニューラルネットワークのパラメータのある一部だけに対して求め、更新する際にも有効である。

また、stop_gradientは、ストレートスルー推定の実装にも使うことができる。ストレートスルー推定とは、微分不可能な関数の「勾配」を定義するためのトリックである。微分不可能な関数 f: \mathbb{R}^ n \rightarrow \mathbb{R}^ nが、勾配を求めたいより大きな関数の一部として用いられている場合、逆伝播を求める際には、 fが恒等関数であるかのように装う。

def f(x):
  return jnp.round(x)  # non-differentiable

def straight_through_f(x):
  # Create an exactly-zero expression with Sterbenz lemma that has
  # an exactly-one gradient.
  zero = x - jax.lax.stop_gradient(x)
  return zero + jax.lax.stop_gradient(f(x))

print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x):  3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0

Per-example gradients

 多くのMLシステムでは、バッチデータから勾配を計算しパラメータの更新を行うが、計算効率や分散の削減などの理由から、バッチ内の各特定サンプルに関連する勾配や更新にアクセスすることが必要な場合がある。

例えば、勾配の大きさに基づいてデータに優先順位をつけたり、サンプル単位でクリッピングや正規化を行う場合である。

多くのフレームワーク(PyTorch、TF、Theano)では、ライブラリがバッチ全体の勾配を直接蓄積するため、サンプルごとの勾配を計算するのは簡単ではないことが多い。しかも、サンプルごとに個別に損失を計算し、その結果の勾配を集計するような素朴な回避策は、一般に非常に非効率である。

JAXでは、サンプルごとに勾配を計算するコードを、jit、vmap、gradの各変換を一緒にするだけで、簡単かつ効率的に実装することができる。

perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

各変換について一つずつ見てみる。

まず、td_lossにjax.gradを作用させ、単一(バッチ単位ではない)の入力におけるパラメータに対する損失の勾配を計算する関数を得る。

dtdloss_dtheta = jax.grad(td_loss)

dtdloss_dtheta(theta, s_tm1, r_t, s_t)
DeviceArray([-2.4, -4.8,  2.4], dtype=float32)

この関数は、先の配列の1行を計算する。

次に、この関数をjax.vmapでベクトル化する。これにより、すべての入力と出力にバッチ次元が追加され、バッチ内の各出力は、入力バッチ内の対応する各データに対する勾配である。

almost_perex_grads = jax.vmap(dtdloss_dtheta)

batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

これは我々が望むものとは異なる。なぜなら、我々は実際には単一のthetaを利用したいのであるが、この関数にthetaのバッチを手動で入力しなければならないからである。しかし、前の記事でも書いた通り、jax.vmapにin_axesを追加し、thetaをNone、他の引数を0と指定することで解決できる。これにより、結果として得られる関数は、他の引数にのみバッチ軸を追加し、thetaに対してはバッチ化されていない、単一のthetaを用いることができる。

inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)

これで欲しい機能は実装できたものの、計算効率が悪く必要以上に遅い。そこで、jax.jitをラップして、同じ関数のコンパイル済みで効率的なバージョンを用いる。

perex_grads = jax.jit(inefficient_perex_grads)

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
DeviceArray([[-2.4, -4.8,  2.4],
             [-2.4, -4.8,  2.4]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
100 loops, best of 5: 7.74 ms per loop
10000 loops, best of 5: 86.2 µs per loop

かなり早くなった。

次は、自動微分の実装について書く。

JAX学習記録②ーPytree

Preface

 この記事ではJAX学習記録シリーズ第二弾として、Pytreeについてまとめる。基本的に第一弾は読んでる前提で書くので、もし読んでない方がいたら読むことをお勧めする。Pytreeは以下の第一弾のブログでも出てきたように、モデルパラメータを格納するデータ構造(本当はそれ以外にも使える。)であり、JAXの素晴らしいポイントの一つは、jax.gradにpytreeの形式でパラメータを持つ関数を与えたとしても、各パラメータに対してその関数の勾配を計算できることである。

izmyon.hatenablog.com

 この記事は、以下のJAXおよびFlaxの公式ドキュメント読み、一部修正や改善を行いながら内容をまとめた。ちなみに、JAX編が終わったらFlax編に突入する。

目次

Pytree

What's a pytree?

 JAXでは、コンテナのようなオブジェクト(container-like object)(例えばdictのlistのlist、arrayのdictというような、ネストされた構造を持つオブジェクト)から構成されるツリー状の構造(tree-like structure)を表現するために、「pytree」という用語を用いる。クラスは、pytree container registry(デフォルトでlist, tuple, dict)にあるクラスである場合は、コンテナ的(container-like)であると見なされる。つまり、

  1. pytree container registryにないクラスのオブジェクトは、すべてpytreeのleafと見なされる。
  2. pytree container registryにあるクラスのオブジェクトは、すべてpytreeのnodeと見なされる。

 pytree container registryでは、各コンテナ型は、コンテナ型のインスタンスを(children, metadata)のペアに変換する関数と、そのペアをコンテナ型のインスタンスに変換する関数の、二つの関数のペアと共に登録されている。もちろん、ユーザーが定義したクラスをpytree container registryに新たに登録することもできる。

 pytreeの例:

[1, "a", object()]  # 3 leaves
(1, (2, 3), ())  # 3 leaves
[1, {"k1": 2, "k2": (3, 4)}, 5]  # 5 leaves

Common Pytree Functions 

 JAXは、jax.tree_util.tree_* というprefixを持つ、pytreeを扱う関数を用意している。JAXのdocumentを読むと、2022/11/09現在において、jax.tree_* を使うやり方が書いてあるが、例えばjax.tree_leaves()を使おうとすると以下の警告が出る。これによると、jax.tree_* は将来的にjax.tree_util.tree_*に移動するらしく、将来的に消される予定らしいので、jax.tree_util.tree_* をprefixに持つ関数を使うことにする。

FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.

 本筋に戻る。jax.tree_util.tree_leaves()を使うと、以下のようにtreeからleafを抽出することができる。

import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
  leaves = tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7fded60bb8c0>]   has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
DeviceArray([1, 2, 3], dtype=int32)           has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]

 先ほど紹介したようにlist, tuple, dictがnodeとして扱われ、その中の要素がleafとして認識されているが、文字列やJAX arrayなど、そのほかのオブジェクトはleafとして認識されていることが分かる。

 jax.tree_util.tree_*で最もよく使われるであろう関数は、tree_mapである。これは、Pythonネイティブのmap関数と同じように動作し、pytreeに格納したパラメータを更新する際に利用される、非常に大事な関数である。 基本的には、以下のように第一引数に関数をあたえ、第二引数以降にその入力となるpytreeを渡す。

from jax.tree_util import tree_map
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

tree_map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

tree_mapは、複数の引数に対しても動作する。

another_list_of_lists = list_of_lists
tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

この際、引数として与えられるpytreeの構造は正確に一致している必要がある。つまり、引数として与えられたpytreeのlistは同じ数の要素を持ち、辞書は同じkeyを持つ必要がある。

Custom pytree nodes

 前述したとおり、pytree container registryには自作クラスを登録できる。登録していない場合には、自作クラスのインスタンスleafを持っていたとしても、インスタンスそのものがleafとして認識され、インスタンスが持っているleafleafとして認識されない。登録することによりインスタンス自体はnodeとして認識され、持っているleafleafとして認識されるようになる。例えば以下のように自作クラスを定義してインスタンスのリストをtree_leavesに渡すと、それぞれのインスタンスleafとして認識される。

class MyContainer:
  """A named container."""

  def __init__(self, name: str, a: int, b: int, c: int):
    self.name = name
    self.a = a
    self.b = b
    self.c = c
tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
[<__main__.MyContainer at 0x7fdec166ce50>,
 <__main__.MyContainer at 0x7fded89ba490>]

そこで、以下のようにコンテナからleafと中間的なデータを取り出す関数と、逆にleafと中間的なデータからコンテナを作製する関数を作製し、jax.tree_util.register_pytree_node()でpytree container registerに登録する。

from typing import Tuple, Iterable
from jax.tree_util import register_pytree_node

def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
  """Returns an iterable over container contents, and aux data."""
  flat_contents = [container.a, container.b, container.c]

  # we don't want the name to appear as a child, so it is auxiliary data.
  # auxiliary data is usually a description of the structure of a node,
  # e.g., the keys of a dict -- anything that isn't a node's children.
  aux_data = container.name
  return flat_contents, aux_data

def unflatten_MyContainer(
    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
  """Converts aux data and the flat contents into a MyContainer."""
  return MyContainer(aux_data, *flat_contents)

register_pytree_node(
    MyContainer, flatten_MyContainer, unflatten_MyContainer)

tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
[1, 2, 3, 4, 5, 6]

以上のように、きちんと自作コンテナが持つleafが認識されていることが分かる。

 また、以下のように@register_pytree_node_classで登録したいクラスをデコレートする方法もある。

from jax.tree_util import register_pytree_node_class
from typing import Tuple, Iterable

@register_pytree_node_class
class MyContainer2:
  def __init__(self, name, a, b, c):
    self.name = name
    self.a = a
    self.b = b
    self.c = c

  def tree_flatten(self) -> Tuple[Iterable[int], str]:
    aux_data = self.name
    return (self.a, self.b, self.c), aux_data

  @classmethod
  def tree_unflatten(cls, aux_data: str, children: Iterable[int]):
    return cls(aux_data, *children)

 このやり方だとスマートに書けはするが、クラスの中にそれ自身のクラスのインスタンスからleafと中間的なデータを取り出すメソッドと、leafと中間的なデータから自身のクラスのインスタンスを作り出すメソッドを定義する必要があり少しややこしい。しかもそこまでコードが短くなるわけでもないのであまり良いやり方のようには思えない。

Pytree for ML model parameters

 この章ではPytreeをMLモデルのパラメータとして用いる方法を示すために、Linear RegressionとMLPの学習方法についてみてみる。

Linear Regression

 ここでは、前の章に引き続きLinear Regressionを行う。ただし、問題設定がより複雑で、パラメータが行列の形となる。 データ \{ (x_i, y_i), i \in \{1, \cdots, k \}, x_i \in \mathbb{R}^n, y_i \in \mathbb{R}^m \}が与えられた際、パラメータ W \in \mathbb{R}^{m×n} , b \in \mathbb{R}^mを持つモデル f_{W, b}(x) = Wx + bを考え、以下の式で表される二乗誤差を最小にするようにパラメータを学習する。

 \displaystyle  L \left( W, b \right) = \dfrac{1}{k} \sum_{i=1}^{k} \dfrac{1}{2} \left || y_i - f_{W, b}(x_i) \right ||_2^2

まずはモデルとloss関数を定義する。

def model(params, x):
  return jnp.dot(x, params['W']) + params['b']

def loss(params, x_batched,y_batched):
  def squared_error(x,y):
    y_pred = model(params, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

loss関数内でjax.vmapという関数を使っているが、これは、jaxにおいて非常に重要な関数の一つであり、また別の記事で解説する。modelはパラメータと一つのxのサンプルを受けとり、出力としてm次元のベクトルを返す関数であり、その一つのサンプルに対してlossを計算するには、loss関数内で定義されたsquared_error関数を計算すればよい。ここでjax.vmapは、それらの計算を与えられたすべてのサンプル(1バッチ)に対してまとめて行っている。そして、その平均を取ってそのバッチのlossとしている。

サンプル数を20、 xの次元を10 ( n=10)、 yの次元を5 ( m=5)とする。 xと、真のパラメータとなる W_{true},  b_{true}を標準正規分布からサンプリングし、それにさらにノイズを足して、 y = W_{true}x + b_{true} + εという形で yを生成する。

import jax
from jax import numpy as jnp, random
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W_true = random.normal(k1, (x_dim, y_dim))
b_true = random.normal(k2, (y_dim,))
true_params = {'W': W_true, 'b': b_true}

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = model(true_params, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)

これでサンプル数20のxとyを生成できた。

ここでjax.gradを使いloss関数の勾配を計算してみる。ここで、true_paramsは先ほど定義したようにvalueにパラメータを持つdict形式のpytreeとなっているが、このようにloss関数のパラメータがpytreeで与えられても、それぞれのパラメータに対して勾配を計算できるのがJAXの素晴らしいポイントである。

jax.grad(loss)(true_params, x_samples, y_samples)
{'W': DeviceArray([[ 0.03074035,  0.00365056, -0.02740437,  0.03342171,
                0.00768137],
              [ 0.00480341, -0.00601222,  0.01753974,  0.02849675,
                0.01315357],
              [ 0.02740791,  0.00509001,  0.01633988,  0.01654152,
               -0.03833748],
              [ 0.04439825,  0.01456456, -0.01553174,  0.00361327,
               -0.00520618],
              [-0.01146152,  0.01010936, -0.02445415, -0.02912016,
               -0.00614015],
              [-0.01862438, -0.01283839, -0.00399569, -0.00660799,
                0.01106221],
              [ 0.03790367,  0.01057525,  0.00876093, -0.00756639,
                0.00785039],
              [ 0.0138653 , -0.00148796, -0.02050714,  0.01991827,
                0.01749898],
              [ 0.01063178, -0.03577084, -0.02890723,  0.01893806,
                0.02628964],
              [ 0.0198197 ,  0.01517353, -0.01874115, -0.01610945,
               -0.02562965]], dtype=float32),
 'b': DeviceArray([ 0.0193719 , -0.00509552, -0.00823335, -0.01260918,
               0.00415814], dtype=float32)}

もちろん勾配は与えたパラメータと同じ形状( W \in \mathbb{R}^{m×n} , b \in \mathbb{R}^m )となるため、tree_mapの引数に更新前のパラメータを保持したpytreeと、勾配を計算した結果を格納したpytreeを渡すことができる。

それでは早速モデルを学習してみる。バッチサイズ=サンプルサイズ=20とする。パラメータ W bを0で初期化して、初期化したpytreeを新たに作る。updateでは、tree_mapを使ってパラメータの更新を行い、100エポック学習する。jax.jitでupdate関数がデコレートされているが、これはjust-in-timeコンパイルをするために必要なもので、この記事では特に触れない。計算処理を早くするおまじないのようなものである。

# Initialize estimated W and b with zeros.
W = jnp.zeros_like(W)
b = jnp.zeros_like(b)
params = {'W': W, 'b': b}

# Always remember to jit!
@jax.jit
def update(params, learning_rate, x_samples, y_samples):
  params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params,
        jax.grad(loss)(params, x_samples, y_samples))
  return params

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', loss({'W': W, 'b': b}, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  params = update(params, learning_rate, x_samples, y_samples)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", loss(params, x_samples, y_samples))
Loss for "true" W,b:  0.023639793
Loss step 0:  10.971409
Loss step 5:  1.0798324
Loss step 10:  0.37958255
Loss step 15:  0.17855294
Loss step 20:  0.09441521
Loss step 25:  0.054522194
Loss step 30:  0.03448924
Loss step 35:  0.024058027
Loss step 40:  0.018480862
Loss step 45:  0.015438682
Loss step 50:  0.01375394
Loss step 55:  0.0128103
Loss step 60:  0.012277315
Loss step 65:  0.011974388
Loss step 70:  0.011801447
Loss step 75:  0.011702419
Loss step 80:  0.011645544
Loss step 85:  0.011612837
Loss step 90:  0.011594015
Loss step 95:  0.011583163
Loss step 100:  0.011576912

いい感じに学習できた。

また、このようにlossの値を記録したい場合には、勾配を計算する際にjax.gradではなく、jax.value_and_gradを使うと、以下のようにlossを計算し直すことなく記録できる。

# Using jax.value_and_grad instead:
loss_grad_fn = jax.value_and_grad(loss)
for i in range(101):
  # Note that here the loss is computed before the param update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    if (i % 5 == 0):
        print(f"Loss step {i}: ", loss_val)

MLP

 pytreeを使って簡単なMLPを作製して、訓練してみる。今回の例ではパラメータのpytreeはもう少し複雑になる。

import jax
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

ここで、layer_widths[:-1]は最後から二番目の要素までのリストであり、ここでは[1, 128, 128]のこと。layer_widths[1:]は二番目の要素から最後までのリストであり、ここでは[128, 128, 1]のことである。また、weightsは各層において標準偏差 \sqrt{\frac{2}{n_{in}}}正規分布からサンプリングされた値で初期化され、biasは1で初期化されている。 tree_mapを使ってパラメータの形状を確認する。

jax.tree_util.tree_map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

以上で、各次元が1, 128, 128, 1の四層のMLPが定義できた。

次に、forward関数とloss関数、update関数を定義する。

def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss)(params, x, y)
  return jax.tree_util.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

forward関数内で、「@」という演算が行われているが、これは行列積のこと( What’s New In Python 3.5 — Python 3.11.0 documentation)で、jnp.dotと基本的には同じ。 Linear Regressionでも説明した通り、update関数内で勾配を計算する際、第一引数として渡されているparamsについてloss関数の微分が行われ、返り値であるgradsもparamsと同じ形状をしたpytreeである。そして例のごとくupdate関数はtree_mapを使い値の更新を行い、updateされたparamsを返す。

次に、モデルの訓練を行う。正規分布からサンプリングした値からデータを作製し、1000エポック学習させる。

import matplotlib.pyplot as plt

xs = np.random.normal(size=(128, 1))
ys = xs ** 2

for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend();

いい感じに学習できた。 やはりJAXの自動微分は素晴らしい!!次回は自動微分について書く。

JAX学習記録①ーJAX As Accelerated NumPy

Preface

 このシリーズは、JAXのドキュメントを読みながら学習した際に、ノート代わりに記事に残したものである。JAXを理解する際の要所は、XLAやjust-in-time、Asynchronous dispatchなど多岐にわたるが、この記事ではまず第一弾として、Numpy-likeなAPIで利用でき、自動微分にも対応するJAXのアクセラレートされたNumpyとしての機能(便宜的にJAX numpyと呼ぶ。)についてまとめる。

jax.readthedocs.io

目次

What’s JAX?

 JAXは、CPUだけでなく、GPU, TPUなどのアクセラレータ上でも動き、自動微分を兼ね備えたNumPyである。Autogradを用いて、関数の微分だけでなくさらにその微分を行ったり、順方向の微分だけでなく逆方向の微分も行い、さらに両者の合成も行うことができる。しかもJAXは、XLA(Accelerated Linear Algebra)を用いてNumpyのコードをGPUやTPU上でコンパイルして実行することができる。ライブラリの呼び出しはjust-in-timeでコンパイルされ実行され、さらに一つのAPIを使って独自のPython関数をXLAに最適化されたカーネルにjust-in-timeコンパイルすることもできる。また、コンパイルと自動微分は任意に構成でき、Pythonを離れなくても高度なアルゴリズムを表現し、最大限の性能を得ることができる。

 ....と、JAXの概要を述べたが、僕もXLAやjust-in-timeコンパイルについてはまだ理解できていない。これからこのブログを書き進めていきながら理解していこうと思っているので、諦めそうになった方がいても安心してほしい。これから一緒に理解していきましょう。

 個人的には、JAXの特筆すべき特徴としては以下があると考えている。

  • CPUだけでなく、GPU, TPU上でも動くが、実装がPythonで完結している。
  • XLAによる高速化を行っている。
  • デフォルトで非同期処理を行う。
  • 関数型プログラミングを意識し、関数的に設計されている。
  • 実際の数学と同じように関数を微分することができ、勾配ベクトルだけではなく、勾配関数を求めることができる。
  • Pythonは動的型付けだが、型付けに厳しめに実装されている。
  • 乱数の管理が独特で厳密である。

特に、PyTorchなどのように、勾配の計算などがbackward()といった形でブラックボックスになっていると個人的に非常に気持ち悪い。しかしJAXは、関数型プログラミングを意識して設計されていることもあり、loss関数を本当にパラメータとデータの関数として定義でき、しかもAutoGradによりその勾配の関数を求めることができて、勾配の計算をブラックボックスにせずにしかも数式を書くのと同じように実装することができる。個人的にJAXのこの特徴は本当に好き。大好き。(PyTorchのここら辺の話は、研究室の同期が書いたこの記事「PyTorchの微分はいかにして動くのか?」を参照のこと。)

 ここではまず、JAXのNumpyとの違いや、APIの設計指針などについて述べ、次章でJAX numpyの具体的な使い方について入門していく。

JAX vs. Numpy

 JAXのkey-conceptは、以下。

  • 簡便のため、JAXはNumPyにインスパイアされたインタフェースを持つ。
  • duck-typingにより、JAX arrayは多くの場合、Numpy arrayの代用になりうる。
  • NumPy arrayと違い、JAX arrayは常にimmutableである。

JAXは基本的にはNumpyをwrapし、機能を追加したものと捉えられるが、JAX arrayは常にimmutableであることには注意。 例えば、JAX arrayを以下のように変更しようとするとTypeErrorが出る。

x = jnp.arange(10)
x[0] = 10
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-6b90817377fe> in <module>()
      1 # JAX: immutable arrays
      2 x = jnp.arange(10)
----> 3 x[0] = 10

TypeError: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

特定のindexにあるelementを変更したい場合には、以下のように、indexed update syntaxを使って複製する必要がある。

y = x.at[0].set(10)

このような違いは、JAXが関数型プログラミングのように、関数的に設計されていることに起因している。

JAX API layering

 JAXのAPI設計のkey-conceptは、以下。

  • jax.numpyは、numpyと似たインタフェースを提供するより高いレベルのwrapper
  • jax.laxは、より低いレベルのAPIで、より静的(strict)だがよりパワフル
  • すべてのJAX演算は、XLAの演算として実装されている。

 特に、jax.laxはjax.numpyと比べて静的(strict)であることには注意する必要がある。 例えば、以下の演算はjax.numpyでは実行できるが、jax.laxではTypeErrorが出る。

import jax.numpy as jnp
jnp.add(1, 1.0)
DeviceArray(2., dtype=float32)
from jax import lax
lax.add(1, 1.0)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-63245925fccf> in <module>()
      1 from jax import lax
----> 2 lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.

TypeError: add requires arguments to have the same dtypes, got int32, float32.

この場合、しっかりと型を指定して、以下のように書く。

lax.add(jnp.float32(1), 1.0)
DeviceArray(2., dtype=float32)

Getting Start with JAX numpy

 JAXの概要やJAX numpyとNumPyやjax.laxとの違い、JAXの設計指針に触れたとこで、ここから以下の記事を読んでJAX numpyの使い方をまとめていく。 ここでまず念頭におくべきは、冒頭でも述べた通り、JAXは基本的にはアクセラレータ上で動き、自動微分を兼ね備えたNumpyと考えてよいということである。

jax.readthedocs.io

JAX array

ベクトルを作成するのは、Numpyと同様で、arangeを用いて以下のようにする。

import jax
import jax.numpy as jnp

x = jnp.arange(10)
print(x)
x
[0 1 2 3 4 5 6 7 8 9]
DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

変数のタイプは、DeviceArrayである。 また、この際、GPU/TPUが使用できないと、以下の警告が出る。

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

このように出るのは、JAX numpyが基本的にはGPU/TPUでの使用を想定されているからである。 JAXの便利な特徴の一つは、同じコードで異なるバックエンドを使用することができることである。 GPU/TPUを使用可能にし(例えばGoogle ColabではRuntimeをGPUに変更し)、例えば以下のようにdot積を実行すると、コードの変更なく別のデバイス(ここではGPU)で実行することができる。

long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()
The slowest run took 7.39 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 7.85 ms per loop

JAX first transformation: grad

 JAXの基本的な特徴として、「関数の変換を可能にする」、というものがある。 その中の最も基本的な変換がjax.gradである。これは、その名の通り、関数を与えると、勾配関数を返す。 例えば、以下のように二乗和を返す関数を定義する。

def sum_of_squares(x):
    return jnp.sum(x**2)

次にjax.gradを用いて勾配関数を求め、xの値を定義して勾配を計算すると、以下のようになる。

sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_of_squares(x))
print(sum_of_squares_dx(x))
30.0
[2. 4. 6. 8.]

このように、jax.gradは微分演算子  \nabla であると考えることができる。与えられた関数が f(x)であるとすると、 \nabla f fの勾配関数である。 したがって、jax.grad(f)は勾配を計算する関数であり、jax.grad(f)(x)は f xにおける勾配である。

これは、前述したようにTensorflowやPyTorchと大きく異なる特徴である。それらのライブラリでは、勾配はテンソル自体を使用してloss.backward()などで計算するが、JAX APIでは、loss関数は本当にパラメータとデータの関数であり、数学と同じようにその勾配を求めることができる。

微分する関数は複数の変数を持っている場合があるが、jax.gradは、デフォルトで第一引数として渡されている変数に対して微分を行う。 例えば、以下のような二乗誤差の勾配関数を求める場合には、第一引数として渡されている xに対して微分を行っている。

def sum_squared_error(x, y):
  return jnp.sum((x-y)**2)

sum_squared_error_dx = jax.grad(sum_squared_error)
y = jnp.asarray([1.1, 2.1, 3.1, 4.1])
print(sum_squared_error_dx(x, y))
[-0.20000005 -0.19999981 -0.19999981 -0.19999981]

異なる変数についてそれぞれ勾配関数を求めたい場合には、以下のようにargnumsを設定する。

jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # Find gradient wrt both x & y
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))

これだけ見ると、機械学習のように膨大なパラメータを持つ関数の勾配を求めたい場合には、いちいち引数を指定するのは非常に面倒なのではないかと思うかもしれないが、実用上は「pytree」というデータ構造でパラメータをを定義するため、その必要はなく、実際の実装は以下のようになる。(pytreeのついては、この記事では触れず、別の記事で書く。)

def loss_fn(params, data):
  ...

grads = jax.grad(loss_fn)(params, data_batch)

ここで、paramsは、例えばネストされたarrayの辞書であり、返されるgradsもまた、同じ構造を持つネストされたarrayの辞書となる。

Value and Grad

 あるデータを与えたときの関数の返り値と勾配の両方を欲しい場合には、value_and_grad関数を使う。

jax.value_and_grad(sum_squared_error)(x, y)
(DeviceArray(0.03999995, dtype=float32),
 DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))

上のように、value_and_grad関数は(value, grad)という形で、結果をタプルで返す。 つまり、vlue_and_gradは、関数 fおよびその変数*xsに対して、以下のように定義される。

jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs)) 

Auxiliary data

 勾配を計算する際、中間的な計算結果を得たい場合、以下のようにするとエラーが出る。

def squared_error_with_aux(x, y):
  return sum_squared_error(x, y), x-y

jax.grad(squared_error_with_aux)(x, y)
---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-9-7433a86e7375> in <module>()
      3 
----> 4 jax.grad(squared_error_with_aux)(x, y)

FilteredStackTrace: TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).

The stack trace above excludes JAX-internal frames.

これは、jax.gradがスカラーを返す関数にのみ対応しているのに対し、tupleを返す関数を与えているからである。 この場合、引数has_auxをTrueにする。

jax.grad(squared_error_with_aux, has_aux=True)(x, y)
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))

has_aux = Trueは、与えられた関数は(out, aux)を返すということをjax.gradに知らせている。これにより、jax.gradはauxを無視し、outで計算した勾配とともにまとめて返す。

First JAX Training Loop

 ここまでの知識を用い、linear regressionの学習を行ってみる。 データは、 y = w_{true} x + b_{true}+ ε からサンプリングされるとする。

import numpy as np
import matplotlib.pyplot as plt

xs = np.random.normal(size=(100,))
noise = np.random.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise

plt.scatter(xs, ys);

用いるモデルは、 \hat{y} (x; \theta ) = wx + bとする。 theta = [w, b]として二つのパラメータを \thetaにまとめ、modelはあくまでパラメータとデータの関数として定義する。

def model(theta, x):
  """Computes wx + b on a batch of input x."""
  w, b = theta
  return w * x + b

ロス関数は  J (x, y; \theta ) = \left( \hat{y} - y \right)^{2} とする。

def loss_fn(theta, x, y):
  prediction = model(theta, x)
  return jnp.mean((prediction-y)**2)

勾配降下法により、各更新ステップで各バラメータに対する損失の勾配を求め、最急降下する方向にわずかにパラメータを更新する。

 \theta_{new} = \theta - 0.1 \left( \nabla_{\theta} J \right) \left(x, y; \theta \right)

def update(theta, x, y, lr=0.1):
  return theta - lr * jax.grad(loss_fn)(theta, x, y)

JAXでは、ステップごとに呼び出されるupdate()関数を定義し、現在のパラメータを入力として新たなパラメータを返す。これはJAXの関数的性質の帰結なのだが、これも数式と同じように実装できて好き。

theta = jnp.array([1., 1.])

for _ in range(1000):
  theta = update(theta, xs, ys)

plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs))

w, b = theta
print(f"w: {w:<.2f}, b: {b:<.2f}")
w: 3.00, b: -1.00

上で行った処理は、jax.jitでupdata()をデコレートすることで、JIT-compile(just-in-timeコンパイル)をすれば効率的に処理することが可能だが、jitに関してはまた後の記事で解説する。

JAXで実装されるトレーニングループはほとんどがこの例のように実装される。特に、のちの記事で説明するpytreeによって、パラメータをthetaにまとめjax.grad(loss_fn)(theta, x, y)のような形で勾配を計算することができた。このように数式と同じように実装できるのはほんとに好き。めっちゃ好き。ただし、今回はパラメータが二つしかなかったから一つのarrayにまとめることができたが、複雑なモデルになり多くのパラメータを扱う必要が出てくると、前述したように、パラメータは例えばネストされたarrayの辞書というような複雑な形で定義する。