izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

今日の論文2023/06/04,05:RWKV: Reinventing RNNs for the Transformer Era

RWKV: Reinventing RNNs for the Transformer Era

arxiv.org

©2022 The Authors

License: Creative Commons Attribution 4.0 International License(CC-BY)

本記事は、原著の一部を筆者が翻訳したものです。以下の図は、そこから引用しています。

This article is my translation of the part of the original publication. The following figures are taken from it.

要点まとめ

Transformerは、ほぼ全ての自然言語処理NLP)タスクに革命をもたらしたが、配列長に対して2次関数的にスケールするメモリと計算機の複雑さに悩まされている。 一方、リカレントニューラルネットワーク(RNN)は、メモリと計算機要件に線形スケーリングを示すが、並列化とスケーラビリティに限界があるため、Transformerと同等の性能を発揮することは難しい。我々は、Transformerの効率的な並列化トレーニングとRNNの効率的な推論を組み合わせた新しいモデルアーキテクチャ、Receptance Weighted Key Value(RWKV)を提案する。我々のアプローチは、線形注意メカニズムを活用し、モデルをトランスフォーマーまたはRNNとして定式化することを可能にし、学習時の計算を並列化し、推論時の計算およびメモリの複雑さを一定に保つことで、数百億のパラメータに拡張できる最初の非トランスフォーマーアーキテクチャを実現した。我々の実験では、RWKVは同規模のTransformerと同等の性能を発揮することが明らかになり、将来の研究がこのアーキテクチャを活用してより効率的なモデルを作成できることが示唆された。本研究は、シーケンス処理タスクにおける計算効率とモデルの性能のトレードオフを調整するための重要な一歩を提示した。

github.com

1 序論

ディープラーニング技術は、人工知能において大きな進歩を遂げ、様々な科学的・産業的アプリケーションにおいて極めて重要な役割を果たしている。これらのアプリケーションは、自然言語理解、会話AI、時系列分析、さらには画像やグラフなどのシーケンスとして再構成可能な間接的なモーダリティを含む、複雑なシーケンシャルデータ処理タスクを含むことが多い(Brown et al., 2020; Ismail Fawazet al., 2019; Wu et al., 2020; Albalak et al., 2022)これらの手法の中で主流なのは、RNN、畳み込みニューラルネットワーク(CNN)、およびTransformerモデル(Vaswani et al., 2017)である。

 これらはそれぞれ明確な欠点があり、特定のシナリオで効率が制限される。RNNは消失勾配問題に悩まされ、長いシーケンスの学習が困難である。さらに、RNNは学習中の時間次元での並列化ができないため、スケーラビリティが制限される(Hochreiter, 1998; Le and Zuidema, 2016)。一方、CNNは局所的なパターンを捉えることに長けているだけであり、多くのシーケンス処理タスクに重要な長距離依存性を扱う能力に限界がある(Bai et al., 2018)。

 Transformerモデルは、局所および長距離依存性の両方を扱う能力と並列化トレーニングの能力により、強力な代替案として浮上した(Tay et al., 2022)。GPT-3 (Brown et al., 2020)、ChatGPT (OpenAI, 2022; Koco ́n et al., 2023)、GPT-4 (Ope-nAI, 2023)、LLaMA (Touvron et al., 2023)、そしてChinchilla (Hoffmann et al., 2022) など最近のモデルはこのアーキテクチャの能力を実証し、NLPで何が可能かという境界を押し広げている。これらの重要な進歩にもかかわらず、Transformerに内在する自己注視メカニズムは、主にその2次的な複雑性によって、ユニークな課題を提起している。この複雑さにより、長い入力シーケンスを含むタスクやリソースに制約のある状況では、計算量とメモリ使用量が多いアーキテクチャになってしまいる。これらの制限により、Transformerのスケーリング特性を改善することを目的とした研究が活発に行われているが、多くの場合、Transformerを非常に効果的にするいくつかの特性が犠牲になっている(Wang et al., 2020; Zaheer et al., 2020; Dao et al.. 2022a)。

 これらの課題に取り組むために、RNNとTransformerの長所を効果的に組み合わせ、主要な欠点を回避する新しいアーキテクチャ、 Receptance Weighted Key Value(RWKV)を紹介する。 RWKVは、Transformer(Katharopoulos et al., 2020)に関連するメモリボトルネックと2次スケーリングをより効率的な線形スケーリングで緩和するように慎重に設計されているが、Transformerをこの分野で優位なアーキテクチャにしている豊かで印象的な特性は依然として維持される。

 RWKVの特徴のひとつは、Transformerのような並列トレーニングやロバストスケーラビリティを提供できることである。さらに、RWKVのアテンション機構は、従来のドット積トークンの相互作用を排除し、より効果的なチャネル指向のアテンションを採用した、線形アテンションのバリエーションを導入するように再構築されている。このアプローチは、従来のTransformerアーキテクチャでは、特定のトークンとのインタラクションがアテンションを支配していたのとは大きく異なるものである。RWKVの線形アテンションの実装は、近似処理なしで行われるため、効率が大幅に改善され、スケーラビリティが向上している(表1参照)。

 RWKV開発の包括的な動機は、ニューラルネットワークアーキテクチャにおける計算効率と表現力のギャップを埋めることである。RWKVは、数十億のパラメータを持つ大規模なモデルを扱うタスクに対して、計算コストの数分の一で実用的な性能を発揮する有望なソリューションを提供する。我々の実験結果は、RWKVが、様々なドメイン、特に逐次データ処理を含むAIモデルのスケーリングとデプロイメントにおける継続的な課題に対処するための貴重なツールになり得ることを示唆している。 このように、RWKVは、シーケンス処理タスクのための、より実現可能で計算効率の高い次世代のAIモデルへの道を切り開いたのである。

