Pular para o conteúdo principal

Referência do SDK Python

Guia abrangente da API Python do AITraining.

Instalação

pip install aitraining torch torchvision torchaudio
Nome do pacote vs. import: instale com `pip install aitraining`, mas importe com `from autotrain import ...`.

Treinamento LLM

LLMTrainingParams

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    # Required
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Training method
    trainer="sft",  # sft, dpo, orpo, ppo, reward, distillation

    # Training settings
    epochs=3,
    batch_size=2,
    lr=2e-5,
    gradient_accumulation=4,
    mixed_precision="bf16",

    # LoRA
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,

    # Data processing
    text_column="text",
    block_size=2048,
    add_eos_token=True,

    # Logging
    log="wandb",
    logging_steps=-1,  # Default: -1 (auto)
)

Parâmetros Principais

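| Parâmetro | Tipo | Descrição |
|---|---|---|
| `model` | `str` | Nome ou caminho do modelo |
| `data_path` | `str` | Caminho para os dados de treinamento |
| `project_name` | `str` | Diretório de saída |
| `trainer` | `str` | Método de treinamento |
| `epochs` | `int` | Número de épocas |
| `batch_size` | `int` | Tamanho do lote |
| `lr` | `float` | Taxa de aprendizado |
| `peft` | `bool` | Habilitar LoRA |
| `lora_r` | `int` | Rank do LoRA |
| `lora_alpha` | `int` | Alpha do LoRA |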

Classificação de Texto

TextClassificationParams

from autotrain.trainers.text_classification.params import TextClassificationParams

params = TextClassificationParams(
    model="bert-base-uncased",
    data_path="./reviews.csv",
    project_name="sentiment",
    text_column="text",
    target_column="label",
    epochs=5,
    batch_size=16,
    lr=2e-5,
)

Classificação de Imagem

ImageClassificationParams

from autotrain.trainers.image_classification.params import ImageClassificationParams

params = ImageClassificationParams(
    model="google/vit-base-patch16-224",
    data_path="./images/",
    project_name="classifier",
    image_column="image",
    target_column="label",
    epochs=10,
    batch_size=32,
)

Execução do Projeto

AutoTrainProject

from autotrain.project import AutoTrainProject

# Create and run project
project = AutoTrainProject(
    params=params,
    # backend: "local" runs on this machine; remote targets use a prefixed
    # name such as "spaces-a10g-large" or "spaces-t4-medium" (ep-*, ngc-*
    # and nvcf-* prefixes select the other listed backends).
    backend="local",
    process=True      # Start training immediately
)

job_id = project.create()

Opções de Backend

| Backend | Descrição |
|---|---|
| `local` | Executar na máquina local |
| `spaces-*` | Executar no Hugging Face Spaces (ex.: `spaces-a10g-large`, `spaces-t4-medium`) |
| `ep-*` | Hugging Face Endpoints |
| `ngc-*` | NVIDIA NGC |
| `nvcf-*` | NVIDIA Cloud Functions |

Inferência

Usando Completers

from autotrain.generation import CompletionConfig, create_completer

# Configure generation
config = CompletionConfig(
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
)

# Create completer (first param is "model", not "model_path")
completer = create_completer(
    model="./my-trained-model",
    completer_type="message",
    config=config
)

# Generate (returns MessageCompletionResult)
result = completer.chat("Hello, how are you?")
print(result.content)  # Access the text content

Usando Transformers Diretamente

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")

# Generate
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))

Manipulação de Datasets

AutoTrainDataset

from autotrain.dataset import AutoTrainDataset

dataset = AutoTrainDataset(
    train_data=["train.csv"],
    task="text_classification",
    token="hf_...",
    project_name="my-project",
    username="my-username",
    column_mapping={
        "text": "review_text",
        "label": "sentiment"
    },
)

# Prepare dataset
data_path = dataset.prepare()

Arquivos de Configuração

Carregando de YAML

from autotrain.parser import AutoTrainConfigParser

# Parse config file
parser = AutoTrainConfigParser("config.yaml")

# Run training
parser.run()

Tratamento de Erros

from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams

try:
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )

    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()

except ValueError as e:
    print(f"Configuration error: {e}")
except RuntimeError as e:
    print(f"Training error: {e}")

Exemplos Completos

Pipeline de Treinamento SFT

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_sft():
    """Launch a local LoRA supervised fine-tuning (SFT) run.

    Builds ``LLMTrainingParams`` for ``google/gemma-3-270m`` trained on
    ``./conversations.jsonl`` and submits them through an
    ``AutoTrainProject`` that uses the ``local`` backend.

    Returns:
        The value returned by ``AutoTrainProject.create()``; the script
        entry point prints it as the job ID.
    """
    sft_params = LLMTrainingParams(
        # Model, data and output directory
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        # Training method and schedule
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        # LoRA adapter settings
        peft=True,
        lora_r=16,
        lora_alpha=32,
        # Experiment logging
        log="wandb",
    )

    # process=True starts training immediately when the project is created.
    return AutoTrainProject(
        params=sft_params,
        backend="local",
        process=True,
    ).create()

if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")

Pipeline de Treinamento DPO

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_dpo():
    """Launch a local LoRA DPO (preference-tuning) run.

    Builds ``LLMTrainingParams`` for ``meta-llama/Llama-3.2-1B`` trained on
    ``./preferences.jsonl`` with the ``dpo`` trainer and submits them through
    an ``AutoTrainProject`` that uses the ``local`` backend.

    Returns:
        The value returned by ``AutoTrainProject.create()``.
    """
    dpo_params = LLMTrainingParams(
        # Model, data and output directory
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        # DPO-specific settings
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,  # Default: 128
        max_completion_length=None,  # Default: None
        # Training schedule
        epochs=1,
        batch_size=2,
        lr=5e-6,
        # LoRA adapter settings
        peft=True,
        lora_r=16,
    )

    # process=True starts training immediately when the project is created.
    return AutoTrainProject(
        params=dpo_params,
        backend="local",
        process=True,
    ).create()

Próximos Passos