Flash Attention

Flash Attention 2 speeds up transformer training and reduces memory usage by computing exact attention with fused, IO-aware kernels that minimize reads and writes to GPU high-bandwidth memory.

Requirements

Flash Attention 2 requires:
  • Linux operating system
  • NVIDIA GPU with CUDA support
  • flash-attn package installed
pip install flash-attn
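
You can confirm these prerequisites from Python before starting a run (a minimal check; assumes PyTorch is already installed):

import importlib.util
import platform

import torch

# Flash Attention 2 needs Linux, a CUDA-capable GPU, and the flash-attn package
print("OS:", platform.system())                      # expect "Linux"
print("CUDA available:", torch.cuda.is_available())  # expect True
print("flash-attn installed:", importlib.util.find_spec("flash_attn") is not None)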

Quick Start

aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./data.jsonl \
  --project-name fast-model \
  --use-flash-attention-2

Python API

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="fast-model",

    use_flash_attention_2=True,
)

Parameters

Parameter             | CLI Flag                | Default | Description
use_flash_attention_2 | --use-flash-attention-2 | False   | Enable Flash Attention 2
attn_implementation   | --attn-implementation   | None    | Override attention implementation: eager, sdpa, flash_attention_2

Attention Implementation Options

Option            | Description
eager             | Standard PyTorch attention (default for some models)
sdpa              | Scaled Dot Product Attention (PyTorch 2.0+)
flash_attention_2 | Flash Attention 2 (fastest; requires flash-attn)
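
To select an implementation explicitly, set attn_implementation instead of relying on use_flash_attention_2 alone. For example, falling back to SDPA (a minimal sketch mirroring the Quick Start example; the project name is arbitrary):

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="sdpa-model",

    # Explicitly request SDPA instead of Flash Attention 2
    attn_implementation="sdpa",
)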

Model Compatibility

Gemma models use eager attention by default. Because of compatibility issues, Flash Attention 2 is automatically disabled for Gemma models and attn_implementation is forced to eager.
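
Conceptually, the override behaves like the following sketch (illustrative only; the exact internal check may differ):

model_name = "google/gemma-2-2b"

# Illustrative guard: Gemma models fall back to eager attention
# even if Flash Attention 2 was requested
if "gemma" in model_name.lower():
    attn_implementation = "eager"
else:
    attn_implementation = "flash_attention_2"

print(attn_implementation)  # "eager"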

Supported Models

Model Family | Flash Attention 2 | Notes
Llama        | Yes               | Full support
Mistral      | Yes               | Full support
Qwen         | Yes               | Full support
Phi          | Yes               | Full support
Gemma        | No                | Uses eager attention

With Quantization

Combine Flash Attention with quantization for maximum efficiency:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-8B",
    data_path="./data.jsonl",
    project_name="fast-quantized",

    peft=True,
    quantization="int4",
    use_flash_attention_2=True,
)

The equivalent CLI command:

aitraining llm --train \
  --model meta-llama/Llama-3.1-8B \
  --data-path ./data.jsonl \
  --project-name fast-quantized \
  --peft \
  --quantization int4 \
  --use-flash-attention-2

With Sequence Packing

Flash Attention enables efficient sequence packing:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="packed-model",

    use_flash_attention_2=True,
    packing=True,
)
Sequence packing requires Flash Attention to be enabled.

Performance Benefits

Configuration      | Memory    | Speed
Standard attention | Baseline  | Baseline
SDPA               | ~15% less | ~20% faster
Flash Attention 2  | ~40% less | ~2x faster

Results vary by model size, sequence length, and hardware.

Troubleshooting

Installation Errors

If pip install flash-attn fails:
# Ensure CUDA toolkit is installed
nvcc --version

# Build against the locally installed PyTorch and CUDA toolkit (skips build isolation)
pip install flash-attn --no-build-isolation
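
flash-attn compiles against the PyTorch build you already have, so it also helps to confirm which CUDA version that PyTorch was built with (a quick check):

import torch

# Should report a CUDA version compatible with the toolkit shown by nvcc
print(torch.version.cuda)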

Runtime Errors

“Flash Attention is not available”
  • Verify flash-attn is installed: python -c "import flash_attn"
  • Ensure you’re on Linux with CUDA
  • Check GPU compute capability (requires SM 80+, e.g., A100, H100)
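
You can query the compute capability with PyTorch (a quick check):

import torch

# Flash Attention 2 requires compute capability 8.0 or newer (Ampere and later)
major, minor = torch.cuda.get_device_capability(0)
print(f"SM {major}{minor}")
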
Model uses eager attention despite flag
  • Some models (like Gemma) force eager attention
  • Check model documentation for compatibility
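
To see which implementation a loaded model actually resolved to, you can inspect the model config directly (assumes a recent transformers release; the attribute is internal and may change):

import torch
from transformers import AutoModelForCausalLM

# Requires flash-attn and a supported GPU; otherwise loading will fail or fall back
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
# Reports the resolved implementation, e.g. "flash_attention_2" or "eager"
print(model.config._attn_implementation)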

Next Steps