izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

JAXの自動微分の仕組み ―Dougal Maclaurin氏の博士論文およびMatthew J Johnson氏の講演から①

Preface

 本シリーズでは、JAXの開発者であり、Autogradの開発者でもあるDougal Maclaurin氏と、Matthew James Johnson氏に許可をいただき、Maclaurin氏の博士論文およびJohnson氏の講演の内容から、少しずつJAXの自動微分について解説していく。

 本記事では第一弾として、Maclaurin氏の博士論文の第2.5節、Johnson氏の講演の27分当たりまで、最初のセクションの内容について書く。特に、前者については本文自体が非常にわかりやすかったためにそのまま翻訳し、後者については筆者(izmyon)の理解に基づきスライドを抜粋して解説している。

 本記事は以下の記事のより平易な解説として、補足のような位置づけとなっている。とりあえず読むことをお勧めする。

izmyon.hatenablog.com

Dougal Maclaurin氏

dougalmaclaurin.com

博士論文:Modeling, Inference and Optimization with ComposableDifferentiable Procedures

Matthew James Johnson氏

people.csail.mit.edu

講演:Automatic Differentiation: Deep Learning (DLSS) and Reinforcement Learning (RLSS) Summer School, Montreal 2017

Modeling, Inference and Optimization with ComposableDifferentiable Procedures

2 Background

2.5 Computing Gradients in Reverse Accumulation Mode

連続最適化問題や推論問題は、目的関数や確率の対数(対数尤度、情報量など)の勾配を利用することができれば、はるかに容易になる。これは特に高次元の関数、 \mathbb{R}^n \rightarrow \mathbb{R} に当てはまり、各勾配の評価はD個の関数評価を追加することと等しいからである。時間コストは関数自体の評価コストと比べたら小さな定数のファクターでしかないため、適切な評価戦略を用いることで、勾配は非常に楽に求めることができる。本節では、この評価戦略(逆(積算)モード微分法、あるいはニューラルネットワークの世界ではバックプロパゲーションと呼ばれる)を説明する。

ベクトルからスカラーへの関数、 \mathbb{R}^D \rightarrow \mathbb{R}が、既知のヤコビアンを持つ(様々なM, Nに対する)原始関数 \mathbb{R}^M \rightarrow \mathbb{R}^Nの集合で構成されている場合、その合成の勾配は、チェインルールに従って、原始関数のヤコビアンの積で与えられる。しかし、チェインルールはヤコビアンを乗じる順序を規定していない。

具体的には、 F= D \circ C \circ B \circ Aという4つの原始関数の合成として定義された F: \mathbb{R}^D \rightarrow \mathbb{R}を考える。中間値を参照できるように、この関数を分解する。

 F(x) = y \quad(x \in \mathbb{R}^D, y \in \mathbb{R} )  where \  y = D(c), \ c = C(b) , \ b= B(a), \ a = A(x) \tag{2.14}

 F F'の勾配(またはヤコビアン)は以下のように与えられる。

 \begin{align} 
F'(x) = &\frac{\partial y}{\partial x} \\
where \  &\frac{\partial y}{\partial x} =  \frac{\partial y}{\partial c}  \frac{\partial c}{\partial b}  \frac{\partial b}{\partial a}  \frac{\partial a}{\partial x} \tag{2.15} \\
&\frac{\partial y}{\partial c} = D'(c), \  \frac{\partial c}{\partial b} = C'(b), \ \frac{\partial b}{\partial a} = B'(a), \  \frac{\partial a}{\partial x} = A'(x) \tag{2.16} 
\end{align}

ここで、 A′ B′ C′および D′ A B Cおよび Dヤコビアンを計算する関数である。我々は、入力として x \in \mathbb{R}^Dを、出力として y \in \mathbb{R}を常に用いる。ここで、 yスカラーであり、対して xは(巨大かもしれない)ベクトルである。

 yは行列の乗算は結合的であるため、式2.15のヤコビアンの積を任意の順序で評価することができる。左から順に評価することを「逆積算モード」(または単にリバースモード)、右から順に評価することを 「順積算モード」(または単にフォアワードモード)と呼ぶことにする。

