Preface
この記事ではJAX学習記録シリーズ第四弾として、JAXの自動微分がどのように実現されているか、詳細を見てみる。正直なところ、特に数学的な説明に関して理解しきれていないところが多く、基本的には原文に忠実に訳しており、自分の理解による説明はあまり加えられていない。もし修正が必要な個所があったり、もっと分かりやすい説明の仕方があれば、修正案または補足案と共に教えて頂けると非常にありがたい。また、この記事は以下の第三弾の続きであり、これを読んでる前提で書くのでまだ読んでいない方は読むことをお勧めする。
今回の内容の元記事は、以下。
目次
- Preface
- How it's made: two foundational autodiff functions
How it's made: two foundational autodiff functions
JAXはforward-mode、reverse-modeの自動微分を持つ。ちなみに、おなじみのjax.gradは、reverse-modeをベースとして作られている。この章では、これら二つのmodeの違いを説明し、それをうまく使うための説明をおこなう。
Jacobian-Vector products (JVPs, aka forward-mode autodiff)
JVPs in math
数学的には、ある関数が与えられたとき、で表されるある入力点におけるのヤコビアン(日本語では行列自体はヤコビ行列と呼び、その行列式をヤコビアンと呼んで区別することもあるが、ここでは原文に従い、行列の方をヤコビアンと呼ぶことで統一する。また、行列式の方は出てこない。)は、多くの場合の行列の形をしている。は、空間における入力点の接空間を、空間における出力点の接空間に移す線形写像となる。
この変換はのにおける押し出し(pushforward) と呼ばれ、ヤコビアンはこの線形写像を標準基底で表したものである。もし、特定の入力点にこだわらないのであれば、を最初に入力点を取り、その点のヤコビアンを線形写像として返す関数であると考えることができる。
特に、入力点と接(勾配)ベクトルが与えられたとき、出力の接ベクトルを得ることができると考えることができる。このペアから出力となる接ベクトルへの写像を、ヤコビアン-ベクトル積と呼び、次のように書く。
JVPs in JAX code
JAXでは、jvpという関数でこの変換を実装している。数学関数を評価するPython関数が与えられたとき、JAXのjvpは、を評価するPython関数を返す。
from jax import jvp # Isolate the function from the weight matrix to the predictions f = lambda W: predict(W, b, inputs) key, subkey = random.split(key) v = random.normal(subkey, W.shape) # Push forward the vector `v` along `f` evaluated at `W` y, u = jvp(f, (W,), (v,))
Haskellのように型を定義するとするなら、以下のようになる。
jvp :: (a -> b) -> a -> T a -> (b, T b)
つまり、jvpはa->b型の関数、a型の値、T a型の接ベクトルを引数にとり、b型の値とT b型の出力接ベクトルの組を返す。
jvp変換された関数は、元の関数と同様に評価されるが、a型の各原始値と対になり、T a型の接ベクトル値を押し出す。元の関数が適用するはずの各原始数値演算に対して、jvp変換された関数は、原始値に対する原始値の評価と、それらの原始値における原始値のJVP適用という「JVP規則」を実行する。
この評価方法は、計算量に影響を与える。JVPを評価しながら計算するので、後で保存する必要がなく、メモリコストは計算の深さに関係なく発生する。さらに、jvp変換された関数のFLOPコストは、関数を評価するだけのコストの約3倍(sin(x)などの元の関数の評価に1ユニット、cos(x)などの線形化に1ユニット、cos_x * vなど線形化した関数をベクトルに適用して1ユニットが必要)となる。別の言い方をすれば、固定原点に対して、を評価するのとほぼ同じ限界コストでを評価することができる。
このメモリの複雑さは、かなり説得力がある。では、なぜ機械学習ではforward-modeをあまり見かけないのだろうか?
その答えとして、まず、JVPを使って完全なヤコビアン行列を構築する方法を考えてみる。もし、JVPを1つのone-hotの接ベクトルに適用すると、ヤコビアン行列の1列が明らかになり、入力した非ゼロの要素に対応する。つまり、一度に1列ずつ完全なヤコビアンを構築することができ、各列を得るのに、1回の関数評価と同じぐらいのコストがかかる。これは「高い」(行方向の要素数が列方向の要素数より多い。)ヤコビアンを持つ関数には効率的だが、「広い」(行方向の要素数が列方向の要素数より少ない。)ヤコビアンには非効率的である。
機械学習で勾配に基づく最適化を行う場合、おそらくはパラメータとデータからスカラー損失値を求める損失関数を最小化したいと思うことが多いだろう。これは、この関数のヤコビアンが非常に広い行列であることを意味する。この行列を1列ずつ構築し、それぞれの呼び出しで元の関数を評価するために同数のFLOPsを必要とするのは、確かに非効率的である。特に、学習損失関数が数百万から数十億にもなるニューラルネットワークを学習する場合、この方法では、スケールアップは望めない。
これらの理由から、reverse-modeを用いる。
Vector-Jacobian products (VJPs, aka reverse-mode autodiff)
forward-modeがヤコビアン-ベクトル積を評価する関数を返し、それを使ってヤコビアンを一列ずつ作ることができたのに対し、reverse-modeでは、ベクトル-ヤコビアン積(ヤコビアン-転置ベクトル積と同等)の評価関数を返す方法で、一行ずつヤコビアンを作ることができる。
VJPs in math
もう一度関数について考える。VJPは以下で表される。
ここで、はのにおける余接空間の要素である。厳密には、は線形写像であると考えるべきであり、と書くと、合成関数を意味していると考える。しかし、多くの場合、は、単にのベクトルであると考え、両者を互換的に用いることができる。
また、VJPの線形部はJVPの線形部の転置(または随伴行列)であると考えることもできる。
余接空間上に対応する写像は、しばしばのにおける引き戻し(pullback)と呼ばれる。重要なのは、の出力のようなものから、入力のようなものへ引き戻すことである。
VJPs in JAX code
JAX変換vjpでは、数学関数を評価するPython関数が与えられたとき、VJP を評価するPython関数を返すということである。
from jax import vjp # Isolate the function from the weight matrix to the predictions f = lambda W: predict(W, b, inputs) y, vjp_fun = vjp(f, W) key, subkey = random.split(key) u = random.normal(subkey, y.shape) # Pull back the covector `u` along `f` evaluated at `W` v = vjp_fun(u)
Haskellのように型を定義するとするなら、以下のようになる。
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
つまり、vjpはa -> b型の関数とa型の点を引数にとり、b型の値とCT b -> CT a型の線形写像からなるペアを返す。
これは、ヤコビアン行列を一度に1行ずつ構築することを可能にし、[tex: (x, v) \rightarrow (f(x), vT \partial f(x))]を評価するためのFLOPコストは、の評価のコストの約3倍となる。特に、ある関数の勾配を求める場合、1回の呼び出しで求めることができる。このようにgradは、数百万から数十億のパラメータを持つニューラルネットワークの学習損失関数のような目的に対しても、勾配に基づく最適化を効率的に行うことができる。
しかし、gradにはコストがかかる。FLOPsが少ないとはいえ、メモリは計算の深さに比例して増加する。また、JAXにはいくつかのトリックがあるが、伝統的に実装は順方向モードよりも複雑となる。
reverse-modeの仕組みについては、2017年に開催されたDeep Learning Summer Schoolのチュートリアルビデオを観ると良い。
Hessian-vector products using both forward- and reverse-mode
前の記事(第三弾)で、ヘッシアン-ベクトル積関数を、(二階導関数の連続性を仮定して、)reverse-modeで実装した。
def hvp(f, x, v): return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
この実装は効率的ではあるが、forward-modeとreverse-modeを組み合わせることで、メモリを節約しつつより効率的に実装できる。
数学的には、微分する関数、関数を線形化する点、ベクトルが与えられたとき、ヘッシアン-ベクトル積関数は以下のようになる。
の微分(または勾配)をとすると、である。求めるのはそのJVPであり、
を得る。
from jax import jvp, grad # forward-over-reverse def hvp(f, primals, tangents): return jvp(grad(f), primals, tangents)[1]
さらに良いことに、jnp.dotを直接呼び出す必要がないため、このhvp関数はpytreeのように、どんな形の配列でも、どんなコンテナ型でも動作する。しかも、jax.numpyに依存することさえない。
以下は、hvp関数の使用例である。
def f(X): return jnp.sum(jnp.tanh(X)**2) key, subkey1, subkey2 = random.split(key, 3) X = random.normal(subkey1, (30, 40)) V = random.normal(subkey2, (30, 40)) ans1 = hvp(f, (X,), (V,)) ans2 = jnp.tensordot(hessian(f)(X), V, 2) print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
以前定義したhessian関数でも、hvp関数でも、ほぼ同様の答えが得られていることが分かる。
もう一つの方法として、reverse-over-forwardを使う方法がある。
# reverse-over-forward def hvp_revfwd(f, primals, tangents): g = lambda primals: jvp(f, primals, tangents)[1] return grad(g)(primals)
forward-modeは、reverse-modeよりもオーバーヘッドが少なく、また、ここでの外側の微分演算子は内側の演算子よりも多くの微分しなければならないので、forward-modeを外側にしておくことが最もうまくいくのである。
# reverse-over-reverse, only works for single arguments def hvp_revrev(f, primals, tangents): x, = primals v, = tangents return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) print("Forward over reverse") %timeit -n10 -r3 hvp(f, (X,), (V,)) print("Reverse over forward") %timeit -n10 -r3 hvp_revfwd(f, (X,), (V,)) print("Reverse over reverse") %timeit -n10 -r3 hvp_revrev(f, (X,), (V,)) print("Naive full Hessian materialization") %timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse 5.45 ms ± 187 µs per loop (mean ± std. dev. of 3 runs, 10 loops each) Reverse over forward 9.52 ms ± 4.27 ms per loop (mean ± std. dev. of 3 runs, 10 loops each) Reverse over reverse 14.1 ms ± 7.22 ms per loop (mean ± std. dev. of 3 runs, 10 loops each) Naive full Hessian materialization 53.6 ms ± 831 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Composing VJPs, JVPs, and vmap
Jacobian-Matrix and Matrix-Jacobian products
これで、jvpとvjp関数がそろい、一度に一つのベクトルを押し出し(pushforward)たり、引き戻し(pullback)たりすることができるようになった。ここで、一番最初に紹介したJAXのvmap変換を使うと、基底全体を押し出したり、引き戻したりすることができるようになる。特に、行列-ヤコビアン積およびヤコビアン-行列積を早く計算できるようになる。
# Isolate the function from the weight matrix to the predictions f = lambda W: predict(W, b, inputs) # Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`. # First, use a list comprehension to loop over rows in the matrix M. def loop_mjp(f, x, M): y, vjp_fun = vjp(f, x) return jnp.vstack([vjp_fun(mi) for mi in M]) # Now, use vmap to build a computation that does a single fast matrix-matrix # multiply, rather than an outer loop over vector-matrix multiplies. def vmap_mjp(f, x, M): y, vjp_fun = vjp(f, x) outs, = vmap(vjp_fun)(M) return outs key = random.PRNGKey(0) num_covecs = 128 U = random.normal(key, (num_covecs,) + y.shape) loop_vs = loop_mjp(f, W, M=U) print('Non-vmapped Matrix-Jacobian product') %timeit -n10 -r3 loop_mjp(f, W, M=U) print('\nVmapped Matrix-Jacobian product') vmap_vs = vmap_mjp(f, W, M=U) %timeit -n10 -r3 vmap_mjp(f, W, M=U) assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product 132 ms ± 398 µs per loop (mean ± std. dev. of 3 runs, 10 loops each) Vmapped Matrix-Jacobian product 5.53 ms ± 109 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
def loop_jmp(f, W, M): # jvp immediately returns the primal and tangent values as a tuple, # so we'll compute and select the tangents in a list comprehension return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M]) def vmap_jmp(f, W, M): _jvp = lambda s: jvp(f, (W,), (s,))[1] return vmap(_jvp)(M) num_vecs = 128 S = random.normal(key, (num_vecs,) + W.shape) loop_vs = loop_jmp(f, W, M=S) print('Non-vmapped Jacobian-Matrix product') %timeit -n10 -r3 loop_jmp(f, W, M=S) vmap_vs = vmap_jmp(f, W, M=S) print('\nVmapped Jacobian-Matrix product') %timeit -n10 -r3 vmap_jmp(f, W, M=S) assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product 246 ms ± 494 µs per loop (mean ± std. dev. of 3 runs, 10 loops each) Vmapped Jacobian-Matrix product 2.94 ms ± 20.9 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
The implementation of jacfwd and jacrev
前の節で高速なヤコビアン-行列積、行列-ヤコビアン積を見てきた。同じ手法で標準基底全体を一度に押し出し(pushforward)または引き戻し(pullback)することで、jacfwdとfacrevを実装する。
from jax import jacrev as builtin_jacrev def our_jacrev(f): def jacfun(x): y, vjp_fun = vjp(f, x) # Use vmap to do a matrix-Jacobian product. # Here, the matrix is the Euclidean basis, so we get all # entries in the Jacobian at once. J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) return J return jacfun assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd def our_jacfwd(f): def jacfun(x): _jvp = lambda s: jvp(f, (x,), (s,))[1] Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x))) return jnp.transpose(Jt) return jacfun assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
興味深いのは、AutoGradでは、これはできなかったことである。Autogradでのreverse-modeでのヤコビアンの実装は、map関数のループで一度に一つのベクトルを引き戻す(pullback)する必要があった。しかし、一度に一つのベクトルを計算で押し出す(pushforward)ことは、vmapですべてのバッチ処理をするよりもはるかに効率が悪いのである。
もう一つ、AutoGradでできなかったことは、jit(just-in-timeコンパイル)である。面白いことに、微分する関数にどれだけPythonのダイナミズムを使っても、計算の線形部分には常にjitを使うことができる。
def f(x): try: if x < 3: return 2 * x ** 3 else: raise ValueError except ValueError: return jnp.pi * x y, f_vjp = vjp(f, 4.) print(jit(f_vjp)(1.))
(DeviceArray(3.1415927, dtype=float32, weak_type=True),)
Complex numbers and differentiation
JAXは複素数と、その微分も扱うことができる。正則関数と非正則関数の両方の微分を行うには、JVPとVJPの両方の観点から考える必要がある。
まず、複素数から複素数への関数を考え、それに対応する関数を同定する。
def f(z): x, y = jnp.real(z), jnp.imag(z) return u(x, y) + v(x, y) * 1j def g(x, y): return (u(x, y), v(x, y))
ここでは、複素数を受けとり、を返す関数と、複素数空間をと見なす関数を定義した。 関数は実数の入力と出力しか持たないため、例え接ベクトル を使ったヤコビアン-ベクトル積は以下のようになる。
元の関数を接ベクトルに適用してJVPを得るには、同じ定義を使って、結果を別の複素数として特定するだけで良い。
これが関数JVPの定義である。ここで、が正則かどうかは関係ない。
以下で確認してみる。
def check(seed): key = random.PRNGKey(seed) # random coeffs for u and v key, subkey = random.split(key) a, b, c, d = random.uniform(subkey, (4,)) def fun(z): x, y = jnp.real(z), jnp.imag(z) return u(x, y) + v(x, y) * 1j def u(x, y): return a * x + b * y def v(x, y): return c * x + d * y # primal point key, subkey = random.split(key) x, y = random.uniform(subkey, (2,)) z = x + y * 1j # tangent vector key, subkey = random.split(key) c, d = random.uniform(subkey, (2,)) z_dot = c + d * 1j # check jvp _, ans = jvp(fun, (z,), (z_dot,)) expected = (grad(u, 0)(x, y) * c + grad(u, 1)(x, y) * d + grad(v, 0)(x, y) * c * 1j+ grad(v, 1)(x, y) * d * 1j) print(jnp.allclose(ans, expected))
check(0) check(1) check(2)
True True True
次にVJPについてみてみる。余接ベクトルに対して、のVJPを以下のように定義する。
ここで、複素共役と余接を扱っているために、マイナスが出てきている。
以下で確認してみる。
def check(seed): key = random.PRNGKey(seed) # random coeffs for u and v key, subkey = random.split(key) a, b, c, d = random.uniform(subkey, (4,)) def fun(z): x, y = jnp.real(z), jnp.imag(z) return u(x, y) + v(x, y) * 1j def u(x, y): return a * x + b * y def v(x, y): return c * x + d * y # primal point key, subkey = random.split(key) x, y = random.uniform(subkey, (2,)) z = x + y * 1j # cotangent vector key, subkey = random.split(key) c, d = random.uniform(subkey, (2,)) z_bar = jnp.array(c + d * 1j) # for dtype control # check vjp _, fun_vjp = vjp(fun, z) ans, = fun_vjp(z_bar) expected = (grad(u, 0)(x, y) * c + grad(v, 0)(x, y) * (-d) + grad(u, 1)(x, y) * c * (-1j) + grad(v, 1)(x, y) * (-d) * (-1j)) assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0) check(1) check(2)
True True True
それでは、grad、jacfwd、jacrevのようなラッパーに関してはどうだろうか?
の関数の場合、grad(f)(x)をvjp(f, x)[1](1.0)と定義したことを踏まえると、VJPを1.0の値に適用させると勾配(またはヤコビアンや微分)を求めることができる。の関数の場合でも同様であり、1.0を余接ベクトルとして利用し、ヤコビアンを複素数で表した結果を得ることができる。
def f(z): x, y = jnp.real(z), jnp.imag(z) return x**2 + y**2 z = 3. + 4j grad(f)(z)
DeviceArray(6.-8.j, dtype=complex64)
一般的な関数では、ヤコビアンは実数値の自由度が4である(上で定義した2×2のヤコビアンのように)ため、そのすべてを複素数で表現することは望めない。しかし、正則関数であればそれが可能である。正則関数は、その微分が一つの複素数表現できるという特殊な性質を持っている特殊な関数である(コーシー・リーマン方程式が、上記の2×2のヤコビアンが、複素数平面上でスケールと回転の行列という特殊な形式、つまり、乗算を行うことで、1つの複素数の作用を持つことを示している)。そして、1.0の余ベクトル(covector)を持つvjpを一度呼び出すだけで、その複素数を求めることができる。
これは正則関数にしか使えないため、これを利用するには、以下のようにholomorphic=Trueとし、関数が正則であることをJAXに約束する。もしそうでなければ、JAXはgradを複素数が出力の関数に使ったときにエラーを出す。
def f(z): return jnp.sin(z) z = 3. + 4j grad(f, holomorphic=True)(z)
DeviceArray(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
holomorphic=Trueの約束をすることで、出力が複素数値であるときのエラーを無効にする。関数が正則ではない場合にも、holomorphic=Trueと書けるが、出力されるのは完全なヤコビアンを表すものではなく、出力の虚数部を捨てた関数のヤコビアンが出力される。
def f(z): return jnp.conjugate(z) z = 3. + 4j grad(f, holomorphic=True)(z) # f is not actually holomorphic!
DeviceArray(1.-0.j, dtype=complex64, weak_type=True)
ここでは、gradの働きについて、いくつかの有用な知見を示す。
- 正則関数にgradを使うことができる。
- 例えば、複素数パラメータxの実数値損失関数のように、grad(f)(x)の共役の方向にパラメータを更新することによって、関数を最適化するためにgradを使用することができる。
- もし、内部でたまたま複素数演算(畳み込みに使われるFFTのような非正則演算)を使っている関数があったとしても、gradは機能し、実数値だけを使った実装と同じ結果を得ることができる。
いずれにしろ、JVPとVJPは常に曖昧性を持たず、もし非正則関数の完全なヤコビアンを計算したい場合、JVPやVJPで計算することができる。
JAXでは、複素数はどこでも使えると思ってよい。以下は、複素数行列をコレスキー分解で微分している。
A = jnp.array([[5., 2.+3j, 5j], [2.-3j, 7., 1.+7j], [-5j, 1.-7j, 12.]]) def f(X): L = jnp.linalg.cholesky(X) return jnp.sum((L - jnp.sin(L))**2) grad(f, holomorphic=True)(A)
DeviceArray([[-0.7534182 +0.j , -3.0509028 -10.940544j, 5.9896846 +3.542303j], [-3.0509028 +10.940544j, -8.904491 +0.j , -5.1351523 -6.559373j], [ 5.9896846 -3.542303j, -5.1351523 +6.559373j, 0.01320427 +0.j ]], dtype=complex64)
次は、jitについて書く。