Python SDK Reference
Comprehensive guide to the AITraining Python API.
Installation
pip install aitraining torch torchvision torchaudio
Package vs. import name: install the package with `pip install aitraining`, but import it as `autotrain` (for example, `from autotrain import ...`).
LLM Training
LLMTrainingParams
from autotrain.trainers.clm.params import LLMTrainingParams
params = LLMTrainingParams(
# Required
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
# Training method
trainer="sft", # sft, dpo, orpo, ppo, reward, distillation
# Training settings
epochs=3,
batch_size=2,
lr=2e-5,
gradient_accumulation=4,
mixed_precision="bf16",
# LoRA
peft=True,
lora_r=16,
lora_alpha=32,
lora_dropout=0.05,
# Data processing
text_column="text",
block_size=2048,
add_eos_token=True,
# Logging
log="wandb",
logging_steps=-1, # Default: -1 (auto)
)
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| `model` | str | Model name or path |
| `data_path` | str | Path to training data |
| `project_name` | str | Output directory |
| `trainer` | str | Training method |
| `epochs` | int | Number of epochs |
| `batch_size` | int | Batch size |
| `lr` | float | Learning rate |
| `peft` | bool | Enable LoRA |
| `lora_r` | int | LoRA rank |
| `lora_alpha` | int | LoRA alpha |
Text Classification
TextClassificationParams
from autotrain.trainers.text_classification.params import TextClassificationParams
params = TextClassificationParams(
model="bert-base-uncased",
data_path="./reviews.csv",
project_name="sentiment",
text_column="text",
target_column="label",
epochs=5,
batch_size=16,
lr=2e-5,
)
Image Classification
ImageClassificationParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
params = ImageClassificationParams(
model="google/vit-base-patch16-224",
data_path="./images/",
project_name="classifier",
image_column="image",
target_column="label",
epochs=10,
batch_size=32,
)
Project Execution
AutoTrainProject
from autotrain.project import AutoTrainProject
# Create and run project
# `params` is a params object built beforehand (e.g. the LLMTrainingParams instance shown earlier).
project = AutoTrainProject(
params=params,
backend="local", # "local", or a remote backend such as "spaces-a10g-large" (see Backend Options)
process=True # Start training immediately
)
job_id = project.create()
Backend Options
| Backend | Description |
|---|---|
| `local` | Run on local machine |
| `spaces-*` | Run on Hugging Face Spaces (e.g., `spaces-a10g-large`, `spaces-t4-medium`) |
| `ep-*` | Hugging Face Endpoints |
| `ngc-*` | NVIDIA NGC |
| `nvcf-*` | NVIDIA Cloud Functions |
Inference
Using Completers
from autotrain.generation import CompletionConfig, create_completer
# Configure generation
config = CompletionConfig(
max_new_tokens=256,
temperature=0.7,
top_p=0.95,
top_k=50,
)
# Create completer (first param is "model", not "model_path")
completer = create_completer(
model="./my-trained-model",
completer_type="message",
config=config
)
# Generate (returns MessageCompletionResult)
result = completer.chat("Hello, how are you?")
print(result.content) # Access the text content
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")
# Generate
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
Dataset Handling
AutoTrainDataset
from autotrain.dataset import AutoTrainDataset
dataset = AutoTrainDataset(
train_data=["train.csv"],
task="text_classification",
token="hf_...",
project_name="my-project",
username="my-username",
column_mapping={
"text": "review_text",
"label": "sentiment"
},
)
# Prepare dataset
data_path = dataset.prepare()
Configuration Files
Loading from YAML
from autotrain.parser import AutoTrainConfigParser
# Parse config file
parser = AutoTrainConfigParser("config.yaml")
# Run training
parser.run()
Error Handling
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
# Build the parameters, start a local training run, and report failures
# by exception type instead of letting them propagate.
try:
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )
    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()
except ValueError as e:
    # ValueError is reported as a configuration problem.
    print(f"Configuration error: {e}")
except RuntimeError as e:
    # RuntimeError is reported as a failure of the training run itself.
    print(f"Training error: {e}")
Complete Examples
SFT Training Pipeline
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_sft():
    """Launch a local SFT run with LoRA enabled and return the job ID.

    Builds ``LLMTrainingParams`` for ``google/gemma-3-270m`` using the
    ``sft`` trainer, wraps them in an ``AutoTrainProject`` on the
    ``local`` backend, and returns whatever ``project.create()`` returns.
    """
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        peft=True,
        lora_r=16,
        lora_alpha=32,
        log="wandb",
    )
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True,
    )
    return project.create()


if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")
DPO Training Pipeline
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_dpo():
    """Launch a local DPO run with LoRA enabled and return the job ID.

    Builds ``LLMTrainingParams`` for ``meta-llama/Llama-3.2-1B`` using the
    ``dpo`` trainer on a preferences dataset, wraps them in an
    ``AutoTrainProject`` on the ``local`` backend, and returns whatever
    ``project.create()`` returns.
    """
    params = LLMTrainingParams(
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,  # Default: 128
        max_completion_length=None,  # Default: None
        epochs=1,
        batch_size=2,
        lr=5e-6,
        peft=True,
        lora_r=16,
    )
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True,
    )
    return project.create()
Next Steps