Saltar al contenido principal

API de Entrenamiento LLM

Referencia completa de la API para entrenamiento de LLM.

LLMTrainingParams

La clase principal de configuración para entrenamiento de LLM.
from autotrain.trainers.clm.params import LLMTrainingParams

Parámetros Básicos

params = LLMTrainingParams(
    # --- Always set these three explicitly ---
    # model defaults to "google/gemma-3-270m".
    model="google/gemma-3-270m",
    # data_path defaults to "data".
    data_path="./data.jsonl",
    # project_name defaults to "project-name".
    project_name="my-model",
    # --- Dataset splits ---
    # train_split defaults to "train"; valid_split and max_samples default
    # to None (max_samples=None means every sample is used).
    train_split="train",
    valid_split=None,
    max_samples=None,
)

Selección de Trainer

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # Trainer to run. The default is "default" (pretraining); the other
    # options are "sft", "dpo", "orpo", "ppo", "grpo" and "reward".
    trainer="sft",
)

Hiperparámetros de Entrenamiento

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # Core hyperparameters -- every value below is the default.
    epochs=1,
    batch_size=2,
    lr=3e-5,
    warmup_ratio=0.1,
    gradient_accumulation=4,
    weight_decay=0.0,
    max_grad_norm=1.0,
    # Precision: "bf16", "fp16" or None (default: None).
    mixed_precision=None,
    # Optimizer, LR scheduler and random seed
    # (defaults: "adamw_torch", "linear", 42).
    optimizer="adamw_torch",
    scheduler="linear",
    seed=42,
)

Configuración PEFT/LoRA

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="my-model",
    # LoRA is disabled unless peft=True (default: False).
    peft=True,
    # LoRA hyperparameters, shown at their defaults
    # (r=16, alpha=32, dropout=0.05, target_modules="all-linear").
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    # Optional quantization: "int4", "int8" or None (default: None).
    quantization="int4",
    # merge_adapter defaults to True: the LoRA adapter is merged into the
    # base model automatically once training finishes.
    merge_adapter=True,
)

Procesamiento de Datos

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Text processing
    text_column="text",
    block_size=-1,            # Default: -1 (model default)
    # When not set, model_max_length is auto-detected from the model config;
    # 2048 is only the fallback used if auto-detection fails. Setting it
    # explicitly (as done here) overrides auto-detection -- omit this
    # argument to use the model's native context length.
    model_max_length=2048,
    add_eos_token=True,       # Default: True
    padding="right",          # Default: "right"

    # Chat format
    chat_template=None,       # Auto-detect or specify
    apply_chat_template=True, # Default: True

    # Efficiency
    packing=None,             # Default: None (set True to enable)
    use_flash_attention_2=False,  # Default: False
    attn_implementation=None,     # Default: None
)
Auto-detección de model_max_length: este parámetro ahora se detecta automáticamente a partir de la configuración del modelo. Por ejemplo, Gemma 2 (8192 tokens) y Gemma 3 (de 32K a 128K tokens, según la variante) usarán automáticamente su longitud de contexto nativa. El valor predeterminado de 2048 solo se usa como respaldo cuando la auto-detección falla. Si lo configuras explícitamente, tu valor anula la auto-detección.

Parámetros DPO

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./preferences.jsonl",
    project_name="my-model",
    trainer="dpo",
    # DPO-specific settings, shown at their defaults
    # (dpo_beta=0.1, max_prompt_length=128, max_completion_length=None).
    dpo_beta=0.1,
    max_prompt_length=128,
    max_completion_length=None,
    # Optional reference model; when None, the same model as `model` is used.
    model_ref=None,
    # Column mapping -- DPO requires prompt, chosen and rejected columns.
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)

Parámetros ORPO

params = LLMTrainingParams(
    model="google/gemma-2-2b",
    data_path="./preferences.jsonl",
    project_name="my-model",
    trainer="orpo",
    # ORPO-specific settings, shown at their defaults
    # (dpo_beta=0.1, max_prompt_length=128, max_completion_length=None).
    dpo_beta=0.1,
    max_prompt_length=128,
    max_completion_length=None,
    # Column mapping -- ORPO requires prompt, chosen and rejected columns.
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)

Parámetros GRPO

params = LLMTrainingParams(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    project_name="grpo-agent",

    trainer="grpo",

    # GRPO-specific (required)
    rl_env_module="my_envs.hotel_env",   # Python module path of the environment
    rl_env_class="HotelEnv",             # Class name inside the environment module
    rl_num_generations=4,                # Default: 4 — completions per prompt

    # Shared RL parameters (used by both PPO and GRPO)
    rl_kl_coef=0.1,           # Default: 0.1 — KL-divergence penalty (beta)
    rl_clip_range=0.2,        # Default: 0.2 — clipping range (epsilon)
    rl_env_config=None,       # Default: None — JSON config passed to the environment constructor
    rl_max_new_tokens=256,    # Default: 128 — max tokens per completion (raised to 256 in this example)
    rl_top_k=50,              # Default: 50
    rl_top_p=1.0,             # Default: 1.0
    rl_temperature=1.0,       # Default: 1.0

    # vLLM acceleration (optional)
    use_vllm=False,                    # Default: False — enable vLLM for faster generation
    vllm_mode="colocate",              # Default: "colocate" — or "server"
    vllm_gpu_memory_utilization=0.3,   # Default: 0.3 — fraction of GPU memory given to vLLM (colocate mode)
    vllm_server_url=None,              # Default: None — URL of the vLLM server (server mode)
    vllm_tensor_parallel_size=1,       # Default: 1 — GPUs used for vLLM tensor parallelism
    vllm_server_gpus=1,                # Default: 1 — GPUs reserved for the vLLM server
)
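GRPO no requiere data_path: el dataset lo construye el método build_dataset() del entorno. Para habilitar el soporte de vLLM, instala el extra correspondiente con pip install "aitraining[vllm]" (las comillas evitan que shells como zsh interpreten los corchetes).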
Ver Entrenamiento GRPO para la interfaz completa del entorno (build_dataset, score_episode, get_tools).

Destilación de Conocimiento

params = LLMTrainingParams(
    # `model` is the student being trained; `teacher_model` is the model
    # it is distilled from.
    model="google/gemma-3-270m",
    teacher_model="google/gemma-2-2b",
    data_path="./prompts.jsonl",
    project_name="distilled-model",
    # Distillation switch and its settings; the values shown are the
    # defaults (temperature 3.0, alpha 0.7, max teacher length 512).
    use_distillation=True,
    distill_temperature=3.0,
    distill_alpha=0.7,
    distill_max_teacher_length=512,
)

Logging y Guardado

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    # --- Logging ---
    # Backend: "wandb" (the default), "tensorboard" or None.
    log="wandb",
    # -1 (the default) lets the logging interval be chosen automatically.
    logging_steps=-1,
    wandb_visualizer=True,  # terminal visualizer
    wandb_token=None,  # optional W&B API token
    # --- Checkpointing ---
    # save_strategy is "steps" or "epoch" (default: "epoch").
    save_strategy="steps",
    save_steps=500,
    save_total_limit=1,  # default: 1
    eval_strategy="steps",
)

Integración con Hub

import os

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Push to Hub
    push_to_hub=True,
    username="your-username",
    # Read the Hugging Face token from the environment rather than
    # hard-coding it, so the secret never ends up in source control.
    # Raises KeyError early if HF_TOKEN is not set.
    token=os.environ["HF_TOKEN"],
)

Ejecutar Entrenamiento

from autotrain.project import AutoTrainProject

# Build a project from `params`, run it on the local backend, and keep the
# value returned by create() as the job id.
project = AutoTrainProject(params=params, backend="local", process=True)

job_id = project.create()

Ejemplo Completo

import os

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Full configuration
params = LLMTrainingParams(
    # Model
    model="meta-llama/Llama-3.2-1B",
    project_name="llama-production",

    # Data
    data_path="./conversations.jsonl",
    train_split="train",
    valid_split="validation",
    text_column="text",
    block_size=2048,

    # Training
    trainer="sft",
    epochs=3,
    batch_size=2,
    gradient_accumulation=8,
    lr=2e-5,
    warmup_ratio=0.1,
    mixed_precision="bf16",

    # LoRA
    peft=True,
    lora_r=32,
    lora_alpha=64,
    lora_dropout=0.05,

    # Optimization
    use_flash_attention_2=True,
    packing=True,
    auto_find_batch_size=True,
    unsloth=False,  # Set to True to use Unsloth for faster training

    # Distribution (for multi-GPU)
    distributed_backend=None,  # None for auto (DDP), or "deepspeed"

    # Logging
    log="wandb",
    logging_steps=-1,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=1,

    # Hub
    push_to_hub=True,
    username="my-username",
    # Read the token from the environment; never hard-code secrets in code.
    token=os.environ["HF_TOKEN"],
)

# Run training
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True,
)
job_id = project.create()

Próximos Pasos