Referência do SDK Python
Guia abrangente da API Python do AITraining.
Instalação
pip install aitraining torch torchvision torchaudio
Nome do Pacote vs. Import: instale com `pip install aitraining`, mas importe com `from autotrain import ...`.
Treinamento LLM
LLMTrainingParams
from autotrain.trainers.clm.params import LLMTrainingParams
# Example LLM fine-tuning configuration: SFT with LoRA adapters enabled.
params = LLMTrainingParams(
    # --- Required ---
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # --- Training method ---
    # One of: sft, dpo, orpo, ppo, reward, distillation
    trainer="sft",

    # --- Training settings ---
    epochs=3,
    batch_size=2,
    lr=2e-5,
    gradient_accumulation=4,
    mixed_precision="bf16",

    # --- LoRA ---
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,

    # --- Data processing ---
    text_column="text",
    block_size=2048,
    add_eos_token=True,

    # --- Logging ---
    log="wandb",
    logging_steps=-1,  # -1 is the default (auto)
)
Parâmetros Principais
| Parâmetro | Tipo | Descrição |
|---|---|---|
| model | str | Nome ou caminho do modelo |
| data_path | str | Caminho para os dados de treinamento |
| project_name | str | Diretório de saída |
| trainer | str | Método de treinamento |
| epochs | int | Número de épocas |
| batch_size | int | Tamanho do lote |
| lr | float | Taxa de aprendizado |
| peft | bool | Habilitar LoRA |
| lora_r | int | Rank do LoRA |
| lora_alpha | int | Alpha do LoRA |
Classificação de Texto
TextClassificationParams
from autotrain.trainers.text_classification.params import TextClassificationParams
# Example text-classification configuration reading a CSV file.
params = TextClassificationParams(
    # Model, data location and output project
    model="bert-base-uncased",
    data_path="./reviews.csv",
    project_name="sentiment",

    # Dataset columns holding the input text and the label
    text_column="text",
    target_column="label",

    # Training settings
    epochs=5,
    batch_size=16,
    lr=2e-5,
)
Classificação de Imagem
ImageClassificationParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
# Example image-classification configuration reading from an image folder.
params = ImageClassificationParams(
    # Model, data location and output project
    model="google/vit-base-patch16-224",
    data_path="./images/",
    project_name="classifier",

    # Dataset columns holding the image and the label
    image_column="image",
    target_column="label",

    # Training settings
    epochs=10,
    batch_size=32,
)
Execução do Projeto
AutoTrainProject
from autotrain.project import AutoTrainProject
# Wrap a params object (built as in the examples above) in a project and launch it.
project = AutoTrainProject(
    params=params,
    backend="local",  # "local", or a remote backend such as "spaces-a10g-large"
    process=True,  # start training immediately
)
job_id = project.create()
Opções de Backend
| Backend | Descrição |
|---|---|
| local | Executar na máquina local |
| spaces-* | Executar no Hugging Face Spaces (ex.: spaces-a10g-large, spaces-t4-medium) |
| ep-* | Hugging Face Endpoints |
| ngc-* | NVIDIA NGC |
| nvcf-* | NVIDIA Cloud Functions |
Inferência
Usando Completers
from autotrain.generation import CompletionConfig, create_completer
# Sampling settings used for generation
config = CompletionConfig(
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
)

# Build the completer. Note: the first parameter is "model", not "model_path".
completer = create_completer(
    model="./my-trained-model",
    completer_type="message",
    config=config,
)

# chat() returns a MessageCompletionResult; the generated text is in .content
reply = completer.chat("Hello, how are you?")
print(reply.content)
# Separate approach: run inference with Hugging Face Transformers directly.
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and its tokenizer from the same directory
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")

# Tokenize the prompt, generate up to 100 new tokens, and decode the first sequence
encoded = tokenizer("Hello!", return_tensors="pt")
generated = model.generate(**encoded, max_new_tokens=100)
print(tokenizer.decode(generated[0]))
Manipulação de Datasets
AutoTrainDataset
from autotrain.dataset import AutoTrainDataset
# Describe a text-classification dataset backed by a local CSV file.
dataset = AutoTrainDataset(
    train_data=["train.csv"],
    task="text_classification",
    token="hf_...",
    project_name="my-project",
    username="my-username",
    # Mapping between the "text"/"label" fields and the CSV column names
    column_mapping={
        "text": "review_text",
        "label": "sentiment",
    },
)

# Prepare the dataset and keep the value returned by prepare()
data_path = dataset.prepare()
Arquivos de Configuração
Carregando de YAML
from autotrain.parser import AutoTrainConfigParser
# Load the YAML configuration file
config_parser = AutoTrainConfigParser("config.yaml")

# Start the training run described by that file
config_parser.run()
Tratamento de Erros
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
# Build the params and launch the project inside one try block so that
# failures can be reported by category.
try:
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )
    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()
except ValueError as e:
    # ValueError is reported as a configuration problem
    print(f"Configuration error: {e}")
except RuntimeError as e:
    # RuntimeError is reported as a training failure
    print(f"Training error: {e}")
Exemplos Completos
Pipeline de Treinamento SFT
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_sft():
    """Configure and launch a LoRA-based SFT run on the local backend.

    Returns:
        Whatever ``AutoTrainProject.create()`` returns; the entry point
        below prints it as the job ID.
    """
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        trainer="sft",
        # Training settings
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        # LoRA
        peft=True,
        lora_r=16,
        lora_alpha=32,
        # Logging
        log="wandb",
    )
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True,
    )
    return project.create()


if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")
Pipeline de Treinamento DPO
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
def train_dpo():
    """Configure and launch a LoRA-based DPO run on the local backend.

    Returns:
        Whatever ``AutoTrainProject.create()`` returns.
    """
    params = LLMTrainingParams(
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        trainer="dpo",
        # DPO-specific settings
        dpo_beta=0.1,
        max_prompt_length=128,  # Default: 128
        max_completion_length=None,  # Default: None
        # Training settings
        epochs=1,
        batch_size=2,
        lr=5e-6,
        # LoRA
        peft=True,
        lora_r=16,
    )
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True,
    )
    return project.create()
Próximos Passos