本論文における我々の貢献は、以下の通りである。

  • RNNとTransformerの長所を兼ね備え、それらの既知の制限を緩和するRWKVネットワークアーキテクチャを紹介する。

  • 私たちは、標準的なトランスフォーマーモデルに関連する2次的な複雑さを解消し、線形アテンションをもたらす新しいアテンションメカニズムの再定式化を提案する。

  • 我々は、大規模なモデルや長距離の依存関係を含むタスクを管理するRWKVの性能、効率、スケーリングを示すために、ベンチマークデータセットの包括的な一連の実験を実施した。

  • 1億6900万から140億のパラメータを持つ事前学習済みモデルをPile(Gao et al, 2020)上で公開した。

4 The Receptance Weighted Key Value (RWKV) Mode

RWKVアーキテクチャの名前は、タイムミキシングブロックとチャンネルミキシングブロックで使用される4つの主要なモデル要素に由来する。

  • R: レセプタンス(Receptance)ベクトルが、過去の情報の受容として作用する。

  • W: ウェイト(Weight)は、位置の重み減衰ベクトルで、学習可能なモデルパラメータである。

  • K: キー(Key)は、従来のアテンションのKに類似したベクトルである。

  • V: バリュー(Value)は、従来のアテンションのVに類似したベクトルである。

図2に示すように、各タイミングステップにおける主要な要素間の相互作用は乗算的である。

4.1 ハイレベルサマリー

RWKVアーキテクチャは,時間混合サブブロックとチャネル混合サブブロックから構成される、一連の積み重ねられた残差ブロックから構成され、それぞれ再帰構造を持つ。

 再帰構造は、現在の入力と前の時間ステップの入力の間の線形補間(図3の対角線で示す、時間シフト混合またはトークンシフトと呼ぶ技術)として定式化され、入力埋め込みの線形投影ごとに独立に調整できる(たとえば 時間混合におけるR、K、V、R、チャネル混合におけるKなど)、さらに式14で定式化されるWKVの時間依存の更新として調整される。WKVの計算はAFT(Zhai et al., 2021)に似ているが、WはAFTのペアワイズ行列ではなく、相対位置を乗じたチャネルワイズベクトルとなる。また、Wの潜在的な退化を補うために、現在のトークンに個別に注目するためのベクトルUを導入している(詳細は付録Gを参照)。

タイムミキシングブロックは次式で与えられる:

ここで,WKVの計算 wkv_tは、Transformerの Attn(Q,K,V)の役割を果たし、スカラー間の相互作用であるため2次コストを発生させない。直感的には、時間が長くなるにつれて、ベクトル o_tは増加する項の総和で表される長い歴史に依存することになる。RWKVは目標位置 tに対して、 [1,t ]の位置区間で重み付け和を行い、レセプタンス σ(r)と掛け合わせる。したがって、相互作用はあるタイムステップ内では乗法的であり、異なるタイムステップでは総和的である。さらに、チャネルミキシングブロックは以下のように与えられる:

