Flash Attention
Flash Attention 2 通过优化内存访问模式,为 transformer 训练提供显著的加速。
Flash Attention 2 需要:
- Linux 操作系统
- 支持 CUDA 的 NVIDIA GPU
- 已安装
flash-attn 包
快速开始
aitraining llm --train \
--model meta-llama/Llama-3.2-1B \
--data-path ./data.jsonl \
--project-name fast-model \
--use-flash-attention-2
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./data.jsonl",
project_name="fast-model",
use_flash_attention_2=True,
)
| 参数 | CLI 标志 | 默认值 | 描述 |
|---|
use_flash_attention_2 | --use-flash-attention-2 | False | 启用 Flash Attention 2 |
attn_implementation | --attn-implementation | None | 覆盖注意力:eager、sdpa、flash_attention_2 |
注意力实现选项
| 选项 | 描述 |
|---|
eager | 标准 PyTorch 注意力(某些模型的默认值) |
sdpa | 缩放点积注意力(PyTorch 2.0+) |
flash_attention_2 | Flash Attention 2(最快,需要 flash-attn) |
模型兼容性
Gemma 模型默认使用 eager 注意力。 由于兼容性问题,Gemma 模型会自动禁用 Flash Attention 2。attn_implementation 被强制为 eager。
支持的模型
| 模型系列 | Flash Attention 2 | 备注 |
|---|
| Llama | 是 | 完全支持 |
| Mistral | 是 | 完全支持 |
| Qwen | 是 | 完全支持 |
| Phi | 是 | 完全支持 |
| Gemma | 否 | 使用 eager 注意力 |
与量化结合
将 Flash Attention 与量化结合以获得最大效率:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-8B",
data_path="./data.jsonl",
project_name="fast-quantized",
peft=True,
quantization="int4",
use_flash_attention_2=True,
)
aitraining llm --train \
--model meta-llama/Llama-3.2-8B \
--data-path ./data.jsonl \
--project-name fast-quantized \
--peft \
--quantization int4 \
--use-flash-attention-2
与序列打包结合
Flash Attention 支持高效的序列打包:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./data.jsonl",
project_name="packed-model",
use_flash_attention_2=True,
packing=True,
)
序列打包需要启用 Flash Attention。
性能优势
| 配置 | 内存 | 速度 |
|---|
| 标准注意力 | 基线 | 基线 |
| SDPA | ~15% 更少 | ~20% 更快 |
| Flash Attention 2 | ~40% 更少 | ~2x 更快 |
结果因模型大小、序列长度和硬件而异。
故障排除
安装错误
如果 pip install flash-attn 失败:
# 确保安装了 CUDA 工具包
nvcc --version
# 使用特定 CUDA 版本安装
pip install flash-attn --no-build-isolation
运行时错误
“Flash Attention 不可用”
- 验证 flash-attn 已安装:
python -c "import flash_attn"
- 确保您在 Linux 上且支持 CUDA
- 检查 GPU 计算能力(需要 SM 80+,例如 A100、H100)
模型尽管设置了标志仍使用 eager 注意力
- 某些模型(如 Gemma)强制使用 eager 注意力
- 检查模型文档以了解兼容性
下一步