計算される中間値の大きさが大きく異なることに注目してほしい。 順方向モードでは、これらは \frac{\partial b}{\partial x}のようなヤコビアンである。 x \in \mathbb{R}^Dはベクトルなので、これは対応する値 b D倍の要素を含んでいる。リバースモードでは、 \frac{\partial y}{\partial b}のような値を計算する。 y \in \mathbb{R}スカラーなので、これは対応する値 bと同じ数の要素を含んでいる。

したがって、原始関数のヤコビアン A′(x) B′(a) C′(b) D′(c)を評価した後は、リバースモードが、多変数実数値関数の勾配をより効率的に評価する方法となるのである。しかし、さらにもっと良い方法がある。そもそも原始関数のヤコビアンをはじめに評価する必要さえないのだ。ヤコビアンは非常に疎なことが多く、行列の積で使うだけである。行列は、結局のところ、線形写像の表現に過ぎないため、それらをインスタンス化するのではなく、線形写像を適用する関数を直接実装すればよい。すなわち,各原始関数 A: \mathbb{R}^M \rightarrow \mathbb{R}^Nヤコビアン A': \mathbb{R}^M \rightarrow \mathbb{R}^{N \times M}に対して、左乗算のヤコビアン-ベクトル積関数(JVP)、 J_A: \mathbb{R}^M \rightarrow (\mathbb{R}^M \rightarrow \mathbb{R}^N )を(キャリーとして)以下のように書ける。

 J_A (x, g) = A'(x)g \tag{2.19}

右乗算のベクター-ヤコビアン積関数(VJP)、 J_A^T: \mathbb{R}^M \rightarrow (\mathbb{R}^N \rightarrow \mathbb{R}^M)は、以下のように書ける。

 J_A^T (x, g) = gA'(x) \tag{2.20}

例えば、各要素を二乗する関数を考える。

 ElemSquare(x) = x \odot x \tag{2.21}

ここで,記号 \odotは要素ごとの乗算を表す。ElemSquareは非常に疎なヤコビアンを持ち、対角線上に 2x、その他の部分に 0を持つ単なる行列である。VJP関数は以下のように与えらる。

 J^T_{ElemSquare} (x, g) = 2g \odot x \tag{2.22}

ヤコビアンは対称行列であるので、左乗算JVP関数 J_{ElemSquare}も同様になる。

図2.2に示すように、JVPを連結することで、順方向と逆方向の両方の微分を実現することができる。順方向モードでは、各ステップで各入力次元ごとにJVPを適用するが、逆方向モードでは、各ステップでVJPを1回だけ適用する。JVPやVJPの評価は、通常、原始関数そのものを評価するよりも小さな定数倍(1〜3)だけ遅くなる。

図2.2:合成関数 F: \mathbb{R}^D \rightarrow \mathbb{R} F = C \cdot B \cdot Aに対するフォアワードモードとリバースモードの微分の違いを図示した。フォアワードモードは入力 xに対する各中間変数のヤコビアン、例えば \frac{\partial a}{\partial x} \frac{\partial b}{\partial x} のような値を積算する。これは、前のステップの積算ヤコビアンと現在の原始関数のヤコビアンで左掛けすることによって行う。リバースモードはヤコビアンを逆向きに乗算することで動作する。これは、前ステップの蓄積されたヤコビアンを現在の原始関数のヤコビアンで右掛けすることにより、中間変数それぞれに対する出力 yヤコビアン \frac{\partial y}{\partial a} \frac{\partial y}{\partial b} といった値を蓄積する。 y \in \mathbb{R} x \in \mathbb{R}^D Dが大きい場合、リバースモードで蓄積される値は Dの因子だけ小さくなり、計算の効率が大幅に向上する。