ここで、二乗ReLU活性化を採用する(So et al., 2021)。タイムミキシングでもチャンネルミキシングでも、レセプタンスのシグモイドを取ることで、直感的に不要な履歴情報を排除する「忘却ゲート」として利用していることに注意。

4.2 Transformerのような並列化

RWKVは、Transformerを彷彿とさせるような、いわゆる時間並列モードで効率的に並列化することが可能である。1つのレイヤーでシーケンスのバッチを処理する時間的複雑さは[tex: O(BTd2)]である、 これは主に行列の乗算 W_i, i \in \{r,k,v,o \}(B個のシーケンス、T個の最大トークン、d個のチャンネルを仮定)からなる。一方、注目スコア wkv_tの更新にはシリアルスキャンが必要であり(詳細は付録Bを参照)、複雑度は O(BTd)である。

 行列乗算は典型的なTransformerで W_i, i \in \{Q, K, V, O \} に並列化することができる。 要素ごとのWKV計算は時間に依存するが、他の2つの次元に沿って容易に並列化できる(Lei et al., 2018)さらに、トークンシフトは、各ブロックで時間次元の単純なオフセットとしてPyTorch (Paszke et al., 2019) ライブラリのnn.ZeroPad2d*1を使って実装される。

4.3 RNNのような逐次的デコーディング

リカレントネットワークでは、状態 tの出力を状態 t+1の入力として使用することが一般的である。特に言語モデルの自己回帰推論では、各トークンが次のステップに進む前に計算される必要があり、RWKVは時系列モードと呼ばれるRNNのような構造を利用することが可能である。このような状況では、RWKVは、付録Bに示すように、推論中の復号化のために再帰的に構成することができ、各出力トークンは、シーケンス長に関係なく、一定の大きさの最新の状態にのみ依存するという利点を活用する。

 これは、RNNデコーダとして振る舞い、シーケンス長に対して一定の速度とメモリフットプリントをもたらし、より長いシーケンスを効率的に処理できる。一方、自己アテンションは、一般的に、シーケンス長に対して直線的に成長するKVキャッシュを必要とし、シーケンスが長くなるにつれて効率が低下し、メモリフットプリントと時間が増加する結果となる。

4.4 ソフトウェア実装

RWKVはもともとPytorch Deep Learning Library(Paszke et al., 2019)と4.7で説明するWKV計算用のカスタムCUDAカーネルを用いて実装されている。RWKVは一般的なリカレントネットワークであるが、現在の実装では言語モデリング(RWKV-LM)のタスクに焦点を当てている。モデルアーキテクチャは、4.7節で説明した埋め込み層と、4.6節で説明した原則に従って図2および図3に示すような複数の同じ残渣ブロックを順次適用することで構成されている。最後のブロックの後、LayerNorm (Ba et al.,2016) と線形射影で構成される単純な出力射影ヘッドを使用して、次のトークン予測タスクで使用するロジットを取得し、トレーニング中のクロスエントロピー損失を計算する。最後の残渣ブロックの後に生成された埋め込みとロジットは両方とも、後で下流NLPタスクに使用することもできる。学習は時間並列モード(セクション4.2)で行われ、自動進行推論と潜在的なチャットインターフェースは時系列モード(セクション4.3)を利用する。

4.5 勾配の安定性とレイヤースタッキング

RWKVアーキテクチャは、TransformersとRNNの両方の融合として設計されており、従来のRNNと比較して安定した勾配とTransformersの深いアーキテクチャーという利点を提供しながら、推論において効率的である。

 以前の研究では、RNNにおける勾配の安定性の問題に取り組むために、非飽和活性化関数を使用する(Chandar et al, 2019)、ゲーティング機構(Gu et al., 2019)、勾配クリッピング(Pascanu et al., 2012)、および制約の追加(Kanai et al., 2017; Miller and Hardt, 2018)などさまざまな技術を用いてきた。 しかし、RWKVはソフトマックスをRNN形式の更新と併用することで、この問題を本質的に回避している。

 RWKVモデルは、アテンションのようなスコアを更新するためのシングルステッププロセスを特徴とし、数値的安定性を助け、消失する勾配から保護する時間依存のソフトマックス演算を含む(厳密な証明は付録Fを参照)。直感的には、この操作は勾配が最も関連性の高い経路に沿って伝搬されることを保証する。レイヤー正規化(Ba et al., 2016)は、勾配を安定させることでディープニューラルネットワークの学習ダイナミクスを強化し、消失勾配と爆発勾配の両方の問題に対処するアーキテクチャのもう一つの重要な側面である。

 これらの設計要素は、RWKVアーキテクチャの安定性と学習能力に貢献するだけでなく、既存のRNNの能力を超える方法で複数の層を積層することができる。これにより、様々な抽象化レベルにおいて、より複雑なパターンを捉えることができるようになった(付録Gも参照)。

