Python SDK Reference
A complete guide to the AITraining Python API.
Installation
pip install aitraining torch torchvision torchaudio
Package Name vs. Import: Install with pip install aitraining, but import with from autotrain import ...
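A quick sanity check that the mapping works after installation:
# The package installs as "aitraining" but the import name is "autotrain"
import autotrain
print(autotrain.__file__)  # resolves to the installed aitraining package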
LLM Training
LLMTrainingParams
from autotrain.trainers.clm.params import LLMTrainingParams
params = LLMTrainingParams(
# Required
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
# Training method
trainer="sft", # sft, dpo, orpo, ppo, reward, distillation
# Training settings
epochs=3,
batch_size=2,
lr=2e-5,
gradient_accumulation=4,
mixed_precision="bf16",
# LoRA
peft=True,
lora_r=16,
lora_alpha=32,
lora_dropout=0.05,
# Data processing
text_column="text",
block_size=2048,
add_eos_token=True,
save_processed_data="auto", # auto, local, hub, both, none
# Logging
log="wandb",
logging_steps=-1, # Default: -1 (auto)
# Hyperparameter sweep (optional)
# use_sweep=True,
# sweep_backend="optuna",
# sweep_n_trials=20,
# sweep_params='{"lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3}}',
)
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| model | str | Model name or path |
| data_path | str | Path to the training data |
| project_name | str | Output directory |
| trainer | str | Training method |
| epochs | int | Number of epochs |
| batch_size | int | Batch size |
| lr | float | Learning rate |
| peft | bool | Enable LoRA |
| lora_r | int | LoRA rank |
| lora_alpha | int | LoRA alpha |
| save_processed_data | str | Save processed data: auto, local, hub, both, none |
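As the full example above indicates, only model, data_path, and project_name are required; everything else falls back to a default. A minimal sketch:
from autotrain.trainers.clm.params import LLMTrainingParams
# Only the three required fields; all other settings keep their defaults
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
)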
Hyperparameter Sweep Parameters
| Parameter | Type | Description |
|---|---|---|
| use_sweep | bool | Enable hyperparameter sweep |
| sweep_backend | str | Backend: optuna, grid, random |
| sweep_n_trials | int | Number of trials |
| sweep_metric | str | Metric to optimize |
| sweep_direction | str | minimize or maximize |
| sweep_params | str | Custom search space (JSON string) |
| wandb_sweep | bool | Enable the native W&B sweeps dashboard |
| wandb_sweep_project | str | W&B project for the sweep |
| wandb_sweep_entity | str | W&B entity (team/user) |
| wandb_sweep_id | str | Existing sweep ID to resume |
Both dict and list formats are supported:
import json
# Dict format (recommended)
sweep_params = json.dumps({
"lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},
"batch_size": {"type": "categorical", "values": [2, 4, 8]},
})
# List format (shorthand for categorical values)
sweep_params = json.dumps({
"batch_size": [2, 4, 8],
})
Supported types: categorical, loguniform, uniform, int.
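A sketch combining all four types in one search space. The low/high keys for uniform and int are an assumption that mirrors the loguniform example above; verify against your installed version:
import json
sweep_params = json.dumps({
    "lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},       # log scale
    "weight_decay": {"type": "uniform", "low": 0.0, "high": 0.1},  # linear scale (assumed keys)
    "lora_r": {"type": "int", "low": 8, "high": 64},               # integer range (assumed keys)
    "batch_size": {"type": "categorical", "values": [2, 4, 8]},
})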
Text Classification
TextClassificationParams
from autotrain.trainers.text_classification.params import TextClassificationParams
params = TextClassificationParams(
model="bert-base-uncased",
data_path="./reviews.csv",
project_name="sentiment",
text_column="text",
target_column="label",
epochs=5,
batch_size=16,
lr=2e-5,
)
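The file at data_path needs the columns named by text_column and target_column. A hypothetical reviews.csv matching the configuration above:
import csv
# Hypothetical rows; column names match text_column="text", target_column="label"
rows = [
    {"text": "Great product, works as advertised.", "label": "positive"},
    {"text": "Stopped working after two days.", "label": "negative"},
]
with open("reviews.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["text", "label"])
    writer.writeheader()
    writer.writerows(rows)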
Image Classification
ImageClassificationParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
params = ImageClassificationParams(
model="google/vit-base-patch16-224",
data_path="./images/",
project_name="classifier",
image_column="image",
target_column="label",
epochs=10,
batch_size=32,
)
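Here data_path points at an image directory. A class-per-subfolder layout (the common imagefolder convention; an assumption here, so check your version's data docs) can be sketched like this:
from pathlib import Path
# Hypothetical layout: images/<split>/<class>/<file>.jpg
for split in ("train", "validation"):
    for label in ("cats", "dogs"):
        Path("images", split, label).mkdir(parents=True, exist_ok=True)
# Place .jpg files into images/<split>/<label>/ before training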
Running a Project
AutoTrainProject
from autotrain.project import AutoTrainProject
# Create and run project
project = AutoTrainProject(
params=params,
backend="local", # "local" or "spaces"
process=True # Start training immediately
)
job_id = project.create()
Backend Options
| Backend | Description |
|---|---|
| local | Run on the local machine |
| spaces-* | Run on Hugging Face Spaces (e.g. spaces-a10g-large, spaces-t4-medium) |
| ep-* | Hugging Face Endpoints |
| ngc-* | NVIDIA NGC |
| nvcf-* | NVIDIA Cloud Functions |
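Switching backends only changes the backend string. A sketch for a Spaces run, assuming a Hugging Face token with write access is already configured (e.g. via huggingface-cli login):
from autotrain.project import AutoTrainProject
# Same params object as before; only the backend string changes
project = AutoTrainProject(params=params, backend="spaces-a10g-large", process=True)
job_id = project.create()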
Inference
Using Completers
from autotrain.generation import CompletionConfig, create_completer
# Configure generation
config = CompletionConfig(
max_new_tokens=256,
temperature=0.7,
top_p=0.95,
top_k=50,
)
# Create completer (first param is "model", not "model_path")
completer = create_completer(
model="./my-trained-model",
completer_type="message",
config=config
)
# Generate (returns MessageCompletionResult)
result = completer.chat("Hello, how are you?")
print(result.content) # Access the text content
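The completer is reusable across prompts; a small sketch using only the chat() call and .content field shown above:
prompts = [
    "Summarize LoRA in one sentence.",
    "What does gradient accumulation do?",
]
for prompt in prompts:
    result = completer.chat(prompt)  # returns MessageCompletionResult
    print(f"> {prompt}\n{result.content}\n")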
Using Transformers Directly
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")
# Generate
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
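If the model was trained on chat-formatted data, the tokenizer's chat template (standard transformers API, assuming the tokenizer ships one) keeps inference consistent with training:
messages = [{"role": "user", "content": "Hello!"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
outputs = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))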
Dataset Handling
AutoTrainDataset
from autotrain.dataset import AutoTrainDataset
dataset = AutoTrainDataset(
train_data=["train.csv"],
task="text_classification",
token="hf_...",
project_name="my-project",
username="my-username",
column_mapping={
"text": "review_text",
"label": "sentiment"
},
)
# Prepare dataset
data_path = dataset.prepare()
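The returned data_path can then be passed straight to the matching params class. A sketch; whether the prepared columns keep the mapped or the original names is an assumption worth verifying:
from autotrain.trainers.text_classification.params import TextClassificationParams
# data_path comes from dataset.prepare() above
params = TextClassificationParams(
    model="bert-base-uncased",
    data_path=data_path,
    project_name="my-project",
    text_column="text",    # assumed to match the prepared dataset's columns
    target_column="label",
)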
Configuration Files
Loading from YAML
from autotrain.parser import AutoTrainConfigParser
# Parse config file
parser = AutoTrainConfigParser("config.yaml")
# Run training
parser.run()
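The exact YAML schema is task-dependent and not shown here. As a purely hypothetical sketch (the field names are assumptions; check a shipped example config), an llm-sft file could be written and run like this:
from autotrain.parser import AutoTrainConfigParser
config_text = """\
task: llm-sft                     # hypothetical field names
base_model: google/gemma-3-270m
project_name: my-model
data:
  path: ./data.jsonl
params:
  epochs: 3
  lr: 2.0e-5
"""
with open("config.yaml", "w") as f:
    f.write(config_text)
AutoTrainConfigParser("config.yaml").run()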
Error Handling
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
try:
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
)
project = AutoTrainProject(params=params, backend="local", process=True)
job_id = project.create()
except ValueError as e:
print(f"Configuration error: {e}")
except RuntimeError as e:
print(f"Training error: {e}")
Complete Examples
SFT Training Pipeline
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_sft():
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./conversations.jsonl",
project_name="gemma-sft",
trainer="sft",
epochs=3,
batch_size=2,
gradient_accumulation=8,
lr=2e-5,
peft=True,
lora_r=16,
lora_alpha=32,
log="wandb",
)
project = AutoTrainProject(
params=params,
backend="local",
process=True
)
return project.create()
if __name__ == "__main__":
job_id = train_sft()
print(f"Job ID: {job_id}")
DPO Training Pipeline
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_dpo():
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./preferences.jsonl",
project_name="llama-dpo",
trainer="dpo",
dpo_beta=0.1,
max_prompt_length=128, # Default: 128
max_completion_length=None, # Default: None
epochs=1,
batch_size=2,
lr=5e-6,
peft=True,
lora_r=16,
)
project = AutoTrainProject(
params=params,
backend="local",
process=True
)
return project.create()
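DPO expects preference pairs. A hypothetical preferences.jsonl record (the prompt/chosen/rejected field names are an assumption; map your columns accordingly):
import json
record = {
    "prompt": "Explain overfitting briefly.",
    "chosen": "Overfitting is when a model memorizes training data instead of generalizing.",
    "rejected": "Overfitting means the model is training well.",
}
with open("preferences.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")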
Next Steps