DPO Training
Direct Preference Optimization aligns models with human preferences without reward modeling.
What is DPO?
DPO (Direct Preference Optimization) is a simpler alternative to RLHF. Instead of training a separate reward model, DPO directly optimizes the model to prefer chosen responses over rejected ones.
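Conceptually, DPO minimizes a logistic loss over log-probability ratios of the chosen response y_w and the rejected response y_l under the policy being trained (π_θ) and a frozen reference policy (π_ref). This is the standard formulation from the DPO paper, shown here only for orientation; the dpo_beta parameter used below is the β in this loss:

\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]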
Quick Start
aitraining llm --train \
  --model meta-llama/Llama-3.2-1B \
  --data-path ./preferences.jsonl \
  --project-name llama-dpo \
  --trainer dpo \
  --prompt-text-column prompt \
  --text-column chosen \
  --rejected-text-column rejected \
  --dpo-beta 0.1 \
  --peft
DPO requires --prompt-text-column and --rejected-text-column. The --text-column defaults to "text", so only specify it if your chosen column has a different name.
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./preferences.jsonl",
project_name="llama-dpo",
trainer="dpo",
prompt_text_column="prompt",
text_column="chosen",
rejected_text_column="rejected",
dpo_beta=0.1,
max_prompt_length=128, # Default: 128
max_completion_length=None, # Default: None
epochs=1,
batch_size=2,
gradient_accumulation=4,
lr=5e-6,
peft=True,
lora_r=16,
lora_alpha=32,
)
project = AutoTrainProject(params=params, backend="local", process=True)
project.create()
DPO requires preference pairs: a prompt with chosen and rejected responses.
{
"prompt": "What is the capital of France?",
"chosen": "The capital of France is Paris.",
"rejected": "France's capital is London."
}
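A minimal sketch for producing the preferences.jsonl file used in the examples above (one JSON object per line, using the prompt/chosen/rejected field names shown here):

import json

pairs = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "France's capital is London.",
    },
    # ... add more pairs
]

# JSONL format: one JSON object per line
with open("preferences.jsonl", "w") as f:
    for pair in pairs:
        f.write(json.dumps(pair) + "\n")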
Multiple Turns
{
"prompt": [
{"role": "user", "content": "What is AI?"},
{"role": "assistant", "content": "AI is artificial intelligence."},
{"role": "user", "content": "Give me an example."}
],
"chosen": "A common example is ChatGPT, which uses AI to understand and generate text.",
"rejected": "idk lol"
}
Parameters
| Parameter | Description | Default |
|---|---|---|
| trainer | Set to "dpo" | Required |
| dpo_beta | KL penalty coefficient | 0.1 |
| max_prompt_length | Max tokens for the prompt | 128 |
| max_completion_length | Max tokens for the response | None |
| model_ref | Reference model (optional) | None (uses base model) |
Beta
The beta parameter controls how much the model can deviate from the reference:
- 0.01-0.05: Aggressive optimization (may overfit)
- 0.1: Standard (recommended)
- 0.5-1.0: Conservative (stays close to reference)
# Conservative training
params = LLMTrainingParams(
    ...
    trainer="dpo",
    dpo_beta=0.5,  # Higher = more conservative
)
Reference Model
When model_ref is None (the default), DPO uses the initial model as the reference. You can specify a different one:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B", # Model to train
model_ref="meta-llama/Llama-3.2-1B-base", # Reference model
...
trainer="dpo",
)
Training Tips
Use LoRA
DPO works well with LoRA:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)
Lower Learning Rate
DPO is sensitive to learning rate:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    lr=5e-7,  # Much lower than SFT
)
Fewer Epochs
DPO typically needs fewer epochs:
params = LLMTrainingParams(
    ...
    trainer="dpo",
    epochs=1,  # Often 1-3 epochs is enough
)
Example: Helpful Assistant
Create a more helpful assistant:
params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B",
data_path="./helpfulness_prefs.jsonl",
project_name="helpful-assistant",
trainer="dpo",
dpo_beta=0.1,
max_prompt_length=1024,
max_completion_length=512,
epochs=1,
batch_size=2,
gradient_accumulation=8,
lr=1e-6,
peft=True,
lora_r=32,
lora_alpha=64,
log="wandb",
)
DPO vs ORPO
| Aspect | DPO | ORPO |
|---|---|---|
| Reference model | Required | Not required |
| Memory usage | Higher | Lower |
| Training speed | Slower | Faster |
| Use case | Fine-grained alignment | Combined SFT + alignment |
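If the memory overhead of the reference model is a concern, the same preference data can usually be reused with the ORPO trainer. A sketch, assuming your installed version supports trainer="orpo" with the same column mapping as DPO:

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./preferences.jsonl",
    project_name="llama-orpo",
    trainer="orpo",  # no reference model or dpo_beta needed
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
    peft=True,
)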
Collecting Preference Data
Human Annotation
- Generate multiple responses per prompt
- Have annotators rank responses
- Create chosen/rejected pairs (see the sketch below)
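A minimal sketch for turning ranked annotations into pairs (hypothetical helper; assumes the responses for each prompt are ordered best-first):

def rankings_to_pairs(prompt, ranked_responses):
    """Pair the top-ranked response against every lower-ranked one."""
    best = ranked_responses[0]
    return [
        {"prompt": prompt, "chosen": best, "rejected": worse}
        for worse in ranked_responses[1:]
    ]

pairs = rankings_to_pairs(
    "What is the capital of France?",
    ["The capital of France is Paris.", "Paris is the capital.", "France's capital is London."],
)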
LLM-as-Judge
def create_preference_pairs(prompts, model_responses, judge_fn):
    """Use an LLM judge (e.g. GPT-4) to decide which of two candidate responses is better."""
    pairs = []
    for prompt, (resp_a, resp_b) in zip(prompts, model_responses):
        better, worse = judge_fn(prompt, resp_a, resp_b)  # judge_fn returns (better, worse)
        pairs.append({"prompt": prompt, "chosen": better, "rejected": worse})
    return pairs
Next Steps