Flash Attention
O Flash Attention 2 fornece acelerações significativas para o treinamento de transformers otimizando padrões de acesso à memória.
Requisitos
Flash Attention 2 requer:
- Sistema operacional Linux
- GPU NVIDIA com suporte CUDA
- Pacote
flash-attn instalado
Início Rápido
aitraining llm --train \
--model meta-llama/Llama-3.2-1B \
--data-path ./data.jsonl \
--project-name fast-model \
--use-flash-attention-2
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./data.jsonl",
project_name="fast-model",
use_flash_attention_2=True,
)
Parâmetros
| Parâmetro | Flag CLI | Padrão | Descrição |
|---|
use_flash_attention_2 | --use-flash-attention-2 | False | Habilitar Flash Attention 2 |
attn_implementation | --attn-implementation | None | Sobrescrever atenção: eager, sdpa, flash_attention_2 |
Opções de Implementação de Atenção
| Opção | Descrição |
|---|
eager | Atenção padrão PyTorch (padrão para alguns modelos) |
sdpa | Scaled Dot Product Attention (PyTorch 2.0+) |
flash_attention_2 | Flash Attention 2 (mais rápido, requer flash-attn) |
Compatibilidade de Modelos
Modelos Gemma usam atenção eager por padrão. Flash Attention 2 é automaticamente desabilitado para modelos Gemma devido a problemas de compatibilidade. O attn_implementation é forçado para eager.
Modelos Suportados
| Família de Modelo | Flash Attention 2 | Notas |
|---|
| Llama | Sim | Suporte completo |
| Mistral | Sim | Suporte completo |
| Qwen | Sim | Suporte completo |
| Phi | Sim | Suporte completo |
| Gemma | Não | Usa atenção eager |
Combine Flash Attention com quantização para máxima eficiência:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-8B",
data_path="./data.jsonl",
project_name="fast-quantized",
peft=True,
quantization="int4",
use_flash_attention_2=True,
)
aitraining llm --train \
--model meta-llama/Llama-3.2-8B \
--data-path ./data.jsonl \
--project-name fast-quantized \
--peft \
--quantization int4 \
--use-flash-attention-2
Flash Attention permite sequence packing eficiente:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./data.jsonl",
project_name="packed-model",
use_flash_attention_2=True,
packing=True,
)
Sequence packing requer que Flash Attention esteja habilitado.
| Configuração | Memória | Velocidade |
|---|
| Atenção padrão | Baseline | Baseline |
| SDPA | ~15% menos | ~20% mais rápido |
| Flash Attention 2 | ~40% menos | ~2x mais rápido |
Resultados variam por tamanho do modelo, comprimento da sequência e hardware.
Solução de Problemas
Erros de Instalação
Se pip install flash-attn falhar:
# Certifique-se de que o toolkit CUDA está instalado
nvcc --version
# Instalar com versão CUDA específica
pip install flash-attn --no-build-isolation
Erros de Runtime
“Flash Attention não está disponível”
- Verifique se flash-attn está instalado:
python -c "import flash_attn"
- Certifique-se de estar no Linux com CUDA
- Verifique capacidade de computação da GPU (requer SM 80+, ex: A100, H100)
Modelo usa atenção eager apesar da flag
- Alguns modelos (como Gemma) forçam atenção eager
- Verifique documentação do modelo para compatibilidade
Próximos Passos