跳转到主要内容

超参数扫描

自动搜索最佳超参数。

快速开始

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./data.jsonl \
  --project-name sweep-experiment \
  --use-sweep \
  --sweep-backend optuna \
  --sweep-n-trials 20

Python API

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="sweep-experiment",

    # Enable sweep
    use_sweep=True,
    sweep_backend="optuna",
    sweep_n_trials=20,
    sweep_metric="eval_loss",
    sweep_direction="minimize",

    # Base parameters (sweep will vary some)
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
)

参数

参数描述默认值
use_sweep启用扫描False
sweep_backend后端(optunagridrandomoptuna
sweep_n_trials试验次数10
sweep_metric要优化的指标eval_loss
sweep_direction最小化或最大化minimize
sweep_params自定义搜索空间(JSON 字符串)None(自动)

搜索空间

默认搜索空间

默认情况下,扫描搜索:
  • 学习率:1e-5 到 1e-3(对数均匀)
  • 批量大小:2、4、8、16(分类)
  • 预热比例:0.0 到 0.2(均匀)
LoRA 秩不包含在默认扫描中。如果需要,请通过 sweep_params 手动添加。

自定义搜索空间

sweep_params 参数期望一个 JSON 字符串:
import json

sweep_params = json.dumps({
    "lr": {"type": "loguniform", "low": 1e-6, "high": 1e-3},
    "batch_size": {"type": "categorical", "values": [2, 4, 8]},
    "lora_r": {"type": "categorical", "values": [8, 16, 32, 64]},
    "warmup_ratio": {"type": "uniform", "low": 0.0, "high": 0.2},
})

params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_params=sweep_params,  # JSON string
)

扫描后端

Optuna

高效的贝叶斯优化:
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="optuna",
)

网格搜索

对所有组合进行穷举搜索:
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="grid",
)

随机搜索

从搜索空间随机采样:
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="random",
)

指标

标准指标

指标描述
eval_loss验证损失
train_loss训练损失
accuracy分类准确率
perplexity语言模型困惑度

增强评估指标

启用 use_enhanced_eval 以访问其他指标:
指标描述
perplexity语言模型困惑度(默认)
bleu翻译/生成的 BLEU 分数
rouge摘要的 ROUGE 分数
bertscore语义相似度的 BERTScore
accuracy分类准确率
f1F1 分数
exact_match精确匹配准确率
meteorMETEOR 分数

增强评估参数

参数描述默认值
use_enhanced_eval启用增强指标False
eval_metrics逗号分隔的指标"perplexity"
eval_strategy何时评估(epochstepsno"epoch"
eval_batch_size评估批量大小8
eval_dataset_path评估数据集路径(如果不同)None
eval_save_predictions在评估期间保存预测False
eval_benchmark运行标准基准测试None

标准基准测试

使用 eval_benchmark 运行标准 LLM 基准测试:
基准测试描述
mmlu大规模多任务语言理解
hellaswagHellaSwag 常识推理
arcAI2 推理挑战
truthfulqaTruthfulQA 事实性

自定义指标示例

params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_metric="bleu",
    use_enhanced_eval=True,
    eval_metrics="bleu,rouge,bertscore",
    eval_batch_size=8,
)

示例:找到最佳 LR

import json

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="lr-sweep",

    use_sweep=True,
    sweep_n_trials=10,
    sweep_params=json.dumps({
        "lr": {"type": "loguniform", "low": 1e-6, "high": 1e-3},
    }),

    # Fixed parameters
    trainer="sft",
    epochs=1,
    batch_size=4,
)

查看结果

Optuna 仪表板

pip install optuna-dashboard
optuna-dashboard sqlite:///optuna.db

W&B 仪表板

在 W&B Web 界面中查看扫描。

最佳实践

  1. 从小开始 - 初始探索 10-20 次试验
  2. 使用早停 - 尽早停止不良试验
  3. 固定已知内容 - 仅扫描不确定的参数
  4. 使用验证数据 - 始终有评估分割

下一步