跳转到主要内容

知识蒸馏

训练更小、更快的模型,使其模仿更大的教师模型的行为。

什么是蒸馏?

知识蒸馏将知识从大型”教师”模型转移到较小的”学生”模型。学生学会产生与教师相似的输出,获得仅从数据中无法学到的能力。

快速开始

aitraining llm --train \
  --model google/gemma-3-270m \
  --teacher-model google/gemma-2-2b \
  --data-path ./prompts.jsonl \
  --project-name distilled-model \
  --use-distillation \
  --distill-temperature 3.0

Python API

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

params = LLMTrainingParams(
    # Student model (smaller)
    model="google/gemma-3-270m",

    # Teacher model (larger)
    teacher_model="google/gemma-2-2b",

    # Data
    data_path="./prompts.jsonl",
    project_name="distilled-gemma",

    # Enable distillation
    use_distillation=True,
    distill_temperature=3.0,   # Default: 3.0
    distill_alpha=0.7,         # Default: 0.7
    distill_max_teacher_length=512,  # Default: 512

    # Training
    trainer="sft",
    epochs=5,
    batch_size=4,
    lr=1e-4,
)

project = AutoTrainProject(params=params, backend="local", process=True)
project.create()

参数

参数描述默认值
use_distillation启用蒸馏False
teacher_model教师模型路径use_distillation=True 时必需
distill_temperatureSoftmax 温度(推荐 2.0-4.0)3.0
distill_alpha蒸馏损失权重0.7
distill_max_teacher_length教师模型最大 token 数512
teacher_prompt_template教师模型提示词模板None
student_prompt_template学生模型提示词模板"{input}"

温度

更高的温度使教师模型的概率分布更柔和,使学生更容易学习:
  • 1.0: 正常概率
  • 2.0-4.0: 更柔和,更易教学(推荐)
  • >4.0: 非常柔和,可能失去精度

Alpha

控制蒸馏和标准损失之间的平衡:
  • 0.0: 仅标准损失(无蒸馏)
  • 0.5: 平衡
  • 0.7: 默认值(更重视蒸馏)
  • 1.0: 仅蒸馏损失

提示词模板

自定义教师和学生模型的提示词格式化方式:
params = LLMTrainingParams(
    ...
    use_distillation=True,
    teacher_prompt_template="<|system|>You are helpful.<|user|>{input}<|assistant|>",
    student_prompt_template="{input}",
)
使用 {input} 作为实际提示词文本的占位符。

数据格式

简单的提示词对蒸馏效果很好:
{"text": "What is machine learning?"}
{"text": "Explain how neural networks work."}
{"text": "Write a function to sort a list in Python."}
或包含期望输出:
{"prompt": "What is AI?", "response": "..."}

最佳实践

明智选择模型

  • 教师模型应该显著更大(4倍+ 参数)
  • 相同架构族通常效果最好
  • 教师模型应该在目标任务上表现良好

温度调优

# 保守(教师不确定)
distill_temperature=2.0

# 标准(大多数情况,默认)
distill_temperature=3.0

# 激进(教师自信)- 在推荐范围上限
distill_temperature=4.0
推荐温度范围为 2.0-4.0。超过 4.0 的值可能会失去精度。

训练时长

蒸馏通常受益于更长的训练:
params = LLMTrainingParams(
    ...
    epochs=5,  # 比标准微调更多的轮次
    lr=1e-4,   # 略高的学习率
)

示例:API 助手

蒸馏大型模型的 API 知识:
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    teacher_model="meta-llama/Llama-3.2-8B",
    data_path="./api_prompts.jsonl",
    project_name="api-assistant-small",

    use_distillation=True,
    distill_temperature=3.0,
    distill_alpha=0.7,

    epochs=10,
    batch_size=8,
    lr=5e-5,
    peft=True,
    lora_r=32,
)

对比

无蒸馏

# 在 270M 模型上的标准微调
aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./data.jsonl \
  --project-name standard-model

有蒸馏

# 从 2B 模型蒸馏
aitraining llm --train \
  --model google/gemma-3-270m \
  --teacher-model google/gemma-2-2b \
  --data-path ./data.jsonl \
  --project-name distilled-model \
  --use-distillation
蒸馏模型通常在复杂任务上表现更好。

使用场景

  • 部署: 为生产环境创建快速模型
  • 边缘设备: 在移动/嵌入式系统上运行
  • 成本降低: 降低推理成本
  • 专业化: 将大型模型知识聚焦到特定领域

下一步