Referência do SDK Python
Guia abrangente da API Python do AITraining.
Instalação
pip install aitraining torch torchvision torchaudio
Nome do pacote vs. import: instale com `pip install aitraining`, mas importe com `from autotrain import ...`.
Treinamento LLM
LLMTrainingParams
from autotrain.trainers.clm.params import LLMTrainingParams
params = LLMTrainingParams(
    # --- Required ---
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # --- Training method ---
    trainer="sft",  # one of: sft, dpo, orpo, ppo, reward, distillation
    # --- Training settings ---
    epochs=3,
    batch_size=2,
    lr=2e-5,
    gradient_accumulation=4,
    mixed_precision="bf16",
    # --- LoRA ---
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    # --- Data processing ---
    text_column="text",
    block_size=2048,
    add_eos_token=True,
    save_processed_data="auto",  # one of: auto, local, hub, both, none
    # --- Logging ---
    log="wandb",
    logging_steps=-1,  # -1 is the default and means "auto"
    # --- Hyperparameter sweep (optional; uncomment to enable) ---
    # use_sweep=True,
    # sweep_backend="optuna",
    # sweep_n_trials=20,
    # sweep_params='{"lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3}}',
)
Parâmetros Principais
| Parâmetro | Tipo | Descrição |
|---|---|---|
| `model` | str | Nome ou caminho do modelo |
| `data_path` | str | Caminho para os dados de treinamento |
| `project_name` | str | Diretório de saída |
| `trainer` | str | Método de treinamento |
| `epochs` | int | Número de épocas |
| `batch_size` | int | Tamanho do lote |
| `lr` | float | Taxa de aprendizado |
| `peft` | bool | Habilitar LoRA |
| `lora_r` | int | Rank do LoRA |
| `lora_alpha` | int | Alpha do LoRA |
| `save_processed_data` | str | Salvar dados processados: auto, local, hub, both, none |
Parâmetros de Sweep de Hiperparâmetros
| Parâmetro | Tipo | Descrição |
|---|---|---|
| `use_sweep` | bool | Habilitar sweep de hiperparâmetros |
| `sweep_backend` | str | Backend: optuna, grid, random |
| `sweep_n_trials` | int | Número de tentativas |
| `sweep_metric` | str | Métrica a otimizar |
| `sweep_direction` | str | minimize ou maximize |
| `sweep_params` | str | Espaço de busca personalizado (string JSON) |
| `wandb_sweep` | bool | Habilitar dashboard nativo de sweeps W&B |
| `wandb_sweep_project` | str | Projeto W&B para sweep |
| `wandb_sweep_entity` | str | Entidade W&B (equipe/usuário) |
| `wandb_sweep_id` | str | ID de sweep existente para continuar |
Formatos de lista e dicionário são suportados:
import json
# The search space is handed to `sweep_params` as a JSON string.

# Dict format (recommended): each parameter spells out its distribution.
dict_space = {
    "lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},
    "batch_size": {"type": "categorical", "values": [2, 4, 8]},
}
sweep_params = json.dumps(dict_space)

# List format: shorthand for a categorical parameter.
list_space = {"batch_size": [2, 4, 8]}
sweep_params = json.dumps(list_space)
Tipos suportados: categorical, loguniform, uniform, int.
Classificação de Texto
TextClassificationParams
from autotrain.trainers.text_classification.params import TextClassificationParams
params = TextClassificationParams(
    # Model, data source and project name
    model="bert-base-uncased",
    data_path="./reviews.csv",
    project_name="sentiment",
    # Columns holding the input text and the class label
    text_column="text",
    target_column="label",
    # Training settings
    epochs=5,
    batch_size=16,
    lr=2e-5,
)
Classificação de Imagem
ImageClassificationParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
params = ImageClassificationParams(
    # Model, data source and project name
    model="google/vit-base-patch16-224",
    data_path="./images/",
    project_name="classifier",
    # Columns holding the image and the class label
    image_column="image",
    target_column="label",
    # Training settings
    epochs=10,
    batch_size=32,
)
Execução do Projeto
AutoTrainProject
from autotrain.project import AutoTrainProject
# Create the project and run it.
project = AutoTrainProject(
    params=params,
    backend="local",  # "local", or a remote target such as "spaces-a10g-large"
    process=True,  # Start training immediately
)
job_id = project.create()
Opções de Backend
| Backend | Descrição |
|---|---|
| `local` | Executar na máquina local |
| `spaces-*` | Executar no Hugging Face Spaces (ex.: `spaces-a10g-large`, `spaces-t4-medium`) |
| `ep-*` | Hugging Face Endpoints |
| `ngc-*` | NVIDIA NGC |
| `nvcf-*` | NVIDIA Cloud Functions |
Inferência
Usando Completers
from autotrain.generation import CompletionConfig, create_completer
# Sampling settings used for generation.
config = CompletionConfig(
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
)

# Build the completer. Note: the first parameter is named "model", not "model_path".
completer = create_completer(
    model="./my-trained-model",
    completer_type="message",
    config=config,
)

# chat() returns a MessageCompletionResult; the generated text is in `.content`.
result = completer.chat("Hello, how are you?")
print(result.content)
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and its tokenizer from the same local directory.
model_dir = "./my-model"
model = AutoModelForCausalLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Tokenize the prompt, generate up to 100 new tokens, then print the decoded text.
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
generated_text = tokenizer.decode(outputs[0])
print(generated_text)
Manipulação de Datasets
AutoTrainDataset
from autotrain.dataset import AutoTrainDataset
# NOTE(review): presumably maps the task's field names to the column names in
# train.csv -- confirm the direction against the AutoTrainDataset documentation.
column_mapping = {
    "text": "review_text",
    "label": "sentiment",
}

dataset = AutoTrainDataset(
    train_data=["train.csv"],
    task="text_classification",
    token="hf_...",
    project_name="my-project",
    username="my-username",
    column_mapping=column_mapping,
)

# prepare() returns the path that is stored here as `data_path`.
data_path = dataset.prepare()
Arquivos de Configuração
Carregando de YAML
from autotrain.parser import AutoTrainConfigParser
# Build a parser for the YAML configuration file "config.yaml".
parser = AutoTrainConfigParser("config.yaml")
# Start the training run described by that configuration.
parser.run()
Tratamento de Erros
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
# Build the parameters and start the job, reporting configuration problems
# (ValueError) separately from training failures (RuntimeError).
try:
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )
    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()
except ValueError as e:
    print(f"Configuration error: {e}")
except RuntimeError as e:
    print(f"Training error: {e}")
Exemplos Completos
Pipeline de Treinamento SFT
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_sft():
    """Configure and launch a local SFT run with LoRA.

    Returns:
        The value returned by ``AutoTrainProject.create()`` (used as the
        job ID by the caller).
    """
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        peft=True,
        lora_r=16,
        lora_alpha=32,
        log="wandb",
    )
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True,
    )
    return project.create()
# Run the SFT pipeline only when this file is executed as a script.
if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")
Pipeline de Treinamento DPO
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_dpo():
    """Configure and launch a local DPO run with LoRA.

    Returns:
        The value returned by ``AutoTrainProject.create()``.
    """
    params = LLMTrainingParams(
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,  # Default: 128
        max_completion_length=None,  # Default: None
        epochs=1,
        batch_size=2,
        lr=5e-6,
        peft=True,
        lora_r=16,
    )
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True,
    )
    return project.create()
Próximos Passos