
Knowledge Distillation

Train smaller, faster models that mimic the behavior of larger teacher models.

What is Distillation?

Knowledge distillation transfers knowledge from a large “teacher” model to a smaller “student” model. The student learns to match the teacher’s outputs, often reaching quality it could not achieve from the training data alone.
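At its core, the student is trained to match the teacher’s temperature-softened output distribution, usually through a KL-divergence term. The following is a minimal PyTorch sketch of that idea, not this library’s internal implementation:
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=3.0):
    # Soften both distributions with the same temperature
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # KL(teacher || student); the T**2 factor keeps gradient magnitudes comparable across temperatures
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2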

Quick Start

aitraining llm --train \
  --model google/gemma-3-270m \
  --teacher-model google/gemma-2-2b \
  --data-path ./prompts.jsonl \
  --project-name distilled-model \
  --use-distillation \
  --distill-temperature 3.0

Python API

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

params = LLMTrainingParams(
    # Student model (smaller)
    model="google/gemma-3-270m",

    # Teacher model (larger)
    teacher_model="google/gemma-2-2b",

    # Data
    data_path="./prompts.jsonl",
    project_name="distilled-gemma",

    # Enable distillation
    use_distillation=True,
    distill_temperature=3.0,   # Default: 3.0
    distill_alpha=0.7,         # Default: 0.7
    distill_max_teacher_length=512,  # Default: 512

    # Training
    trainer="sft",
    epochs=5,
    batch_size=4,
    lr=1e-4,
)

project = AutoTrainProject(params=params, backend="local", process=True)
project.create()

Parameters

Parameter                    Description                                  Default
use_distillation             Enable distillation                          False
teacher_model                Path to teacher model                        Required when use_distillation=True
distill_temperature          Softmax temperature (2.0-4.0 recommended)    3.0
distill_alpha                Distillation loss weight                     0.7
distill_max_teacher_length   Max tokens for teacher                       512
teacher_prompt_template      Template for teacher prompts                 None
student_prompt_template      Template for student prompts                 "{input}"

Temperature

A higher temperature softens the teacher’s probability distribution, which gives the student an easier target to learn from (see the sketch after this list):
  • 1.0: Normal probabilities
  • 2.0-4.0: Softer, more teachable (recommended)
  • >4.0: Very soft, may lose precision
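To see the effect numerically, here is a small, self-contained sketch with made-up logits for a four-token vocabulary (illustrative values only):
import torch
import torch.nn.functional as F

# Toy next-token logits (not taken from any real model)
logits = torch.tensor([4.0, 2.0, 1.0, 0.5])

for temperature in (1.0, 3.0, 6.0):
    probs = F.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")

# T=1.0 keeps most of the mass on the top token (~0.82); T=3.0 spreads it
# across the alternatives; very high temperatures approach a uniform distribution.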

Alpha

Controls the balance between the distillation loss and the standard training loss (combined as sketched after this list):
  • 0.0: Only standard loss (no distillation)
  • 0.5: Equal balance
  • 0.7: Default (more weight on distillation)
  • 1.0: Only distillation loss
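With that convention, the total objective is a weighted sum of the two terms. A minimal sketch (the variable names are illustrative, not the library’s internals):
def combined_loss(distillation_loss, standard_loss, distill_alpha=0.7):
    # distill_alpha = 0.0 -> only the standard loss; 1.0 -> only the distillation loss
    return distill_alpha * distillation_loss + (1.0 - distill_alpha) * standard_loss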

Prompt Templates

Customize how prompts are formatted for teacher and student models:
params = LLMTrainingParams(
    ...
    use_distillation=True,
    teacher_prompt_template="<|system|>You are helpful.<|user|>{input}<|assistant|>",
    student_prompt_template="{input}",
)
Use {input} as the placeholder for the actual prompt text.
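For example, with the templates above and assuming plain Python str.format substitution, the same record produces different prompts for the two models:
prompt = "What is machine learning?"

teacher_template = "<|system|>You are helpful.<|user|>{input}<|assistant|>"
student_template = "{input}"

print(teacher_template.format(input=prompt))
# <|system|>You are helpful.<|user|>What is machine learning?<|assistant|>
print(student_template.format(input=prompt))
# What is machine learning?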

Data Format

Simple prompts work well for distillation:
{"text": "What is machine learning?"}
{"text": "Explain how neural networks work."}
{"text": "Write a function to sort a list in Python."}
Or with expected outputs:
{"prompt": "What is AI?", "response": "..."}

Best Practices

Choose Models Wisely

  • Teacher should be significantly larger (4x+ parameters); see the sketch after this list for a quick way to compare sizes
  • Same architecture family often works best
  • Teacher should be capable at the target task
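One lightweight way to compare sizes is to build each model on PyTorch’s meta device from its config, which counts parameters without downloading any weights (assumes recent torch and transformers, plus access to the model repos on the Hub):
import torch
from transformers import AutoConfig, AutoModelForCausalLM

def count_params(model_name):
    # Instantiate the architecture on the meta device: no weights are downloaded
    config = AutoConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)
    return sum(p.numel() for p in model.parameters())

student = count_params("google/gemma-3-270m")
teacher = count_params("google/gemma-2-2b")
print(f"teacher/student parameter ratio: {teacher / student:.1f}x")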

Temperature Tuning

# Conservative (teacher is uncertain)
distill_temperature=2.0

# Standard (most cases, default)
distill_temperature=3.0

# Aggressive (teacher is confident) - upper end of the recommended range
distill_temperature=4.0
The recommended range is 2.0-4.0; values above 4.0 flatten the teacher’s distribution so much that useful signal may be lost.

Training Duration

Distillation often benefits from longer training:
params = LLMTrainingParams(
    ...
    epochs=5,  # More epochs than standard fine-tuning
    lr=1e-4,   # Slightly higher learning rate
)

Example: API Assistant

Distill a large model’s API knowledge:
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    teacher_model="meta-llama/Llama-3.2-8B",
    data_path="./api_prompts.jsonl",
    project_name="api-assistant-small",

    use_distillation=True,
    distill_temperature=3.0,
    distill_alpha=0.7,

    epochs=10,
    batch_size=8,
    lr=5e-5,
    peft=True,
    lora_r=32,
)

Comparison

Without Distillation

# Standard fine-tuning on 270M model
aitraining llm --train \
  --model google/gemma-3-270m \
  --data-path ./data.jsonl \
  --project-name standard-model

With Distillation

# Distillation from 2B model
aitraining llm --train \
  --model google/gemma-3-270m \
  --teacher-model google/gemma-2-2b \
  --data-path ./data.jsonl \
  --project-name distilled-model \
  --use-distillation
The distilled model typically outperforms the same student fine-tuned without a teacher, especially on complex tasks.

Use Cases

  • Deployment: Create fast models for production
  • Edge devices: Run on mobile/embedded systems
  • Cost reduction: Lower inference costs
  • Specialization: Focus a large model’s knowledge on a specific domain

Next Steps