LLM Training API

Complete API reference for LLM training.

LLMTrainingParams

The main configuration class for LLM training.
from autotrain.trainers.clm.params import LLMTrainingParams

Basic Parameters

params = LLMTrainingParams(
    # Core parameters (always specify these)
    model="google/gemma-3-270m",       # Default: "google/gemma-3-270m"
    data_path="./data.jsonl",          # Default: "data"
    project_name="my-model",           # Default: "project-name"

    # Data splits
    train_split="train",               # Default: "train"
    valid_split=None,                  # Default: None
    max_samples=None,                  # Default: None (use all)
)
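
The data_path above is assumed to point to a JSONL file with one record per line, using the default "text" column. A minimal sketch of producing such a file (the example content is hypothetical):

import json

# One training example per line, under the default "text" column
examples = [
    {"text": "### Instruction:\nSay hello.\n### Response:\nHello!"},
    {"text": "### Instruction:\nCount to three.\n### Response:\n1, 2, 3"},
]
with open("data.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")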

Trainer Selection

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    trainer="sft",  # Default: "default" (pretraining). Options: "default", "sft", "dpo", "orpo", "reward"
)

Training Hyperparameters

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Core hyperparameters (showing defaults)
    epochs=1,           # Default: 1
    batch_size=2,       # Default: 2
    lr=3e-5,            # Default: 3e-5
    warmup_ratio=0.1,   # Default: 0.1
    gradient_accumulation=4,  # Default: 4
    weight_decay=0.0,   # Default: 0.0
    max_grad_norm=1.0,  # Default: 1.0

    # Precision
    mixed_precision=None,  # Default: None (options: bf16, fp16, None)

    # Optimization
    optimizer="adamw_torch",  # Default: adamw_torch
    scheduler="linear",       # Default: linear
    seed=42,                  # Default: 42
)
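
Note that gradient accumulation multiplies the effective batch size: each optimizer step sees batch_size × gradient_accumulation examples per GPU. A quick sanity check with the defaults above (single-GPU run assumed):

# Effective batch size with the defaults above
batch_size = 2
gradient_accumulation = 4
num_gpus = 1  # assumption: single-GPU run
print(batch_size * gradient_accumulation * num_gpus)  # 8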

PEFT/LoRA Configuration

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./data.jsonl",
    project_name="my-model",

    # Enable LoRA (default: False)
    peft=True,
    lora_r=16,           # Default: 16
    lora_alpha=32,       # Default: 32
    lora_dropout=0.05,   # Default: 0.05
    target_modules="all-linear",  # Default: all-linear

    # Quantization (optional)
    quantization="int4",  # Options: int4, int8, or None (default: None)

    # Merge the LoRA adapter into the base model after training (default: True)
    merge_adapter=True,
)

Data Processing

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Text processing
    text_column="text",
    block_size=-1,            # Default: -1 (model default)
    model_max_length=2048,    # Default: 2048
    add_eos_token=True,       # Default: True
    padding="right",          # Default: "right"

    # Chat format
    chat_template=None,       # Auto-detect or specify
    apply_chat_template=True, # Default: True

    # Efficiency
    packing=None,             # Default: None (set True to enable)
    use_flash_attention_2=False,  # Default: False
    attn_implementation=None,     # Default: None
)

DPO Parameters

params = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="./preferences.jsonl",
    project_name="my-model",

    trainer="dpo",

    # DPO-specific
    dpo_beta=0.1,              # Default: 0.1
    max_prompt_length=128,     # Default: 128
    max_completion_length=None, # Default: None

    # Reference model (optional)
    model_ref=None,  # Defaults to the base model when None

    # Data columns (required for DPO)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)
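
Each line of preferences.jsonl is expected to carry the three columns configured above. A sketch of one record (the content is hypothetical):

import json

record = {
    "prompt": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "I don't know.",
}
with open("preferences.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")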

ORPO Parameters

params = LLMTrainingParams(
    model="google/gemma-2-2b",
    data_path="./preferences.jsonl",
    project_name="my-model",

    trainer="orpo",

    # Preference settings (ORPO reuses the dpo_* parameters)
    dpo_beta=0.1,              # Default: 0.1
    max_prompt_length=128,     # Default: 128
    max_completion_length=None, # Default: None

    # Data columns (required for ORPO)
    prompt_text_column="prompt",
    text_column="chosen",
    rejected_text_column="rejected",
)

Knowledge Distillation

params = LLMTrainingParams(
    model="google/gemma-3-270m",           # Student
    teacher_model="google/gemma-2-2b",     # Teacher
    data_path="./prompts.jsonl",
    project_name="distilled-model",

    use_distillation=True,
    distill_temperature=3.0,   # Default: 3.0
    distill_alpha=0.7,         # Default: 0.7
    distill_max_teacher_length=512,  # Default: 512
)
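
For intuition, these parameters map onto the standard knowledge-distillation loss: distill_temperature softens both models' logits, and distill_alpha weights the soft (teacher) term against the hard-label term. A sketch of that usual formulation (AutoTrain's exact loss may differ):

import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels,
                 temperature=3.0, alpha=0.7):
    # Soft targets: KL between temperature-softened distributions,
    # scaled by T^2 as in standard knowledge distillation
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # Hard targets: ordinary cross-entropy on the labels
    hard = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
    )
    return alpha * soft + (1 - alpha) * hard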

Logging and Saving

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Logging
    log="wandb",  # wandb, tensorboard, or None (default: wandb)
    logging_steps=-1,     # Default: -1 (auto)
    wandb_visualizer=True,  # Terminal visualizer
    wandb_token=None,       # W&B API token (optional)

    # Checkpointing
    save_strategy="steps",  # steps or epoch (default: epoch)
    save_steps=500,
    save_total_limit=1,   # Default: 1
    eval_strategy="steps",
)

Hub Integration

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Push to Hub
    push_to_hub=True,
    username="your-username",
    token="hf_...",
)
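
To avoid hardcoding credentials, the token can be read from the environment instead (HF_TOKEN is a common convention, not a requirement of the API):

import os

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    push_to_hub=True,
    username="your-username",
    token=os.environ.get("HF_TOKEN"),  # read the token from the environment
)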

Running Training

from autotrain.project import AutoTrainProject

# Create and run project
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True
)

job_id = project.create()
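
With the local backend, the trained model is written to the project_name directory, so it can be loaded back with transformers once training finishes (assuming the standard output layout):

from transformers import AutoModelForCausalLM, AutoTokenizer

# "my-model" here is the project_name used above
model = AutoModelForCausalLM.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")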

Complete Example

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Full configuration
params = LLMTrainingParams(
    # Model
    model="meta-llama/Llama-3.2-1B",
    project_name="llama-production",

    # Data
    data_path="./conversations.jsonl",
    train_split="train",
    valid_split="validation",
    text_column="text",
    block_size=2048,

    # Training
    trainer="sft",
    epochs=3,
    batch_size=2,
    gradient_accumulation=8,
    lr=2e-5,
    warmup_ratio=0.1,
    mixed_precision="bf16",

    # LoRA
    peft=True,
    lora_r=32,
    lora_alpha=64,
    lora_dropout=0.05,

    # Optimization
    use_flash_attention_2=True,
    packing=True,
    auto_find_batch_size=True,
    unsloth=False,  # Set True to use Unsloth for faster training

    # Distribution (for multi-GPU)
    distributed_backend=None,  # None for auto (DDP), or "deepspeed"

    # Logging
    log="wandb",
    logging_steps=-1,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=1,

    # Hub
    push_to_hub=True,
    username="my-username",
    token="hf_...",
)

# Run training
project = AutoTrainProject(
    params=params,
    backend="local",
    process=True
)
job_id = project.create()

Next Steps