Saltar al contenido principal

Referencia del SDK de Python

Guía completa de la API de Python de AITraining.

Instalación

pip install aitraining torch torchvision torchaudio
Nombre del paquete vs. import: instala con `pip install aitraining`, pero importa con `from autotrain import ...`.

Entrenamiento LLM

LLMTrainingParams

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    # Required
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Training method
    trainer="sft",  # sft, dpo, orpo, ppo, reward, distillation

    # Training settings
    epochs=3,
    batch_size=2,
    lr=2e-5,
    gradient_accumulation=4,
    mixed_precision="bf16",

    # LoRA
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,

    # Data processing
    text_column="text",
    block_size=2048,
    add_eos_token=True,

    # Logging
    log="wandb",
    logging_steps=-1,  # Default: -1 (auto)
)

Parámetros Clave

| Parámetro | Tipo | Descripción |
|---|---|---|
| model | str | Nombre o ruta del modelo |
| data_path | str | Ruta a los datos de entrenamiento |
| project_name | str | Directorio de salida |
| trainer | str | Método de entrenamiento |
| epochs | int | Número de épocas |
| batch_size | int | Tamaño del lote |
| lr | float | Tasa de aprendizaje |
| peft | bool | Habilitar LoRA |
| lora_r | int | Rango de LoRA |
| lora_alpha | int | Alpha de LoRA |

Clasificación de Texto

TextClassificationParams

from autotrain.trainers.text_classification.params import TextClassificationParams

params = TextClassificationParams(
    model="bert-base-uncased",
    data_path="./reviews.csv",
    project_name="sentiment",
    text_column="text",
    target_column="label",
    epochs=5,
    batch_size=16,
    lr=2e-5,
)

Clasificación de Imagen

ImageClassificationParams

from autotrain.trainers.image_classification.params import ImageClassificationParams

params = ImageClassificationParams(
    model="google/vit-base-patch16-224",
    data_path="./images/",
    project_name="classifier",
    image_column="image",
    target_column="label",
    epochs=10,
    batch_size=32,
)

Ejecución del Proyecto

AutoTrainProject

from autotrain.project import AutoTrainProject

# Create the project from previously built params and launch it
project = AutoTrainProject(
    params=params,
    backend="local",  # "local", or a remote backend such as "spaces-a10g-large"
    process=True      # Start training immediately
)

# create() returns the identifier of the launched job
job_id = project.create()

Opciones de Backend

| Backend | Descripción |
|---|---|
| local | Ejecutar en máquina local |
| spaces-* | Ejecutar en Hugging Face Spaces (ej.: spaces-a10g-large, spaces-t4-medium) |
| ep-* | Hugging Face Endpoints |
| ngc-* | NVIDIA NGC |
| nvcf-* | NVIDIA Cloud Functions |

Inferencia

Usando Completers

from autotrain.generation import CompletionConfig, create_completer

# Configure generation
config = CompletionConfig(
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
)

# Create completer (first param is "model", not "model_path")
completer = create_completer(
    model="./my-trained-model",
    completer_type="message",
    config=config
)

# Generate (returns MessageCompletionResult)
result = completer.chat("Hello, how are you?")
print(result.content)  # Access the text content

Usando Transformers Directamente

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")

# Generate
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))

Manejo de Datasets

AutoTrainDataset

from autotrain.dataset import AutoTrainDataset

dataset = AutoTrainDataset(
    train_data=["train.csv"],
    task="text_classification",
    token="hf_...",
    project_name="my-project",
    username="my-username",
    column_mapping={
        "text": "review_text",
        "label": "sentiment"
    },
)

# Prepare dataset
data_path = dataset.prepare()

Archivos de Configuración

Cargar desde YAML

from autotrain.parser import AutoTrainConfigParser

# Parse config file
parser = AutoTrainConfigParser("config.yaml")

# Run training
parser.run()

Manejo de Errores

from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams

try:
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )

    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()

except ValueError as e:
    print(f"Configuration error: {e}")
except RuntimeError as e:
    print(f"Training error: {e}")

Ejemplos Completos

Pipeline de Entrenamiento SFT

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_sft():
    """Run a LoRA-based SFT job on the local backend.

    Builds the ``LLMTrainingParams`` for the SFT trainer, wraps them in an
    ``AutoTrainProject`` that starts processing immediately, and returns
    whatever ``AutoTrainProject.create()`` returns (used as the job ID by
    the caller).
    """
    sft_params = LLMTrainingParams(
        # Model, data and output directory
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        # Training method and schedule
        trainer="sft",
        epochs=3,
        lr=2e-5,
        batch_size=2,
        gradient_accumulation=8,
        # LoRA adapter settings
        peft=True,
        lora_r=16,
        lora_alpha=32,
        # Experiment tracking
        log="wandb",
    )

    sft_project = AutoTrainProject(params=sft_params, backend="local", process=True)
    job_id = sft_project.create()
    return job_id

if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")

Pipeline de Entrenamiento DPO

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_dpo():
    """Run a LoRA-based DPO job on the local backend.

    Builds the ``LLMTrainingParams`` for the DPO trainer, wraps them in an
    ``AutoTrainProject`` that starts processing immediately, and returns
    whatever ``AutoTrainProject.create()`` returns.
    """
    dpo_params = LLMTrainingParams(
        # Model, data and output directory
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        # DPO-specific settings
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,       # Default: 128
        max_completion_length=None,  # Default: None
        # Training schedule
        epochs=1,
        lr=5e-6,
        batch_size=2,
        # LoRA adapter settings
        peft=True,
        lora_r=16,
    )

    dpo_project = AutoTrainProject(params=dpo_params, backend="local", process=True)
    job_id = dpo_project.create()
    return job_id

Próximos Pasos