Saltar al contenido principal

Entrenamiento DPO

La Optimización Directa de Preferencias alinea modelos con preferencias humanas sin modelado de recompensas.

¿Qué es DPO?

DPO (Direct Preference Optimization) es una alternativa más simple a RLHF. En lugar de entrenar un modelo de recompensa separado, DPO optimiza directamente el modelo para preferir respuestas elegidas sobre rechazadas.

Inicio Rápido

aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./preferences.jsonl \
  --project-name llama-dpo \
  --trainer dpo \
  --prompt-text-column prompt \
  --text-column chosen \
  --rejected-text-column rejected \
  --dpo-beta 0.1 \
  --peft
DPO requiere --prompt-text-column y --rejected-text-column. El --text-column tiene por defecto "text", así que solo especifícalo si tu columna elegida tiene un nombre diferente.

Python API

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./preferences.jsonl",
    project_name="llama-dpo",

    trainer="dpo",
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
    dpo_beta=0.1,
    max_prompt_length=128,  # Default: 128
    max_completion_length=None,  # Default: None

    epochs=1,
    batch_size=2,
    gradient_accumulation=4,
    lr=5e-6,

    peft=True,
    lora_r=16,
    lora_alpha=32,
)

project = AutoTrainProject(params=params, backend="local", process=True)
project.create()

Formato de Datos

DPO requiere pares de preferencia: un prompt con respuestas elegida y rechazada.
{
  "prompt": "What is the capital of France?",
  "chosen": "The capital of France is Paris.",
  "rejected": "France's capital is London."
}

Múltiples Turnos

{
  "prompt": [
    {"role": "user", "content": "What is AI?"},
    {"role": "assistant", "content": "AI is artificial intelligence."},
    {"role": "user", "content": "Give me an example."}
  ],
  "chosen": "A common example is ChatGPT, which uses AI to understand and generate text.",
  "rejected": "idk lol"
}

Parámetros

ParámetroDescripciónPor Defecto
trainerEstablecer como "dpo"Requerido
dpo_betaCoeficiente de penalización KL0.1
max_prompt_lengthMáximo de tokens para el prompt128
max_completion_lengthMáximo de tokens para la respuestaNone
model_refModelo de referencia (opcional)None (usa modelo base)

Beta

El parámetro beta controla cuánto puede desviarse el modelo de la referencia:
  • 0.01-0.05: Optimización agresiva (puede sobreajustar)
  • 0.1: Estándar (recomendado)
  • 0.5-1.0: Conservador (permanece cerca de la referencia)
# Entrenamiento conservador
params = LLMTrainingParams(
    ...
    trainer="dpo",
    dpo_beta=0.5,  # Mayor = más conservador
)

Modelo de Referencia

Cuando model_ref es None (el defecto), DPO usa el modelo inicial como referencia. Puedes especificar uno diferente:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",  # Modelo a entrenar
    model_ref="meta-llama/Llama-3.2-1B-base",  # Modelo de referencia
    ...
    trainer="dpo",
)

Consejos de Entrenamiento

Usa LoRA

DPO funciona bien con LoRA:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

Tasa de Aprendizaje Más Baja

DPO es sensible a la tasa de aprendizaje:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    lr=5e-7,  # Mucho más baja que SFT
)

Menos Épocas

DPO típicamente necesita menos épocas:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    epochs=1,  # A menudo 1-3 épocas es suficiente
)

Ejemplo: Asistente Útil

Crear un asistente más útil:
params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./helpfulness_prefs.jsonl",
    project_name="helpful-assistant",

    trainer="dpo",
    dpo_beta=0.1,
    max_prompt_length=1024,
    max_completion_length=512,

    epochs=1,
    batch_size=2,
    gradient_accumulation=8,
    lr=1e-6,

    peft=True,
    lora_r=32,
    lora_alpha=64,

    log="wandb",
)

DPO vs ORPO

AspectoDPOORPO
Modelo de referenciaRequeridoNo requerido
Uso de memoriaMayorMenor
Velocidad de entrenamientoMás lentoMás rápido
Caso de usoAlineación finaSFT + alineación combinado

Recolectando Datos de Preferencia

Anotación Humana

  1. Genera múltiples respuestas por prompt
  2. Ten anotadores clasificando respuestas
  3. Crea pares elegido/rechazado

LLM como Juez

def create_preference_pairs(prompts, model_responses):
    """Usa GPT-4 para juzgar qué respuesta es mejor."""
    # ... generar juicios
    return {"prompt": p, "chosen": better, "rejected": worse}

Próximos Pasos