Modelagem de Recompensa
Treine modelos de recompensa que pontuam respostas de texto para uso em treinamento PPO/RLHF.
Importante: Modelos de recompensa NÃO são geradores de texto. Eles produzem uma pontuação escalar para um dado texto, usada para fornecer recompensas durante o treinamento PPO. Você não pode usar um modelo de recompensa como um LLM normal para geração de texto.
Início Rápido
aitraining llm --train \
--model google/gemma-3-270m \
--data-path ./preferences.jsonl \
--project-name reward-model \
--trainer reward \
--prompt-text-column prompt \
--text-column chosen \
--rejected-text-column rejected
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./preferences.jsonl",
project_name="reward-model",
trainer="reward",
# Column mappings (required for reward training)
prompt_text_column="prompt",
text_column="chosen",
rejected_text_column="rejected",
epochs=1,
batch_size=4,
lr=2e-5,
)
project = AutoTrainProject(params=params, backend="local", process=True)
project.create()
O treinamento de recompensa requer dados de preferência com três colunas:
| Coluna | Descrição |
|---|
prompt | O prompt/pergunta de entrada |
chosen | A resposta preferida/melhor |
rejected | A resposta menos preferida/pior |
Exemplo de Dados
{"prompt": "Explain gravity", "chosen": "Gravity is a fundamental force...", "rejected": "gravity makes stuff fall down"}
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "its a snake"}
{"prompt": "Write a greeting", "chosen": "Hello! How can I assist you today?", "rejected": "hey"}
Parâmetros Obrigatórios
O treinamento de recompensa requer que todos os três parâmetros de coluna sejam especificados:
--prompt-text-column
--text-column (para respostas escolhidas)
--rejected-text-column
Parâmetros
| Parâmetro | Flag CLI | Padrão | Descrição |
|---|
prompt_text_column | --prompt-text-column | prompt | Coluna com prompts |
text_column | --text-column | text | Coluna com respostas escolhidas |
rejected_text_column | --rejected-text-column | rejected | Coluna com respostas rejeitadas |
Modelo de Saída
O modelo treinado é um AutoModelForSequenceClassification que:
- Recebe entrada de texto
- Retorna uma pontuação de recompensa escalar
- Pontuações mais altas indicam respostas melhores
- Usado como entrada para treinamento PPO via
--rl-reward-model-path
Usando o Modelo de Recompensa
aitraining llm --train \
--model google/gemma-3-270m \
--data-path ./prompts.jsonl \
--project-name ppo-model \
--trainer ppo \
--rl-reward-model-path ./reward-model
Inferência Direta
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Carregar modelo de recompensa
model = AutoModelForSequenceClassification.from_pretrained("./reward-model")
tokenizer = AutoTokenizer.from_pretrained("./reward-model")
# Pontuar uma resposta
text = "What is AI? AI is artificial intelligence, a field of computer science..."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
score = outputs.logits.item()
print(f"Reward score: {score}")
Melhores Práticas
- Dados de preferência de qualidade - O modelo de recompensa é tão bom quanto suas anotações
- Exemplos diversos - Inclua prompts variados e níveis de qualidade de resposta
- Sinais de preferência claros - Escolhido deve ser claramente melhor que rejeitado
- Dataset balanceado - Evite viés em direção a certos tipos de resposta
- Dados suficientes - Procure por pelo menos 1.000+ pares de preferência
Exemplo: Construindo Dados de Preferência
# Script de exemplo para criar dados de preferência
import json
preferences = [
{
"prompt": "Summarize machine learning",
"chosen": "Machine learning is a subset of AI that enables systems to learn from data...",
"rejected": "ml is computers learning stuff"
},
# Adicione mais exemplos...
]
with open("preferences.jsonl", "w") as f:
for item in preferences:
f.write(json.dumps(item) + "\n")
Próximos Passos