Treinamento DPO
A Otimização Direta de Preferências alinha modelos com preferências humanas sem modelagem de recompensa.
O que é DPO?
DPO (Direct Preference Optimization) é uma alternativa mais simples ao RLHF. Em vez de treinar um modelo de recompensa separado, o DPO otimiza diretamente o modelo para preferir respostas escolhidas sobre rejeitadas.
Início Rápido
aitraining llm --train \
--model meta-llama/Llama-3.2-1B \
--data-path ./preferences.jsonl \
--project-name llama-dpo \
--trainer dpo \
--prompt-text-column prompt \
--text-column chosen \
--rejected-text-column rejected \
--dpo-beta 0.1 \
--peft
DPO requer --prompt-text-column e --rejected-text-column. O --text-column tem padrão "text", então especifique apenas se sua coluna escolhida tiver um nome diferente.
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./preferences.jsonl",
project_name="llama-dpo",
trainer="dpo",
prompt_text_column="prompt",
text_column="chosen",
rejected_text_column="rejected",
dpo_beta=0.1,
max_prompt_length=128, # Default: 128
max_completion_length=None, # Default: None
epochs=1,
batch_size=2,
gradient_accumulation=4,
lr=5e-6,
peft=True,
lora_r=16,
lora_alpha=32,
)
project = AutoTrainProject(params=params, backend="local", process=True)
project.create()
DPO requer pares de preferência: um prompt com respostas escolhida e rejeitada.
{
"prompt": "What is the capital of France?",
"chosen": "The capital of France is Paris.",
"rejected": "France's capital is London."
}
Múltiplas Voltas
{
"prompt": [
{"role": "user", "content": "What is AI?"},
{"role": "assistant", "content": "AI is artificial intelligence."},
{"role": "user", "content": "Give me an example."}
],
"chosen": "A common example is ChatGPT, which uses AI to understand and generate text.",
"rejected": "idk lol"
}
Parâmetros
| Parâmetro | Descrição | Padrão |
|---|
trainer | Definir como "dpo" | Obrigatório |
dpo_beta | Coeficiente de penalidade KL | 0.1 |
max_prompt_length | Máximo de tokens para o prompt | 128 |
max_completion_length | Máximo de tokens para a resposta | None |
model_ref | Modelo de referência (opcional) | None (usa modelo base) |
Beta
O parâmetro beta controla o quanto o modelo pode se desviar da referência:
0.01-0.05: Otimização agressiva (pode sobreajustar)
0.1: Padrão (recomendado)
0.5-1.0: Conservador (permanece próximo à referência)
# Treinamento conservador
params = LLMTrainingParams(
...
trainer="dpo",
dpo_beta=0.5, # Maior = mais conservador
)
Modelo de Referência
Quando model_ref é None (o padrão), o DPO usa o modelo inicial como referência. Você pode especificar um diferente:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B", # Modelo a treinar
model_ref="meta-llama/Llama-3.2-1B-base", # Modelo de referência
...
trainer="dpo",
)
Dicas de Treinamento
Use LoRA
DPO funciona bem com LoRA:
params = LLMTrainingParams(
...
trainer="dpo",
peft=True,
lora_r=16,
lora_alpha=32,
lora_dropout=0.05,
)
Taxa de Aprendizado Mais Baixa
DPO é sensível à taxa de aprendizado:
params = LLMTrainingParams(
...
trainer="dpo",
lr=5e-7, # Muito menor que SFT
)
Menos Épocas
DPO tipicamente precisa de menos épocas:
params = LLMTrainingParams(
...
trainer="dpo",
epochs=1, # Frequentemente 1-3 épocas é suficiente
)
Exemplo: Assistente Útil
Criar um assistente mais útil:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./helpfulness_prefs.jsonl",
project_name="helpful-assistant",
trainer="dpo",
dpo_beta=0.1,
max_prompt_length=1024,
max_completion_length=512,
epochs=1,
batch_size=2,
gradient_accumulation=8,
lr=1e-6,
peft=True,
lora_r=32,
lora_alpha=64,
log="wandb",
)
DPO vs ORPO
| Aspecto | DPO | ORPO |
|---|
| Modelo de referência | Obrigatório | Não obrigatório |
| Uso de memória | Maior | Menor |
| Velocidade de treinamento | Mais lento | Mais rápido |
| Caso de uso | Alinhamento fino | SFT + alinhamento combinado |
Coletando Dados de Preferência
Anotação Humana
- Gere múltiplas respostas por prompt
- Tenha anotadores classificando respostas
- Crie pares escolhido/rejeitado
LLM como Juiz
def create_preference_pairs(prompts, model_responses):
"""Use GPT-4 para julgar qual resposta é melhor."""
# ... gerar julgamentos
return {"prompt": p, "chosen": better, "rejected": worse}
Próximos Passos