したがって、フォアワードモードの微分はリバースモードの微分よりも D x \in \mathbb{R}^D)倍遅くなり、リバースモードの微分は合成関数そのものを評価するより小さな定数倍だけ遅くなる。ここで、リバースモードの微分には1つの大きな欠点があることに注意しなければならない。 というのも、リバースパスでJVPを適用する前に、完全なフォワードパスで中間値を計算する必要があるため、すべての中間値をメモリに格納する必要があるからである。このことは、第六章で説明するように、時として大問題となることがある。

 逆積算モードの微分を用いて原始関数の合成鎖の勾配を効率的に計算する方法を説明したが、一般に、合成関数は原始関数の有向非巡回グラフとして記述することができる。幸いなことに、鎖として考えるための戦略は、グラフにも簡単に適用できる。鎖の場合と同様に、関数を評価するために完全なフォアワードパスを計算し、すべての中間値を格納する。次にグラフを逆に走査し、各中間値 zについてヤコビアン-ベクトル積を適用して \frac{\partial y}{\partial z} を計算する。

チェインルールによる合成では発生しないが、対処すべき追加のケースがある。ある値が複数回使用される「ファンアウト」と、関数が複数の入力を受け取る 「ファンイン」である。ファンインは、関数の各引数に対してヤコビアン-ベクトル積関数を定義することで対処する。変数 zの再利用であるファンアウトは、 zを利用する各分岐 iに対して \frac{\partial y}{\partial z}^{(i)} を計算し、その結果を合計して完全な \frac{\partial y}{\partial z} を得ることによって対処する。この場合、グラフを走査する順番に制約が生じ、すべての \frac{\partial y}{\partial z}^{(i)} が利用可能でなければ続行できない。図2.3にこの2つのケースを示す。

リバースモードの微分は、様々な量的分野で独自に何度も発見され[9]、機械学習でも何度か発見されている[114]。逆伝搬の(再)発明は Rumelhartら[93]のものが最も有名である。私自身の理解は、PearlmutterとSiskindの仕事(例えば[79])によって大きく形成されたものである。

図2.3:逆積算モードの計算グラフを図示する。ファンアウト(値 zの再利用)とファンイン(2つの引数を関数 Cに与える)の処理方法を示す。

Reference

[9] Baydin, A. G., Pearlmutter, B. A., & Radul, A. A. (2015). Automatic differentiation in machine learning: a survey. CoRR, abs/1502.05767.

[79] Pearlmutter, B. A. & Siskind, J. M. (2008). Reverse-mode AD in a functional framework: Lambda the ultimate backpropagator. ACM Transactions on Programming Languages and Systems (TOPLAS), 30(2), 7.

[93] Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning repres by back-propagating errors. Nature, 323, 533–536.

[114] Widrow, B. & Lehr, M. A. (1990). 30 years of adaptive neural networks: perceptron, madaline, and backpropagation. Proceedings of the IEEE, 78(9), 1415–1442.

Automatic Differentiation: Deep Learning (DLSS) and Reinforcement Learning (RLSS) Summer School, Montreal 2017

1. Jacobians and the chain rule ―Forward and reverse accumulation

図のように入力 \boldsymbol{x} \in \mathbb{R}^nを受け取り、 y \in \mathbb{R}を出力する関数 Fを考える。 F A B C Dを原始関数とする合成関数である。また、 \boldsymbol{x}を入力した際の各原始関数の出力を、それぞれ、 \boldsymbol{a} = A(\boldsymbol{x})  \boldsymbol{b} = B(\boldsymbol{a}) \boldsymbol{c} = C(\boldsymbol{b}) y = D(\boldsymbol{c})とする。

関数 F微分は、チェインルールにより、 \frac{\partial y}{\partial \boldsymbol{c}} \frac{\partial \boldsymbol{c}}{\partial \boldsymbol{b}} \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{a}} \frac{\partial \boldsymbol{a}}{\partial \boldsymbol{x}} の積で表せる。そして、それらのサイズは、それぞれ、 1 \times ( size \ of \ \boldsymbol{c} ) ( size \ of \ \boldsymbol{c} ) \times ( size \ of \ \boldsymbol{b} ) ( size \ of \ \boldsymbol{b} ) \times ( size \ of \ \boldsymbol{a} ) ( size \ of \ \boldsymbol{a} ) \times \ nである。入力データ \boldsymbol{x}のサイズ nは、大きくなりがちである。その場合、サイズが ( size \ of \ \boldsymbol{a} ) \times \ n \frac{\partial \boldsymbol{a}}{\partial \boldsymbol{x}}も大きくなる。