4.6 シーケンシャルなデータ処理に時間構造を利用する

RWKVは、再帰性、時間減衰、トークンシフトという3つのメカニズムの組み合わせにより、連続した情報を捉え、伝播させる。

 RWKVの時間混合ブロックにある再帰性は、シーケンス要素間の複雑な関係を捕捉し、時間を通じて局所的な情報を伝播するモデルの能力の基礎となるものである。

 このモデルは、過去の情報の影響を時間経過とともに徐々に減少させることで、逐次処理に不可欠な時間的な位置関係や進行の感覚を維持している。このような順序データにおける位置情報の扱いは、線形バイアスが入力長の外挿を容易にするAttention with Linear Biases (ALiBi) モデル (Press et al., 2022) と類似している。この文脈から、RWKVアーキテクチャは、ALiBiの訓練可能なバージョンとして認識することができ、明示的な符号化の必要なく、位置情報をシームレスに取り込むことができる。また、Zhaiら(2021)が導入したゲートコンボリューションを、あるステップまでのシーケンスの全長に拡張したものと見ることもできる。

 トークンシフトとタイムシフトの混合(図3の斜めの矢印)も、シーケンシャルデータへのモデルの適応に寄与している。 現在の入力と前のタイムステップの入力を線形補間することで、モデルは入力チャンネルの情報を自然に集約し、ゲート化する。タイムシフトミキシングの全体的な構造は、時系列データの予測に使われる古典的なアーキテクチャであるWaveNet (van den Oord et al., 2016)の減衰のない因果関係畳み込みに類似している。

4.7 追加の最適化

カスタムカーネル:標準的な深層学習フレームワークを使用した場合のタスクの連続的な性質によるWKV計算の非効率性に対処するため、トレーニンアクセラレータで単一の計算カーネルを起動するように、カスタムCUDAカーネルを実装した。それ以外の部分はすべて行列の乗算やポイントワイズ演算で、効率的に並列化することが可能である。

RゲートによるFFN:先行研究(Tolstikhin et al., 2021; Liu et al.,2021; Yu et al., 2022)は、Transformerベースの視覚タスクにおいて、セルフアテンションが以前考えられていたほど必須ではない可能性を示唆している。 しかし、自然言語タスクにおいて自己注意を完全に置き換えることは、あまりに思い切った方法である可能性がある。そこで、本研究では、固定的なQKV式をKVに置き換え、新たに時間減衰因子Wを導入することで、注意メカニズムを部分的に解体することにした。このアプローチにより、MLP-mixer(Tolstikhinet al., 2021)に似たトークンとチャンネル混合コンポーネントと、gMLP(Liu et al., 2021)に似たゲーティングユニットRを組み込むことができ、我々のRWKVモデルの性能を向上させる。

小さな初期埋め込みトランスフォーマーモデル(Vaswani et al, 2017)のトレーニングの初期段階ででは、埋め込み行列がゆっくりと変化しないことが観察され、モデルが初期のノイズの多い埋め込み状態から逸脱することが課題となっている。この問題を軽減するために、我々は、埋め込み行列を小さな値で初期化し、その後、追加のLayerNorm操作を適用するアプローチを提案する。この手法を導入することで、学習プロセスを高速化・安定化し、post-LNのコンポーネントを持つディープアーキテクチャの学習を可能にする。この手法の有効性は図8で示されており、モデルが初期の小さな埋め込みから素早く移行することで、コンバージェンスが向上することが示されている。これは、シングルステップの後の小さな変化によって達成され、その結果、方向が大幅に変化し、その後、LayerNorm操作後に大きく変化する。

