izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

RetNetを完全に理解する①:Retentionメカニズム

Transformerの後継と称されるRetNetの以下の論文中にて、特に二章で解説されるRetNetのアーキテクチャについて、行間を埋めながら解説する。

arxiv.org

*自分の理解をもとに書いているので、違っているようでしたらコメントください。

Retentive Network

RetNetは、Transformerと同じように、 L個の同じブロックを積み重ねる形状をしており、それぞれのブロックが、Multi-Scale RetentionとFeed Forward Networkモジュールを持ち、それらの前後でそれぞれRMS NormとResidual Connectionが行われる。 つまり、RetNetの各ブロックは、基本的にはTransformerブロックにおいて、Pre NormにしてLayer Normの代わりにRMS Normを行い、Multi-head AttentionをMulti-Scale Retentionで置き換えたもので、さらに相対位置埋め込みを行うので位置埋め込みを各層の入力前に行わないようなモデルである。 (めっちゃざっくりと言ってるので細かいことは論文を読んでください。)

retnet

RetNetは、入力ベクトル \{ \textbf{x}_i \}_{i=1}^{|x|} が単語埋め込みにより X^0 = [ \textbf{x}_1, \cdots, \textbf{x}_{|x|} ] \in \mathbb{R}^{|x| \times d_{model} } に埋め込まれたとき、文脈づけられたベクトル X^l = \text{RetNet}_l ( X^{l-1} ) , l \in [1, L ] を計算する。

Retention

Retentionメカニズムは、再帰形式と並列形式という二つの形式を持ち、並列に学習を行いながら、再帰的に推論を行うことが出来る。

入力 X \in \mathbb{R}^{|x| \times d_{model} } が与えられたとき、学習可能な行列 W_V \in \mathbb{R}^{d \times d} を用いて、 V =  X W_Vを求める。 時刻 nにおけるトーク X_nから得られる時刻 nにおけるバリュー V_n = X_n \cdot w_V V n行目のベクトル)と、一つ前の時刻 n-1における状態 s_{n-1}を使って、変換後の系列の n番目のトーク o_nを求める写像 v(n) \mapsto o(n)モデリングする系列変換問題を考える。

RetNetでは、この写像は以下の再帰形式で表せる。

 \displaystyle \begin{align}
s_n &= A s_{n-1} + K_n^T V_n \quad &A \in \mathbb{R}^{d \times d}, K_n \in \mathbb{R}^{1 \times d} \\
o_n &= Q_n s_n = \sum_{m=1}^{n} Q_n A^{n-m} K^T_m V_m &Q_n \in \mathbb{R}^{1 \times d} \tag{1}
\end{align}

ここで、クエリーとキーは、バリューと同じように、学習可能な行列 W_Q, W_K \in \mathbb{R}^{d \times d} を用いて、 Q =  X W_Q, K = X W_K \tag{2}で求められる。  Q_n, K_n \in \mathbb{R}^{1 \times d}はそれぞれ Q, K \in \mathbb{R}^{|x| \times d_{model} } n行目のベクトルであり、それぞれトーク X_nに対応するクエリとキーである。

式(1)で、二つ目の式がどうして出てきたのかが分かりづらい。

これを理解するためにまず、一つ目の式から以下のように s_1, s_2, s_3を順次求めていく。

 \displaystyle \begin{align}
s_1 &= A s_0 + K_1^T v_1 \\
s_2 &= A s_1 + K_2^T v_2 = A (A s_0 + K_1^T v_1) + K_2^T v_2 = A^2 s_0 + A K_1^T v_1 + K_2^T v_2 \\
s_3 &= A s_2 + K_3^T v_3 = A (A^2 s_0 + A K_1^T v_1 + K_2^T v_2 ) +  K_3^T v_3 \\
      &= A^3 s_0 + ( A^2 K_1^T v_1 + A K_2^T v_2 + K_3^T v_3 ) = A^3 s_0 + \sum_{m=1}^{3} A^{3-m} K^T_m V_m
\end{align}

