Flash Attention
Flash Attention 2 proporciona aceleraciones significativas para el entrenamiento de transformers optimizando patrones de acceso a memoria.
Requisitos
Flash Attention 2 requiere:
- Sistema operativo Linux
- GPU NVIDIA con soporte CUDA
- Paquete
flash-attn instalado
Inicio Rápido
aitraining llm --train \
--model meta-llama/Llama-3.2-1B \
--data-path ./data.jsonl \
--project-name fast-model \
--use-flash-attention-2
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./data.jsonl",
project_name="fast-model",
use_flash_attention_2=True,
)
Parámetros
| Parámetro | Flag CLI | Por Defecto | Descripción |
|---|
use_flash_attention_2 | --use-flash-attention-2 | False | Habilitar Flash Attention 2 |
attn_implementation | --attn-implementation | None | Sobrescribir atención: eager, sdpa, flash_attention_2 |
Opciones de Implementación de Atención
| Opción | Descripción |
|---|
eager | Atención estándar PyTorch (por defecto para algunos modelos) |
sdpa | Scaled Dot Product Attention (PyTorch 2.0+) |
flash_attention_2 | Flash Attention 2 (más rápido, requiere flash-attn) |
Compatibilidad de Modelos
Los modelos Gemma usan atención eager por defecto. Flash Attention 2 se deshabilita automáticamente para modelos Gemma debido a problemas de compatibilidad. El attn_implementation se fuerza a eager.
Modelos Soportados
| Familia de Modelo | Flash Attention 2 | Notas |
|---|
| Llama | Sí | Soporte completo |
| Mistral | Sí | Soporte completo |
| Qwen | Sí | Soporte completo |
| Phi | Sí | Soporte completo |
| Gemma | No | Usa atención eager |
Con Cuantización
Combina Flash Attention con cuantización para máxima eficiencia:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-8B",
data_path="./data.jsonl",
project_name="fast-quantized",
peft=True,
quantization="int4",
use_flash_attention_2=True,
)
aitraining llm --train \
--model meta-llama/Llama-3.2-8B \
--data-path ./data.jsonl \
--project-name fast-quantized \
--peft \
--quantization int4 \
--use-flash-attention-2
Con Sequence Packing
Flash Attention permite sequence packing eficiente:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./data.jsonl",
project_name="packed-model",
use_flash_attention_2=True,
packing=True,
)
Sequence packing requiere que Flash Attention esté habilitado.
Beneficios de Rendimiento
| Configuración | Memoria | Velocidad |
|---|
| Atención estándar | Línea base | Línea base |
| SDPA | ~15% menos | ~20% más rápido |
| Flash Attention 2 | ~40% menos | ~2x más rápido |
Los resultados varían según tamaño del modelo, longitud de secuencia y hardware.
Solución de Problemas
Errores de Instalación
Si pip install flash-attn falla:
# Asegúrate de que el toolkit CUDA está instalado
nvcc --version
# Instalar con versión CUDA específica
pip install flash-attn --no-build-isolation
Errores de Runtime
“Flash Attention no está disponible”
- Verifica que flash-attn está instalado:
python -c "import flash_attn"
- Asegúrate de estar en Linux con CUDA
- Verifica capacidad de computación de GPU (requiere SM 80+, ej: A100, H100)
Modelo usa atención eager a pesar de la flag
- Algunos modelos (como Gemma) fuerzan atención eager
- Verifica documentación del modelo para compatibilidad
Próximos Pasos