LLM Training API
Complete API reference for LLM training.
LLMTrainingParams
The main configuration class for LLM training.
from autotrain.trainers.clm.params import LLMTrainingParams
Basic Parameters
params = LLMTrainingParams(
# Core parameters (always specify these)
model="google/gemma-3-270m", # Default: "google/gemma-3-270m"
data_path="./data.jsonl", # Default: "data"
project_name="my-model", # Default: "project-name"
# Data splits
train_split="train", # Default: "train"
valid_split=None, # Default: None
max_samples=None, # Default: None (use all)
)
Trainer Selection
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
trainer="sft", # Default: "default" (pretraining). Options: sft, dpo, orpo, ppo, grpo, reward
)
Training Hyperparameters
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
# Core hyperparameters (showing defaults)
epochs=1, # Default: 1
batch_size=2, # Default: 2
lr=3e-5, # Default: 3e-5
warmup_ratio=0.1, # Default: 0.1
gradient_accumulation=4, # Default: 4
weight_decay=0.0, # Default: 0.0
max_grad_norm=1.0, # Default: 1.0
# Precision
mixed_precision=None, # Default: None (options: bf16, fp16, None)
# Optimization
optimizer="adamw_torch", # Default: adamw_torch
scheduler="linear", # Default: linear
seed=42, # Default: 42
)
PEFT/LoRA Configuration
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./data.jsonl",
project_name="my-model",
# Enable LoRA (default: False)
peft=True,
lora_r=16, # Default: 16
lora_alpha=32, # Default: 32
lora_dropout=0.05, # Default: 0.05
target_modules="all-linear", # Default: all-linear
# Quantization (optional)
quantization="int4", # Options: int4, int8, or None (default: None)
# Merge after training (default is True - LoRA merged automatically)
merge_adapter=True,
)
Data Processing
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
# Text processing
text_column="text",
block_size=-1, # Default: -1 (model default)
model_max_length=2048, # Auto-detected from the model config (see note below)
add_eos_token=True, # Default: True
padding="right", # Default: "right"
# Chat format
chat_template=None, # Auto-detect or specify
apply_chat_template=True, # Default: True
# Efficiency
packing=None, # Default: None (set True to enable)
use_flash_attention_2=False, # Default: False
attn_implementation=None, # Default: None
)
model_max_length auto-detection: This parameter is now detected automatically from the model configuration. For example, Gemma 2 (8192 tokens) and Gemma 3 (32K-128K tokens depending on the variant) will automatically use their native context lengths. The default of 2048 is only used as a fallback when auto-detection fails. Set the parameter explicitly to override the detected value.
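As a quick sketch (the model name and values here are illustrative only), you can rely on auto-detection by omitting the parameter, or set it explicitly to override:
# Rely on auto-detection: omit model_max_length and the native context
# length from the model config is used (8192 for Gemma 2).
params = LLMTrainingParams(
    model="google/gemma-2-2b",
    data_path="./data.jsonl",
    project_name="my-model",
)

# Override explicitly, e.g. to cap memory use:
params = LLMTrainingParams(
    model="google/gemma-2-2b",
    data_path="./data.jsonl",
    project_name="my-model",
    model_max_length=4096,  # takes precedence over the auto-detected value
)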
DPO Parameters
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./preferences.jsonl",
project_name="my-model",
trainer="dpo",
# DPO-specific
dpo_beta=0.1, # Default: 0.1
max_prompt_length=128, # Default: 128
max_completion_length=None, # Default: None
# Reference model (optional)
model_ref=None, # Uses same as model if None
# Data columns (required for DPO)
prompt_text_column="prompt",
text_column="chosen",
rejected_text_column="rejected",
)
ORPO Parameters
params = LLMTrainingParams(
model="google/gemma-2-2b",
data_path="./preferences.jsonl",
project_name="my-model",
trainer="orpo",
# ORPO-specific
dpo_beta=0.1, # Default: 0.1
max_prompt_length=128, # Default: 128
max_completion_length=None, # Default: None
# Data columns (required for ORPO)
prompt_text_column="prompt",
text_column="chosen",
rejected_text_column="rejected",
)
GRPO Parameters
params = LLMTrainingParams(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
project_name="grpo-agent",
trainer="grpo",
# GRPO-specific (required)
rl_env_module="my_envs.hotel_env", # Python module path for the environment
rl_env_class="HotelEnv", # Class name inside the environment module
rl_num_generations=4, # Default: 4 (completions per prompt)
# Shared RL parameters (used by PPO and GRPO)
rl_kl_coef=0.1, # Default: 0.1 (KL divergence penalty, beta)
rl_clip_range=0.2, # Default: 0.2 (clipping range, epsilon)
rl_env_config=None, # Default: None (JSON config for the environment constructor)
rl_max_new_tokens=256, # Default: 128 (max tokens per completion)
rl_top_k=50, # Default: 50
rl_top_p=1.0, # Default: 1.0
rl_temperature=1.0, # Default: 1.0
# vLLM acceleration (optional)
use_vllm=False, # Default: False (enable vLLM for faster generation)
vllm_mode="colocate", # Default: "colocate" (alternative: "server")
vllm_gpu_memory_utilization=0.3, # Default: 0.3 (fraction of GPU memory for vLLM in colocate mode)
vllm_server_url=None, # Default: None (vLLM server URL in server mode)
vllm_tensor_parallel_size=1, # Default: 1 (GPUs for vLLM tensor parallelism)
vllm_server_gpus=1, # Default: 1 (GPUs reserved for the vLLM server)
)
GRPO does not require data_path: the dataset is built by the environment's build_dataset() method. Install vLLM support with pip install aitraining[vllm]. See GRPO Training for the complete environment interface (build_dataset, score_episode, get_tools). A minimal environment sketch follows.
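For orientation only, here is a hypothetical sketch of such an environment module; the method signatures and return shapes below are assumptions for illustration, not the documented interface (see GRPO Training for the real one).
# my_envs/hotel_env.py (hypothetical sketch; signatures and return shapes are assumptions)
class HotelEnv:
    def __init__(self, config=None):
        # Receives the parsed rl_env_config JSON, if provided (assumption)
        self.config = config or {}

    def build_dataset(self):
        # Build the prompts the trainer samples from (list-of-dicts shape is an assumption)
        return [{"prompt": "Book a double room in Lisbon for two nights."}]

    def get_tools(self):
        # Optional tool definitions exposed to the agent (assumption)
        return []

    def score_episode(self, episode):
        # Return a scalar reward for a completed episode (assumption)
        text = episode.get("completion", "") if isinstance(episode, dict) else str(episode)
        return 1.0 if "confirmation" in text.lower() else 0.0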
Knowledge Distillation
params = LLMTrainingParams(
model="google/gemma-3-270m", # Student
teacher_model="google/gemma-2-2b", # Teacher
data_path="./prompts.jsonl",
project_name="distilled-model",
use_distillation=True,
distill_temperature=3.0, # Default: 3.0
distill_alpha=0.7, # Default: 0.7
distill_max_teacher_length=512, # Default: 512
)
Logging and Saving
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
# Logging
log="wandb", # wandb, tensorboard, or None (default: wandb)
logging_steps=-1, # Default: -1 (auto)
wandb_visualizer=True, # Terminal visualizer
wandb_token=None, # W&B API token (optional)
# Checkpointing
save_strategy="steps", # steps or epoch (default: epoch)
save_steps=500,
save_total_limit=1, # Default: 1
eval_strategy="steps",
)
Hub Integration
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
# Push to Hub
push_to_hub=True,
username="your-username",
token="hf_...",
)
Running the Training
from autotrain.project import AutoTrainProject
# Create and run project
project = AutoTrainProject(
params=params,
backend="local",
process=True
)
job_id = project.create()
Complete Example
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
# Full configuration
params = LLMTrainingParams(
# Model
model="meta-llama/Llama-3.2-1B",
project_name="llama-production",
# Data
data_path="./conversations.jsonl",
train_split="train",
valid_split="validation",
text_column="text",
block_size=2048,
# Training
trainer="sft",
epochs=3,
batch_size=2,
gradient_accumulation=8,
lr=2e-5,
warmup_ratio=0.1,
mixed_precision="bf16",
# LoRA
peft=True,
lora_r=32,
lora_alpha=64,
lora_dropout=0.05,
# Optimization
use_flash_attention_2=True,
packing=True,
auto_find_batch_size=True,
unsloth=False, # Set True to use Unsloth for faster training
# Distribution (for multi-GPU)
distributed_backend=None, # None for auto (DDP), or "deepspeed"
# Logging
log="wandb",
logging_steps=-1,
save_strategy="steps",
save_steps=500,
save_total_limit=1,
# Hub
push_to_hub=True,
username="my-username",
token="hf_...",
)
# Run training
project = AutoTrainProject(
params=params,
backend="local",
process=True
)
job_id = project.create()