Hyperparameter Sweeps

Automatically search for the best hyperparameters.

Quick Start

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./data.jsonl \
  --project-name sweep-experiment \
  --use-sweep \
  --sweep-backend optuna \
  --sweep-n-trials 20

Python API

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="sweep-experiment",

    # Enable sweep
    use_sweep=True,
    sweep_backend="optuna",
    sweep_n_trials=20,
    sweep_metric="eval_loss",
    sweep_direction="minimize",

    # Base parameters (sweep will vary some)
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
)

Parameters

| Parameter | Description | Default |
|---|---|---|
| use_sweep | Enable sweeping | False |
| sweep_backend | Backend (optuna, grid, random) | optuna |
| sweep_n_trials | Number of trials | 10 |
| sweep_metric | Metric to optimize | eval_loss |
| sweep_direction | minimize or maximize | minimize |
| sweep_params | Custom search space (JSON string) | None (auto) |
| post_trial_script | Shell script to run after each trial | None |
| wandb_sweep | Enable W&B native sweep dashboard | False |
| wandb_sweep_project | W&B project for sweep | project_name |
| wandb_sweep_entity | W&B entity (team/username) | None |
| wandb_sweep_id | Existing sweep ID to continue | None |
| wandb_run_id | W&B run ID to resume (for external sweep agents) | None |

Search Spaces

Default Search Space

By default, sweeps search over:
  • Learning rate: 1e-5 to 1e-3 (log uniform)
  • Batch size: 2, 4, 8, 16 (categorical)
  • Warmup ratio: 0.0 to 0.2 (uniform)
LoRA rank is NOT included in the default sweep. Add it manually via sweep_params if needed.
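
If you want to reproduce the default space explicitly (for example, to extend it with LoRA rank), a sketch using the sweep_params dict format described in the next section might look like this; the ranges below simply restate the defaults listed above:
import json

# Restates the default search space explicitly and adds lora_r on top of it.
sweep_params = json.dumps({
    "lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},
    "batch_size": {"type": "categorical", "values": [2, 4, 8, 16]},
    "warmup_ratio": {"type": "uniform", "low": 0.0, "high": 0.2},
    "lora_r": {"type": "categorical", "values": [8, 16, 32, 64]},
})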

Custom Search Space

The sweep_params parameter expects a JSON string. Both list and dict formats are supported:
import json

# Dict format (recommended) - explicit type specification
sweep_params = json.dumps({
    "lr": {"type": "loguniform", "low": 1e-6, "high": 1e-3},
    "batch_size": {"type": "categorical", "values": [2, 4, 8]},
    "lora_r": {"type": "categorical", "values": [8, 16, 32, 64]},
    "warmup_ratio": {"type": "uniform", "low": 0.0, "high": 0.2},
    "epochs": {"type": "int", "low": 1, "high": 5},
})

# List format (shorthand) - for categorical values only
sweep_params = json.dumps({
    "batch_size": [2, 4, 8],
    "lora_r": [8, 16, 32, 64],
})

params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_params=sweep_params,  # JSON string
)
Supported dict types:
| Type | Description | Parameters |
|---|---|---|
| categorical | Choose from list | values: list of options |
| loguniform | Log-uniform distribution | low, high |
| uniform | Uniform distribution | low, high |
| int | Integer range | low, high |

Sweep Backends

Optuna

Efficient Bayesian optimization (the default backend):
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="optuna",
)

Grid

Exhaustive search over all combinations:
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="grid",
)
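
Because grid search tries every combination, it pairs naturally with the categorical list shorthand for sweep_params. Assuming the backend enumerates exactly the listed values, the effective number of trials is the product of the list lengths:
import json

# 3 batch sizes x 2 LoRA ranks = 6 combinations in total
sweep_params = json.dumps({
    "batch_size": [2, 4, 8],
    "lora_r": [8, 16],
})

params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="grid",
    sweep_params=sweep_params,
)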

Random

Random sampling from search space:
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="random",
)
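
Random search draws independent samples from the search space, so continuous distributions are a natural fit. A sketch using only the sweep parameters documented above; the trial count and ranges are illustrative:
import json

# Each of the 30 trials samples lr and warmup_ratio independently.
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_backend="random",
    sweep_n_trials=30,
    sweep_params=json.dumps({
        "lr": {"type": "loguniform", "low": 1e-6, "high": 1e-3},
        "warmup_ratio": {"type": "uniform", "low": 0.0, "high": 0.2},
    }),
)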

Metrics

Standard Metrics

| Metric | Description |
|---|---|
| eval_loss | Validation loss |
| train_loss | Training loss |
| accuracy | Classification accuracy |
| perplexity | Language model perplexity |
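
When sweeping on a score rather than a loss, remember to flip sweep_direction. A minimal sketch using the parameters documented above:
# Maximize accuracy instead of minimizing eval_loss.
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_metric="accuracy",
    sweep_direction="maximize",
)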

Enhanced Evaluation Metrics

Enable use_enhanced_eval to access additional metrics:
| Metric | Description |
|---|---|
| perplexity | Language model perplexity (default) |
| bleu | BLEU score for translation/generation |
| rouge | ROUGE score for summarization |
| bertscore | BERTScore for semantic similarity |
| accuracy | Classification accuracy |
| f1 | F1 score |
| exact_match | Exact match accuracy |
| meteor | METEOR score |

Enhanced Evaluation Parameters

| Parameter | Description | Default |
|---|---|---|
| use_enhanced_eval | Enable enhanced metrics | False |
| eval_metrics | Comma-separated metrics | "perplexity" |
| eval_strategy | When to evaluate (epoch, steps, no) | "epoch" |
| eval_batch_size | Batch size for evaluation | 8 |
| eval_dataset_path | Path to eval dataset (if different) | None |
| eval_save_predictions | Save predictions during eval | False |
| eval_benchmark | Run standard benchmark | None |
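
A sketch combining these parameters; the file path and the specific values are illustrative, not defaults:
params = LLMTrainingParams(
    ...
    use_enhanced_eval=True,
    eval_metrics="perplexity,f1",
    eval_strategy="steps",
    eval_batch_size=16,
    eval_dataset_path="./eval.jsonl",   # hypothetical separate eval set
    eval_save_predictions=True,
)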

Standard Benchmarks

Use eval_benchmark to run standard LLM benchmarks:
| Benchmark | Description |
|---|---|
| mmlu | Massive Multitask Language Understanding |
| hellaswag | HellaSwag commonsense reasoning |
| arc | AI2 Reasoning Challenge |
| truthfulqa | TruthfulQA factuality |
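
For example, to score runs against MMLU (assuming eval_benchmark accepts the names listed above):
params = LLMTrainingParams(
    ...
    use_enhanced_eval=True,  # assumed to be enabled alongside eval_benchmark
    eval_benchmark="mmlu",
)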

Custom Metrics Example

params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_metric="bleu",
    use_enhanced_eval=True,
    eval_metrics="bleu,rouge,bertscore",
    eval_batch_size=8,
)

Example: Find Best LR

import json

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="lr-sweep",

    use_sweep=True,
    sweep_n_trials=10,
    sweep_params=json.dumps({
        "lr": {"type": "loguniform", "low": 1e-6, "high": 1e-3},
    }),

    # Fixed parameters
    trainer="sft",
    epochs=1,
    batch_size=4,
)

Viewing Results

Optuna Dashboard

pip install optuna-dashboard
optuna-dashboard sqlite:///optuna.db

W&B Native Sweep Dashboard

By default, sweeps run locally and only log individual runs to W&B. Enable native W&B sweep integration to get aggregated views, parallel coordinates plots, and parameter importance analysis in a dedicated sweep dashboard.
Local vs W&B Sweeps: Without wandb_sweep=True, each trial logs as a separate W&B run. With wandb_sweep=True, all trials are grouped under a single sweep dashboard with unified visualizations.

Enabling W&B Sweeps

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./data.jsonl \
  --project-name sweep-experiment \
  --use-sweep \
  --sweep-backend optuna \
  --sweep-n-trials 20 \
  --log wandb \
  --wandb-sweep \
  --wandb-sweep-project my-sweep-project \
  --wandb-sweep-entity my-team
Or in Python:
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="sweep-experiment",

    use_sweep=True,
    sweep_backend="optuna",
    sweep_n_trials=20,

    # Enable W&B native sweep
    log="wandb",
    wandb_sweep=True,
    wandb_sweep_project="my-sweep-project",
    wandb_sweep_entity="my-team",
)

W&B Sweep Parameters

| Parameter | Description | Default |
|---|---|---|
| wandb_sweep | Enable W&B native sweep dashboard | False |
| wandb_sweep_project | W&B project name for sweep | Uses project_name |
| wandb_sweep_entity | W&B entity (team/username) | None (uses default) |
| wandb_sweep_id | Existing sweep ID to continue | None (creates new) |

Continuing an Existing Sweep

To add more trials to an existing sweep instead of creating a new one, pass the sweep ID:
# First run creates sweep (prints "Created W&B sweep: abc123xyz")
aitraining llm --train \
  --use-sweep --sweep-n-trials 10 \
  --wandb-sweep --wandb-sweep-project my-project

# Later, continue the same sweep with more trials
aitraining llm --train \
  --use-sweep --sweep-n-trials 10 \
  --wandb-sweep --wandb-sweep-project my-project \
  --wandb-sweep-id abc123xyz
If you don’t pass wandb_sweep_id, a new sweep is created every time. The sweep ID is printed in the logs when the sweep starts (look for “Created W&B sweep: ”).
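
The Python equivalent uses the same parameters; the sweep ID below is the placeholder from the CLI example:
params = LLMTrainingParams(
    ...
    use_sweep=True,
    sweep_n_trials=10,
    log="wandb",
    wandb_sweep=True,
    wandb_sweep_project="my-project",
    wandb_sweep_id="abc123xyz",  # ID printed when the sweep was first created
)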

Accessing the Sweep Dashboard

  1. Go to wandb.ai and open your project
  2. Click the Sweep icon (broom) in the left panel
  3. Select your sweep from the list

Built-in Visualizations

W&B automatically generates three visualizations:
| Visualization | Description |
|---|---|
| Parallel Coordinates Plot | Shows relationships between hyperparameters and metrics at a glance |
| Scatter Plot | Compares all runs to identify patterns |
| Parameter Importance | Ranks which hyperparameters most affect your metric |
Each panel has an Edit button to customize axes and behavior.
The parallel coordinates plot is especially useful for identifying which hyperparameter combinations lead to the best results. You can drag on any axis to filter runs.

Using with External W&B Sweep Agents

If you’re running AITraining from an external W&B sweep agent (not AITraining’s built-in sweep), use --wandb-run-id to resume the agent’s run instead of creating a duplicate:
# External W&B sweep agent calls AITraining with run ID
aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./data.jsonl \
  --wandb-run-id $WANDB_RUN_ID \
  --lr $SWEEP_LR \
  --batch-size $SWEEP_BATCH_SIZE
When --wandb-run-id is set, AITraining automatically sets WANDB_RESUME=allow so the trainer resumes the specified run instead of creating a new one.
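
If you drive training from Python instead of the CLI, the same idea applies. A hedged sketch that assumes the agent exports WANDB_RUN_ID plus the swept values used in the CLI example above:
import os

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="agent-run",  # hypothetical project name
    wandb_run_id=os.environ["WANDB_RUN_ID"],         # resume the agent's run
    lr=float(os.environ["SWEEP_LR"]),                # value sampled by the agent
    batch_size=int(os.environ["SWEEP_BATCH_SIZE"]),
)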

Important Notes

  • Requires W&B login: Run wandb login before using W&B sweeps
  • Sweep ID is logged: Look for “Created W&B sweep: ” in the logs
  • Trials are grouped: Each trial appears as a run with group={sweep_id} for aggregation
  • Optuna still manages search: W&B is for visualization only; Optuna/grid/random handles the actual hyperparameter search

Post-Trial Actions

Execute custom actions after each trial completes, such as committing checkpoints to git, sending notifications, or syncing to remote storage.

CLI Usage

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./data.jsonl \
  --project-name sweep-experiment \
  --use-sweep \
  --sweep-n-trials 10 \
  --post-trial-script 'echo "Trial $TRIAL_NUMBER completed with metric $TRIAL_METRIC_VALUE"'

Environment Variables

The post-trial script receives these environment variables:
| Variable | Description | Example |
|---|---|---|
| TRIAL_NUMBER | Trial index (0-based) | 0, 1, 2 |
| TRIAL_METRIC_VALUE | Metric value for this trial | 0.234 |
| TRIAL_IS_BEST | Whether this is the best trial so far | true or false |
| TRIAL_OUTPUT_DIR | Output directory for the trial | /path/to/sweep/trial_0 |
| TRIAL_PARAMS | Trial parameters as string | {'lr': 0.0001, 'batch_size': 8} |
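
Because the hook is just a shell command, it can also call a small script, for example a hypothetical notify.py invoked with --post-trial-script 'python notify.py':
# notify.py (hypothetical helper; reads the variables listed above)
import os

trial = os.environ.get("TRIAL_NUMBER")
metric = os.environ.get("TRIAL_METRIC_VALUE")
is_best = os.environ.get("TRIAL_IS_BEST") == "true"
output_dir = os.environ.get("TRIAL_OUTPUT_DIR")

print(f"Trial {trial}: metric={metric}, best={is_best}, output={output_dir}")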

Example: Git Commit Best Models

aitraining llm --train \
  --use-sweep \
  --sweep-n-trials 20 \
  --post-trial-script 'if [ "$TRIAL_IS_BEST" = "true" ]; then git add . && git commit -m "Best model: trial $TRIAL_NUMBER, metric $TRIAL_METRIC_VALUE"; fi'

Example: Slack Notification

aitraining llm --train \
  --use-sweep \
  --sweep-n-trials 10 \
  --post-trial-script 'curl -X POST -H "Content-type: application/json" --data "{\"text\":\"Trial $TRIAL_NUMBER: $TRIAL_METRIC_VALUE\"}" $SLACK_WEBHOOK_URL'

Python API with Callback

For more control, use the Python API with a callback function:
from autotrain.utils import HyperparameterSweep, SweepConfig, TrialInfo

def on_trial_complete(trial_info: TrialInfo):
    """Called after each trial completes."""
    print(f"Trial {trial_info.trial_number} completed")
    print(f"  Params: {trial_info.params}")
    print(f"  Metric: {trial_info.metric_value}")
    print(f"  Is best: {trial_info.is_best}")

    if trial_info.is_best:
        # Do something special for best trials
        save_best_model(trial_info.output_dir)

config = SweepConfig(
    parameters={"lr": (1e-5, 1e-3, "log_uniform")},
    n_trials=10,
    backend="optuna",
    post_trial_callback=on_trial_complete,
)

sweep = HyperparameterSweep(config, train_function)
result = sweep.run()
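
train_function and save_best_model above are user-supplied. The sketch below is purely illustrative and assumes, without this being the documented contract, that the sweep calls your function with the sampled hyperparameters and treats its return value as the metric:
def train_function(trial_params):
    # Assumed contract: receive the sampled hyperparameters for one trial
    # and return the metric value that the sweep optimizes.
    ...  # launch one training run with trial_params and evaluate it
    return 0.0  # placeholder for the measured eval_loss

def save_best_model(output_dir):
    # User-defined helper: e.g. copy or upload the best checkpoint directory.
    print(f"Keeping best checkpoint from {output_dir}")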

TrialInfo Fields

| Field | Type | Description |
|---|---|---|
| trial_number | int | Trial index (0-based) |
| params | Dict[str, Any] | Hyperparameters used in this trial |
| metric_value | float | Metric value achieved |
| output_dir | Optional[str] | Path to trial output directory |
| is_best | bool | Whether this is the best trial so far |
| all_metrics | Optional[Dict[str, float]] | All metrics if available |
Post-trial actions are non-blocking. If a callback or script fails, a warning is logged but the sweep continues. This ensures that sweep progress isn’t lost due to callback errors.

Best Practices

  1. Start small - 10-20 trials for initial exploration
  2. Use early stopping - Stop bad trials early
  3. Fix what you know - Only sweep uncertain params
  4. Use validation data - Always have eval split
  5. Use post-trial scripts - Automate checkpointing or notifications

Next Steps