Saltar al contenido principal

API de Python

AITraining proporciona una API de Python para acceso programático a toda la funcionalidad de entrenamiento.

Instalación

pip install aitraining torch

Inicio Rápido

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Configure training
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
    peft=True,
    lora_r=16,
)

# Start training
project = AutoTrainProject(params=params, backend="local", process=True)
job_id = project.create()
print(f"Training started: {job_id}")

Estructura de la API

Parámetros de Entrenamiento

Cada tipo de tarea tiene su propia clase de parámetros:
TareaClase de Parámetros
Entrenamiento LLMLLMTrainingParams
Clasificación de TextoTextClassificationParams
Clasificación de ImagenImageClassificationParams
Clasificación de TokensTokenClassificationParams
Seq2SeqSeq2SeqParams
TabularTabularParams
Detección de ObjetosObjectDetectionParams
VLMVLMTrainingParams

Ejecución del Proyecto

from autotrain.project import AutoTrainProject

# Create project
project = AutoTrainProject(
    params=params,
    backend="local",  # or "spaces"
    process=True      # Start immediately
)

# Run training
job_id = project.create()

Ejemplo: Script Completo de Entrenamiento

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_model():
    # Configure parameters
    params = LLMTrainingParams(
        # Model
        model="meta-llama/Llama-3.2-1B",
        project_name="llama-sft",

        # Data
        data_path="./conversations.jsonl",
        train_split="train",
        text_column="text",
        block_size=2048,

        # Training
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=4,
        lr=2e-5,
        mixed_precision="bf16",

        # LoRA
        peft=True,
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,

        # Logging
        log="wandb",
        logging_steps=10,
    )

    # Start training
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()

if __name__ == "__main__":
    job_id = train_model()
    print(f"Training complete: {job_id}")

Módulos Principales

MóduloDescripción
autotrain.projectEjecución del proyecto
autotrain.trainers.clm.paramsParámetros LLM
autotrain.trainers.text_classification.paramsClasificación de texto
autotrain.datasetManejo de datasets
autotrain.generationUtilidades de inferencia

Próximos Pasos