すると、実は一つ目の漸化式から s_nの一般式が以下のように求められることが分かるだろう。

 \displaystyle
s_n = A s_{n-1} + K_n^T V_n = A^n s_0 + \sum_{m=1}^{n} A^{n-m} K^T_m V_m

この式において s_0 = 0とすれば、 o_n = Q_n s_n = \sum_{m=1}^{n} Q_n A^{n-m} K^T_m V_mが導かれる。

次に、正方行列 Aは、 \Lambda正則行列とし、 \gamma, \theta \in \mathbb{R}^dとして  \Lambda^{-1} A \Lambda = \gamma e^{i \theta} と対角化することが出来るものとしている。 対角化を復習したい場合はwikiなどを参照されたい。 ここで、  \theta \in \mathbb{R}^dとしているが、   \gamma e^{i \theta} \in  \mathbb{R}^{d \times d} であり、以下のように表される。 ここら辺についてはgithubのissueでも議論している。

 \displaystyle \begin{align}
( \gamma e^{i \theta} ) = 

\begin{pmatrix}
\gamma_1 e^{i \theta_1} & 0 & \cdots & 0 \\
0 & \gamma_2 e^{i \theta_2} & \cdots & 0 \\
\vdots & \vdots &  \ddots & \vdots \\
0 & 0 & \cdots & \gamma_d e^{i \theta_d} 
\end{pmatrix}

\end{align}

すると、 A = \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1} と表され、 A^{n-m}は、以下のように求められる。

 \displaystyle \begin{align}
A^{n-m} &= ( \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1}  ) ( \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1}  ) \cdots  ( \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1}  ) \\
&=  \Lambda ( \gamma e^{i \theta} ) ( \Lambda^{-1} \Lambda ) ( \gamma e^{i \theta} ) ( \Lambda^{-1}  \cdots  \Lambda ) ( \gamma e^{i \theta} ) \Lambda^{-1}  \\
&= \Lambda ( \gamma e^{i \theta} )^{n-m} \Lambda^{-1}
\end{align}

すると、式(1)の二つ目の式は、以下のように表せる。

 \displaystyle \begin{align}
o_n &= \sum_{m=1}^{n} Q_n  \Lambda \left( \gamma e^{i \theta} \right)^{n-m} \Lambda^{-1} K^T_m V_m
=  \sum_{m=1}^{n}  (X_n w_Q)  \Lambda \left( \gamma e^{i \theta} \right)^{n-m} \Lambda^{-1} (w_K^T X_m^T) V_m \\
&= \sum_{m=1}^{n}  X_n (w_Q  \Lambda) \left( \gamma e^{i \theta} \right)^{n-m} (w_K (\Lambda^{-1})^T )^T X_m^T V_m 
\end{align}

式中で、 \Lambdaは、 w_Q  \Lambda w_K (\Lambda^{-1})^Tのような形で学習可能パラメータ w_Q, w_Kと掛けられるため、別のパラメータとせず w_Q, w_Kに吸収されまとめて学習されるとする。

すると、以下のように表される。

 \displaystyle \begin{align}
o_n &= \sum_{m=1}^{n}  Q_n \left( \gamma e^{i \theta} \right)^{n-m} K_m^T  V_m \\
 &= \sum_{m=1}^{n}  Q_n \left( \gamma e^{i n \theta} \right) \left( \gamma e^{- i m \theta} \right) K_m^T  V_m \\
 &= \sum_{m=1}^{n}  (Q_n \left( \gamma e^{i n \theta} \right) )  \left( \gamma e^{-i m  \theta} \right)^T K_m^T  V_m \\
 &= \sum_{m=1}^{n}  (Q_n \left( \gamma e^{i n \theta} \right) )  \left(  K_m (\gamma e^{ -i m \theta} ) \right)^T  V_m 
\end{align}

