RL Training Module
AITraining includes a comprehensive reinforcement learning module for advanced LLM training scenarios.
The CLI commands (--trainer ppo, --trainer dpo, --trainer reward) use TRL library implementations for stability. The autotrain.trainers.rl module documented here provides lower-level building blocks for custom RL training pipelines.
Overview
The RL module provides:
- PPO Trainer - Proximal Policy Optimization with KL penalty and GAE
- DPO Trainer - Direct Preference Optimization from preference data
- Reward Models - Standard, pairwise, and multi-objective reward models
- RL Environments - Text generation, math, code, and preference comparison environments
- Async Pipeline - Forward-backward training with gradient accumulation
PPO Training
Configuration
from autotrain.trainers.rl import PPOConfig, PPOTrainer
config = PPOConfig(
    model_name="google/gemma-2-2b",
    learning_rate=1e-5,
    batch_size=16,
    mini_batch_size=4,
    gradient_accumulation_steps=1,
    # PPO hyperparameters
    ppo_epochs=4,
    gamma=0.99,           # Discount factor
    lam=0.95,             # GAE lambda
    clip_ratio=0.2,       # PPO clip ratio
    value_clip=0.2,       # Value function clip
    max_grad_norm=1.0,    # Gradient clipping
    # KL penalty
    kl_penalty_coef=0.01,
    kl_target=0.01,
    kl_horizon=10000,     # Horizon for adaptive KL
    # Coefficients
    entropy_coef=0.01,    # Entropy regularization
    value_coef=0.5,       # Value function coefficient
    # Generation
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    # Training loop
    num_iterations=100,
    save_every=10,
    eval_every=5,
    device=None,          # Auto-detected
)
PPO Architecture
The PPO implementation uses a PPOModel wrapper that adds a value head to any causal LM:
# PPOModel wraps the base causal LM and adds a ValueHead
class PPOModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.value_head = ValueHead(hidden_size)  # hidden_size taken from the base model's config

# ValueHead architecture (conceptually):
# hidden states -> ReLU -> linear projection to a scalar value per token
class ValueHead(nn.Module):
    ...
Adaptive KL Controller
The AdaptiveKLController automatically adjusts the KL penalty coefficient to keep KL divergence near the target:
# Automatically managed by PPOTrainer
# Adjusts kl_penalty_coef based on current KL vs target
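The snippet below is a minimal sketch of this behavior, assuming the proportional update scheme from Ziegler et al. (2019); the controller's actual internals may differ.

class AdaptiveKLSketch:
    """Hedged illustration of adaptive KL control, not the module's exact code."""

    def __init__(self, init_kl_coef, target, horizon):
        self.kl_coef = init_kl_coef  # current penalty coefficient (kl_penalty_coef)
        self.target = target         # desired KL divergence (kl_target)
        self.horizon = horizon       # adaptation horizon in steps (kl_horizon)

    def update(self, current_kl, n_steps):
        # Proportional error between measured and target KL, clipped for stability
        error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + error * n_steps / self.horizon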
Training Loop
# Initialize trainer with a custom reward function
def my_reward_fn(prompts, responses, metadata=None):
    rewards = []
    for response in responses:
        score = evaluate_response(response)  # your own scoring logic
        rewards.append(score)
    return rewards

trainer = PPOTrainer(
    config=config,
    tokenizer=tokenizer,  # Optional, loaded from model if not provided
    reward_fn=my_reward_fn,
)

# Train on prompts
prompts = ["Write a poem about...", "Explain quantum..."]
metrics = trainer.train(prompts, num_iterations=100)
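The reward function receives batched prompts and responses and returns one float per response. As a purely illustrative example, a toy function that favors longer, less repetitive responses (the heuristics here are assumptions, not part of the library):

def length_and_diversity_reward(prompts, responses, metadata=None):
    rewards = []
    for response in responses:
        words = response.split()
        uniqueness = len(set(words)) / max(len(words), 1)  # penalize repetition
        length_score = min(len(words) / 100.0, 1.0)        # saturate at 100 words
        rewards.append(length_score * uniqueness)
    return rewards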
Key Features
| Feature | Description |
|---|---|
| Adaptive KL Controller | Automatically adjusts the KL penalty coefficient based on current vs target KL |
| GAE Advantage Estimation | Generalized Advantage Estimation for stable training (see the sketch after this table) |
| Value Head | Separate value function for the critic (PPOModel wraps the base model) |
| Reference Model | Frozen copy of the policy to prevent drift |
| Async Training | Uses AsyncTrainingClient for efficient forward-backward passes |
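The gamma and lam values in PPOConfig feed the advantage computation. Below is a minimal, self-contained sketch of GAE over a single trajectory; it illustrates the recursion rather than the trainer's exact code.

import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    # rewards, values: 1-D tensors over one trajectory (same length)
    advantages = torch.zeros_like(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        last_adv = delta + gamma * lam * last_adv            # GAE recursion
        advantages[t] = last_adv
    returns = advantages + values  # regression targets for the value head
    return advantages, returns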
DPO Training
Train directly from preference data without a separate reward model.
Configuration
from autotrain.trainers.rl import DPOConfig, DPOTrainer
config = DPOConfig(
    model_name="google/gemma-2-2b",
    learning_rate=1e-6,
    batch_size=8,
    gradient_accumulation_steps=2,
    # DPO hyperparameters
    beta=0.1,              # Temperature parameter
    label_smoothing=0.0,   # For robustness
    reference_free=False,  # Use reference model
    # Training
    num_epochs=1,
    max_grad_norm=1.0,
    warmup_ratio=0.1,
    # Sequence lengths
    max_length=512,
    max_prompt_length=256,
    # Checkpointing
    eval_every=100,
    save_every=500,
    device=None,           # Auto-detected
)
Preference Dataset
from autotrain.trainers.rl.dpo import PreferenceDataset
# Create dataset from preference pairs
dataset = PreferenceDataset(
    prompts=["What is AI?", "Explain gravity"],
    chosen=["AI is artificial intelligence...", "Gravity is a force..."],
    rejected=["idk lol", "its like magnets"],
    tokenizer=tokenizer,
    max_length=512,
    max_prompt_length=256,
)
# Train
trainer = DPOTrainer(config=config, tokenizer=tokenizer)
metrics = trainer.train(dataset, eval_dataset=eval_dataset)
PreferenceDataset must be imported directly from autotrain.trainers.rl.dpo as it’s not exported in the main __init__.py.
Reference-Free DPO
For training without a reference model:
config = DPOConfig(
    model_name="google/gemma-2-2b",
    reference_free=True,  # No reference model needed
    beta=0.1,
)
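For reference, the objective that beta and reference_free control can be sketched as follows (based on Rafailov et al., 2023; the trainer's internal implementation may differ):

import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps,
                    beta=0.1, reference_free=False):
    # Log-ratios of policy vs. reference for the chosen and rejected responses
    chosen = policy_chosen_logps - (0.0 if reference_free else ref_chosen_logps)
    rejected = policy_rejected_logps - (0.0 if reference_free else ref_rejected_logps)
    # Encourage a positive margin between chosen and rejected, scaled by beta
    return -F.logsigmoid(beta * (chosen - rejected)).mean()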
Reward Models
Standard Reward Model
from autotrain.trainers.rl import RewardModel, RewardModelConfig, RewardModelTrainer
config = RewardModelConfig(
    model_name="bert-base-uncased",
    num_labels=1,
    pooling_strategy="last",  # "mean", "last", or "cls"
    dropout_prob=0.1,
    temperature=1.0,          # Temperature scaling for rewards
    # LoRA settings
    use_lora=True,
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.1,
    # Training
    learning_rate=1e-4,
    warmup_steps=100,
    gradient_accumulation_steps=1,
)
model = RewardModel(config)
Training on Preferences
trainer = RewardModelTrainer(
    model=model,
    tokenizer=tokenizer,
    config=config,
    device=None,  # Auto-detected
)

trainer.train_on_preferences(
    chosen_texts=["Good response 1", "Good response 2"],
    rejected_texts=["Bad response 1", "Bad response 2"],
    num_epochs=3,
    batch_size=8,
)
# Save/load
trainer.save_model("reward_model.pt")
trainer.load_model("reward_model.pt")
Pairwise Reward Model
For direct preference comparison using the Bradley-Terry model:
from autotrain.trainers.rl import PairwiseRewardModel
model = PairwiseRewardModel(config)
# Forward pass compares two inputs
preference_score = model.forward_pair(
    input_ids_a, attention_mask_a,
    input_ids_b, attention_mask_b,
)

# Bradley-Terry loss for training
loss = model.compute_bradley_terry_loss(
    input_ids_a, attention_mask_a,
    input_ids_b, attention_mask_b,
    labels,  # 1 if A preferred, 0 if B preferred
)
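Conceptually, the Bradley-Terry loss treats the difference of the two scalar rewards as the logit of P(A preferred over B). A hedged sketch of that computation (the model's own compute_bradley_terry_loss may differ in detail):

import torch.nn.functional as F

def bradley_terry_loss_sketch(rewards_a, rewards_b, labels):
    # labels: 1.0 if A is preferred, 0.0 if B is preferred
    # P(A preferred) = sigmoid(r_A - r_B); minimize the binary cross-entropy
    return F.binary_cross_entropy_with_logits(rewards_a - rewards_b, labels.float())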
Multi-Objective Reward Model
Combine multiple reward signals:
from autotrain.trainers.rl import MultiObjectiveRewardModel
model = MultiObjectiveRewardModel(
    config=config,
    num_objectives=3,
    objective_weights=[0.5, 0.3, 0.2],  # Helpfulness, safety, honesty
)

# Get all objectives
outputs = model(input_ids, attention_mask, return_all_objectives=True)
# outputs["rewards"] shape: (batch_size, 3)
# outputs["combined_reward"] shape: (batch_size, 1)

# Multi-objective loss
loss, per_objective_losses = model.compute_multi_objective_loss(
    input_ids, attention_mask,
    target_rewards,       # Shape: (batch_size, num_objectives)
    objective_mask=None,  # Optional: which objectives to train
)
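The combined reward is conceptually a weighted sum of the per-objective rewards. A small numeric illustration (not the model's exact code):

import torch

rewards = torch.tensor([[0.9, 0.2, 0.7]])  # (batch_size, num_objectives)
weights = torch.tensor([0.5, 0.3, 0.2])    # helpfulness, safety, honesty
combined = (rewards * weights).sum(dim=-1, keepdim=True)  # (batch_size, 1) -> 0.65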
RL Environments
Environment Dataclasses
from autotrain.trainers.rl.environments import Observation, StepResult, Trajectory
# Observation from environment
@dataclass
class Observation:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    prompt: str
    metadata: Dict[str, Any]

# Result from env.step()
@dataclass
class StepResult:
    reward: float
    done: bool
    next_observation: Optional[Observation]
    info: Dict[str, Any]
    metrics: Dict[str, float]

# Full episode trajectory
@dataclass
class Trajectory:
    observations: List[Observation]
    actions: List[torch.Tensor]
    rewards: List[float]
    logprobs: List[torch.Tensor]
    done: bool
    total_reward: float
    metrics: Dict[str, Any]
Text Generation Environment
from autotrain.trainers.rl import TextGenerationEnv
env = TextGenerationEnv(
    tokenizer=tokenizer,
    prompts=["Write a story about...", "Explain how..."],
    max_length=512,
    reward_fn=my_reward_function,  # Optional, default is length-based
    stop_sequences=["</s>", "\n\n"],
    temperature=1.0,
)
# Reset and step
observation = env.reset()
result = env.step(action_token)
# result.reward, result.done, result.next_observation
# Render current state
print(env.render())
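Putting the environment and the dataclasses above together, a rollout loop might look like the sketch below; only env.reset()/env.step() and the dataclass fields come from the module, while the policy-sampling details are assumptions.

import torch

def collect_episode(env, policy_model):
    obs = env.reset()
    observations, actions, rewards, logprobs = [obs], [], [], []
    done = False
    while not done:
        with torch.no_grad():
            # assumes input_ids/attention_mask already carry a batch dimension
            logits = policy_model(input_ids=obs.input_ids,
                                  attention_mask=obs.attention_mask).logits[:, -1, :]
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()  # one token as the "action"
        result = env.step(action)
        actions.append(action)
        logprobs.append(dist.log_prob(action))
        rewards.append(result.reward)
        done = result.done
        if result.next_observation is not None:
            obs = result.next_observation
            observations.append(obs)
    return Trajectory(observations=observations, actions=actions,
                      rewards=rewards, logprobs=logprobs, done=done,
                      total_reward=sum(rewards), metrics={})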
Multi-Objective Environment
from autotrain.trainers.rl import MultiObjectiveRewardEnv
def correctness_reward(prompt, generated, full_text):
    return 1.0 if is_correct(generated) else 0.0  # is_correct: your own checker

def formatting_reward(prompt, generated, full_text):
    return 0.5 if properly_formatted(generated) else 0.0  # properly_formatted: your own checker

env = MultiObjectiveRewardEnv(
    tokenizer=tokenizer,
    prompts=prompts,
    reward_components={
        "correctness": correctness_reward,
        "formatting": formatting_reward,
    },
    reward_weights={
        "correctness": 1.0,
        "formatting": 0.1,
    },
)
# Step returns component rewards in metrics
result = env.step(action)
# result.metrics["reward_correctness"], result.metrics["reward_formatting"]
Preference Comparison Environment
For RLHF and DPO data collection:
from autotrain.trainers.rl import PreferenceComparisonEnv
env = PreferenceComparisonEnv(
    tokenizer=tokenizer,
    prompts=prompts,
    preference_model=preference_model,  # Optional
    human_feedback_fn=feedback_fn,      # Optional callback
    max_length=512,
)
# Generates pairs of responses and computes preference
observation = env.reset()
result1 = env.step(response1_tokens) # First response
result2 = env.step(response2_tokens) # Second response, computes preference
Built-in Environments
from autotrain.trainers.rl import create_math_problem_env, create_code_generation_env
# Math problem solving (correctness + formatting rewards)
math_env = create_math_problem_env(tokenizer)
# Code generation (syntax + style rewards)
code_env = create_code_generation_env(tokenizer)
Forward-Backward Pipeline
Async training with gradient accumulation:
from autotrain.trainers.rl import ForwardBackwardPipeline
# Low-level pipeline
pipeline = ForwardBackwardPipeline(
    model=model,
    device="cuda",
    max_workers=2,  # Thread pool size
    gradient_accumulation_steps=4,
)

# Queue forward-backward pass
future = pipeline.forward_backward(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=labels,
    loss_fn="cross_entropy",
)

# Get result (blocks until complete)
result = future.result()
print(f"Loss: {result.loss}")

# Queue optimizer step
optim_future = pipeline.optim_step(
    optimizer=optimizer,
    scheduler=scheduler,  # Optional
    max_grad_norm=1.0,
)
optim_result = optim_future.result()
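With gradient_accumulation_steps=4, one common pattern is to queue several forward-backward passes and only step the optimizer on every fourth micro-batch. A sketch under that assumption (the dataloader and batch keys are illustrative):

for step, batch in enumerate(dataloader):
    fwd = pipeline.forward_backward(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
        loss_fn="cross_entropy",
    )
    loss = fwd.result().loss      # wait for this micro-batch to finish
    if (step + 1) % 4 == 0:       # matches gradient_accumulation_steps=4
        pipeline.optim_step(optimizer, max_grad_norm=1.0).result()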
Built-in Loss Functions
The pipeline supports several built-in loss functions:
| Loss Function | Description | Required kwargs |
|---|---|---|
| "cross_entropy" | Standard language modeling loss | None |
| "importance_sampling" | RL with importance sampling | old_logprobs, advantages |
| "ppo" | Full PPO loss | old_log_probs, advantages; optionally values, returns |
Custom Loss Functions
def custom_loss_fn(model, inputs, outputs, **kwargs):
    # Your custom loss computation
    logits = outputs.logits
    # ... compute loss ...
    return loss_tensor  # Must be a scalar tensor

future = pipeline.forward_backward_custom(
    input_ids=input_ids,
    custom_loss_fn=custom_loss_fn,
    attention_mask=attention_mask,  # Optional
    my_param=42,                    # Passed to the loss function via kwargs
)
High-Level Client
from autotrain.trainers.rl.forward_backward import AsyncTrainingClient
client = AsyncTrainingClient(
    model=model,
    reference_model=reference_model,  # For PPO/DPO
    device="cuda",
    gradient_accumulation_steps=4,
)
# Training step
fwd_future = client.forward_backward(batch, loss_fn="cross_entropy")
optim_future = client.optim_step(optimizer, max_grad_norm=1.0)
# Forward only (for reference model)
ref_future = client.forward(batch, use_reference=True)
# Clean up
client.shutdown()
AsyncTrainingClient must be imported directly from autotrain.trainers.rl.forward_backward as it’s not exported in the main __init__.py.
Checkpointing
# Save checkpoint
checkpoint_info = pipeline.save_state("checkpoint_1000")
# Returns: {"path": ..., "model_path": ..., "optimizer_path": ..., "state_path": ...}
# Load checkpoint
pipeline.load_state("checkpoints/checkpoint_1000")
Sampling
Generate samples during training:
samples = pipeline.sample(
    prompt=prompt_tokens,  # List[int] or Tensor
    max_tokens=100,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    stop=[tokenizer.eos_token_id],
)
print(f"Generated: {samples['tokens']}")
print(f"Logprobs: {samples['logprobs']}")
print(f"Prompt: {samples['prompt']}")
Best Practices
PPO Training
- Start with a small KL coefficient - Let the adaptive controller adjust it
- Use gradient accumulation - Larger effective batch sizes are more stable
- Monitor KL divergence - It should stay close to the target
- Warm up the value function - Train the critic before running full PPO
DPO Training
- High-quality preference data - Quality matters more than quantity
- Low learning rate - 1e-6 to 1e-5 recommended
- Label smoothing - 0.1 can improve robustness
- Evaluate frequently - Track accuracy and reward margin
Reward Modeling
- Balanced data - Equal chosen/rejected examples
- Diverse prompts - Cover expected use cases
- LoRA for efficiency - Adapt large models without updating all weights
- Multi-objective - Separate safety and helpfulness signals
CLI Integration
For production use, the CLI provides simpler interfaces using TRL implementations:
# PPO training (uses TRL PPOTrainer)
aitraining llm --train \
--model google/gemma-2-2b \
--trainer ppo \
--reward-model ./my-reward-model
# DPO training (uses TRL DPOTrainer)
aitraining llm --train \
--model google/gemma-2-2b \
--trainer dpo \
--dpo-beta 0.1
# Reward model training
aitraining llm --train \
--model google/gemma-2-2b \
--trainer reward \
--data-path ./preference_data.jsonl
Next Steps