Skip to main content

Python SDK Reference

Comprehensive guide to the AITraining Python API.

Installation

pip install aitraining torch torchvision torchaudio
Package vs. import name: the package is installed as `aitraining` (`pip install aitraining`), but it is imported as `autotrain` (`from autotrain import ...`).

LLM Training

LLMTrainingParams

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    # Required
    model="google/gemma-3-270m",  # model name or path
    data_path="./data.jsonl",  # path to training data
    project_name="my-model",  # output directory

    # Training method
    trainer="sft",  # one of: sft, dpo, orpo, ppo, reward, distillation

    # Training settings
    epochs=3,
    batch_size=2,
    lr=2e-5,
    gradient_accumulation=4,
    mixed_precision="bf16",

    # LoRA (enabled by peft=True)
    peft=True,
    lora_r=16,  # LoRA rank
    lora_alpha=32,
    lora_dropout=0.05,

    # Data processing
    text_column="text",
    block_size=2048,
    add_eos_token=True,

    # Logging
    log="wandb",
    logging_steps=-1,  # Default: -1 (auto)
)

Key Parameters

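| Parameter | Type | Description |
| --- | --- | --- |
| `model` | `str` | Model name or path |
| `data_path` | `str` | Path to training data |
| `project_name` | `str` | Output directory |
| `trainer` | `str` | Training method |
| `epochs` | `int` | Number of epochs |
| `batch_size` | `int` | Batch size |
| `lr` | `float` | Learning rate |
| `peft` | `bool` | Enable LoRA |
| `lora_r` | `int` | LoRA rank |
| `lora_alpha` | `int` | LoRA alpha |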

Text Classification

TextClassificationParams

from autotrain.trainers.text_classification.params import TextClassificationParams

# Text classification example: bert-base-uncased on ./reviews.csv, with
# text column "text" and target column "label"; 5 epochs, batch size 16.
params = TextClassificationParams(
    model="bert-base-uncased",
    data_path="./reviews.csv",
    project_name="sentiment",
    text_column="text",
    target_column="label",
    epochs=5,
    batch_size=16,
    lr=2e-5,
)

Image Classification

ImageClassificationParams

from autotrain.trainers.image_classification.params import ImageClassificationParams

# Image classification example: google/vit-base-patch16-224 on ./images/,
# with image column "image" and target column "label"; 10 epochs, batch size 32.
params = ImageClassificationParams(
    model="google/vit-base-patch16-224",
    data_path="./images/",
    project_name="classifier",
    image_column="image",
    target_column="label",
    epochs=10,
    batch_size=32,
)

Project Execution

AutoTrainProject

from autotrain.project import AutoTrainProject

# Create and run project
project = AutoTrainProject(
    params=params,
    backend="local",  # "local", "spaces-*", "ep-*", "ngc-*" or "nvcf-*" (see Backend Options)
    process=True      # Start training immediately
)

# create() returns the job ID
job_id = project.create()

Backend Options

| Backend | Description |
| --- | --- |
| `local` | Run on local machine |
| `spaces-*` | Run on Hugging Face Spaces (e.g., `spaces-a10g-large`, `spaces-t4-medium`) |
| `ep-*` | Hugging Face Endpoints |
| `ngc-*` | NVIDIA NGC |
| `nvcf-*` | NVIDIA Cloud Functions |

Inference

Using Completers

from autotrain.generation import CompletionConfig, create_completer

# Sampling settings used for generation
gen_config = CompletionConfig(max_new_tokens=256, temperature=0.7, top_p=0.95, top_k=50)

# Build a message-style completer; the first parameter is named "model", not "model_path"
chat_completer = create_completer(
    model="./my-trained-model",
    completer_type="message",
    config=gen_config,
)

# chat() returns a MessageCompletionResult; the generated text is in .content
reply = chat_completer.chat("Hello, how are you?")
print(reply.content)

Using Transformers Directly

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and its tokenizer from the same local path
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")

# Tokenize the prompt as PyTorch tensors and generate up to 100 new tokens
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)

# Decode the first (only) sequence; skip special tokens so that markers such
# as BOS/EOS/padding are not printed alongside the generated text
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Dataset Handling

AutoTrainDataset

from autotrain.dataset import AutoTrainDataset

dataset = AutoTrainDataset(
    train_data=["train.csv"],  # list of training files
    task="text_classification",
    token="hf_...",  # placeholder; replace with your own token
    project_name="my-project",
    username="my-username",
    # NOTE(review): presumably maps the names the task expects ("text", "label")
    # to the column names found in train.csv; confirm against AutoTrainDataset.
    column_mapping={
        "text": "review_text",
        "label": "sentiment"
    },
)

# Prepare dataset; the return value is kept as the data path
data_path = dataset.prepare()

Configuration Files

Loading from YAML

from autotrain.parser import AutoTrainConfigParser

# Parse the YAML config file
parser = AutoTrainConfigParser("config.yaml")

# Run training using the parsed config
parser.run()

Error Handling

from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams

try:
    # Build the training configuration
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )

    # Create the project on the local backend and start the job
    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()

# Separate handlers so that configuration problems and training failures
# are reported with distinct messages.
except ValueError as e:
    print(f"Configuration error: {e}")
except RuntimeError as e:
    print(f"Training error: {e}")

Complete Examples

SFT Training Pipeline

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_sft():
    """Run LoRA-based SFT of google/gemma-3-270m on the local backend.

    Returns:
        The job ID returned by ``AutoTrainProject.create()``.
    """
    sft_params = LLMTrainingParams(
        # Model, data and output location
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        # Training method and schedule
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        # LoRA
        peft=True,
        lora_r=16,
        lora_alpha=32,
        # Logging
        log="wandb",
    )

    # process=True starts training immediately
    sft_project = AutoTrainProject(params=sft_params, backend="local", process=True)
    job_id = sft_project.create()
    return job_id

# Run the SFT pipeline only when this file is executed as a script
if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")

DPO Training Pipeline

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_dpo():
    """Run LoRA-based DPO training of meta-llama/Llama-3.2-1B on the local backend.

    Returns:
        The value returned by ``AutoTrainProject.create()``.
    """
    dpo_params = LLMTrainingParams(
        # Model, data and output location
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        # DPO settings
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,  # same as the default (128)
        max_completion_length=None,  # same as the default (None)
        # Training schedule
        epochs=1,
        batch_size=2,
        lr=5e-6,
        # LoRA
        peft=True,
        lora_r=16,
    )

    # process=True starts training immediately
    dpo_project = AutoTrainProject(params=dpo_params, backend="local", process=True)
    job_id = dpo_project.create()
    return job_id

Next Steps