LLM Training API
Complete API reference for LLM training.
LLMTrainingParams
The main configuration class for LLM training.
from autotrain.trainers.clm.params import LLMTrainingParams
Basic Parameters
params = LLMTrainingParams(
    # Core parameters (always specify these)
    model="google/gemma-3-270m",  # Default: "google/gemma-3-270m"
    data_path="./data.jsonl",     # Default: "data"
    project_name="my-model",      # Default: "project-name"
    # Data splits
    train_split="train",          # Default: "train"
    valid_split=None,             # Default: None
    max_samples=None,             # Default: None (use all)
)
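The expected file layout depends on the trainer. As a minimal sketch for sft with the default text_column="text", each line of data.jsonl is one JSON object; the row content below is illustrative:
import json

# Hypothetical training rows; each becomes one line of data.jsonl.
rows = [
    {"text": "### Instruction: Summarize LoRA.\n### Response: LoRA trains small low-rank adapters."},
    {"text": "### Instruction: Define SFT.\n### Response: Supervised fine-tuning on labeled text."},
]
with open("data.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")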
Trainer Selection
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    trainer="sft",  # Default: "default" (pretraining). Options: sft, dpo, orpo, reward
)
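The choice of trainer determines the shape of your data: default and sft train on a single text column, while dpo and orpo expect prompt/chosen/rejected preference columns (see their sections below).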
Training Hyperparameters
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # Core hyperparameters (showing defaults)
    epochs=1,                 # Default: 1
    batch_size=2,             # Default: 2
    lr=3e-5,                  # Default: 3e-5
    warmup_ratio=0.1,         # Default: 0.1
    gradient_accumulation=4,  # Default: 4
    weight_decay=0.0,         # Default: 0.0
    max_grad_norm=1.0,        # Default: 1.0
    # Precision
    mixed_precision=None,     # Default: None (options: bf16, fp16, None)
    # Optimization
    optimizer="adamw_torch",  # Default: adamw_torch
    scheduler="linear",       # Default: linear
    seed=42,                  # Default: 42
)
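Note that gradient accumulation multiplies the effective batch size: with the defaults above, each optimizer step sees batch_size × gradient_accumulation = 2 × 4 = 8 samples per device.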
PEFT/LoRA Configuration
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="my-model",
    # Enable LoRA (default: False)
    peft=True,
    lora_r=16,                    # Default: 16
    lora_alpha=32,                # Default: 32
    lora_dropout=0.05,            # Default: 0.05
    target_modules="all-linear",  # Default: all-linear
    # Quantization (optional)
    quantization="int4",          # Options: int4, int8, or None (default: None)
    # Merge after training (default is True - LoRA merged automatically)
    merge_adapter=True,
)
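Because merge_adapter=True folds the LoRA weights back into the base model, the finished checkpoint loads like any ordinary transformers model. A minimal sketch, assuming the merged checkpoint lands in the project_name directory:
from transformers import AutoModelForCausalLM, AutoTokenizer

# "my-model" is the project_name used above (an assumption about
# where the merged checkpoint is written; adjust to your output path).
model = AutoModelForCausalLM.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")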
Data Processing
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # Text processing
    text_column="text",
    block_size=-1,                # Default: -1 (model default)
    model_max_length=2048,        # Default: 2048
    add_eos_token=True,           # Default: True
    padding="right",              # Default: "right"
    # Chat format
    chat_template=None,           # Auto-detect or specify
    apply_chat_template=True,     # Default: True
    # Efficiency
    packing=None,                 # Default: None (set True to enable)
    use_flash_attention_2=False,  # Default: False
    attn_implementation=None,     # Default: None
)
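If you prefer to pre-render conversations into the text column yourself rather than relying on apply_chat_template, the standard transformers chat-template API works. A sketch (the message content is illustrative):
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")
messages = [
    {"role": "user", "content": "What does block_size do?"},
    {"role": "assistant", "content": "It caps the token length of each training block."},
]
# Render the conversation with the model's own chat template,
# then store the result under the text column.
text = tokenizer.apply_chat_template(messages, tokenize=False)
with open("data.jsonl", "a") as f:
    f.write(json.dumps({"text": text}) + "\n")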
DPO Parameters
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./preferences.jsonl",
    project_name="my-model",
    trainer="dpo",
    # DPO-specific
    dpo_beta=0.1,                # Default: 0.1
    max_prompt_length=128,       # Default: 128
    max_completion_length=None,  # Default: None
    # Reference model (optional)
    model_ref=None,              # Uses same as model if None
    # Data columns (required for DPO)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)
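The three column settings map onto the fields of each record in preferences.jsonl. A minimal sketch of one record (the content is illustrative):
import json

# One preference pair: the prompt plus a preferred and a rejected answer.
record = {
    "prompt": "Explain DPO in one sentence.",
    "chosen": "DPO tunes a policy directly on preference pairs, no reward model needed.",
    "rejected": "DPO is a kind of database.",
}
with open("preferences.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")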
ORPO Parameters
params = LLMTrainingParams(
    model="google/gemma-2-2b",
    data_path="./preferences.jsonl",
    project_name="my-model",
    trainer="orpo",
    # ORPO-specific
    dpo_beta=0.1,                # Default: 0.1
    max_prompt_length=128,       # Default: 128
    max_completion_length=None,  # Default: None
    # Data columns (required for ORPO, same format as DPO)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)
Knowledge Distillation
params = LLMTrainingParams(
    model="google/gemma-3-270m",        # Student
    teacher_model="google/gemma-2-2b",  # Teacher
    data_path="./prompts.jsonl",
    project_name="distilled-model",
    use_distillation=True,
    distill_temperature=3.0,            # Default: 3.0
    distill_alpha=0.7,                  # Default: 0.7
    distill_max_teacher_length=512,     # Default: 512
)
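For intuition, distill_temperature and distill_alpha usually enter the objective as a temperature-scaled KL term blended with ordinary cross-entropy. A minimal sketch of that standard formulation (an illustration of the common recipe, not necessarily AutoTrain's exact internals):
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
    # Soft-target term: KL between temperature-softened distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature**2
    # Hard-target term: standard cross-entropy against the labels.
    hard = F.cross_entropy(student_logits, labels)
    # distill_alpha weights the soft (teacher) term.
    return alpha * soft + (1 - alpha) * hard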
Logging & Saving
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # Logging
    log="wandb",            # wandb, tensorboard, or None (default: wandb)
    logging_steps=-1,       # Default: -1 (auto)
    wandb_visualizer=True,  # Terminal visualizer
    wandb_token=None,       # W&B API token (optional)
    # Checkpointing
    save_strategy="steps",  # steps or epoch (default: epoch)
    save_steps=500,
    save_total_limit=1,     # Default: 1
    eval_strategy="steps",  # evaluation requires valid_split to be set
)
Hub Integration
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # Push to Hub
    push_to_hub=True,
    username="your-username",
    token="hf_...",
)
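To avoid hardcoding the secret, you can read it from the environment instead and pass it as token=token above (HF_TOKEN is a conventional variable name here, not one AutoTrain requires):
import os

# Export HF_TOKEN in your shell beforehand.
token = os.environ["HF_TOKEN"]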
Running Training
from autotrain.project import AutoTrainProject

# Create and run project
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True,
)
job_id = project.create()
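With backend="local", the job runs in the current environment and the trained model is written to the project_name directory.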
Complete Example
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Full configuration
params = LLMTrainingParams(
    # Model
    model="meta-llama/Llama-3.2-1B",
    project_name="llama-production",
    # Data
    data_path="./conversations.jsonl",
    train_split="train",
    valid_split="validation",
    text_column="text",
    block_size=2048,
    # Training
    trainer="sft",
    epochs=3,
    batch_size=2,
    gradient_accumulation=8,
    lr=2e-5,
    warmup_ratio=0.1,
    mixed_precision="bf16",
    # LoRA
    peft=True,
    lora_r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    # Optimization
    use_flash_attention_2=True,
    packing=True,
    auto_find_batch_size=True,
    unsloth=False,             # Set True to use Unsloth for faster training
    # Distribution (for multi-GPU)
    distributed_backend=None,  # None for auto (DDP), or "deepspeed"
    # Logging
    log="wandb",
    logging_steps=-1,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=1,
    # Hub
    push_to_hub=True,
    username="my-username",
    token="hf_...",
)

# Run training
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True,
)
job_id = project.create()
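Once the push completes, the model can be loaded for inference. A sketch, assuming the pushed repo id follows the usual <username>/<project_name> pattern (verify the exact id on your Hub page):
from transformers import pipeline

# "my-username/llama-production" is an assumed repo id built from the
# username and project_name above; check your Hub account for the real one.
generator = pipeline("text-generation", model="my-username/llama-production")
print(generator("Hello, world", max_new_tokens=50)[0]["generated_text"])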