Pular para o conteúdo principal

API Python

O AITraining fornece uma API Python para acesso programático a toda a funcionalidade de treinamento.

Instalação

pip install aitraining torch

Início Rápido

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Configure training
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
    peft=True,
    lora_r=16,
)

# Start training
project = AutoTrainProject(params=params, backend="local", process=True)
job_id = project.create()
print(f"Training started: {job_id}")

Estrutura da API

Parâmetros de Treinamento

Cada tipo de tarefa tem sua própria classe de parâmetros:
TarefaClasse de Parâmetros
Treinamento de LLMLLMTrainingParams
Classificação de TextoTextClassificationParams
Classificação de ImagemImageClassificationParams
Classificação de TokenTokenClassificationParams
Seq2SeqSeq2SeqParams
TabularTabularParams
Detecção de ObjetosObjectDetectionParams
VLMVLMTrainingParams

Execução do Projeto

from autotrain.project import AutoTrainProject

# Create project
project = AutoTrainProject(
    params=params,
    backend="local",  # or "spaces"
    process=True      # Start immediately
)

# Run training
job_id = project.create()

Exemplo: Script Completo de Treinamento

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_model():
    # Configure parameters
    params = LLMTrainingParams(
        # Model
        model="meta-llama/Llama-3.2-1B",
        project_name="llama-sft",

        # Data
        data_path="./conversations.jsonl",
        train_split="train",
        text_column="text",
        block_size=2048,

        # Training
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=4,
        lr=2e-5,
        mixed_precision="bf16",

        # LoRA
        peft=True,
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,

        # Logging
        log="wandb",
        logging_steps=10,
    )

    # Start training
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()

if __name__ == "__main__":
    job_id = train_model()
    print(f"Training complete: {job_id}")

Módulos Principais

MóduloDescrição
autotrain.projectExecução do projeto
autotrain.trainers.clm.paramsParâmetros de LLM
autotrain.trainers.text_classification.paramsClassificação de texto
autotrain.datasetManipulação de datasets
autotrain.generationUtilitários de inferência

Próximos Passos