Saltar al contenido principal

Flash Attention

Flash Attention 2 proporciona aceleraciones significativas para el entrenamiento de transformers optimizando patrones de acceso a memoria.

Requisitos

Flash Attention 2 requiere:
  • Sistema operativo Linux
  • GPU NVIDIA con soporte CUDA
  • Paquete flash-attn instalado
pip install flash-attn

Inicio Rápido

aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./data.jsonl \
  --project-name fast-model \
  --use-flash-attention-2

Python API

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="fast-model",

    use_flash_attention_2=True,
)

Parámetros

ParámetroFlag CLIPor DefectoDescripción
use_flash_attention_2--use-flash-attention-2FalseHabilitar Flash Attention 2
attn_implementation--attn-implementationNoneSobrescribir atención: eager, sdpa, flash_attention_2

Opciones de Implementación de Atención

OpciónDescripción
eagerAtención estándar PyTorch (por defecto para algunos modelos)
sdpaScaled Dot Product Attention (PyTorch 2.0+)
flash_attention_2Flash Attention 2 (más rápido, requiere flash-attn)

Compatibilidad de Modelos

Los modelos Gemma usan atención eager por defecto. Flash Attention 2 se deshabilita automáticamente para modelos Gemma debido a problemas de compatibilidad. El attn_implementation se fuerza a eager.

Modelos Soportados

Familia de ModeloFlash Attention 2Notas
LlamaSoporte completo
MistralSoporte completo
QwenSoporte completo
PhiSoporte completo
GemmaNoUsa atención eager

Con Cuantización

Combina Flash Attention con cuantización para máxima eficiencia:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-8B",
    data_path="./data.jsonl",
    project_name="fast-quantized",

    peft=True,
    quantization="int4",
    use_flash_attention_2=True,
)
aitraining llm --train \
  --model meta-llama/Llama-3.2-8B \
  --data-path ./data.jsonl \
  --project-name fast-quantized \
  --peft \
  --quantization int4 \
  --use-flash-attention-2

Con Sequence Packing

Flash Attention permite sequence packing eficiente:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="packed-model",

    use_flash_attention_2=True,
    packing=True,
)
Sequence packing requiere que Flash Attention esté habilitado.

Beneficios de Rendimiento

ConfiguraciónMemoriaVelocidad
Atención estándarLínea baseLínea base
SDPA~15% menos~20% más rápido
Flash Attention 2~40% menos~2x más rápido
Los resultados varían según tamaño del modelo, longitud de secuencia y hardware.

Solución de Problemas

Errores de Instalación

Si pip install flash-attn falla:
# Asegúrate de que el toolkit CUDA está instalado
nvcc --version

# Instalar con versión CUDA específica
pip install flash-attn --no-build-isolation

Errores de Runtime

“Flash Attention no está disponible”
  • Verifica que flash-attn está instalado: python -c "import flash_attn"
  • Asegúrate de estar en Linux con CUDA
  • Verifica capacidad de computación de GPU (requiere SM 80+, ej: A100, H100)
Modelo usa atención eager a pesar de la flag
  • Algunos modelos (como Gemma) fuerzan atención eager
  • Verifica documentación del modelo para compatibilidad

Próximos Pasos