Pular para o conteúdo principal

Flash Attention

O Flash Attention 2 fornece acelerações significativas para o treinamento de transformers otimizando padrões de acesso à memória.

Requisitos

Flash Attention 2 requer:
  • Sistema operacional Linux
  • GPU NVIDIA com suporte CUDA
  • Pacote flash-attn instalado
pip install flash-attn

Início Rápido

aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./data.jsonl \
  --project-name fast-model \
  --use-flash-attention-2

Python API

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="fast-model",

    use_flash_attention_2=True,
)

Parâmetros

ParâmetroFlag CLIPadrãoDescrição
use_flash_attention_2--use-flash-attention-2FalseHabilitar Flash Attention 2
attn_implementation--attn-implementationNoneSobrescrever atenção: eager, sdpa, flash_attention_2

Opções de Implementação de Atenção

OpçãoDescrição
eagerAtenção padrão PyTorch (padrão para alguns modelos)
sdpaScaled Dot Product Attention (PyTorch 2.0+)
flash_attention_2Flash Attention 2 (mais rápido, requer flash-attn)

Compatibilidade de Modelos

Modelos Gemma usam atenção eager por padrão. Flash Attention 2 é automaticamente desabilitado para modelos Gemma devido a problemas de compatibilidade. O attn_implementation é forçado para eager.

Modelos Suportados

Família de ModeloFlash Attention 2Notas
LlamaSimSuporte completo
MistralSimSuporte completo
QwenSimSuporte completo
PhiSimSuporte completo
GemmaNãoUsa atenção eager

Com Quantização

Combine Flash Attention com quantização para máxima eficiência:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-8B",
    data_path="./data.jsonl",
    project_name="fast-quantized",

    peft=True,
    quantization="int4",
    use_flash_attention_2=True,
)
aitraining llm --train \
  --model meta-llama/Llama-3.2-8B \
  --data-path ./data.jsonl \
  --project-name fast-quantized \
  --peft \
  --quantization int4 \
  --use-flash-attention-2

Com Sequence Packing

Flash Attention permite sequence packing eficiente:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="packed-model",

    use_flash_attention_2=True,
    packing=True,
)
Sequence packing requer que Flash Attention esteja habilitado.

Benefícios de Performance

ConfiguraçãoMemóriaVelocidade
Atenção padrãoBaselineBaseline
SDPA~15% menos~20% mais rápido
Flash Attention 2~40% menos~2x mais rápido
Resultados variam por tamanho do modelo, comprimento da sequência e hardware.

Solução de Problemas

Erros de Instalação

Se pip install flash-attn falhar:
# Certifique-se de que o toolkit CUDA está instalado
nvcc --version

# Instalar com versão CUDA específica
pip install flash-attn --no-build-isolation

Erros de Runtime

“Flash Attention não está disponível”
  • Verifique se flash-attn está instalado: python -c "import flash_attn"
  • Certifique-se de estar no Linux com CUDA
  • Verifique capacidade de computação da GPU (requer SM 80+, ex: A100, H100)
Modelo usa atenção eager apesar da flag
  • Alguns modelos (como Gemma) forçam atenção eager
  • Verifique documentação do modelo para compatibilidade

Próximos Passos