izmyonの日記

A study log of a graduate student doing research deep in the mountains of Nara.

Understanding RetNet①: Theory of Retention

The RetNet paper, which positions RetNet as a successor to the Transformer, explains the architecture in Section 2.

However, the formulas in the paper are a little confusing, so in this post I explain them in detail while filling in the gaps.

arxiv.org

*This is based on my understanding, so please comment if something seems incorrect.

Retentive Network

RetNet, like the Transformer, is structured as a stack of  L identical blocks. Each block contains a Multi-Scale Retention module and a Feed Forward Network module, with RMS Norm applied before each module and a residual connection around it. Essentially, each RetNet block is akin to a Transformer block, but with RMS Norm replacing Layer Norm and Multi-Scale Retention replacing Multi-Head Attention; since Retention itself performs relative position embedding, no absolute position embedding is required. (Please refer to the paper for more details.)

When the input vectors  \{ \textbf{x}_i \}_{i=1}^{|x|} are embedded into  X^0 = [ \textbf{x}_1, \cdots, \textbf{x}_{|x|} ] \in \mathbb{R}^{|x| \times d_{model} }, RetNet computes contextualized vectors  X^l = \text{RetNet}_l ( X^{l-1} ) , l \in [1, L ].
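To make the block structure concrete, here is a minimal structural sketch in Python (my own illustration, not the paper's implementation); msr and ffn are hypothetical stand-ins for the Multi-Scale Retention and Feed Forward Network modules, and pre-norm residual blocks are assumed:

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    """RMSNorm: normalize each token vector by its root mean square (learnable gain omitted)."""
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)

def retnet_block(x, msr, ffn):
    """One block: pre-norm Multi-Scale Retention and FFN, each wrapped in a residual connection."""
    y = msr(rms_norm(x)) + x
    return ffn(rms_norm(y)) + y

def retnet(x0, blocks):
    """Stack of L blocks: X^l = RetNet_l(X^{l-1}), l = 1, ..., L."""
    x = x0
    for msr, ffn in blocks:
        x = retnet_block(x, msr, ffn)
    return x
```

The rest of this post is about the Retention mechanism that msr would compute.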

Retention

The Retention mechanism has two equivalent forms, recurrent and parallel, which allow training to be parallelized while inference proceeds recurrently.

Given  X \in \mathbb{R}^{|x| \times d_{model} }, a trainable matrix  W_V \in \mathbb{R}^{d \times d} is used to compute  V = X W_V. For the token  X_n at time  n, the value  V_n = X_n W_V (the  nth row of  V) and the state  s_{n-1} from the previous time  n-1 are used to model a mapping  v(n) \mapsto o(n) that produces the  nth token of the transformed sequence.

In RetNet, this mapping is represented recursively as follows:

 \displaystyle \begin{align}
s_n &= A s_{n-1} + K_n^T V_n \quad &A \in \mathbb{R}^{d \times d}, K_n \in \mathbb{R}^{1 \times d} \\
o_n &= Q_n s_n = \sum_{m=1}^{n} Q_n A^{n-m} K^T_m V_m &Q_n \in \mathbb{R}^{1 \times d} \tag{1}
\end{align}

Here, the query and key, like the value, are computed with trainable matrices  W_Q, W_K \in \mathbb{R}^{d \times d} as follows:  Q = X W_Q, K = X W_K \tag{2}

 Q_n, K_n \in \mathbb{R}^{1 \times d} are the  nth row vectors of  Q, K \in \mathbb{R}^{|x| \times d }, corresponding to the query and key of token  X_n.

The second equation of (1) may be less intuitive. To understand it, let's sequentially compute  s_1, s_2, s_3 from the first equation:

 \displaystyle \begin{align}
s_1 &= A s_0 + K_1^T V_1 \\
s_2 &= A s_1 + K_2^T V_2 = A (A s_0 + K_1^T V_1) + K_2^T V_2 = A^2 s_0 + A K_1^T V_1 + K_2^T V_2 \\
s_3 &= A s_2 + K_3^T V_3 = A (A^2 s_0 + A K_1^T V_1 + K_2^T V_2 ) +  K_3^T V_3 \\
      &= A^3 s_0 + ( A^2 K_1^T V_1 + A K_2^T V_2 + K_3^T V_3 ) = A^3 s_0 + \sum_{m=1}^{3} A^{3-m} K^T_m V_m
\end{align}

From here, the general formula for  s_n can be derived from the first recursive formula:

 \displaystyle
s_n = A s_{n-1} + K_n^T V_n = A^n s_0 + \sum_{m=1}^{n} A^{n-m} K^T_m V_m

Setting  s_0 = 0, we derive  o_n = Q_n s_n = \sum_{m=1}^{n} Q_n A^{n-m} K^T_m V_m.
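As a quick numerical sanity check of this unrolling (a hypothetical NumPy sketch of mine, with randomly chosen  A, K, V), the recurrence with  s_0 = 0 indeed reproduces the closed-form sum:

```python
import numpy as np

# Check that unrolling s_n = A s_{n-1} + K_n^T V_n with s_0 = 0
# reproduces the closed form s_n = sum_{m=1}^{n} A^{n-m} K_m^T V_m.
rng = np.random.default_rng(0)
d, n = 3, 5
A = rng.normal(size=(d, d))
K = rng.normal(size=(n, d))                    # rows K_1, ..., K_n
V = rng.normal(size=(n, d))                    # rows V_1, ..., V_n

s = np.zeros((d, d))                           # s_0 = 0
for m in range(1, n + 1):
    s = A @ s + np.outer(K[m - 1], V[m - 1])   # the recurrence

closed_form = sum(np.linalg.matrix_power(A, n - m) @ np.outer(K[m - 1], V[m - 1])
                  for m in range(1, n + 1))
assert np.allclose(s, closed_form)
```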

Next, the square matrix  A is assumed to be diagonalizable by an invertible matrix  \Lambda with  \gamma, \theta \in \mathbb{R}^d, such that  \Lambda^{-1} A \Lambda = \gamma e^{i \theta} . (For a refresher on diagonalization, see Wikipedia.) Here, although  \gamma, \theta \in \mathbb{R}^d, the notation  \gamma e^{i \theta} denotes the diagonal matrix in  \mathbb{C}^{d \times d} shown below. This point is also discussed in a GitHub issue.

 \displaystyle \begin{align}
( \gamma e^{i \theta} ) = 

\begin{pmatrix}
\gamma_1 e^{i \theta_1} & 0 & \cdots & 0 \\
0 & \gamma_2 e^{i \theta_2} & \cdots & 0 \\
\vdots & \vdots &  \ddots & \vdots \\
0 & 0 & \cdots & \gamma_d e^{i \theta_d} 
\end{pmatrix}

\end{align}

Consequently,  A = \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1} , and  A^{n-m} can be computed as:

 \displaystyle \begin{align}
A^{n-m} &= ( \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1}  ) ( \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1}  ) \cdots  ( \Lambda ( \gamma e^{i \theta} ) \Lambda^{-1}  ) \\
&=  \Lambda ( \gamma e^{i \theta} ) ( \Lambda^{-1} \Lambda ) ( \gamma e^{i \theta} ) ( \Lambda^{-1}  \cdots  \Lambda ) ( \gamma e^{i \theta} ) \Lambda^{-1}  \\
&= \Lambda ( \gamma e^{i \theta} )^{n-m} \Lambda^{-1}
\end{align}
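This step can also be confirmed numerically (again a hypothetical sketch of mine, where  A is constructed to be diagonalizable by design):

```python
import numpy as np

# Check that for A = Λ (γ e^{iθ}) Λ^{-1}, the power A^{n-m} equals Λ (γ e^{iθ})^{n-m} Λ^{-1}.
rng = np.random.default_rng(0)
d, power = 3, 4
gamma = rng.uniform(0.5, 1.0, size=d)
theta = rng.uniform(0, np.pi, size=d)
Lam = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))   # assumed invertible
diag = np.diag(gamma * np.exp(1j * theta))                     # the diagonal matrix γ e^{iθ}
A = Lam @ diag @ np.linalg.inv(Lam)
lhs = np.linalg.matrix_power(A, power)
rhs = Lam @ np.linalg.matrix_power(diag, power) @ np.linalg.inv(Lam)
assert np.allclose(lhs, rhs)
```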

Thus, the second equation of (1) can be expressed as:

 \displaystyle \begin{align}
o_n &= \sum_{m=1}^{n} Q_n  \Lambda \left( \gamma e^{i \theta} \right)^{n-m} \Lambda^{-1} K^T_m V_m
=  \sum_{m=1}^{n}  (X_n W_Q)  \Lambda \left( \gamma e^{i \theta} \right)^{n-m} \Lambda^{-1} (W_K^T X_m^T) V_m \\
&= \sum_{m=1}^{n}  X_n (W_Q  \Lambda) \left( \gamma e^{i \theta} \right)^{n-m} (W_K (\Lambda^{-1})^T )^T X_m^T V_m 
\end{align}

In this formula,  \Lambda always appears multiplied with the trainable matrices  W_Q, W_K, so it need not be kept as a separate parameter; it can be absorbed into  W_Q, W_K and learned together with them.

Therefore, it is represented as:

 \displaystyle \begin{align}
o_n &= \sum_{m=1}^{n}  Q_n \left( \gamma e^{i \theta} \right)^{n-m} K_m^T  V_m \\
 &= \sum_{m=1}^{n}  Q_n \left( \gamma^n e^{i n \theta} \right) \left( \gamma^{-m} e^{- i m \theta} \right) K_m^T  V_m \\
 &= \sum_{m=1}^{n}  (Q_n \left( \gamma^n e^{i n \theta} \right) )  \left( \gamma^{-m} e^{-i m  \theta} \right)^T K_m^T  V_m \\
 &= \sum_{m=1}^{n}  (Q_n \left( \gamma^n e^{i n \theta} \right) )  \left(  K_m (\gamma^{-m} e^{ -i m \theta} ) \right)^T  V_m 
\end{align}

Here, since  \gamma^{-m} e^{-i m \theta} is a diagonal matrix,  ( \gamma^{-m} e^{-i m \theta} ) = ( \gamma^{-m} e^{-i m \theta} )^T is used.  Q_n ( \gamma^n e^{i n \theta} ) and  K_m ( \gamma^{-m} e^{- i m \theta} ) are known as xPos embeddings. (Interestingly, the authors of xPos and RetNet largely overlap. It seems the Tsinghua University team that developed xPos was recruited by Microsoft Research.)

For further simplification, treating  \gamma as a scalar leads to:

 \displaystyle 
o_n = \sum_{m=1}^{n} \gamma^{n-m} (Q_n e^{in \theta}) ( K_m e^{im \theta} )^{\dagger} V_m \tag{3}

The definition of  e^{i n \theta} is as follows:

 \displaystyle
e^{i n \theta} = 
\begin{pmatrix}
e^{i n \theta_1} & 0 & \cdots & 0 \\
0 & e^{i n \theta_2} & \cdots & 0 \\
\vdots & \vdots &  \ddots & \vdots \\
0 & 0 & \cdots &e^{i n \theta_d} 
\end{pmatrix}

In this definition, although  \theta \in \mathbb{R}^d,  e^{i n \theta} \in \mathbb{C}^{d \times d}, which can be confusing. Therefore, we redefine  e^{i n \theta} as the row vector

 \displaystyle
e^{i n \theta} = [ e^{i n \theta_1}, e^{i n \theta_2}, \ldots, e^{i n \theta_d} ] \in \mathbb{C}^{1 \times d}

Setting  Q_n = [ q_1, \ldots, q_d ], the product of  Q_n with the diagonal matrix above can then be written as an element-wise product:

 \displaystyle
Q_n \odot e^{i n \theta} = [ q_1 e^{i n \theta_1}, q_2 e^{i n \theta_2}, \ldots, q_d e^{i n \theta_d} ]

Hence,  o_n can be represented as:

 \displaystyle 
o_n = \sum_{m=1}^{n} \gamma^{n-m} (Q_n \odot e^{in \theta}) ( K_m \odot e^{im \theta} )^{\dagger} V_m \tag{4}
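The rewriting from the diagonal-matrix product to the element-wise product can also be checked directly (a hypothetical NumPy sketch; the variable names are mine):

```python
import numpy as np

# Check that Q_n · diag(e^{i n θ}) equals the element-wise product Q_n ⊙ e^{i n θ}.
rng = np.random.default_rng(0)
d, n = 4, 3
theta = rng.uniform(0, np.pi, size=d)
Q_n = rng.normal(size=d)                 # row vector Q_n = [q_1, ..., q_d]
rot_vec = np.exp(1j * n * theta)         # e^{i n θ} as a row vector
rot_mat = np.diag(rot_vec)               # e^{i n θ} as a diagonal matrix
assert np.allclose(Q_n @ rot_mat, Q_n * rot_vec)
```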

Parallel Representation

Furthermore, define  \Theta \in \mathbb{C}^{|x| \times d} by stacking  e^{i n \theta} for all time steps:

 \displaystyle
\Theta = 

\begin{pmatrix}
e^{i  \theta} \\
\vdots \\
e^{i |x| \theta} 
\end{pmatrix}

Then, redefining  Q, K \in \mathbb{C}^{|x| \times d} so that their  nth rows are  Q_n \odot e^{i n \theta} and  K_n \odot e^{- i n \theta} respectively, we obtain:

 \displaystyle \begin{align}
Q &= 
\begin{pmatrix}
Q_1 \odot e^{i \theta} \\
\vdots \\
Q_{|x|} \odot e^{i |x| \theta} 
\end{pmatrix}
= 
\begin{pmatrix}
Q_1 \\
\vdots \\
Q_{|x|} 
\end{pmatrix}
\odot \Theta 
= (X W_Q) \odot  \Theta \\
K &= 
\begin{pmatrix}
K_1 \odot e^{- i \theta} \\
\vdots \\
K_{|x|} \odot e^{- i |x| \theta} 
\end{pmatrix}
= 
\begin{pmatrix}
K_1 \\
\vdots \\
K_{|x|} 
\end{pmatrix}
\odot \bar{\Theta} 
= (X W_K) \odot  \bar{\Theta}
\end{align}

The transformed vector for the entire input sequence  X,  \text{Retention} (X), is expressed with  Q_n^{\prime} = Q_n \odot e^{in \theta}, K_m^{\prime} = K_m \odot e^{im \theta} as follows:

 \displaystyle \begin{align}
\text{Retention} (X) &= 
\begin{pmatrix}
o_1  \\
\vdots \\
o_{|x|}
\end{pmatrix} 
= 
\begin{pmatrix}
\sum_{m=1}^{1} \gamma^{1-m} Q_1^{\prime} K_m^{\prime \dagger} V_m  \\
\vdots \\
\sum_{m=1}^{|x|} \gamma^{|x|-m} Q_{|x|}^{\prime}  K_m^{\prime \dagger} V_m
\end{pmatrix} \\
&= 
\left(
\begin{pmatrix}
Q_1^{\prime} K_1^{\prime \dagger} & \cdots & Q_1^{\prime} K_{|x|}^{\prime \dagger} \\
\vdots & \ddots & \vdots \\
Q_{|x|}^{\prime} K_1^{\prime \dagger} & \cdots & Q_{|x|}^{\prime} K_{|x|}^{\prime \dagger}
\end{pmatrix}
\odot
\begin{pmatrix}
\gamma^{1-1} & 0 & \cdots & 0 \\
\gamma^{2-1} & \gamma^{2-2} & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
\gamma^{|x|-1} & \gamma^{|x|-2} & \cdots & \gamma^{|x|-|x|}
\end{pmatrix} 
\right)
\begin{pmatrix}
V_1 \\
\vdots \\
V_{|x|}
\end{pmatrix} \\
&= (QK^T \odot D) V
\tag{5}
\end{align}

where

 \displaystyle

D_{nm}

= 

\left\{
\begin{array}{ll}
\gamma^{n-m} & (n \geq m)\\
0 & (n < m)
\end{array}
\right.

This form (5) represents Retention in parallel and is used for calculations during training.
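To make the parallel form concrete, here is a minimal NumPy sketch of (5) (my own illustration, not the official implementation; the function name parallel_retention is mine). It keeps the complex-valued formulation used above purely for clarity; practical implementations avoid complex arithmetic, for example via real-valued rotations over paired dimensions, which is omitted here.

```python
import numpy as np

def parallel_retention(X, W_Q, W_K, W_V, gamma, theta):
    """Sketch of the parallel form (5): (Q K^T ⊙ D) V."""
    n_tokens, d = X.shape
    pos = np.arange(1, n_tokens + 1)                    # positions n = 1, ..., |x|
    Theta = np.exp(1j * np.outer(pos, theta))           # row n is e^{i n θ}
    Q = (X @ W_Q) * Theta                               # rows Q_n ⊙ e^{i n θ}
    K = (X @ W_K) * np.conj(Theta)                      # rows K_n ⊙ e^{-i n θ}
    V = X @ W_V
    diff = pos[:, None] - pos[None, :]                  # n - m
    D = np.where(diff >= 0, float(gamma) ** diff, 0.0)  # D_{nm} = γ^{n-m} for n ≥ m, else 0
    return ((Q @ K.T) * D) @ V
```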

Recurrent Representation

Expressing equation (4) using  Q_n^{\prime}, K_m^{\prime} results in:

 \displaystyle 
o_n = \sum_{m=1}^{n} \gamma^{n-m} Q_n^{\prime} K_m^{\prime \dagger} V_m

Comparing this with equation (1), it becomes clear that the following holds:

 \displaystyle \begin{align}
s_n &= \gamma s_{n-1} + K_n^{\prime \dagger} V_n \\
\text{Retention} (X_n) &= Q_n^{\prime} s_n = \sum_{m=1}^{n} Q_n^{\prime} {\gamma}^{n-m} K^{\prime \dagger}_m V_m, \quad n = 1, \ldots, |x| \tag{6}
\end{align}

This is the recurrent form of Retention; during inference, the next token is predicted recurrently using this formula.
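Here is a matching sketch of the recurrent form (6) (again my own illustration; it reuses the hypothetical parallel_retention above to check that the two forms give the same result):

```python
import numpy as np

def recurrent_retention(X, W_Q, W_K, W_V, gamma, theta):
    """Sketch of the recurrent form (6): s_n = γ s_{n-1} + K'^†_n V_n, o_n = Q'_n s_n."""
    n_tokens, d = X.shape
    s = np.zeros((d, d), dtype=complex)        # state s_0 = 0
    outputs = []
    for n in range(1, n_tokens + 1):
        rot = np.exp(1j * n * theta)           # e^{i n θ}
        q = (X[n - 1] @ W_Q) * rot             # Q'_n
        k = (X[n - 1] @ W_K) * np.conj(rot)    # K_n ⊙ e^{-i n θ}
        v = X[n - 1] @ W_V
        s = gamma * s + np.outer(k, v)         # O(1) state update per step
        outputs.append(q @ s)                  # o_n = Q'_n s_n
    return np.stack(outputs)

# Sanity check that the parallel and recurrent forms coincide (hypothetical usage):
rng = np.random.default_rng(0)
d, n_tok = 4, 6
X = rng.normal(size=(n_tok, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
gamma, theta = 0.9, rng.uniform(0, np.pi, size=d)
assert np.allclose(parallel_retention(X, W_Q, W_K, W_V, gamma, theta),
                   recurrent_retention(X, W_Q, W_K, W_V, gamma, theta))
```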

Attention vs Retention

Attention is represented as:

 \displaystyle
\text{Attention} (X) = \text{softmax}(QK^T \odot M ) V

\tag{7}

Here,  M is an attention mask; for simplicity, it is assumed to be a decoder (causal) mask, in which masked positions are excluded from the softmax (equivalently, set to  -\infty before the softmax so that their weights become zero):

 \displaystyle

(M)_{n,m}

= 

\left\{
\begin{array}{ll}
1 & (n \geq m)\\
0 & (n < m)
\end{array}
\right.

Thus,  \text{softmax}(QK^T \odot M ) can be expressed as:

 \displaystyle
\text{softmax} (QK^T \odot M)
=
\begin{pmatrix}
\frac{ \exp ( Q_1 K_1^T ) }{ \sum_{m=1}^{1} \exp ( Q_1 K_m^T ) } & 0 & \cdots & 0 \\
\frac{ \exp ( Q_2 K_1^T ) }{ \sum_{m=1}^{2} \exp ( Q_2 K_m^T ) } & \frac{ \exp ( Q_2 K_2^T ) }{ \sum_{m=1}^{2} \exp ( Q_2 K_m^T ) } & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
\frac{ \exp ( Q_{|x|}  K_1^T ) }{ \sum_{m=1}^{|x|} \exp ( Q_{|x|} K_m^T ) } & \frac{ \exp ( Q_{|x|}  K_2^T ) }{ \sum_{m=1}^{|x|} \exp ( Q_{|x|} K_m^T ) } & \cdots & \frac{ \exp ( Q_{|x|} K_{|x|}^T ) }{ \sum_{m=1}^{|x|} \exp ( Q_{|x|} K_m^T ) }
\end{pmatrix}

Therefore,

 \displaystyle
(\text{softmax} (QK^T \odot M) )_{n,m} =  
\left\{
\begin{array}{ll}
\frac{ \exp ( Q_{n}  K_m^T ) }{ \sum_{m'=1}^{n} \exp ( Q_{n} K_{m'}^T ) } & (n \geq m)\\
0 & (n < m)
\end{array}
\right.

Let's consider the output  o_n of  \text{Attention} (X) at a certain time  n.

 \displaystyle
\text{Attention} (X) 
=
\begin{pmatrix}
o_1  \\
\vdots \\
o_{|x|}  \\
\end{pmatrix} 
= 
\begin{pmatrix}
\sum_{m=1}^{1}  \frac{ \exp ( Q_1 K_m^T ) }{ \sum_{m'=1}^{1} \exp ( Q_1 K_{m'}^T ) } V_m  \\
\vdots \\
\sum_{m=1}^{|x|}  \frac{ \exp ( Q_{|x|} K_m^T ) }{ \sum_{m'=1}^{|x|} \exp ( Q_{|x|} K_{m'}^T ) } V_m  \\
\end{pmatrix}

Therefore,  o_n can be expressed as:

 \displaystyle
o_n =  \sum_{m=1}^{n}  \frac{ \exp ( Q_n K_m^T ) }{ \sum_{m'=1}^{n} \exp ( Q_n K_{m'}^T ) } V_m 
\tag{8}

In Attention, only the current query  Q_n is needed to compute the output at time  n, but the keys and values of all previous tokens must be kept and attended to, so each inference step has a computational complexity of  O(n), as sketched below.
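Here is a minimal sketch of one causal-attention decoding step (my own illustration; scaling by √d is omitted). The key/value cache grows by one entry per token, so step  n touches  n cached entries:

```python
import numpy as np

def attention_step(x_n, cache_K, cache_V, W_Q, W_K, W_V):
    """One decoder attention step: the key/value cache grows with n, so each step is O(n)."""
    q = x_n @ W_Q
    cache_K.append(x_n @ W_K)                      # all previous keys must be kept
    cache_V.append(x_n @ W_V)                      # all previous values must be kept
    scores = np.array([q @ k for k in cache_K])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                       # softmax over all m ≤ n
    return sum(w * v for w, v in zip(weights, cache_V))
```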

With Retention, however, as shown below, the contribution of the keys and values up to time  n-1 can be separated out and kept as the state  s_{n-1}, so each time step has a computational complexity of  O(1).

 \displaystyle \begin{align}
o_n &= Q_n^{\prime} \sum_{m=1}^{n} \gamma^{n-m} K_m^{\prime \dagger} V_m =  Q_n^{\prime} \left( \gamma \left( \sum_{m=1}^{n-1} \gamma^{n-1-m} K_m^{\prime \dagger}  V_m \right) + K_n^{\prime \dagger} V_n  \right) \\
&=  Q_n^{\prime} \left( \gamma s_{n-1} + K_n^{\prime \dagger} V_n  \right)
\end{align}
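As a per-step counterpart to the attention sketch above (again my own sketch; retention_step is not a name from the paper), inference carries only the fixed  d \times d state, independent of the sequence length:

```python
import numpy as np

def retention_step(x_n, n, s_prev, W_Q, W_K, W_V, gamma, theta):
    """One O(1) Retention decoding step: update the fixed-size state and emit o_n."""
    rot = np.exp(1j * n * theta)
    q = (x_n @ W_Q) * rot                    # Q'_n
    k = (x_n @ W_K) * np.conj(rot)           # K_n ⊙ e^{-i n θ}
    v = x_n @ W_V
    s = gamma * s_prev + np.outer(k, v)      # γ s_{n-1} + K'^†_n V_n
    return q @ s, s                          # o_n and the new state

# Hypothetical usage: the state shape stays (d, d) no matter how long the sequence gets.
rng = np.random.default_rng(0)
d = 4
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
gamma, theta = 0.9, rng.uniform(0, np.pi, size=d)
s = np.zeros((d, d), dtype=complex)
for n in range(1, 6):
    o_n, s = retention_step(rng.normal(size=d), n, s, W_Q, W_K, W_V, gamma, theta)
```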

That concludes this section. Next time, I will write about the theory and implementation of the relative position embedding xPos.