
Python API

AITraining provides a Python API for programmatic access to all of its training features.

Installation

pip install aitraining torch

Quick Start

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Configure training
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
    peft=True,
    lora_r=16,
)

# Start training
project = AutoTrainProject(params=params, backend="local", process=True)
job_id = project.create()
print(f"Training started: {job_id}")
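
The quick start points data_path at a local JSONL file. As a minimal sketch of how such a file could be produced, the snippet below writes one JSON object per line with a single "text" field. The "text" field name mirrors the text_column="text" setting used in the complete training script; the sample records and the write_jsonl helper are hypothetical and not part of AITraining.

```python
import json
from pathlib import Path


def write_jsonl(records, path):
    """Write an iterable of dicts to `path`, one JSON object per line."""
    path = Path(path)
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


# Hypothetical training samples; each row carries its text in a "text" field.
samples = [
    {"text": "Question: What is LoRA?\nAnswer: A parameter-efficient fine-tuning method."},
    {"text": "Question: What does SFT stand for?\nAnswer: Supervised fine-tuning."},
]
```

Calling write_jsonl(samples, "./data.jsonl") would create a file at the path the quick start uses. Whether your dataset needs additional columns depends on the trainer you select.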

API Structure

Training Parameters

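Each task type has its own parameter class:

Task                    Parameter class
LLM training            LLMTrainingParams
Text classification     TextClassificationParams
Image classification    ImageClassificationParams
Token classification    TokenClassificationParams
Seq2Seq                 Seq2SeqParams
Tabular data            TabularParams
Object detection        ObjectDetectionParams
VLM                     VLMTrainingParams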
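
When a script has to handle more than one task, the task-to-class mapping can be kept in a small lookup table and the class imported only when needed. The sketch below covers only the two classes whose module paths this page names (autotrain.trainers.clm.params and autotrain.trainers.text_classification.params). The task keys "llm" and "text-classification" and the load_params_class helper are hypothetical names chosen for illustration.

```python
import importlib

# Task label -> (module path, class name).
# Only entries whose module paths appear on this page are listed.
PARAMS_CLASSES = {
    "llm": ("autotrain.trainers.clm.params", "LLMTrainingParams"),
    "text-classification": (
        "autotrain.trainers.text_classification.params",
        "TextClassificationParams",
    ),
}


def load_params_class(task):
    """Import and return the parameter class registered for `task`."""
    try:
        module_path, class_name = PARAMS_CLASSES[task]
    except KeyError:
        known = ", ".join(sorted(PARAMS_CLASSES))
        raise ValueError(f"Unknown task {task!r}; known tasks: {known}") from None
    # The import happens here, so the package is only required at call time.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```

With the package installed, load_params_class("llm") should return LLMTrainingParams, which can then be instantiated as in the quick start.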

Project Execution

from autotrain.project import AutoTrainProject

# Create project
project = AutoTrainProject(
    params=params,
    backend="local",  # or "spaces"
    process=True      # Start immediately
)

# Run training
job_id = project.create()

Example: Complete Training Script

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_model():
    # Configure parameters
    params = LLMTrainingParams(
        # Model
        model="meta-llama/Llama-3.2-1B",
        project_name="llama-sft",

        # Data
        data_path="./conversations.jsonl",
        train_split="train",
        text_column="text",
        block_size=2048,

        # Training
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=4,
        lr=2e-5,
        mixed_precision="bf16",

        # LoRA
        peft=True,
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,

        # Logging
        log="wandb",
        logging_steps=10,
    )

    # Start training
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()

if __name__ == "__main__":
    job_id = train_model()
    print(f"Training complete: {job_id}")
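
The script above combines batch_size=2 with gradient_accumulation=4. As a rough sketch of what that implies, the helper below computes the effective batch size per optimizer step and an approximate number of optimizer steps. It assumes a single device by default and ignores trainer-specific rounding; the function names and the 1,000-example dataset size are hypothetical.

```python
import math


def effective_batch_size(batch_size, gradient_accumulation, num_devices=1):
    """Number of examples that contribute to one optimizer step."""
    return batch_size * gradient_accumulation * num_devices


def approximate_optimizer_steps(num_examples, epochs, batch_size,
                                gradient_accumulation, num_devices=1):
    """Rough count of optimizer steps over the whole run."""
    per_step = effective_batch_size(batch_size, gradient_accumulation, num_devices)
    steps_per_epoch = math.ceil(num_examples / per_step)
    return steps_per_epoch * epochs


# Values from the training script above, on one device.
print(effective_batch_size(2, 4))                  # 8
# Hypothetical dataset of 1,000 examples trained for 3 epochs.
print(approximate_optimizer_steps(1000, 3, 2, 4))  # 375
```

An estimate like this can help when choosing logging_steps: with roughly 375 steps in total, logging_steps=10 would yield a few dozen log points.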

Core Modules

Module                                           Description
autotrain.project                                Project execution
autotrain.trainers.clm.params                    LLM parameters
autotrain.trainers.text_classification.params    Text classification parameters
autotrain.dataset                                Dataset processing
autotrain.generation                             Inference utilities
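
One quick way to confirm that an installation exposes these modules is to look each one up without running any training. The sketch below uses only the standard library; module_available is a hypothetical helper, and the module list is copied from the table above.

```python
import importlib.util

# Module paths as listed in the table above.
CORE_MODULES = [
    "autotrain.project",
    "autotrain.trainers.clm.params",
    "autotrain.trainers.text_classification.params",
    "autotrain.dataset",
    "autotrain.generation",
]


def module_available(name):
    """Return True if `name` can be located by the import system."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package is missing or has no usable spec.
        return False


if __name__ == "__main__":
    for name in CORE_MODULES:
        status = "ok" if module_available(name) else "missing"
        print(f"{name}: {status}")
```

Note that looking up a dotted name imports its parent packages, so the check may take a moment the first time it runs.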

Next Steps