RetNetを完全に理解する①:Retentionメカニズム
Transformerの後継と称されるRetNetの以下の論文中にて、特に二章で解説されるRetNetのアーキテクチャについて、行間を埋めながら解説する。
*自分の理解をもとに書いているので、違っているようでしたらコメントください。
Retentive Network
RetNetは、Transformerと同じように、個の同じブロックを積み重ねる形状をしており、それぞれのブロックが、Multi-Scale RetentionとFeed Forward Networkモジュールを持ち、それらの前後でそれぞれRMS NormとResidual Connectionが行われる。 つまり、RetNetの各ブロックは、基本的にはTransformerブロックにおいて、Pre NormにしてLayer Normの代わりにRMS Normを行い、Multi-head AttentionをMulti-Scale Retentionで置き換えたもので、さらに相対位置埋め込みを行うので位置埋め込みを各層の入力前に行わないようなモデルである。 (めっちゃざっくりと言ってるので細かいことは論文を読んでください。)
RetNetは、入力ベクトルが単語埋め込みによりに埋め込まれたとき、文脈づけられたベクトルを計算する。
Retention
Retentionメカニズムは、再帰形式と並列形式という二つの形式を持ち、並列に学習を行いながら、再帰的に推論を行うことが出来る。
入力が与えられたとき、学習可能な行列を用いて、を求める。 時刻におけるトークンから得られる時刻におけるバリュー(の行目のベクトル)と、一つ前の時刻における状態を使って、変換後の系列の番目のトークンを求める写像 をモデリングする系列変換問題を考える。
ここで、クエリーとキーは、バリューと同じように、学習可能な行列を用いて、で求められる。 はそれぞれの行目のベクトルであり、それぞれトークンに対応するクエリとキーである。
式(1)で、二つ目の式がどうして出てきたのかが分かりづらい。
これを理解するためにまず、一つ目の式から以下のようにを順次求めていく。
すると、実は一つ目の漸化式からの一般式が以下のように求められることが分かるだろう。
この式においてとすれば、が導かれる。
次に、正方行列は、を正則行列とし、としてと対角化することが出来るものとしている。 対角化を復習したい場合はwikiなどを参照されたい。 ここで、としているが、であり、以下のように表される。 ここら辺についてはgithubのissueでも議論している。
すると、と表され、は、以下のように求められる。
すると、式(1)の二つ目の式は、以下のように表せる。
式中で、は、、のような形で学習可能パラメータと掛けられるため、別のパラメータとせずに吸収されまとめて学習されるとする。
すると、以下のように表される。
ここで、は対角行列より、であることを用いた。 、はxPosとして知られている位置埋め込みである。(ちなみにxPosの著者とRetNetの著者はほぼ同じ。xPosを開発した精華大学のチームがMicrosoft Researchに引き抜かれたようだ。)
さらに簡素化してをスカラーとすることで、以下が導かれる。
ここで、の定義は以下である。
この定義では、であるのに、であるのでわかりづらい。 ここで、 ]と新たに定義し、とすると、は以下のように表せる。
すると、は以下のように表すことができる。
Parallel Representation
さらに、全時刻におけるをまとめて新たにを以下のように定義する。
そして、を、各時刻におけるおよびをそれぞれ含むように定義し直すと、それぞれ以下のようになる。
さらに、入力系列全体を変換後のベクトルであるは、とすると、以下のように表される。
ここで、
とした。
式(5)がRetentionの並列形式であり、学習時にはこの式に従って計算される。
Recurrent Representation
式(4)をを用いて表すと、以下のようになる。
式(1)と見比べれば以下が成り立つのが分かるだろう。
これがRetNetの再帰形式であり、推論時はこの式によって再帰的に次トークンが予測される。
Attention vs Retention
Attentionは以下で表される。
ここで、はAttentionマスクであり、簡単のため、以下のように表されるデコーダマスクであるとする。
すると、は、以下のように表せる。
従って、
のある時刻における出力をとする。
従って、は以下のように表せる。
ここで、クエリはインデックスのみを持ち、時刻の出力を計算するのに同時刻のクエリのみを必要とするのに対し、はインデックスを持ち、時刻での出力を計算するのにそれまでの全時刻のキー、バリューが必要になる。 このため、キー、バリューは推論ごとにキャッシュされる。
Retentionの場合は、以下で表されるのであった。
であり、と一見して似た形をしているが、を作用させているかどうかで、が大きく異なっているのが見て取れるだろう。 Attentionでは、まずでを求めた後、さらにでを計算する。このようにクエリーは現時点でのを使えばよいが、これまでの値すべてのバリューとキーとそれぞれ計算を行う必要があるため、推論時の計算量は全体でとなってしまう。
しかしながら、Retentionでは、以下のように時刻までのキーとバリューを用いる計算を分け、状態として保持することができ、各時刻での計算量は全体としてで済む。
以上。次はそのうち相対位置埋め込みxPosの理論と実装について書きます。