Preface
本シリーズでは、JAXの開発者であり、Autogradの開発者でもあるDougal Maclaurin氏と、Matthew James Johnson氏に許可をいただき、Maclaurin氏の博士論文およびJohnson氏の講演の内容から、少しずつJAXの自動微分について解説していく。
本記事では第一弾として、Maclaurin氏の博士論文の第2.5節、Johnson氏の講演の27分当たりまで、最初のセクションの内容について書く。特に、前者については本文自体が非常にわかりやすかったためにそのまま翻訳し、後者については筆者(izmyon)の理解に基づきスライドを抜粋して解説している。
本記事は以下の記事のより平易な解説として、補足のような位置づけとなっている。とりあえず読むことをお勧めする。
Dougal Maclaurin氏
博士論文:Modeling, Inference and Optimization with ComposableDifferentiable Procedures
Matthew James Johnson氏
- Preface
- Modeling, Inference and Optimization with ComposableDifferentiable Procedures
- Automatic Differentiation: Deep Learning (DLSS) and Reinforcement Learning (RLSS) Summer School, Montreal 2017
Modeling, Inference and Optimization with ComposableDifferentiable Procedures
2 Background
2.5 Computing Gradients in Reverse Accumulation Mode
連続最適化問題や推論問題は、目的関数や確率の対数(対数尤度、情報量など)の勾配を利用することができれば、はるかに容易になる。これは特に高次元の関数、に当てはまり、各勾配の評価はD個の関数評価を追加することと等しいからである。時間コストは関数自体の評価コストと比べたら小さな定数のファクターでしかないため、適切な評価戦略を用いることで、勾配は非常に楽に求めることができる。本節では、この評価戦略(逆(積算)モード微分法、あるいはニューラルネットワークの世界ではバックプロパゲーションと呼ばれる)を説明する。
ベクトルからスカラーへの関数、が、既知のヤコビアンを持つ(様々なM, Nに対する)原始関数の集合で構成されている場合、その合成の勾配は、チェインルールに従って、原始関数のヤコビアンの積で与えられる。しかし、チェインルールはヤコビアンを乗じる順序を規定していない。
具体的には、という4つの原始関数の合成として定義されたを考える。中間値を参照できるように、この関数を分解する。
、の勾配(またはヤコビアン)は以下のように与えられる。
ここで、、、およびは、、およびのヤコビアンを計算する関数である。我々は、入力としてを、出力としてを常に用いる。ここで、はスカラーであり、対しては(巨大かもしれない)ベクトルである。
は行列の乗算は結合的であるため、式2.15のヤコビアンの積を任意の順序で評価することができる。左から順に評価することを「逆積算モード」(または単にリバースモード)、右から順に評価することを 「順積算モード」(または単にフォアワードモード)と呼ぶことにする。
計算される中間値の大きさが大きく異なることに注目してほしい。 順方向モードでは、これらはのようなヤコビアンである。はベクトルなので、これは対応する値の倍の要素を含んでいる。リバースモードでは、のような値を計算する。はスカラーなので、これは対応する値と同じ数の要素を含んでいる。
したがって、原始関数のヤコビアン、、、を評価した後は、リバースモードが、多変数実数値関数の勾配をより効率的に評価する方法となるのである。しかし、さらにもっと良い方法がある。そもそも原始関数のヤコビアンをはじめに評価する必要さえないのだ。ヤコビアンは非常に疎なことが多く、行列の積で使うだけである。行列は、結局のところ、線形写像の表現に過ぎないため、それらをインスタンス化するのではなく、線形写像を適用する関数を直接実装すればよい。すなわち,各原始関数とヤコビアンに対して、左乗算のヤコビアン-ベクトル積関数(JVP)、を(キャリーとして)以下のように書ける。
右乗算のベクター-ヤコビアン積関数(VJP)、は、以下のように書ける。
例えば、各要素を二乗する関数を考える。
ここで,記号は要素ごとの乗算を表す。ElemSquareは非常に疎なヤコビアンを持ち、対角線上に、その他の部分にを持つ単なる行列である。VJP関数は以下のように与えらる。
ヤコビアンは対称行列であるので、左乗算JVP関数も同様になる。
図2.2に示すように、JVPを連結することで、順方向と逆方向の両方の微分を実現することができる。順方向モードでは、各ステップで各入力次元ごとにJVPを適用するが、逆方向モードでは、各ステップでVJPを1回だけ適用する。JVPやVJPの評価は、通常、原始関数そのものを評価するよりも小さな定数倍(1〜3)だけ遅くなる。
したがって、フォアワードモードの微分はリバースモードの微分よりも()倍遅くなり、リバースモードの微分は合成関数そのものを評価するより小さな定数倍だけ遅くなる。ここで、リバースモードの微分には1つの大きな欠点があることに注意しなければならない。 というのも、リバースパスでJVPを適用する前に、完全なフォワードパスで中間値を計算する必要があるため、すべての中間値をメモリに格納する必要があるからである。このことは、第六章で説明するように、時として大問題となることがある。
逆積算モードの微分を用いて原始関数の合成鎖の勾配を効率的に計算する方法を説明したが、一般に、合成関数は原始関数の有向非巡回グラフとして記述することができる。幸いなことに、鎖として考えるための戦略は、グラフにも簡単に適用できる。鎖の場合と同様に、関数を評価するために完全なフォアワードパスを計算し、すべての中間値を格納する。次にグラフを逆に走査し、各中間値についてヤコビアン-ベクトル積を適用してを計算する。
チェインルールによる合成では発生しないが、対処すべき追加のケースがある。ある値が複数回使用される「ファンアウト」と、関数が複数の入力を受け取る 「ファンイン」である。ファンインは、関数の各引数に対してヤコビアン-ベクトル積関数を定義することで対処する。変数の再利用であるファンアウトは、を利用する各分岐に対してを計算し、その結果を合計して完全なを得ることによって対処する。この場合、グラフを走査する順番に制約が生じ、すべてのが利用可能でなければ続行できない。図2.3にこの2つのケースを示す。
リバースモードの微分は、様々な量的分野で独自に何度も発見され[9]、機械学習でも何度か発見されている[114]。逆伝搬の(再)発明は Rumelhartら[93]のものが最も有名である。私自身の理解は、PearlmutterとSiskindの仕事(例えば[79])によって大きく形成されたものである。
Reference
[9] Baydin, A. G., Pearlmutter, B. A., & Radul, A. A. (2015). Automatic differentiation in machine learning: a survey. CoRR, abs/1502.05767.
[79] Pearlmutter, B. A. & Siskind, J. M. (2008). Reverse-mode AD in a functional framework: Lambda the ultimate backpropagator. ACM Transactions on Programming Languages and Systems (TOPLAS), 30(2), 7.
[93] Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning repres by back-propagating errors. Nature, 323, 533–536.
[114] Widrow, B. & Lehr, M. A. (1990). 30 years of adaptive neural networks: perceptron, madaline, and backpropagation. Proceedings of the IEEE, 78(9), 1415–1442.
Automatic Differentiation: Deep Learning (DLSS) and Reinforcement Learning (RLSS) Summer School, Montreal 2017
1. Jacobians and the chain rule ―Forward and reverse accumulation
図のように入力を受け取り、を出力する関数を考える。は、、、を原始関数とする合成関数である。また、を入力した際の各原始関数の出力を、それぞれ、、、、とする。
関数の微分は、チェインルールにより、、、、の積で表せる。そして、それらのサイズは、それぞれ、、、、である。入力データのサイズは、大きくなりがちである。その場合、サイズがのも大きくなる。
チェインルールでは計算の順序についての制約はないため、各原始関数のヤコビアンの合成順序は自由である。順積算モードでは、各ヤコビアンを右側から順に合成する。逆積算モードでは、各ヤコビアンを左から合成する。この時、各ヤコビアンの合成では単に行列積を計算することを思い出してほしい(参考:連鎖律(多変数関数の合成関数の微分) | 高校数学の美しい物語)。例えば、順積算モードでは、図のようにまずを計算するが、この計算はサイズとサイズの行列積であり、サイズのとなる。その後、左から新たなヤコビアンを掛けていくことになるが、そのたびに得られる積算ヤコビアンは、行方向のサイズが常にとなる。したがって、通常が大きいために、順積算モードでは、計算結果の途中で得られる積算ヤコビアンの形状が大きくなってしまう。一方、逆積算モードでは、であるために、積算ヤコビアンの列方向のサイズは常にとなり、単なる行ベクトルとなる。ゆえに、逆積算モードの方が計算が楽に、メモリも少なく計算することができ、高速に微分することができる。
順積算モードの微分では、JVPを再帰的に計算することで求めることができ、最初に与えるベクトルをサイズの単位行列、ヤコビアンをにして計算を開始し、順次、計算結果である積算ヤコビアンをベクトルに、左のヤコビアンをヤコビアンとして新たにJVPを計算していくことで最終的に合成関数の勾配を求めることができる。ここで、実際には計算の開始時に単位行列を行ごとのワンホットベクトルに分け、データの各次元ごとにJVPを回計算する。(図2.2上 参照)
逆積算モードの微分では、VJP積を再帰的に計算することで求めることができ、最初に与えるベクトルをサイズの単位行列、ヤコビアンをにして計算を開始し、順次、計算結果である積算ヤコビアンをベクトルに、右のヤコビアンをヤコビアンとして新たなVJPを計算していくことで最終的に合成関数の勾配を求めることができる。ここで、逆方向モードでは積算ヤコビアンが常に行ベクトルになるために、順積算モードのようにの各次元ごとにヤコビアンを計算することはせず、一度VJPを適用するだけでよい。(図2.2下 参照)したがって、ナイーブに考えれば、順積算モードと比べて計算量はで済む。
ある値が複数回使用される「ファンアウト」と、関数が複数の入力を受け取る 「ファンイン」では、連鎖が分岐してしまい、計算グラフが分岐してしまうため、特別な対応が必要となる。ファンインは、単に各引数に対してJVPを定義して引数ごとにヤコビアンを求め、単にスタックする。ある変数の再利用であるファンアウトは、を利用する各分岐に対してを計算し、その結果を合計して完全なを得ることによって対処する。ここでは、単純な線形変換を行うを考える。そのVJPを考える場合は、単純に(転置)ベクトルの各成分を足し合わせることで求めることができる。