ここで、  ( \gamma e^{i \theta} )  は対角行列より、  ( \gamma e^{i \theta} )  =   ( \gamma e^{i \theta} )^Tであることを用いた。  Q_n ( \gamma e^{i n \theta} ) K_m ( \gamma e^{- i m \theta} )はxPosとして知られている位置埋め込みである。(ちなみにxPosの著者とRetNetの著者はほぼ同じ。xPosを開発した精華大学のチームがMicrosoft Researchに引き抜かれたようだ。)

さらに簡素化して \gammaスカラーとすることで、以下が導かれる。

 \displaystyle 
o_n = \sum_{m=1}^{n} \gamma^{n-m} (Q_n e^{in \theta}) ( K_m e^{im \theta} )^{\dagger} V_m \tag{3}

ここで、 e^{i n \theta}の定義は以下である。

 \displaystyle
e^{i n \theta} = 
\begin{pmatrix}
e^{i n \theta_1} & 0 & \cdots & 0 \\
0 & e^{i n \theta_2} & \cdots & 0 \\
\vdots & \vdots &  \ddots & \vdots \\
0 & 0 & \cdots &e^{i n \theta_d} 
\end{pmatrix}

この定義では、 \theta \in \mathbb{R}であるのに、 e^{i n \theta} \in \mathbb{R}^{d \times d}であるのでわかりづらい。 ここで、 e^{i n \theta} =  [ e^{i n \theta_1} , \ldots, e^{i n \theta_d}  ]と新たに定義し、 Q_n = [ q_1, \ldots, q_d ] とすると、 Q_n e^{i n \theta}は以下のように表せる。

 \displaystyle \begin{align}
Q_n e^{i n \theta} &= [ q_1, \ldots, q_d ] 

\begin{pmatrix}
e^{i n \theta_1} & 0 & \cdots & 0 \\
0 & e^{i n \theta_2} & \cdots & 0 \\
\vdots & \vdots &  \ddots & \vdots \\
0 & 0 & \cdots &e^{i n \theta_d} 
\end{pmatrix}
=  [ q_1 e^{i n \theta_1} , \ldots, q_d e^{i n \theta_d}  ] \\
&= Q_n \odot [ e^{i n \theta_1} , \ldots, e^{i n \theta_d}  ] = Q_n \odot e^{i n \theta}
\end{align}

すると、 o_nは以下のように表すことができる。

 \displaystyle 
o_n = \sum_{m=1}^{n} \gamma^{n-m} (Q_n \odot e^{in \theta}) ( K_m \odot e^{im \theta} )^{\dagger} V_m \tag{4}

Parallel Representation

さらに、全時刻における e^{i n \theta }をまとめて新たに \Theta \in \mathbb{R}^{|x| \times d_{model} }を以下のように定義する。

 \displaystyle
\Theta = 

\begin{pmatrix}
e^{i  \theta} \\
\vdots \\
e^{i |x| \theta} 
\end{pmatrix}

そして、 Q, K \in \mathbb{R}^{|x| \times d_{model} }を、各時刻における Q_n \odot e^{i n \theta}および K_n \odot e^{- i n \theta}をそれぞれ含むように定義し直すと、それぞれ以下のようになる。

 \displaystyle
Q = 

\begin{pmatrix}
Q_1 \odot e^{i \theta} \\
\vdots \\
Q_{|x|} \odot e^{i |x| \theta} 
\end{pmatrix}

= 

\begin{pmatrix}
Q_1 \\
\vdots \\
Q_{|x|} 
\end{pmatrix}

\odot \Theta \\

= (X W_Q) \odot  \Theta
 \displaystyle
K = 

\begin{pmatrix}
K_1 \odot e^{- i \theta} \\
\vdots \\
K_{|x|} \odot e^{- i |x| \theta} 
\end{pmatrix}

= 

\begin{pmatrix}
K_1 \\
\vdots \\
K_{|x|} 
\end{pmatrix}

\odot \bar{\Theta} \\

= (X W_K) \odot  \bar{\Theta}

さらに、入力系列 X全体を変換後のベクトルである \text{Retention} (X)は、 Q_n^{\prime} = Q_n \odot e^{in \theta},   K_m^{\prime} = K_m \odot e^{im \theta} とすると、以下のように表される。

 \displaystyle
