Pular para o conteúdo principal

Treinamento DPO

A Otimização Direta de Preferências alinha modelos com preferências humanas sem modelagem de recompensa.

O que é DPO?

DPO (Direct Preference Optimization) é uma alternativa mais simples ao RLHF. Em vez de treinar um modelo de recompensa separado, o DPO otimiza diretamente o modelo para preferir respostas escolhidas sobre rejeitadas.

Início Rápido

aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./preferences.jsonl \
  --project-name llama-dpo \
  --trainer dpo \
  --prompt-text-column prompt \
  --text-column chosen \
  --rejected-text-column rejected \
  --dpo-beta 0.1 \
  --peft
DPO requer --prompt-text-column e --rejected-text-column. O --text-column tem padrão "text", então especifique apenas se sua coluna escolhida tiver um nome diferente.

Python API

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./preferences.jsonl",
    project_name="llama-dpo",

    trainer="dpo",
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
    dpo_beta=0.1,
    max_prompt_length=128,  # Default: 128
    max_completion_length=None,  # Default: None

    epochs=1,
    batch_size=2,
    gradient_accumulation=4,
    lr=5e-6,

    peft=True,
    lora_r=16,
    lora_alpha=32,
)

project = AutoTrainProject(params=params, backend="local", process=True)
project.create()

Formato dos Dados

DPO requer pares de preferência: um prompt com respostas escolhida e rejeitada.
{
  "prompt": "What is the capital of France?",
  "chosen": "The capital of France is Paris.",
  "rejected": "France's capital is London."
}

Múltiplas Voltas

{
  "prompt": [
    {"role": "user", "content": "What is AI?"},
    {"role": "assistant", "content": "AI is artificial intelligence."},
    {"role": "user", "content": "Give me an example."}
  ],
  "chosen": "A common example is ChatGPT, which uses AI to understand and generate text.",
  "rejected": "idk lol"
}

Parâmetros

ParâmetroDescriçãoPadrão
trainerDefinir como "dpo"Obrigatório
dpo_betaCoeficiente de penalidade KL0.1
max_prompt_lengthMáximo de tokens para o prompt128
max_completion_lengthMáximo de tokens para a respostaNone
model_refModelo de referência (opcional)None (usa modelo base)

Beta

O parâmetro beta controla o quanto o modelo pode se desviar da referência:
  • 0.01-0.05: Otimização agressiva (pode sobreajustar)
  • 0.1: Padrão (recomendado)
  • 0.5-1.0: Conservador (permanece próximo à referência)
# Treinamento conservador
params = LLMTrainingParams(
    ...
    trainer="dpo",
    dpo_beta=0.5,  # Maior = mais conservador
)

Modelo de Referência

Quando model_ref é None (o padrão), o DPO usa o modelo inicial como referência. Você pode especificar um diferente:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",  # Modelo a treinar
    model_ref="meta-llama/Llama-3.2-1B-base",  # Modelo de referência
    ...
    trainer="dpo",
)

Dicas de Treinamento

Use LoRA

DPO funciona bem com LoRA:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

Taxa de Aprendizado Mais Baixa

DPO é sensível à taxa de aprendizado:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    lr=5e-7,  # Muito menor que SFT
)

Menos Épocas

DPO tipicamente precisa de menos épocas:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    epochs=1,  # Frequentemente 1-3 épocas é suficiente
)

Exemplo: Assistente Útil

Criar um assistente mais útil:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./helpfulness_prefs.jsonl",
    project_name="helpful-assistant",

    trainer="dpo",
    dpo_beta=0.1,
    max_prompt_length=1024,
    max_completion_length=512,

    epochs=1,
    batch_size=2,
    gradient_accumulation=8,
    lr=1e-6,

    peft=True,
    lora_r=32,
    lora_alpha=64,

    log="wandb",
)

DPO vs ORPO

AspectoDPOORPO
Modelo de referênciaObrigatórioNão obrigatório
Uso de memóriaMaiorMenor
Velocidade de treinamentoMais lentoMais rápido
Caso de usoAlinhamento finoSFT + alinhamento combinado

Coletando Dados de Preferência

Anotação Humana

  1. Gere múltiplas respostas por prompt
  2. Tenha anotadores classificando respostas
  3. Crie pares escolhido/rejeitado

LLM como Juiz

def create_preference_pairs(prompts, model_responses):
    """Use GPT-4 para julgar qual resposta é melhor."""
    # ... gerar julgamentos
    return {"prompt": p, "chosen": better, "rejected": worse}

Próximos Passos