Entrenamiento DPO
La Optimización Directa de Preferencias alinea modelos con preferencias humanas sin modelado de recompensas.
¿Qué es DPO?
DPO (Direct Preference Optimization) es una alternativa más simple a RLHF. En lugar de entrenar un modelo de recompensa separado, DPO optimiza directamente el modelo para preferir respuestas elegidas sobre rechazadas.
Inicio Rápido
aitraining llm --train \
--model meta-llama/Llama-3.2-1B \
--data-path ./preferences.jsonl \
--project-name llama-dpo \
--trainer dpo \
--prompt-text-column prompt \
--text-column chosen \
--rejected-text-column rejected \
--dpo-beta 0.1 \
--peft
DPO requiere --prompt-text-column y --rejected-text-column. El --text-column tiene por defecto "text", así que solo especifícalo si tu columna elegida tiene un nombre diferente.
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./preferences.jsonl",
project_name="llama-dpo",
trainer="dpo",
prompt_text_column="prompt",
text_column="chosen",
rejected_text_column="rejected",
dpo_beta=0.1,
max_prompt_length=128, # Default: 128
max_completion_length=None, # Default: None
epochs=1,
batch_size=2,
gradient_accumulation=4,
lr=5e-6,
peft=True,
lora_r=16,
lora_alpha=32,
)
project = AutoTrainProject(params=params, backend="local", process=True)
project.create()
DPO requiere pares de preferencia: un prompt con respuestas elegida y rechazada.
{
"prompt": "What is the capital of France?",
"chosen": "The capital of France is Paris.",
"rejected": "France's capital is London."
}
Múltiples Turnos
{
"prompt": [
{"role": "user", "content": "What is AI?"},
{"role": "assistant", "content": "AI is artificial intelligence."},
{"role": "user", "content": "Give me an example."}
],
"chosen": "A common example is ChatGPT, which uses AI to understand and generate text.",
"rejected": "idk lol"
}
Parámetros
| Parámetro | Descripción | Por Defecto |
|---|
trainer | Establecer como "dpo" | Requerido |
dpo_beta | Coeficiente de penalización KL | 0.1 |
max_prompt_length | Máximo de tokens para el prompt | 128 |
max_completion_length | Máximo de tokens para la respuesta | None |
model_ref | Modelo de referencia (opcional) | None (usa modelo base) |
Beta
El parámetro beta controla cuánto puede desviarse el modelo de la referencia:
0.01-0.05: Optimización agresiva (puede sobreajustar)
0.1: Estándar (recomendado)
0.5-1.0: Conservador (permanece cerca de la referencia)
# Entrenamiento conservador
params = LLMTrainingParams(
...
trainer="dpo",
dpo_beta=0.5, # Mayor = más conservador
)
Modelo de Referencia
Cuando model_ref es None (el defecto), DPO usa el modelo inicial como referencia. Puedes especificar uno diferente:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B", # Modelo a entrenar
model_ref="meta-llama/Llama-3.2-1B-base", # Modelo de referencia
...
trainer="dpo",
)
Consejos de Entrenamiento
Usa LoRA
DPO funciona bien con LoRA:
params = LLMTrainingParams(
...
trainer="dpo",
peft=True,
lora_r=16,
lora_alpha=32,
lora_dropout=0.05,
)
Tasa de Aprendizaje Más Baja
DPO es sensible a la tasa de aprendizaje:
params = LLMTrainingParams(
...
trainer="dpo",
lr=5e-7, # Mucho más baja que SFT
)
Menos Épocas
DPO típicamente necesita menos épocas:
params = LLMTrainingParams(
...
trainer="dpo",
epochs=1, # A menudo 1-3 épocas es suficiente
)
Ejemplo: Asistente Útil
Crear un asistente más útil:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./helpfulness_prefs.jsonl",
project_name="helpful-assistant",
trainer="dpo",
dpo_beta=0.1,
max_prompt_length=1024,
max_completion_length=512,
epochs=1,
batch_size=2,
gradient_accumulation=8,
lr=1e-6,
peft=True,
lora_r=32,
lora_alpha=64,
log="wandb",
)
DPO vs ORPO
| Aspecto | DPO | ORPO |
|---|
| Modelo de referencia | Requerido | No requerido |
| Uso de memoria | Mayor | Menor |
| Velocidad de entrenamiento | Más lento | Más rápido |
| Caso de uso | Alineación fina | SFT + alineación combinado |
Recolectando Datos de Preferencia
Anotación Humana
- Genera múltiples respuestas por prompt
- Ten anotadores clasificando respuestas
- Crea pares elegido/rechazado
LLM como Juez
def create_preference_pairs(prompts, model_responses):
"""Usa GPT-4 para juzgar qué respuesta es mejor."""
# ... generar juicios
return {"prompt": p, "chosen": better, "rejected": worse}
Próximos Pasos