\text{Retention} (X) = 

\begin{pmatrix}
o_1  \\
\vdots \\
o_|x|
\end{pmatrix} \\

= 

\begin{pmatrix}
\sum_{m=1}^{1} \gamma^{1-m} Q_1^{\prime} K_m^{\prime \dagger} V_m  \\
\vdots \\
\sum_{m=1}^{|x|} \gamma^{|x|-m} Q_{|x|}^{\prime}  K_m^{\prime \dagger} V_m
\end{pmatrix}
 \displaystyle

= 

\left(

\begin{pmatrix}
Q_1^{\prime} K_1^{\prime \dagger}, \ldots, Q_1^{\prime} K_{|x|}^{\prime \dagger} \\
\vdots \\
Q_{|x|}^{\prime} K_1^{\prime \dagger}, \ldots, Q_{|x|}^{\prime} K_{|x|}^{\prime \dagger}
\end{pmatrix}

\odot

\begin{pmatrix}
\gamma^{1-1}, 0, 0, \ldots, 0 \\
\gamma^{2-1}, \gamma^{2-2}, 0, \ldots, 0 \\
\vdots \\
\gamma^{|x|-1}, \ldots, \gamma^{|x|-|x|}
\end{pmatrix} 

\right)

\begin{pmatrix}
V_1 \\
\vdots \\
V_{|x|}
\end{pmatrix} \\
 \displaystyle

= (QK^T \odot D) V

\tag{5}

ここで、

 \displaystyle

D_{nm}

= 

