Skip to main content

Python API

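AITraining provides a Python API for programmatic access to all training functionality. The package is installed as `aitraining` but imported under the `autotrain` namespace, as the examples below show.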

Installation

pip install aitraining torch

Quick Start

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Configure training: SFT with LoRA adapters on a Gemma base model.
params = LLMTrainingParams(
    # Model
    model="google/gemma-3-270m",
    project_name="my-model",
    # Data
    data_path="./data.jsonl",
    # Training
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
    # LoRA
    peft=True,
    lora_r=16,
)

# Start training: create() launches the run and returns its job ID.
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True,
)
job_id = project.create()
print(f"Training started: {job_id}")

API Structure

Training Parameters

Each task type has its own params class:

| Task | Params Class |
|------|--------------|
| LLM Training | `LLMTrainingParams` |
| Text Classification | `TextClassificationParams` |
| Image Classification | `ImageClassificationParams` |
| Token Classification | `TokenClassificationParams` |
| Seq2Seq | `Seq2SeqParams` |
| Tabular | `TabularParams` |
| Object Detection | `ObjectDetectionParams` |
| VLM | `VLMTrainingParams` |

Project Execution

from autotrain.project import AutoTrainProject

# Create project (this does not start training yet)
project = AutoTrainProject(
    params=params,
    backend="local",  # or "spaces"
    # NOTE(review): process=True presumably pre-processes the params/dataset
    # before launch; confirm against the AutoTrainProject docs. It does not
    # start training by itself; training runs on create() below.
    process=True,
)

# Run training; create() returns the job ID
job_id = project.create()

Example: Full Training Script

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_model():
    """Fine-tune Llama-3.2-1B with the SFT trainer and LoRA, locally.

    The configuration is grouped by concern (model, data, training, LoRA,
    logging). The groups are merged into a single ``LLMTrainingParams``,
    which is handed to a local ``AutoTrainProject``.

    Returns:
        The value returned by ``AutoTrainProject.create()`` (the job ID).
    """
    model_settings = dict(
        model="meta-llama/Llama-3.2-1B",
        project_name="llama-sft",
    )
    data_settings = dict(
        data_path="./conversations.jsonl",
        train_split="train",
        text_column="text",
        block_size=2048,
    )
    training_settings = dict(
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=4,
        lr=2e-5,
        mixed_precision="bf16",
    )
    lora_settings = dict(
        peft=True,
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,
    )
    logging_settings = dict(
        log="wandb",
        logging_steps=10,
    )

    # Every group uses distinct keys, so the merged call receives exactly one
    # value per parameter.
    params = LLMTrainingParams(
        **model_settings,
        **data_settings,
        **training_settings,
        **lora_settings,
        **logging_settings,
    )

    # Training is launched by create(); its result is passed back to the caller.
    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()
    return job_id

if __name__ == "__main__":
    # Script entry point: run the training job and report what create() returned.
    job_id = train_model()
    print("Training complete: {}".format(job_id))

Core Modules

| Module | Description |
|--------|-------------|
| `autotrain.project` | Project execution |
| `autotrain.trainers.clm.params` | LLM training parameters |
| `autotrain.trainers.text_classification.params` | Text classification parameters |
| `autotrain.dataset` | Dataset handling |
| `autotrain.generation` | Inference utilities |

Next Steps