Modelado de Recompensas
Entrena modelos de recompensa que puntúan respuestas de texto para uso en entrenamiento PPO/RLHF.
Importante: Los modelos de recompensa NO son generadores de texto. Producen una puntuación escalar para un texto dado, usada para proporcionar recompensas durante el entrenamiento PPO. No puedes usar un modelo de recompensa como un LLM normal para generación de texto.
Inicio Rápido
aitraining llm --train \
--model google/gemma-3-270m \
--data-path ./preferences.jsonl \
--project-name reward-model \
--trainer reward \
--prompt-text-column prompt \
--text-column chosen \
--rejected-text-column rejected
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./preferences.jsonl",
project_name="reward-model",
trainer="reward",
# Column mappings (required for reward training)
prompt_text_column="prompt",
text_column="chosen",
rejected_text_column="rejected",
epochs=1,
batch_size=4,
lr=2e-5,
)
project = AutoTrainProject(params=params, backend="local", process=True)
project.create()
El entrenamiento de recompensa requiere datos de preferencia con tres columnas:
| Columna | Descripción |
|---|
prompt | El prompt/pregunta de entrada |
chosen | La respuesta preferida/mejor |
rejected | La respuesta menos preferida/peor |
Datos de Ejemplo
{"prompt": "Explain gravity", "chosen": "Gravity is a fundamental force...", "rejected": "gravity makes stuff fall down"}
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "its a snake"}
{"prompt": "Write a greeting", "chosen": "Hello! How can I assist you today?", "rejected": "hey"}
Parámetros Requeridos
El entrenamiento de recompensa requiere que se especifiquen los tres parámetros de columna:
--prompt-text-column
--text-column (para respuestas elegidas)
--rejected-text-column
Parámetros
| Parámetro | Flag CLI | Por Defecto | Descripción |
|---|
prompt_text_column | --prompt-text-column | prompt | Columna con prompts |
text_column | --text-column | text | Columna con respuestas elegidas |
rejected_text_column | --rejected-text-column | rejected | Columna con respuestas rechazadas |
Modelo de Salida
El modelo entrenado es un AutoModelForSequenceClassification que:
- Toma entrada de texto
- Retorna una puntuación de recompensa escalar
- Puntuaciones más altas indican respuestas mejores
- Usado como entrada para entrenamiento PPO vía
--rl-reward-model-path
Usando el Modelo de Recompensa
Con Entrenamiento PPO
aitraining llm --train \
--model google/gemma-3-270m \
--data-path ./prompts.jsonl \
--project-name ppo-model \
--trainer ppo \
--rl-reward-model-path ./reward-model
Inferencia Directa
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Cargar modelo de recompensa
model = AutoModelForSequenceClassification.from_pretrained("./reward-model")
tokenizer = AutoTokenizer.from_pretrained("./reward-model")
# Puntuar una respuesta
text = "What is AI? AI is artificial intelligence, a field of computer science..."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
score = outputs.logits.item()
print(f"Reward score: {score}")
Mejores Prácticas
- Datos de preferencia de calidad - El modelo de recompensa es tan bueno como tus anotaciones
- Ejemplos diversos - Incluye prompts variados y niveles de calidad de respuesta
- Señales de preferencia claras - Elegido debe ser claramente mejor que rechazado
- Dataset balanceado - Evita sesgo hacia ciertos tipos de respuesta
- Datos suficientes - Apunta a mínimo 1,000+ pares de preferencia
Ejemplo: Construyendo Datos de Preferencia
# Script de ejemplo para crear datos de preferencia
import json
preferences = [
{
"prompt": "Summarize machine learning",
"chosen": "Machine learning is a subset of AI that enables systems to learn from data...",
"rejected": "ml is computers learning stuff"
},
# Añade más ejemplos...
]
with open("preferences.jsonl", "w") as f:
for item in preferences:
f.write(json.dumps(item) + "\n")
Próximos Pasos