\left\{
\begin{array}{ll}
\gamma^{n-m} & (n \geq m)\\
0 & (n < m)
\end{array}
\right.

とした。

式(5)がRetentionの並列形式であり、学習時にはこの式に従って計算される。

Recurrent Representation

式(4)を Q_n^{\prime}, K_m^{\prime} を用いて表すと、以下のようになる。

 \displaystyle 
o_n = \sum_{m=1}^{n} \gamma^{n-m} Q_n^{\prime} K_m^{\prime \dagger} V_m

式(1)と見比べれば以下が成り立つのが分かるだろう。

 \displaystyle \begin{align}
s_n &= \gamma s_{n-1} + K_n^{\prime T} V_n \\
\text{Retention} (X_n) &= Q_n^{\prime} s_n = \sum_{m=1}^{n} Q_n^{\prime} {\gamma}^{n-m} K^{\prime T}_m V_m, \quad n = 1, \ldots, |x| \tag{6}
\end{align}

これがRetNetの再帰形式であり、推論時はこの式によって再帰的に次トークンが予測される。

Attention vs Retention

Attentionは以下で表される。

 \displaystyle

\text{Attention} (X) = \text{softmax}(QK^T  \odot M )  V

\tag{7}

ここで、 MはAttentionマスクであり、簡単のため、以下のように表されるデコーダマスクであるとする。

 \displaystyle

(M)_{n,m}

= 

\left\{
\begin{array}{ll}
1 & (n \geq m)\\
0 & (n < m)
\end{array}
\right.

すると、 \text{softmax}(QK^T  \odot M ) は、以下のように表せる。

 \displaystyle

\text{softmax} (QK^T \odot M)

=

\begin{pmatrix}
\frac{ exp ( Q_1 K_1^T ) }{ \sum_{m=1}^{1} exp ( Q_1 K_m^T ) } , 0, \ldots,  0, 0  \\
\frac{ exp ( Q_2 K_1^T ) }{ \sum_{m=1}^{2} exp ( Q_2 K_m^T ) } , \frac{ exp ( Q_2 K_2^T ) }{ \sum_{m=1}^{2} exp ( Q_2 K_m^T ) }, \ldots,  0  \\
\vdots \\
\frac{ exp ( Q_{|x|}  K_1^T ) }{ \sum_{m=1}^{|x|} exp ( Q_{|x|} K_m^T ) } , \frac{ exp ( Q_{|x|}  K_2^T ) }{ \sum_{m=1}^{|x|} exp ( Q_{|x|} K_m^T ) },  \ldots,  \frac{ exp ( Q_{|x|} K_{|x|}^T ) }{ \sum_{m=1}^{|x|} exp ( Q_{|x|} K_m^T ) }  \\
\end{pmatrix}

従って、

 \displaystyle

(\text{softmax} (QK^T) \odot M )_{n,m} =  

\left\{
\begin{array}{ll}
\frac{ exp ( Q_{n}  K_m^T ) }{ \sum_{m=1}^{n} exp ( Q_{n} K_m^T ) } & (n \geq m)\\
0 & (n < m)
\end{array}
\right.

 \text{Attention} (X)のある時刻 nにおける出力を o_nとする。

 \displaystyle

\text{Attention} (X) 

=

\begin{pmatrix}
o_1  \\
\vdots \\
o_{|x|}  \\
\end{pmatrix} 

= 

\begin{pmatrix}
\sum_{m=1}^{1}  \frac{ exp ( Q_1 K_m^T ) }{ \sum_{m=1}^{1} exp ( Q_1 K_m^T ) } V_m  \\
\vdots \\
\sum_{m=1}^{|x|}  \frac{ exp ( Q_{|x|} K_m^T ) }{ \sum_{m=1}^{|x|} exp ( Q_{|x|} K_m^T ) } V_m  \\
\end{pmatrix}

従って、 o_nは以下のように表せる。

 \displaystyle

o_n =  \sum_{m=1}^{n}  \frac{ exp ( Q_n K_m^T ) }{ \sum_{m=1}^{n} exp ( Q_n K_m^T ) } V_m 

\tag{8}

ここで、クエリはインデックス nのみを持ち、時刻 nの出力を計算するのに同時刻のクエリのみを必要とするのに対し、 K, Vはインデックス mを持ち、時刻 nでの出力を計算するのにそれまでの全時刻のキー、バリューが必要になる。 このため、キー、バリューは推論ごとにキャッシュされる。

Retentionの場合は、以下で表されるのであった。

 \displaystyle 
o_n = Q_n^{\prime} \sum_{m=1}^{n} \gamma^{n-m} K_m^{\prime \dagger} V_m

 Retention(X) = (QK^T \odot D) Vであり、 Attention(X) = \text{Attention} (X) = \text{softmax}(QK^T  \odot M )  Vと一見して似た形をしているが、 \text{softmax}を作用させているかどうかで、 o_nが大きく異なっているのが見て取れるだろう。 Attentionでは、まず O (n)  (\text{softmax} (QK^T) \odot M )_{n,m} = \frac{ exp ( Q_n K_m^T ) }{ \sum_{m=1}^{n} exp ( Q_n K_m^T ) }を求めた後、さらに O(n) o_n = \sum_{m=1}^{n}  \frac{ exp ( Q_n K_m^T ) }{ \sum_{m=1}^{n} exp ( Q_n K_m^T ) } V_mを計算する。このようにクエリーは現時点での Q_nを使えばよいが、これまでの値すべてのバリューとキーとそれぞれ計算を行う必要があるため、推論時の計算量は全体で O(n)となってしまう。

しかしながら、Retentionでは、以下のように時刻 n-1までのキーとバリューを用いる計算を分け、状態として保持することができ、各時刻での計算量は全体として O(1)で済む。

 \displaystyle 
o_n = Q_n^{\prime} \sum_{m=1}^{n} \gamma^{n-m} K_m^{\prime \dagger} V_m =  Q_n^{\prime} \left( \gamma \left( \sum_{m=1}^{n-1} \gamma^{n-1-m} K_m^{\prime \dagger}  V_m \right) + K_n^{\prime \dagger} V_n  \right) \\
=  Q_n^{\prime} \left( \gamma s_{n-1} + K_n^{\prime \dagger} V_n  \right)

以上。次はそのうち相対位置埋め込みxPosの理論と実装について書きます。