LLM Training API

Complete API reference for LLM training.

LLMTrainingParams

The main configuration class for LLM training.
from autotrain.trainers.clm.params import LLMTrainingParams

Basic Parameters

params = LLMTrainingParams(
    # Core parameters (always specify these)
    model="google/gemma-3-270m",       # Default: "google/gemma-3-270m"
    data_path="./data.jsonl",          # Default: "data"
    project_name="my-model",           # Default: "project-name"

    # Data splits
    train_split="train",               # Default: "train"
    valid_split=None,                  # Default: None
    max_samples=None,                  # Default: None (use all)
)
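
For the default text_column, data_path is expected to point at a dataset with one text field per record. As an illustration (the exact layout depends on the trainer and your dataset), a minimal data.jsonl could look like:

{"text": "Question: What is 2 + 2? Answer: 4."}
{"text": "Question: What is the capital of France? Answer: Paris."}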

Trainer Selection

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    trainer="sft",  # Default: "default" (plain causal-LM pretraining). Options: default, sft, dpo, orpo, reward
)

Training Hyperparameters

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Core hyperparameters (showing defaults)
    epochs=1,           # Default: 1
    batch_size=2,       # Default: 2
    lr=3e-5,            # Default: 3e-5
    warmup_ratio=0.1,   # Default: 0.1
    gradient_accumulation=4,  # Default: 4
    weight_decay=0.0,   # Default: 0.0
    max_grad_norm=1.0,  # Default: 1.0

    # Precision
    mixed_precision=None,  # Default: None (options: bf16, fp16, None)

    # Optimization
    optimizer="adamw_torch",  # Default: adamw_torch
    scheduler="linear",       # Default: linear
    seed=42,                  # Default: 42
)
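
Note that batch_size and gradient_accumulation multiply: with the defaults above, each optimizer step sees an effective batch of 8 examples per device.

# Effective per-device batch per optimizer step, using the defaults above
effective_batch = 2 * 4  # batch_size * gradient_accumulation = 8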

PEFT/LoRA Configuration

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="my-model",

    # Enable LoRA (default: False)
    peft=True,
    lora_r=16,           # Default: 16
    lora_alpha=32,       # Default: 32
    lora_dropout=0.05,   # Default: 0.05
    target_modules="all-linear",  # Default: all-linear

    # Quantization (optional)
    quantization="int4",  # Options: int4, int8, or None (default: None)

    # Merge the LoRA adapter into the base model after training (default: True)
    merge_adapter=True,
)
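
As a rule of thumb when tuning lora_r and lora_alpha: in the standard LoRA formulation the adapter update is scaled by lora_alpha / lora_r, so the defaults above apply a scaling factor of 2.

# Standard LoRA update: W' = W + (lora_alpha / lora_r) * (B @ A)
scaling = 32 / 16  # = 2.0 with the defaults above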

Data Processing

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Text processing
    text_column="text",
    block_size=-1,            # Default: -1 (model default)
    model_max_length=2048,    # Default: 2048
    add_eos_token=True,       # Default: True
    padding="right",          # Default: "right"

    # Chat format
    chat_template=None,       # Auto-detect or specify
    apply_chat_template=True, # Default: True

    # Efficiency
    packing=None,             # Default: None (set True to enable)
    use_flash_attention_2=False,  # Default: False
    attn_implementation=None,     # Default: None
)
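
When a chat template is applied, conversational data is commonly stored as role/content message lists. A typical JSONL record might look like the following (illustrative; the column name "messages" is an assumption and depends on your dataset):

{"messages": [{"role": "user", "content": "What is LoRA?"}, {"role": "assistant", "content": "A parameter-efficient fine-tuning method."}]}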

DPO Parameters

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./preferences.jsonl",
    project_name="my-model",

    trainer="dpo",

    # DPO-specific
    dpo_beta=0.1,              # Default: 0.1
    max_prompt_length=128,     # Default: 128
    max_completion_length=None, # Default: None

    # Reference model (optional)
    model_ref=None,  # Default: None (falls back to the base model as reference)

    # Data columns (required for DPO)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)
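
Given the column mapping above, each DPO record pairs one prompt with a preferred and a rejected completion, e.g.:

{"prompt": "Explain overfitting in one sentence.", "chosen": "Overfitting is when a model memorizes training data and fails to generalize.", "rejected": "Overfitting means the model is very accurate."}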

ORPO Parameters

params = LLMTrainingParams(
    model="google/gemma-2-2b",
    data_path="./preferences.jsonl",
    project_name="my-model",

    trainer="orpo",

    # ORPO-specific
    dpo_beta=0.1,              # Default: 0.1
    max_prompt_length=128,     # Default: 128
    max_completion_length=None, # Default: None

    # Data columns (required for ORPO)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)
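
ORPO consumes the same prompt/chosen/rejected layout as DPO, but it folds the preference objective into the supervised loss, so no reference model is needed.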

Knowledge Distillation

params = LLMTrainingParams(
    model="google/gemma-3-270m",           # Student
    teacher_model="google/gemma-2-2b",     # Teacher
    data_path="./prompts.jsonl",
    project_name="distilled-model",

    use_distillation=True,
    distill_temperature=3.0,   # Default: 3.0
    distill_alpha=0.7,         # Default: 0.7
    distill_max_teacher_length=512,  # Default: 512
)
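
These parameters follow the usual knowledge-distillation recipe: a temperature-softened KL term against the teacher, blended with the regular language-modeling loss via distill_alpha. A minimal sketch of that standard formulation (not necessarily AutoTrain's exact implementation):

import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.7):
    # Temperature-softened KL between student and teacher (Hinton-style KD);
    # the T*T factor keeps gradient magnitudes comparable across temperatures
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Regular cross-entropy against the ground-truth tokens
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))
    return alpha * kd + (1 - alpha) * ce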

Logging & Saving

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Logging
    log="wandb",  # wandb, tensorboard, or None (default: wandb)
    logging_steps=-1,     # Default: -1 (auto)
    wandb_visualizer=True,  # Terminal visualizer
    wandb_token=None,       # W&B API token (optional)

    # Checkpointing
    save_strategy="steps",  # steps or epoch (default: epoch)
    save_steps=500,
    save_total_limit=1,   # Default: 1
    eval_strategy="steps",  # Evaluation cadence; requires valid_split to be set
)

Hub Integration

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Push to Hub
    push_to_hub=True,
    username="your-username",
    token="hf_...",
)
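
Avoid hard-coding the token in source files; reading it from an environment variable works just as well (HF_TOKEN is a conventional name, not a requirement):

import os

token = os.environ["HF_TOKEN"]  # keeps the secret out of version control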

Running Training

from autotrain.project import AutoTrainProject

# Create and run project
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True
)

job_id = project.create()

Complete Example

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Full configuration
params = LLMTrainingParams(
    # Model
    model="meta-llama/Llama-3.2-1B",
    project_name="llama-production",

    # Data
    data_path="./conversations.jsonl",
    train_split="train",
    valid_split="validation",
    text_column="text",
    block_size=2048,

    # Training
    trainer="sft",
    epochs=3,
    batch_size=2,
    gradient_accumulation=8,
    lr=2e-5,
    warmup_ratio=0.1,
    mixed_precision="bf16",

    # LoRA
    peft=True,
    lora_r=32,
    lora_alpha=64,
    lora_dropout=0.05,

    # Optimization
    use_flash_attention_2=True,
    packing=True,
    auto_find_batch_size=True,
    unsloth=False,  # Set True to use Unsloth for faster training

    # Distribution (for multi-GPU)
    distributed_backend=None,  # None for auto (DDP), or "deepspeed"

    # Logging
    log="wandb",
    logging_steps=-1,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=1,

    # Hub
    push_to_hub=True,
    username="my-username",
    token="hf_...",
)

# Run training
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True
)
job_id = project.create()
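
Once the push completes, the result loads like any other Hub checkpoint. A minimal inference sketch with transformers, assuming the model was pushed as my-username/llama-production:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "my-username/llama-production"  # assumed from username + project_name above
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)

inputs = tokenizer("The three laws of robotics are", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))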

Next Steps