跳转到主要内容

Flash Attention

Flash Attention 2 通过优化内存访问模式,为 transformer 训练提供显著的加速。

要求

Flash Attention 2 需要:
  • Linux 操作系统
  • 支持 CUDA 的 NVIDIA GPU
  • 已安装 flash-attn
pip install flash-attn

快速开始

aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./data.jsonl \
  --project-name fast-model \
  --use-flash-attention-2

Python API

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="fast-model",

    use_flash_attention_2=True,
)

参数

参数CLI 标志默认值描述
use_flash_attention_2--use-flash-attention-2False启用 Flash Attention 2
attn_implementation--attn-implementationNone覆盖注意力:eagersdpaflash_attention_2

注意力实现选项

选项描述
eager标准 PyTorch 注意力(某些模型的默认值)
sdpa缩放点积注意力(PyTorch 2.0+)
flash_attention_2Flash Attention 2(最快,需要 flash-attn)

模型兼容性

Gemma 模型默认使用 eager 注意力。 由于兼容性问题,Gemma 模型会自动禁用 Flash Attention 2。attn_implementation 被强制为 eager

支持的模型

模型系列Flash Attention 2备注
Llama完全支持
Mistral完全支持
Qwen完全支持
Phi完全支持
Gemma使用 eager 注意力

与量化结合

将 Flash Attention 与量化结合以获得最大效率:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-8B",
    data_path="./data.jsonl",
    project_name="fast-quantized",

    peft=True,
    quantization="int4",
    use_flash_attention_2=True,
)
aitraining llm --train \
  --model meta-llama/Llama-3.2-8B \
  --data-path ./data.jsonl \
  --project-name fast-quantized \
  --peft \
  --quantization int4 \
  --use-flash-attention-2

与序列打包结合

Flash Attention 支持高效的序列打包:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="packed-model",

    use_flash_attention_2=True,
    packing=True,
)
序列打包需要启用 Flash Attention。

性能优势

配置内存速度
标准注意力基线基线
SDPA~15% 更少~20% 更快
Flash Attention 2~40% 更少~2x 更快
结果因模型大小、序列长度和硬件而异。

故障排除

安装错误

如果 pip install flash-attn 失败:
# 确保安装了 CUDA 工具包
nvcc --version

# 使用特定 CUDA 版本安装
pip install flash-attn --no-build-isolation

运行时错误

“Flash Attention 不可用”
  • 验证 flash-attn 已安装:python -c "import flash_attn"
  • 确保您在 Linux 上且支持 CUDA
  • 检查 GPU 计算能力(需要 SM 80+,例如 A100、H100)
模型尽管设置了标志仍使用 eager 注意力
  • 某些模型(如 Gemma)强制使用 eager 注意力
  • 检查模型文档以了解兼容性

下一步