超参数扫描
自动搜索最佳超参数。
快速开始
aitraining llm --train \
--model google/gemma-3-270m \
--data-path ./data.jsonl \
--project-name sweep-experiment \
--use-sweep \
--sweep-backend optuna \
--sweep-n-trials 20
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="sweep-experiment",
# Enable sweep
use_sweep=True,
sweep_backend="optuna",
sweep_n_trials=20,
sweep_metric="eval_loss",
sweep_direction="minimize",
# Base parameters (sweep will vary some)
trainer="sft",
epochs=3,
batch_size=4,
lr=2e-5,
)
| 参数 | 描述 | 默认值 |
|---|
use_sweep | 启用扫描 | False |
sweep_backend | 后端(optuna、grid、random) | optuna |
sweep_n_trials | 试验次数 | 10 |
sweep_metric | 要优化的指标 | eval_loss |
sweep_direction | 最小化或最大化 | minimize |
sweep_params | 自定义搜索空间(JSON 字符串) | None(自动) |
搜索空间
默认搜索空间
默认情况下,扫描搜索:
- 学习率:1e-5 到 1e-3(对数均匀)
- 批量大小:2、4、8、16(分类)
- 预热比例:0.0 到 0.2(均匀)
LoRA 秩不包含在默认扫描中。如果需要,请通过 sweep_params 手动添加。
自定义搜索空间
sweep_params 参数期望一个 JSON 字符串:
import json
sweep_params = json.dumps({
"lr": {"type": "loguniform", "low": 1e-6, "high": 1e-3},
"batch_size": {"type": "categorical", "values": [2, 4, 8]},
"lora_r": {"type": "categorical", "values": [8, 16, 32, 64]},
"warmup_ratio": {"type": "uniform", "low": 0.0, "high": 0.2},
})
params = LLMTrainingParams(
...
use_sweep=True,
sweep_params=sweep_params, # JSON string
)
扫描后端
Optuna
高效的贝叶斯优化:
params = LLMTrainingParams(
...
use_sweep=True,
sweep_backend="optuna",
)
网格搜索
对所有组合进行穷举搜索:
params = LLMTrainingParams(
...
use_sweep=True,
sweep_backend="grid",
)
随机搜索
从搜索空间随机采样:
params = LLMTrainingParams(
...
use_sweep=True,
sweep_backend="random",
)
标准指标
| 指标 | 描述 |
|---|
eval_loss | 验证损失 |
train_loss | 训练损失 |
accuracy | 分类准确率 |
perplexity | 语言模型困惑度 |
增强评估指标
启用 use_enhanced_eval 以访问其他指标:
| 指标 | 描述 |
|---|
perplexity | 语言模型困惑度(默认) |
bleu | 翻译/生成的 BLEU 分数 |
rouge | 摘要的 ROUGE 分数 |
bertscore | 语义相似度的 BERTScore |
accuracy | 分类准确率 |
f1 | F1 分数 |
exact_match | 精确匹配准确率 |
meteor | METEOR 分数 |
增强评估参数
| 参数 | 描述 | 默认值 |
|---|
use_enhanced_eval | 启用增强指标 | False |
eval_metrics | 逗号分隔的指标 | "perplexity" |
eval_strategy | 何时评估(epoch、steps、no) | "epoch" |
eval_batch_size | 评估批量大小 | 8 |
eval_dataset_path | 评估数据集路径(如果不同) | None |
eval_save_predictions | 在评估期间保存预测 | False |
eval_benchmark | 运行标准基准测试 | None |
标准基准测试
使用 eval_benchmark 运行标准 LLM 基准测试:
| 基准测试 | 描述 |
|---|
mmlu | 大规模多任务语言理解 |
hellaswag | HellaSwag 常识推理 |
arc | AI2 推理挑战 |
truthfulqa | TruthfulQA 事实性 |
自定义指标示例
params = LLMTrainingParams(
...
use_sweep=True,
sweep_metric="bleu",
use_enhanced_eval=True,
eval_metrics="bleu,rouge,bertscore",
eval_batch_size=8,
)
示例:找到最佳 LR
import json
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="lr-sweep",
use_sweep=True,
sweep_n_trials=10,
sweep_params=json.dumps({
"lr": {"type": "loguniform", "low": 1e-6, "high": 1e-3},
}),
# Fixed parameters
trainer="sft",
epochs=1,
batch_size=4,
)
查看结果
Optuna 仪表板
pip install optuna-dashboard
optuna-dashboard sqlite:///optuna.db
W&B 仪表板
在 W&B Web 界面中查看扫描。
最佳实践
- 从小开始 - 初始探索 10-20 次试验
- 使用早停 - 尽早停止不良试验
- 固定已知内容 - 仅扫描不确定的参数
- 使用验证数据 - 始终有评估分割
下一步