Pular para o conteúdo principal

Modelagem de Recompensa

Treine modelos de recompensa que pontuam respostas de texto para uso em treinamento PPO/RLHF.
Importante: Modelos de recompensa NÃO são geradores de texto. Eles produzem uma pontuação escalar para um dado texto, usada para fornecer recompensas durante o treinamento PPO. Você não pode usar um modelo de recompensa como um LLM normal para geração de texto.

Início Rápido

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./preferences.jsonl \
  --project-name reward-model \
  --trainer reward \
  --prompt-text-column prompt \
  --text-column chosen \
  --rejected-text-column rejected

Python API

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./preferences.jsonl",
    project_name="reward-model",

    trainer="reward",

    # Column mappings (required for reward training)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",

    epochs=1,
    batch_size=4,
    lr=2e-5,
)

project = AutoTrainProject(params=params, backend="local", process=True)
project.create()

Formato dos Dados

O treinamento de recompensa requer dados de preferência com três colunas:
ColunaDescrição
promptO prompt/pergunta de entrada
chosenA resposta preferida/melhor
rejectedA resposta menos preferida/pior

Exemplo de Dados

{"prompt": "Explain gravity", "chosen": "Gravity is a fundamental force...", "rejected": "gravity makes stuff fall down"}
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "its a snake"}
{"prompt": "Write a greeting", "chosen": "Hello! How can I assist you today?", "rejected": "hey"}

Parâmetros Obrigatórios

O treinamento de recompensa requer que todos os três parâmetros de coluna sejam especificados:
  • --prompt-text-column
  • --text-column (para respostas escolhidas)
  • --rejected-text-column

Parâmetros

ParâmetroFlag CLIPadrãoDescrição
prompt_text_column--prompt-text-columnpromptColuna com prompts
text_column--text-columntextColuna com respostas escolhidas
rejected_text_column--rejected-text-columnrejectedColuna com respostas rejeitadas

Modelo de Saída

O modelo treinado é um AutoModelForSequenceClassification que:
  • Recebe entrada de texto
  • Retorna uma pontuação de recompensa escalar
  • Pontuações mais altas indicam respostas melhores
  • Usado como entrada para treinamento PPO via --rl-reward-model-path

Usando o Modelo de Recompensa

Com Treinamento PPO

aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./prompts.jsonl \
  --project-name ppo-model \
  --trainer ppo \
  --rl-reward-model-path ./reward-model

Inferência Direta

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Carregar modelo de recompensa
model = AutoModelForSequenceClassification.from_pretrained("./reward-model")
tokenizer = AutoTokenizer.from_pretrained("./reward-model")

# Pontuar uma resposta
text = "What is AI? AI is artificial intelligence, a field of computer science..."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(**inputs)
    score = outputs.logits.item()

print(f"Reward score: {score}")

Melhores Práticas

  1. Dados de preferência de qualidade - O modelo de recompensa é tão bom quanto suas anotações
  2. Exemplos diversos - Inclua prompts variados e níveis de qualidade de resposta
  3. Sinais de preferência claros - Escolhido deve ser claramente melhor que rejeitado
  4. Dataset balanceado - Evite viés em direção a certos tipos de resposta
  5. Dados suficientes - Procure por pelo menos 1.000+ pares de preferência

Exemplo: Construindo Dados de Preferência

# Script de exemplo para criar dados de preferência
import json

preferences = [
    {
        "prompt": "Summarize machine learning",
        "chosen": "Machine learning is a subset of AI that enables systems to learn from data...",
        "rejected": "ml is computers learning stuff"
    },
    # Adicione mais exemplos...
]

with open("preferences.jsonl", "w") as f:
    for item in preferences:
        f.write(json.dumps(item) + "\n")

Próximos Passos