チェインルールでは計算の順序についての制約はないため、各原始関数のヤコビアンの合成順序は自由である。順積算モードでは、各ヤコビアンを右側から順に合成する。逆積算モードでは、各ヤコビアンを左から合成する。この時、各ヤコビアンの合成では単に行列積を計算することを思い出してほしい(参考:連鎖律(多変数関数の合成関数の微分) | 高校数学の美しい物語)。例えば、順積算モードでは、図のようにまず \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{a}} \frac{\partial \boldsymbol{a}}{\partial \boldsymbol{x}}を計算するが、この計算はサイズ ( size \ of \ \boldsymbol{b} ) \times ( size \ of \ \boldsymbol{a} )とサイズ ( size \ of \ \boldsymbol{a} ) \times \ nの行列積であり、サイズ ( size \ of \ \boldsymbol{b} ) \times \ n \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{x}}となる。その後、左から新たなヤコビアンを掛けていくことになるが、そのたびに得られる積算ヤコビアンは、行方向のサイズが常に nとなる。したがって、通常 nが大きいために、順積算モードでは、計算結果の途中で得られる積算ヤコビアンの形状が大きくなってしまう。一方、逆積算モードでは、 y \in \mathbb{R}であるために、積算ヤコビアンの列方向のサイズは常に 1となり、単なる行ベクトルとなる。ゆえに、逆積算モードの方が計算が楽に、メモリも少なく計算することができ、高速に微分することができる。

順積算モードの微分では、JVPを再帰的に計算することで求めることができ、最初に与えるベクトルをサイズ n \times n単位行列ヤコビアン \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{x}}にして計算を開始し、順次、計算結果である積算ヤコビアンをベクトルに、左のヤコビアンヤコビアンとして新たにJVPを計算していくことで最終的に合成関数 Fの勾配を求めることができる。ここで、実際には計算の開始時に単位行列を行ごとのワンホットベクトルに分け、データ \boldsymbol{x}の各次元ごとにJVPを n回計算する。(図2.2上 参照)

逆積算モードの微分では、VJP積を再帰的に計算することで求めることができ、最初に与えるベクトルをサイズ 1 \times 1単位行列ヤコビアン \frac{\partial \boldsymbol{y}}{\partial \boldsymbol{c}}にして計算を開始し、順次、計算結果である積算ヤコビアンをベクトルに、右のヤコビアンヤコビアンとして新たなVJPを計算していくことで最終的に合成関数 Fの勾配を求めることができる。ここで、逆方向モードでは積算ヤコビアンが常に行ベクトルになるために、順積算モードのように \boldsymbol{x}の各次元ごとにヤコビアンを計算することはせず、一度VJPを適用するだけでよい。(図2.2下 参照)したがって、ナイーブに考えれば、順積算モードと比べて計算量は \frac{1}{n}で済む。

ある値が複数回使用される「ファンアウト」と、関数が複数の入力を受け取る 「ファンイン」では、連鎖が分岐してしまい、計算グラフが分岐してしまうため、特別な対応が必要となる。ファンインは、単に各引数に対してJVPを定義して引数ごとにヤコビアンを求め、単にスタックする。ある変数 xの再利用であるファンアウトは、 xを利用する各分岐 iに対して \frac{\partial y}{\partial x}^{(i)} を計算し、その結果を合計して完全な \frac{\partial y}{\partial x} を得ることによって対処する。ここでは、単純な線形変換を行う G(\boldsymbol{x})を考える。そのVJPを考える場合は、単純に(転置)ベクトルの各成分を足し合わせることで求めることができる。