最近、Mambaチームの研究が注目を集めています。コーネル大学やプリンストン大学などの研究者らが、大規模TransformerモデルであるLlamaを「蒸留」してMambaに変換することに成功し、新たな推論デコードアルゴリズムを設計することで、モデルの推論速度を大幅に向上させました。
研究者の目標は、LlamaをMambaに変換することでした。なぜでしょうか?それは、大規模モデルをゼロから訓練するには莫大なコストがかかるためです。Mambaは発表以来注目を集めていますが、実際には大規模なMambaモデルを独自に訓練しているチームはほとんどありません。AI21のJambaやNVIDIAのHybrid Mamba2など、有名な派生モデルはいくつか存在しますが、多くの成功したTransformerモデルには豊富な知識が蓄積されています。これらの知識を保持しつつ、TransformerをMambaに微調整できれば、問題は解決します。
研究チームは、漸進的蒸留、教師あり微調整、方向性のある選好最適化など、複数の方法を組み合わせることで、この目標を達成しました。注目すべきは、性能を落とさずに速度を向上させた点です。Mambaは長シーケンス推論において大きな優位性があり、Transformerにも推測デコードなどの推論高速化策がありますが、Mambaの独自の構造ではこれらの策を直接適用できません。そこで研究者らは、Mambaベースの推測デコードを実現するために、新たなアルゴリズムを設計し、ハードウェア特性も考慮しました。
最終的に、研究者らはZephyr-7BとLlama-38Bを線形RNNモデルに成功裏に変換し、蒸留前の標準モデルと同等の性能を実現しました。訓練プロセス全体で200億トークンしか使用せず、1.2兆トークンでゼロから訓練されたMamba7Bモデルや3.5兆トークンで訓練されたNVIDIA Hybrid Mamba2モデルと遜色ない結果となりました。
技術的な詳細としては、線形RNNと線形アテンションは共通点があるため、研究者らはアテンションメカニズム内の射影行列を直接再利用し、パラメータの初期化によってモデル構築を完了することができました。さらに、研究チームはTransformerのMLP層のパラメータを固定し、線形RNN層(つまりMamba)をアテンションヘッドに段階的に置き換え、ヘッド間で共有されるキーと値のグループ化クエリアテンションを処理しました。
蒸留プロセスでは、アテンション層を段階的に置き換える戦略を採用しました。教師あり微調整には、単語レベルのKLダイバージェンスに基づく方法と、シーケンスレベルの知識蒸留の2つの主要な方法があります。ユーザー選好への最適化段階では、教師モデルの出力と比較することで、モデルが生成するコンテンツがユーザーの期待によりよく合致するように、直接選好最適化(DPO)の方法を使用しました。
次に、研究者らはTransformerの推測デコードをMambaモデルに適用し始めました。推測デコードは、簡単に言えば、小さなモデルで複数の出力を生成し、大きなモデルでそれらの出力を検証することです。小さなモデルは高速に動作し、複数の出力ベクトルを迅速に生成できます。一方、大きなモデルはこれらの出力の正確性を評価し、全体的な推論速度を向上させます。
このプロセスを実現するために、研究者らはアルゴリズムを設計しました。このアルゴリズムでは、小さなモデルを使用してK個の下書き出力を生成し、その後、大きなモデルが検証を行い、最終的な出力と中間状態のキャッシュを返します。この方法はGPU上で良好な効果を示し、Mamba2.8Bでは推論速度が1.5倍に向上し、アクセプタンスレートは60%に達しました。異なるアーキテクチャのGPUでは効果が異なるものの、研究チームはカーネルの融合と実装方法の調整により、さらなる最適化を行い、最終的に理想的な速度向上を実現しました。
実験段階では、研究者らはZephyr-7BとLlama-3Instruct8Bを用いて3段階の蒸留訓練を行い、わずか8枚の80G A100 GPUで3~4日間の稼働で研究成果を再現することに成功しました。この研究は、MambaとLlama間の変換方法を示しただけでなく、将来のモデルの推論速度と性能向上のための新たなアイデアを提供しました。
論文アドレス:https://arxiv.org/pdf/2408.15237