# Flash Attention
Flash Attention 2 provides significant speedups for transformer training by optimizing memory access patterns.
## Requirements

Flash Attention 2 requires:

- A Linux operating system
- An NVIDIA GPU with CUDA support (Ampere or newer, i.e. compute capability SM 80+)
- The `flash-attn` package installed
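Before launching a run, you can sanity-check these requirements from Python. The helper below is a small convenience sketch; the function name is ours and it is not part of the library.

```python
# Quick pre-flight check for the requirements above.
# The helper name is ours; it is not part of the library.
import importlib.util
import platform

import torch


def flash_attention_ready() -> bool:
    """Return True if the basic Flash Attention 2 requirements appear to be met."""
    if platform.system() != "Linux":
        return False                       # Linux only
    if not torch.cuda.is_available():
        return False                       # needs a CUDA GPU
    major, _ = torch.cuda.get_device_capability()
    if major < 8:
        return False                       # needs Ampere (SM 80) or newer
    return importlib.util.find_spec("flash_attn") is not None


print("Flash Attention 2 ready:", flash_attention_ready())
```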
## Quick Start

```bash
aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./data.jsonl \
  --project-name fast-model \
  --use-flash-attention-2
```
## Python API

```python
from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="fast-model",
    use_flash_attention_2=True,
)
```
## Parameters

| Parameter | CLI Flag | Default | Description |
|---|---|---|---|
| `use_flash_attention_2` | `--use-flash-attention-2` | `False` | Enable Flash Attention 2 |
| `attn_implementation` | `--attn-implementation` | `None` | Override the attention implementation: `eager`, `sdpa`, or `flash_attention_2` |
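If `flash-attn` is not available on your machine, the same parameter surface lets you fall back to SDPA. A minimal sketch using the fields from the table above (the project name is just a placeholder):

```python
from autotrain.trainers.clm.params import LLMTrainingParams

# Fall back to PyTorch SDPA when flash-attn is not installed.
# "sdpa-model" is only a placeholder project name.
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="sdpa-model",
    attn_implementation="sdpa",
)
```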
## Attention Implementation Options

| Option | Description |
|---|---|
| `eager` | Standard PyTorch attention (the default for some models) |
| `sdpa` | Scaled Dot Product Attention (PyTorch 2.0+) |
| `flash_attention_2` | Flash Attention 2 (fastest; requires `flash-attn`) |
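These values correspond to the `attn_implementation` argument that Hugging Face Transformers accepts when loading a model. The snippet below shows that argument used directly; it is illustrative only, not the trainer's internal code. Note that Flash Attention 2 expects fp16 or bf16 weights.

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative: load a model directly with a specific attention backend.
# Flash Attention 2 requires half precision (fp16 or bf16).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
    torch_dtype=torch.bfloat16,
)
```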
## Model Compatibility

Gemma models use eager attention by default. Because of compatibility issues, Flash Attention 2 is automatically disabled for Gemma models and `attn_implementation` is forced to `eager`.
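In other words, the selection behaves roughly like the sketch below (illustrative pseudologic, not the tool's actual source):

```python
from typing import Optional


def resolve_attn_implementation(model_name: str, use_flash_attention_2: bool) -> Optional[str]:
    """Sketch of the selection behaviour described above; not the tool's actual source."""
    if "gemma" in model_name.lower():
        return "eager"  # Gemma models are always forced to eager attention
    if use_flash_attention_2:
        return "flash_attention_2"
    return None  # otherwise let Transformers pick its own default
```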
### Supported Models
| Model Family | Flash Attention 2 | Notes |
|---|---|---|
| Llama | Yes | Full support |
| Mistral | Yes | Full support |
| Qwen | Yes | Full support |
| Phi | Yes | Full support |
| Gemma | No | Uses eager attention |
## With Quantization
Combine Flash Attention with quantization for maximum efficiency:
```python
from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-8B",
    data_path="./data.jsonl",
    project_name="fast-quantized",
    peft=True,
    quantization="int4",
    use_flash_attention_2=True,
)
```
The equivalent CLI command:

```bash
aitraining llm --train \
  --model meta-llama/Llama-3.2-8B \
  --data-path ./data.jsonl \
  --project-name fast-quantized \
  --peft \
  --quantization int4 \
  --use-flash-attention-2
```
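For reference, `quantization="int4"` combined with `use_flash_attention_2=True` corresponds roughly to loading the base model in plain Transformers as shown below. This is an illustrative sketch, not the trainer's internal code.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative only: a 4-bit base model loaded with Flash Attention 2.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-8B",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
)
```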
## With Sequence Packing
Flash Attention enables efficient sequence packing:
```python
from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="packed-model",
    use_flash_attention_2=True,
    packing=True,
)
```
Sequence packing requires Flash Attention to be enabled.
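Conceptually, packing concatenates several short tokenized examples into fixed-length blocks so that no tokens are wasted on padding, and Flash Attention's variable-length kernels are what make this efficient, which is likely why packing depends on it here. Below is a toy sketch of the packing step only; the trainer handles this internally, and the token IDs are arbitrary illustration values.

```python
# Toy sketch of the packing step; the trainer handles this internally.
# Token IDs and the EOS ID below are arbitrary illustration values.
def pack_examples(tokenized_examples, block_size=2048, eos_token_id=2):
    """Greedily concatenate tokenized examples into fixed-length blocks."""
    blocks, current = [], []
    for ids in tokenized_examples:
        current.extend(ids + [eos_token_id])
        while len(current) >= block_size:
            blocks.append(current[:block_size])
            current = current[block_size:]
    return blocks


print(pack_examples([[101, 7592], [101, 2088, 999]], block_size=4))
# [[101, 7592, 2, 101]]
```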
## Performance

| Configuration | Memory Usage | Training Speed |
|---|---|---|
| Standard attention | Baseline | Baseline |
| SDPA | ~15% less | ~20% faster |
| Flash Attention 2 | ~40% less | ~2x faster |
Results vary by model size, sequence length, and hardware.
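To measure the effect on your own setup, you can time a few forward/backward passes with each backend. The harness below is a rough sketch (model, sequence length, and step count are arbitrary choices), not a rigorous benchmark.

```python
import time

import torch
from transformers import AutoModelForCausalLM


def time_backend(attn_implementation, model_name="meta-llama/Llama-3.2-1B",
                 seq_len=2048, steps=5):
    """Average seconds per forward+backward step for one attention backend."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation=attn_implementation,
        torch_dtype=torch.bfloat16,
    ).cuda()
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        loss = model(input_ids=input_ids, labels=input_ids).loss
        loss.backward()
        model.zero_grad()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps


for backend in ("eager", "sdpa", "flash_attention_2"):
    print(backend, f"{time_backend(backend):.3f}s / step")
```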
## Troubleshooting

### Installation Errors

If `pip install flash-attn` fails:
```bash
# Ensure the CUDA toolkit is installed
nvcc --version

# Install without build isolation, the usual fix for build failures
pip install flash-attn --no-build-isolation
```
### Runtime Errors

**"Flash Attention is not available"**

- Verify that `flash-attn` is installed: `python -c "import flash_attn"`
- Ensure you are on Linux with CUDA
- Check the GPU compute capability (Flash Attention 2 requires SM 80+, e.g. A100, H100)
**Model uses eager attention despite the flag**

- Some models (such as Gemma) force eager attention
- Check the model documentation for compatibility
## Next Steps