Recent work from the Mamba team has been remarkable: researchers from Cornell, Princeton, and other universities have successfully "distilled" the large Transformer model Llama into Mamba and designed a new speculative decoding algorithm for it, significantly improving the model's inference speed.

The researchers' goal was to turn Llama into Mamba. Why do this? Because training a large model from scratch is costly, and although Mamba has drawn wide attention since its debut, few teams actually train large-scale Mamba models themselves. There are some well-known variants on the market, such as AI21's Jamba and NVIDIA's Hybrid Mamba2, but the many successful Transformer models already contain a wealth of knowledge. If that knowledge could be retained while a Transformer is fine-tuned into Mamba, the cost problem would be solved.


The research team combined progressive distillation, supervised fine-tuning, and direct preference optimization to achieve this goal. Notably, speed matters only if performance is not compromised. Mamba has a clear advantage in long-sequence inference, and although Transformers have inference-acceleration techniques such as speculative decoding, Mamba's recurrent structure means these techniques cannot be applied directly. The researchers therefore designed a new algorithm tailored to Mamba and combined it with hardware characteristics to implement speculative decoding on Mamba.

Ultimately, the researchers successfully converted Zephyr-7B and Llama-3 8B into linear RNN models whose performance matches that of the original models before distillation. The entire training process used only 20 billion tokens, yet achieved results comparable to a Mamba 7B model trained from scratch on 1.2 trillion tokens and to NVIDIA's hybrid Mamba2 model trained on 3.5 trillion tokens.

In terms of technical details, the linear RNN is compatible with linear attention, which lets the researchers directly reuse the projection matrices from the attention mechanism and build the new model through parameter initialization. In addition, the team froze the parameters of the Transformer's MLP layers, progressively replaced the attention heads with linear RNN layers (i.e., Mamba), and handled grouped-query attention, which shares keys and values across heads.
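To make the weight reuse concrete, here is a minimal sketch of what such an initialization could look like, assuming Hugging-Face-style attention modules with `q_proj`/`k_proj`/`v_proj`/`o_proj` and a hypothetical linear-RNN head exposing matching projections; the exact mapping and module names in the paper's code may differ.

```python
import torch

@torch.no_grad()
def init_linear_rnn_from_attention(attn, rnn_head):
    """Seed a linear-RNN (Mamba-style) head from a frozen attention block.
    Roughly: queries drive the readout projection C, keys drive the input
    projection B, and values drive the state input x. Module names here are
    assumptions, not the paper's actual API."""
    rnn_head.C_proj.weight.copy_(attn.q_proj.weight)    # query weights -> C (readout)
    rnn_head.B_proj.weight.copy_(attn.k_proj.weight)    # key weights   -> B (state update)
    rnn_head.x_proj.weight.copy_(attn.v_proj.weight)    # value weights -> x (state input)
    rnn_head.out_proj.weight.copy_(attn.o_proj.weight)  # output projection carried over
```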

During the distillation process, the attention layers were replaced gradually. Supervised fine-tuning used two main methods: one based on word-level KL divergence and the other on sequence-level knowledge distillation. For the stage of aligning with user preferences, the team used Direct Preference Optimization (DPO), comparing the model's outputs against those of the teacher model so that the distilled model better meets user expectations when generating content.
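As a rough illustration of the word-level objective, the sketch below shows a standard per-token KL distillation loss in PyTorch; the temperature and reduction choices are assumptions, not values taken from the paper.

```python
import torch.nn.functional as F

def word_level_kl_loss(student_logits, teacher_logits, temperature=1.0):
    """Per-token KL between the frozen teacher's next-token distribution and the
    student's. Logit shapes: (batch, seq_len, vocab). Temperature is an assumption."""
    t = temperature
    vocab = student_logits.size(-1)
    student_log_probs = F.log_softmax(student_logits / t, dim=-1).reshape(-1, vocab)
    teacher_probs = F.softmax(teacher_logits / t, dim=-1).reshape(-1, vocab)
    # 'batchmean' over the flattened (batch * seq_len) rows gives the mean per-token KL;
    # the t**2 factor keeps gradient magnitudes stable when t != 1.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * t**2
```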

Next, the researchers turned to applying Transformer-style speculative decoding to the Mamba model. Speculative decoding can be understood simply as having a small model generate multiple candidate tokens, which a large model then verifies. The small model runs fast and drafts several tokens quickly, while the large model checks their correctness in one pass, improving overall inference speed.
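The sketch below illustrates the general draft-then-verify idea in its simplest greedy form; this is a generic illustration rather than the paper's algorithm, and `draft_model`/`target_model` are assumed to be Hugging-Face-style causal LMs whose outputs expose `.logits`.

```python
import torch

@torch.no_grad()
def speculative_step(draft_model, target_model, prefix_ids, k=4):
    """One greedy draft-then-verify step. The small draft model proposes k tokens
    one by one; the large target model scores the whole proposal in a single
    forward pass and keeps the longest prefix it agrees with, plus one of its
    own tokens at the first disagreement."""
    ids = prefix_ids
    # 1) Draft: k cheap autoregressive steps with the small model.
    for _ in range(k):
        next_tok = draft_model(ids).logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_tok], dim=-1)
    # 2) Verify: one forward pass of the large model over prefix + drafts.
    target_logits = target_model(ids).logits
    n = prefix_ids.shape[-1]
    accepted = prefix_ids
    for i in range(k):
        # Logits at position n+i-1 predict the token at position n+i.
        target_tok = target_logits[:, n + i - 1].argmax(dim=-1, keepdim=True)
        accepted = torch.cat([accepted, target_tok], dim=-1)
        if not torch.equal(target_tok, ids[:, n + i : n + i + 1]):
            break  # first mismatch: the target's own token replaces the draft
    return accepted
```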

To implement this, the researchers designed an algorithm in which the small model generates K draft tokens at a time, after which the large model verifies them and returns the final output along with cached intermediate states. The method worked well on GPUs: Mamba 2.8B achieved a 1.5x inference speedup with an acceptance rate of 60%. Results varied across GPU architectures, so the research team further optimized the process with kernel fusion and implementation adjustments, ultimately reaching the desired acceleration.
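Because Mamba keeps a single recurrent state rather than a per-token KV cache, verification has to track where that state ends up when drafts are rejected. The sketch below illustrates that state-handling concern with a hypothetical single-step interface `step(state, token_id) -> (next_logits, new_state)`; the paper instead fuses the multi-token scan and verification into hardware-aware kernels.

```python
import torch

@torch.no_grad()
def verify_drafts(rnn_verifier, state, next_logits, draft_tokens):
    """Greedy verification with a recurrent verifier (illustrative only).
    `state` is the RNN state after the last verified token, `next_logits` the
    verifier's distribution for the next position, and `draft_tokens` a list of
    int token ids proposed by the draft model. `step` is assumed to return a
    fresh state rather than mutate the old one in place."""
    snapshot = state              # state at the last verified position
    accepted = []
    for tok in draft_tokens:      # in practice: one fused multi-token scan
        target_tok = int(next_logits.argmax())
        accepted.append(target_tok)
        if target_tok != tok:     # first mismatch: keep the verifier's token, stop
            break
        next_logits, state = rnn_verifier.step(state, tok)
    # Rebuild the state from the snapshot so it matches exactly the accepted
    # tokens -- there is no positional KV cache to truncate as in a Transformer.
    state = snapshot
    for tok in accepted:
        next_logits, state = rnn_verifier.step(state, tok)
    return accepted, next_logits, state
```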

In the experimental phase, the researchers used Zephyr-7B and Llama-3 Instruct 8B for three-stage distillation training, reproducing the results in just 3 to 4 days on eight 80GB A100 GPUs. This study not only demonstrates a transformation path from Llama to Mamba but also offers new ideas for improving model inference speed and performance.

Paper link: https://arxiv.org/pdf/2408.15237