知识蒸馏
训练更小、更快的模型,使其模仿更大的教师模型的行为。
什么是蒸馏?
知识蒸馏将知识从大型”教师”模型转移到较小的”学生”模型。学生学会产生与教师相似的输出,获得仅从数据中无法学到的能力。
快速开始
aitraining llm --train \
--model google/gemma-3-270m \
--teacher-model google/gemma-2-2b \
--data-path ./prompts.jsonl \
--project-name distilled-model \
--use-distillation \
--distill-temperature 3.0
Python API
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject
params = LLMTrainingParams(
# Student model (smaller)
model="google/gemma-3-270m",
# Teacher model (larger)
teacher_model="google/gemma-2-2b",
# Data
data_path="./prompts.jsonl",
project_name="distilled-gemma",
# Enable distillation
use_distillation=True,
distill_temperature=3.0, # Default: 3.0
distill_alpha=0.7, # Default: 0.7
distill_max_teacher_length=512, # Default: 512
# Training
trainer="sft",
epochs=5,
batch_size=4,
lr=1e-4,
)
project = AutoTrainProject(params=params, backend="local", process=True)
project.create()
| 参数 | 描述 | 默认值 |
|---|
use_distillation | 启用蒸馏 | False |
teacher_model | 教师模型路径 | 当 use_distillation=True 时必需 |
distill_temperature | Softmax 温度(推荐 2.0-4.0) | 3.0 |
distill_alpha | 蒸馏损失权重 | 0.7 |
distill_max_teacher_length | 教师模型最大 token 数 | 512 |
teacher_prompt_template | 教师模型提示词模板 | None |
student_prompt_template | 学生模型提示词模板 | "{input}" |
更高的温度使教师模型的概率分布更柔和,使学生更容易学习:
1.0: 正常概率
2.0-4.0: 更柔和,更易教学(推荐)
>4.0: 非常柔和,可能失去精度
Alpha
控制蒸馏和标准损失之间的平衡:
0.0: 仅标准损失(无蒸馏)
0.5: 平衡
0.7: 默认值(更重视蒸馏)
1.0: 仅蒸馏损失
提示词模板
自定义教师和学生模型的提示词格式化方式:
params = LLMTrainingParams(
...
use_distillation=True,
teacher_prompt_template="<|system|>You are helpful.<|user|>{input}<|assistant|>",
student_prompt_template="{input}",
)
使用 {input} 作为实际提示词文本的占位符。
数据格式
简单的提示词对蒸馏效果很好:
{"text": "What is machine learning?"}
{"text": "Explain how neural networks work."}
{"text": "Write a function to sort a list in Python."}
或包含期望输出:
{"prompt": "What is AI?", "response": "..."}
最佳实践
明智选择模型
- 教师模型应该显著更大(4倍+ 参数)
- 相同架构族通常效果最好
- 教师模型应该在目标任务上表现良好
温度调优
# 保守(教师不确定)
distill_temperature=2.0
# 标准(大多数情况,默认)
distill_temperature=3.0
# 激进(教师自信)- 在推荐范围上限
distill_temperature=4.0
推荐温度范围为 2.0-4.0。超过 4.0 的值可能会失去精度。
训练时长
蒸馏通常受益于更长的训练:
params = LLMTrainingParams(
...
epochs=5, # 比标准微调更多的轮次
lr=1e-4, # 略高的学习率
)
示例:API 助手
蒸馏大型模型的 API 知识:
params = LLMTrainingParams(
model="google/gemma-3-270m",
teacher_model="meta-llama/Llama-3.2-8B",
data_path="./api_prompts.jsonl",
project_name="api-assistant-small",
use_distillation=True,
distill_temperature=3.0,
distill_alpha=0.7,
epochs=10,
batch_size=8,
lr=5e-5,
peft=True,
lora_r=32,
)
无蒸馏
# 在 270M 模型上的标准微调
aitraining llm --train \
--model google/gemma-3-270m \
--data-path ./data.jsonl \
--project-name standard-model
有蒸馏
# 从 2B 模型蒸馏
aitraining llm --train \
--model google/gemma-3-270m \
--teacher-model google/gemma-2-2b \
--data-path ./data.jsonl \
--project-name distilled-model \
--use-distillation
蒸馏模型通常在复杂任务上表现更好。
使用场景
- 部署: 为生产环境创建快速模型
- 边缘设备: 在移动/嵌入式系统上运行
- 成本降低: 降低推理成本
- 专业化: 将大型模型知识聚焦到特定领域
下一步