The RetNet paper ("Retentive Network: A Successor to Transformer for Large Language Models") explains the architecture of RetNet in Section 2.
However, the formulas in the paper can be a little confusing. In this post, I explain the details of the formulas while filling in the gaps.
*This is based on my understanding, so please comment if something seems incorrect.
Retentive Network
RetNet, like the Transformer, is structured as a stack of identical blocks. Each block contains a Multi-Scale Retention module and a Feed-Forward Network module, with RMSNorm and residual connections applied around each. Essentially, a RetNet block is akin to a Transformer block, but with RMSNorm replacing LayerNorm and Multi-Scale Retention replacing Multi-Head Attention; since Retention itself encodes relative positions, no absolute position embedding is required. (Please refer to the paper for more detailed information.)
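To illustrate the block wiring, here is a minimal NumPy sketch as I understand it (pre-norm residuals; `msr` and `ffn` are toy stand-ins I made up for this example, not the real modules, and the learned gain of RMSNorm is omitted):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # RMSNorm: normalize each row by its root mean square (learned gain omitted)
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)

def retnet_block(x, msr, ffn):
    # Pre-norm residual wiring: Y = MSR(RMSNorm(X)) + X; out = FFN(RMSNorm(Y)) + Y
    y = msr(rms_norm(x)) + x
    return ffn(rms_norm(y)) + y

# Toy stand-ins for the two sub-modules (maps of shape (|x|, d_model) -> (|x|, d_model))
rng = np.random.default_rng(0)
d = 8
W1, W2 = rng.standard_normal((d, d)), rng.standard_normal((d, d))
msr = lambda x: x @ W1                        # placeholder for Multi-Scale Retention
ffn = lambda x: np.maximum(x @ W1, 0.0) @ W2  # placeholder feed-forward network

x = rng.standard_normal((6, d))               # |x| = 6 tokens
print(retnet_block(x, msr, ffn).shape)        # (6, 8)
```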
RetNet computes contextually embedded vectors $\mathbf{o}_1, \dots, \mathbf{o}_{|x|}$ when the input tokens are embedded into $X \in \mathbb{R}^{|x| \times d_{\text{model}}}$ (one row $\mathbf{x}_n$ per token).
Retention
The Retention mechanism has two equivalent forms, recurrent and parallel, which allow training to be parallelized while inference runs recurrently.
Given $X \in \mathbb{R}^{|x| \times d_{\text{model}}}$, a trainable matrix $W_V$ is used to compute $V = XW_V$. For a token at time $n$, the value $\mathbf{v}_n$ (the $n$th row of $V$) and the state at the previous time $n-1$, $\mathbf{s}_{n-1}$, are used to model a mapping $\mathbf{v}_n \mapsto \mathbf{o}_n$ that computes the $n$th token of the transformed sequence.
In RetNet, this mapping is represented recurrently as follows:

$$\begin{aligned} \mathbf{s}_n &= A\,\mathbf{s}_{n-1} + K_n^\top \mathbf{v}_n \\ \mathbf{o}_n &= Q_n \mathbf{s}_n = \sum_{m=1}^{n} Q_n A^{n-m} K_m^\top \mathbf{v}_m \end{aligned} \tag{1}$$
Here, the query and key, like the value, are determined using trainable matrices $W_Q$ and $W_K$, as follows:

$$Q = XW_Q, \qquad K = XW_K$$
$Q_n$ and $K_n$ are respectively the $n$th row vectors of $Q$ and $K$, corresponding to the query and key for token $n$.
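To make Equation (1) concrete, here is a minimal NumPy sketch of the recurrent form. Everything here is a toy stand-in (random projections, small dimensions), not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
d, T = 4, 6                                   # hidden size, sequence length
A = rng.standard_normal((d, d)) * 0.3         # state-transition matrix
X = rng.standard_normal((T, d))               # input sequence
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V           # rows are Q_n, K_n, v_n

s = np.zeros((d, d))                          # state s_0 = 0
o = []
for n in range(T):
    s = A @ s + np.outer(K[n], V[n])          # s_n = A s_{n-1} + K_n^T v_n
    o.append(Q[n] @ s)                        # o_n = Q_n s_n
o = np.stack(o)                               # shape (T, d)
```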
The second equation of (1) may be less intuitive. To understand it, let's sequentially compute $\mathbf{s}_1, \mathbf{s}_2, \mathbf{s}_3$ from the first equation:

$$\begin{aligned}
\mathbf{s}_1 &= A\,\mathbf{s}_0 + K_1^\top \mathbf{v}_1 \\
\mathbf{s}_2 &= A\,\mathbf{s}_1 + K_2^\top \mathbf{v}_2 = A^2 \mathbf{s}_0 + A K_1^\top \mathbf{v}_1 + K_2^\top \mathbf{v}_2 \\
\mathbf{s}_3 &= A\,\mathbf{s}_2 + K_3^\top \mathbf{v}_3 = A^3 \mathbf{s}_0 + A^2 K_1^\top \mathbf{v}_1 + A K_2^\top \mathbf{v}_2 + K_3^\top \mathbf{v}_3
\end{aligned}$$
From here, the general formula for $\mathbf{s}_n$ can be derived from the first recursive formula:

$$\mathbf{s}_n = A^n \mathbf{s}_0 + \sum_{m=1}^{n} A^{n-m} K_m^\top \mathbf{v}_m$$
Setting $\mathbf{s}_0 = \mathbf{0}$, we derive $\mathbf{s}_n = \sum_{m=1}^{n} A^{n-m} K_m^\top \mathbf{v}_m$. Substituting this into $\mathbf{o}_n = Q_n \mathbf{s}_n$ gives exactly the second equation of (1).
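Continuing the sketch above (reusing `A`, `Q`, `K`, `V`, and `o` from the previous block), we can check numerically that the closed form reproduces the recurrent outputs:

```python
from numpy.linalg import matrix_power

# Closed form with s_0 = 0: s_n = sum_{m=1}^{n} A^{n-m} K_m^T v_m (0-indexed here)
for n in range(T):
    s_closed = sum(matrix_power(A, n - m) @ np.outer(K[m], V[m])
                   for m in range(n + 1))
    assert np.allclose(Q[n] @ s_closed, o[n])   # matches the recurrence
```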
Next, the square matrix $A$ is assumed to be diagonalizable as $A = \Lambda\,(\gamma e^{i\theta})\,\Lambda^{-1}$, with a regular (invertible) matrix $\Lambda$ and $\gamma, \theta \in \mathbb{R}^d$, where $\gamma e^{i\theta}$ denotes the diagonal matrix $\mathrm{diag}(\gamma_1 e^{i\theta_1}, \dots, \gamma_d e^{i\theta_d})$. (For those who want to review diagonalization, see the Wikipedia article.) Note that although $A$ is real, its eigenvalues and eigenvectors are in general complex, which is why the decomposition is written with complex numbers. This topic is also discussed in an issue on the official GitHub repository.
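As a quick numerical sanity check of this diagonalization (NumPy's `eig` returns complex eigenvalues and eigenvectors for a generic real matrix, which is exactly the situation described above):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))

w, Lam = np.linalg.eig(A)                     # A = Lam @ diag(w) @ Lam^{-1}
Lam_inv = np.linalg.inv(Lam)
assert np.allclose(A, (Lam * w) @ Lam_inv)    # Lam * w == Lam @ diag(w)

# Each eigenvalue in polar form: w_j = gamma_j * e^{i theta_j}
gamma, theta = np.abs(w), np.angle(w)
assert np.allclose(w, gamma * np.exp(1j * theta))

# Powers of A go through the diagonal: A^p = Lam @ diag(w^p) @ Lam^{-1}
p = 3
assert np.allclose(np.linalg.matrix_power(A, p), ((Lam * w**p) @ Lam_inv).real)
```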
Consequently, $A^{n-m} = \Lambda\,(\gamma e^{i\theta})^{n-m}\,\Lambda^{-1}$, and $(\gamma e^{i\theta})^{n-m}$ can be computed as:

$$(\gamma e^{i\theta})^{n-m} = \mathrm{diag}\!\left(\gamma_1^{\,n-m} e^{i(n-m)\theta_1}, \,\dots,\, \gamma_d^{\,n-m} e^{i(n-m)\theta_d}\right)$$
Thus, the second equation of (1) can be expressed as:

$$\mathbf{o}_n = \sum_{m=1}^{n} Q_n \Lambda\,(\gamma e^{i\theta})^{n-m}\,\Lambda^{-1} K_m^\top \mathbf{v}_m$$
In this formula, $\Lambda$ and $\Lambda^{-1}$ are multiplied by the trainable parameters $W_Q$ and $W_K$ (e.g. $Q_n \Lambda = X_n (W_Q \Lambda)$, and similarly on the key side), so instead of being kept as separate parameters, they are absorbed into $W_Q$ and $W_K$ and learned together.
Therefore, it is represented as:

$$\mathbf{o}_n = \sum_{m=1}^{n} Q_n\,(\gamma e^{i\theta})^{n-m}\,K_m^\dagger \mathbf{v}_m$$

where $\dagger$ denotes the conjugate transpose; it replaces the plain transpose because the absorbed $\Lambda$ is complex, making the new $Q$ and $K$ complex-valued.
Here, since $\gamma e^{i\theta}$ is a diagonal matrix, the identity $(\gamma e^{i\theta})^{n-m} = \gamma^{n-m}\, e^{in\theta} \left(e^{im\theta}\right)^\dagger$ is used, so the summand becomes $\gamma^{n-m}\,(Q_n e^{in\theta})(K_m e^{im\theta})^\dagger\, \mathbf{v}_m$. $Q_n e^{in\theta}$ and $K_m e^{im\theta}$ are known as xPos embeddings. (Interestingly, the authors of xPos and RetNet largely overlap. It seems the team from Tsinghua University that developed xPos was recruited by Microsoft Research.)
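The diagonal-matrix identity above is easy to verify elementwise with arbitrary values:

```python
import numpy as np

gamma = np.array([0.9, 0.8, 0.7])             # per-dimension decay
theta = np.array([0.1, 0.5, 1.2])             # per-dimension rotation angle
n, m = 5, 2

lhs = (gamma * np.exp(1j * theta)) ** (n - m)
rhs = gamma ** (n - m) * np.exp(1j * n * theta) * np.conj(np.exp(1j * m * theta))
assert np.allclose(lhs, rhs)
```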
For further simplification, treating $\gamma$ as a scalar (a single decay rate shared across dimensions) leads to:

$$\mathbf{o}_n = \sum_{m=1}^{n} \gamma^{n-m}\,(Q_n e^{in\theta})(K_m e^{im\theta})^\dagger\, \mathbf{v}_m \tag{4}$$
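Here is a direct, complex-valued NumPy sketch of this summation with random stand-ins for the trained projections. Note that $\mathbf{o}_n$ comes out complex here; actual implementations typically keep everything real by pairing dimensions with sin/cos rotations (as in RoPE), so this is only for checking the math:

```python
import numpy as np

rng = np.random.default_rng(1)
d, T, gamma = 4, 6, 0.9                        # scalar decay gamma
theta = rng.uniform(0.0, 2 * np.pi, d)         # rotation angles, one per dim
X = rng.standard_normal((T, d))
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V

def o_n(n):
    # o_n = sum_{m<=n} gamma^{n-m} (Q_n e^{in theta})(K_m e^{im theta})^dagger v_m
    q = Q[n] * np.exp(1j * n * theta)
    out = np.zeros(d, dtype=complex)
    for m in range(n + 1):
        k = K[m] * np.exp(1j * m * theta)
        out += gamma ** (n - m) * (q @ k.conj()) * V[m]
    return out
```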
The definition of $\Theta$ is as follows:

$$\Theta_n = e^{in\theta}$$

In this definition, although $\bar{\Theta}$ is simply the complex conjugate of $\Theta$, the paper's use of a plain transpose in $QK^\top$ (where Equation (4) had a conjugate transpose) can be confusing: the conjugation has already been folded into the key side through $\bar{\Theta}$. Therefore, redefining

$$Q = (XW_Q) \odot \Theta, \qquad K = (XW_K) \odot \bar{\Theta}, \qquad V = XW_V$$

and setting

$$D_{nm} = \begin{cases} \gamma^{\,n-m} & (n \ge m) \\ 0 & (n < m) \end{cases}$$

the parallel form of Retention can be expressed as:

$$\mathrm{Retention}(X) = \left(QK^\top \odot D\right)V$$
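Finally, reusing `theta`, `gamma`, `Q`, `K`, `V`, and `o_n` from the previous sketch, we can confirm numerically that this parallel form matches the recurrent/summation form:

```python
# Parallel form: Retention(X) = (Q K^T ⊙ D) V, with Theta folded into Q and K
n_idx = np.arange(T)
Theta = np.exp(1j * np.outer(n_idx, theta))    # row n is e^{in theta}
Qx, Kx = Q * Theta, K * Theta.conj()           # key side uses Theta-bar

# D combines the exponential decay with causal masking
D = np.tril(gamma ** (n_idx[:, None] - n_idx[None, :]))

O_par = ((Qx @ Kx.T) * D) @ V                  # (Q K^T ⊙ D) V
O_rec = np.stack([o_n(n) for n in range(T)])
assert np.allclose(O_par, O_rec)               # parallel == recurrent
```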