Recientemente, la investigación del equipo Mamba ha llamado la atención: investigadores de universidades como Cornell y Princeton lograron "destilar" el modelo Transformer Llama en Mamba, y diseñaron un nuevo algoritmo de decodificación de inferencia que mejora significativamente la velocidad de inferencia del modelo.
El objetivo de los investigadores era convertir Llama en Mamba. ¿Por qué? Porque entrenar un modelo grande desde cero es costoso, y aunque Mamba ha recibido mucha atención desde su lanzamiento, pocos equipos entrenan modelos Mamba a gran escala por sí mismos. Si bien existen algunas variantes conocidas, como Jamba de AI21 y Hybrid Mamba2 de NVIDIA, los numerosos modelos Transformer exitosos contienen un vasto conocimiento. Si pudiéramos capturar este conocimiento y ajustar finamente el Transformer a Mamba, el problema se resolvería.
El equipo de investigación combinó varios métodos, incluyendo destilación progresiva, ajuste fino supervisado y optimización de preferencias dirigidas, logrando con éxito este objetivo. Cabe destacar que, además de mantener el rendimiento, la velocidad es crucial. Mamba tiene una ventaja significativa en la inferencia de secuencias largas, y los Transformer también tienen soluciones de aceleración de inferencia, como la decodificación predictiva. Debido a la estructura única de Mamba, estas soluciones no se podían aplicar directamente, por lo que los investigadores diseñaron un nuevo algoritmo y lo combinaron con las características del hardware para lograr la decodificación predictiva basada en Mamba.
Finalmente, los investigadores convirtieron con éxito Zephyr-7B y Llama-38B en modelos RNN lineales, con un rendimiento comparable a los modelos estándar antes de la destilación. El proceso de entrenamiento utilizó solo 20B tokens, con resultados comparables a los modelos Mamba7B entrenados desde cero con 1.2T tokens y al modelo NVIDIA Hybrid Mamba2 entrenado con 3.5T tokens.
En cuanto a los detalles técnicos, las RNN lineales y la atención lineal son similares, por lo que los investigadores pudieron reutilizar directamente las matrices de proyección del mecanismo de atención y construir el modelo mediante la inicialización de parámetros. Además, el equipo congeló los parámetros de la capa MLP del Transformer, reemplazando gradualmente las cabezas de atención con capas RNN lineales (es decir, Mamba), y procesando la atención de consulta de agrupación de claves y valores compartidos entre cabezas.
Durante el proceso de destilación, se utilizó una estrategia de reemplazo gradual de las capas de atención. El ajuste fino supervisado incluyó dos métodos principales: uno basado en la divergencia de KL a nivel de palabra, y otro en la destilación de conocimiento a nivel de secuencia. Para la etapa de ajuste según las preferencias del usuario, el equipo utilizó la optimización de preferencias directas (DPO), comparando con la salida del modelo maestro para asegurar que el modelo genere contenido que se ajuste mejor a las expectativas del usuario.
A continuación, los investigadores comenzaron a aplicar la decodificación predictiva del Transformer al modelo Mamba. La decodificación predictiva se puede entender simplemente como el uso de un modelo pequeño para generar múltiples salidas, y luego usar un modelo grande para verificar estas salidas. El modelo pequeño se ejecuta rápidamente, generando múltiples vectores de salida, mientras que el modelo grande evalúa la precisión de estas salidas, mejorando así la velocidad de inferencia general.
Para lograr este proceso, los investigadores diseñaron un algoritmo que utiliza el modelo pequeño para generar K salidas de borrador, y luego el modelo grande verifica y devuelve la salida final y el caché del estado intermedio. Este método funcionó bien en la GPU, con Mamba2.8B logrando una aceleración de la inferencia de 1.5 veces y una tasa de aceptación del 60%. Aunque los resultados variaron en diferentes arquitecturas de GPU, el equipo de investigación realizó más optimizaciones mediante la fusión de núcleos y el ajuste de la implementación, logrando finalmente el efecto de aceleración deseado.
En la fase experimental, los investigadores utilizaron Zephyr-7B y Llama-3Instruct8B para un entrenamiento de destilación de tres fases, y finalmente lograron reproducir los resultados en solo 3 a 4 días en 8 tarjetas A100 de 80G. Esta investigación no solo muestra el camino de la transformación de Transformer a Mamba, sino que también proporciona nuevas ideas para mejorar la velocidad y el rendimiento de los modelos futuros.
Dirección del artículo:https://arxiv.org/pdf/2408.15237