カスタム初期化:先行研究(He et al., 2016; Jumper et al., 2021)の原則に基づき、対称性を崩しながらパラメータをできるだけ同一性マッピングに近い値に初期化して、きれいな情報経路を確保する。ほとんどの重みはゼロに初期化される。線形層にはバイアスは使用されない。具体的な計算式は付録Dに記載されている。具体的な計算式は付録Dに記載されている。初期化の選択は、収束の速度と品質に大きな影響を与えることがわかった(付録E参照)。

5 評価

本節では、次のような疑問に対する評価に焦点を当てる:

RQ1およびRQ2について、図4から、RWKVが6つのベンチマーク(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA、SciQ)において、主要なオープンソース二次関数的複雑度のトランスフォーマーモデルに対して非常に優れていることがわかる: Pythia(Biderman et al., 2023)、OPT(Zhang et al., 2022)、BLOOM(Scao et al., 2022)。RWKVは、PIQA、OBQA、ARC-E、COPAの4つのタスクにおいて、PythiaとGPT-Neo (Black et al., 2022)をも凌駕している(詳細は付録Hを参照)。RQ3については、図5から、文脈の長さを長くするとPileでのテストロスが少なくなることがわかり、RWKVが長い文脈情報を効果的に利用できることがわかる。

6 推論実験

サイズとファミリーに応じた推論要件のベンチマークを実施した。 具体的には、CPU(x86)とGPUNVIDIA A100 80GB)を含む非定型の計算プラットフォームで、テキスト生成の速度とメモリ要件を評価した。すべての実験では、float32の精度を使用している。埋め込み層と非埋め込み層の両方を含む、すべてのモデルパラメータをパラメータ数に含めている。 異なる量子化設定における性能は、今後の研究に委ねられる。より多くの結果については、付録Iを参照してください。

 さらに、RWKV-4とChatGPT / GPT-4の比較研究を行った(付録J参照)。その結果、RWKV-4はプロンプトエンジニアリングに非常に敏感であることがわかった。 GPTで使用したプロンプトをRWKVに適したものに調整したところ、F1測定の性能は44.2%から74.8%へと大幅に向上した。

今後の展望

RWKVアーキテクチャの将来的な研究の方向性として、いくつかの有望なものがある。

  • 時間依存性を高めた定式化でモデルの表現力を高め、効率を維持したままモデルの初期状態を探索する。

  • RWKVの計算効率をさらに向上させるため、 wkv_tステップで並列スキャンを適用し、計算コストを O(B log(T) d) に削減する。

  • RWKVのエンコーダ・デコーダへの応用とクロスアテンションメカニズムの代替の可能性を調査している。これは、seq2seqやマルチモダルの設定に適用でき、学習と推論の両方の効率を向上させることができる。

  • RWKVの状態(またはコンテキスト)を活用することで、解釈のしやすさ、シーケンスデータの予測可能性、安全性を高めることができる。また、隠れ状態を操作することで、動作を誘導し、プロンプトチューニングによってカスタマイズ性を高めることができる。

  • 人間とのインタラクションを強化するために、特定のセットでファインチューンされたモデルを探索する(Ouyang et al., 2022)。特に興味深いのは、異なるデータセットや特定のユースケースにおける性能であろう。

  • LoRA(Hu et al., 2022)のようなパラメータ効率の良いファインチューン方法を採用し、提案アーキテクチャの異なる量子化スキームでの動作を特徴付ける。

8 結論

RWKVは、時間ベースの混合コンポーネントの可能性を利用したRNNモデルへの新しいアプローチである。 RWKVは、現在のアーキテクチャの限界に対処しながら、局所性と長距離依存性を捉えることができるいくつかの重要な戦略を導入している:(1)2次QKアテンションを線形コストを持つスカラー定式化で置き換える。(2)再帰と逐次帰納バイアスを再定義し、効率的な訓練並列化と効率的推論を可能にする。(3)カスタム初期化を使って訓練ダイナミクスを強化する。

 提案アーキテクチャを様々なNLPタスクでベンチマークし、SoTAに匹敵する性能とコスト削減を示した。 さらに、表現力、解釈力、スケーリングに関する実験により、モデルの能力を示し、RWKVと他の LLMとの動作の類似性を示す。

 RWKVは、連続するデータの複雑な関係をモデル化するためのスケーラブルで効率的なアーキテクチャの新しい扉を開く。Transformerの代替案が数多く提案されているが、数百億のパラメーターを持つ事前学習済みモデルでその主張を裏付けたのは、我々のものが初めてである。

*1:0,0,1,-1