Saltar al contenido principal

Modelado de Recompensas

Entrena modelos de recompensa que puntúan respuestas de texto para uso en entrenamiento PPO/RLHF.
Importante: Los modelos de recompensa NO son generadores de texto. Producen una puntuación escalar para un texto dado, usada para proporcionar recompensas durante el entrenamiento PPO. No puedes usar un modelo de recompensa como un LLM normal para generación de texto.

Inicio Rápido

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./preferences.jsonl \
  --project-name reward-model \
  --trainer reward \
  --prompt-text-column prompt \
  --text-column chosen \
  --rejected-text-column rejected

Python API

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./preferences.jsonl",
    project_name="reward-model",

    trainer="reward",

    # Column mappings (required for reward training)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",

    epochs=1,
    batch_size=4,
    lr=2e-5,
)

project = AutoTrainProject(params=params, backend="local", process=True)
project.create()

Formato de Datos

El entrenamiento de recompensa requiere datos de preferencia con tres columnas:
ColumnaDescripción
promptEl prompt/pregunta de entrada
chosenLa respuesta preferida/mejor
rejectedLa respuesta menos preferida/peor

Datos de Ejemplo

{"prompt": "Explain gravity", "chosen": "Gravity is a fundamental force...", "rejected": "gravity makes stuff fall down"}
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "its a snake"}
{"prompt": "Write a greeting", "chosen": "Hello! How can I assist you today?", "rejected": "hey"}

Parámetros Requeridos

El entrenamiento de recompensa requiere que se especifiquen los tres parámetros de columna:
  • --prompt-text-column
  • --text-column (para respuestas elegidas)
  • --rejected-text-column

Parámetros

ParámetroFlag CLIPor DefectoDescripción
prompt_text_column--prompt-text-columnpromptColumna con prompts
text_column--text-columntextColumna con respuestas elegidas
rejected_text_column--rejected-text-columnrejectedColumna con respuestas rechazadas

Modelo de Salida

El modelo entrenado es un AutoModelForSequenceClassification que:
  • Toma entrada de texto
  • Retorna una puntuación de recompensa escalar
  • Puntuaciones más altas indican respuestas mejores
  • Usado como entrada para entrenamiento PPO vía --rl-reward-model-path

Usando el Modelo de Recompensa

Con Entrenamiento PPO

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./prompts.jsonl \
  --project-name ppo-model \
  --trainer ppo \
  --rl-reward-model-path ./reward-model

Inferencia Directa

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Cargar modelo de recompensa
model = AutoModelForSequenceClassification.from_pretrained("./reward-model")
tokenizer = AutoTokenizer.from_pretrained("./reward-model")

# Puntuar una respuesta
text = "What is AI? AI is artificial intelligence, a field of computer science..."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(**inputs)
    score = outputs.logits.item()

print(f"Reward score: {score}")

Mejores Prácticas

  1. Datos de preferencia de calidad - El modelo de recompensa es tan bueno como tus anotaciones
  2. Ejemplos diversos - Incluye prompts variados y niveles de calidad de respuesta
  3. Señales de preferencia claras - Elegido debe ser claramente mejor que rechazado
  4. Dataset balanceado - Evita sesgo hacia ciertos tipos de respuesta
  5. Datos suficientes - Apunta a mínimo 1,000+ pares de preferencia

Ejemplo: Construyendo Datos de Preferencia

# Script de ejemplo para crear datos de preferencia
import json

preferences = [
    {
        "prompt": "Summarize machine learning",
        "chosen": "Machine learning is a subset of AI that enables systems to learn from data...",
        "rejected": "ml is computers learning stuff"
    },
    # Añade más ejemplos...
]

with open("preferences.jsonl", "w") as f:
    for item in preferences:
        f.write(json.dumps(item) + "\n")

Próximos Pasos