Python SDK 参考
AITraining Python API 的全面指南。
pip install aitraining torch torchvision torchaudio
包名 vs 导入名:使用 pip install aitraining 安装,但使用 from autotrain import ... 导入
LLM 训练
LLMTrainingParams
from autotrain.trainers.clm.params import LLMTrainingParams
# Full SFT configuration example. Note: install with `pip install aitraining`
# but import from `autotrain` (package name and import name differ).
params = LLMTrainingParams(
# Required
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",  # also used as the output directory
# Training method
trainer="sft", # sft, dpo, orpo, ppo, reward, distillation
# Training settings
epochs=3,
batch_size=2,
lr=2e-5,
gradient_accumulation=4,  # gradient accumulation steps
mixed_precision="bf16",
# LoRA
peft=True,  # enable parameter-efficient fine-tuning (LoRA)
lora_r=16,
lora_alpha=32,
lora_dropout=0.05,
# Data processing
text_column="text",
block_size=2048,  # maximum sequence length in tokens
add_eos_token=True,
save_processed_data="auto", # auto, local, hub, both, none
# Logging
log="wandb",
logging_steps=-1, # Default: -1 (auto)
# Hyperparameter sweep (optional)
# use_sweep=True,
# sweep_backend="optuna",
# sweep_n_trials=20,
# sweep_params='{"lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3}}',
)
关键参数
| 参数 | 类型 | 描述 |
|---|---|---|
| model | str | 模型名称或路径 |
| data_path | str | 训练数据路径 |
| project_name | str | 输出目录 |
| trainer | str | 训练方法 |
| epochs | int | 训练轮数 |
| batch_size | int | 批次大小 |
| lr | float | 学习率 |
| peft | bool | 启用 LoRA |
| lora_r | int | LoRA 秩 |
| lora_alpha | int | LoRA alpha |
| save_processed_data | str | 保存处理后数据:auto、local、hub、both、none |
超参数扫描参数
| 参数 | 类型 | 描述 |
|---|---|---|
| use_sweep | bool | 启用超参数扫描 |
| sweep_backend | str | 后端:optuna、grid、random |
| sweep_n_trials | int | 试验次数 |
| sweep_metric | str | 要优化的指标 |
| sweep_direction | str | minimize 或 maximize |
| sweep_params | str | 自定义搜索空间(JSON 字符串) |
| wandb_sweep | bool | 启用 W&B 原生 sweep 仪表板 |
| wandb_sweep_project | str | sweep 的 W&B 项目 |
| wandb_sweep_entity | str | W&B 实体(团队/用户名) |
| wandb_sweep_id | str | 要继续的现有 sweep ID |
sweep_params 格式
支持列表和字典格式:
import json

# Dict format (recommended): an explicit distribution spec per parameter.
dict_space = {
    "lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},
    "batch_size": {"type": "categorical", "values": [2, 4, 8]},
}
sweep_params = json.dumps(dict_space)

# List format: shorthand for categorical values.
list_space = {
    "batch_size": [2, 4, 8],
}
sweep_params = json.dumps(list_space)
支持的类型:categorical、loguniform、uniform、int。
文本分类
TextClassificationParams
from autotrain.trainers.text_classification.params import TextClassificationParams
# Minimal text-classification config for a CSV file.
params = TextClassificationParams(
model="bert-base-uncased",
data_path="./reviews.csv",
project_name="sentiment",  # output directory
text_column="text",  # input text column in the CSV
target_column="label",  # label column in the CSV
epochs=5,
batch_size=16,
lr=2e-5,
)
图像分类
ImageClassificationParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
# Image-classification config; data_path points at a local image folder.
params = ImageClassificationParams(
model="google/vit-base-patch16-224",
data_path="./images/",
project_name="classifier",  # output directory
image_column="image",
target_column="label",
epochs=10,
batch_size=32,
)
项目执行
AutoTrainProject
from autotrain.project import AutoTrainProject
# Create and run project
project = AutoTrainProject(
params=params,  # any task params object (LLM, text classification, ...)
backend="local", # "local" or a remote backend such as "spaces-a10g-large" (see backend table)
process=True # Start training immediately
)
job_id = project.create()  # returns an identifier for the launched job
后端选项
| 后端 | 描述 |
|---|---|
| local | 在本地机器上运行 |
| spaces-* | 在 Hugging Face Spaces 上运行(例如:spaces-a10g-large、spaces-t4-medium) |
| ep-* | Hugging Face Endpoints |
| ngc-* | NVIDIA NGC |
| nvcf-* | NVIDIA Cloud Functions |
使用 Completers
from autotrain.generation import CompletionConfig, create_completer
# Configure generation
config = CompletionConfig(
max_new_tokens=256,  # cap on generated tokens
temperature=0.7,
top_p=0.95,
top_k=50,
)
# Create completer (first param is "model", not "model_path")
completer = create_completer(
model="./my-trained-model",  # path to the trained model directory
completer_type="message",  # chat-style completer
config=config
)
# Generate (returns MessageCompletionResult)
result = completer.chat("Hello, how are you?")
print(result.content) # Access the text content
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model (plain transformers — no autotrain dependency needed for inference)
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")
# Generate
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))  # decodes prompt + completion
数据集处理
AutoTrainDataset
from autotrain.dataset import AutoTrainDataset
# Prepare a raw CSV for training.
# NOTE(review): column_mapping presumably maps AutoTrain's expected column
# names (keys) to the columns present in the file (values) — confirm in docs.
dataset = AutoTrainDataset(
train_data=["train.csv"],
task="text_classification",
token="hf_...",  # Hugging Face access token
project_name="my-project",
username="my-username",
column_mapping={
"text": "review_text",
"label": "sentiment"
},
)
# Prepare dataset
data_path = dataset.prepare()  # path to the processed dataset
配置文件
从 YAML 加载
from autotrain.parser import AutoTrainConfigParser
# Parse config file (YAML description of the training job)
parser = AutoTrainConfigParser("config.yaml")
# Run training
parser.run()
错误处理
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
try:
params = LLMTrainingParams(
model="google/gemma-3-270m",
data_path="./data.jsonl",
project_name="my-model",
)
project = AutoTrainProject(params=params, backend="local", process=True)
job_id = project.create()
except ValueError as e:
# Invalid or inconsistent parameter values (raised during validation)
print(f"Configuration error: {e}")
except RuntimeError as e:
# Failures while the training job itself is running
print(f"Training error: {e}")
完整示例
SFT 训练流程
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject


def train_sft():
    """Launch a local LoRA SFT run on a small Gemma model; return the job id."""
    # Hyperparameters for the supervised fine-tuning run.
    sft_config = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        peft=True,
        lora_r=16,
        lora_alpha=32,
        log="wandb",
    )
    # Run on the local backend and start training immediately.
    launcher = AutoTrainProject(params=sft_config, backend="local", process=True)
    return launcher.create()


if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")
DPO 训练流程
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject


def train_dpo():
    """Launch a local LoRA DPO run on Llama 3.2 1B; return the job id."""
    dpo_config = LLMTrainingParams(
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,  # Default: 128
        max_completion_length=None,  # Default: None
        epochs=1,
        batch_size=2,
        lr=5e-6,
        peft=True,
        lora_r=16,
    )
    # Run on the local backend and start training immediately.
    launcher = AutoTrainProject(params=dpo_config, backend="local", process=True)
    return launcher.create()
下一步