一、从DeepSeek出圈说起

最近国产大模型DeepSeek(深度求索)突然爆火,开源模型在多项评测中超越O1的表现令人惊叹。但更值得关注的是,其实很多我们熟知的大模型这都是知识蒸馏技术的杰作哦,然后我就好奇到底怎么实现的模型蒸馏,上网查阅之后发现没有比较好的代码教学,那么这里我就写一个简单的快速体验蒸馏模型威力的代码示例供大家交流学习讨论。

二、知识蒸馏原理(小学生都能懂版)

1. 核心思想

就像学霸把解题思路教给学弟:

  • 教师网络(大模型):知识渊博但行动缓慢的学霸
  • 学生网络(小模型):需要快速解题的普通学生
  • 蒸馏过程:学霸不直接给答案,而是传授解题技巧

2. 关键技术点

概念生活比喻技术作用
温度参数(T)放大镜看细节让概率分布更"柔软"易学习
KL散度损失模仿解题步骤的相似度打分衡量学生与学霸输出的差异
软标签学霸的解题笔记比硬标签包含更多信息

3.蒸馏程度划分

三、代码实战:手把手实现蒸馏

1. 环境准备

pip install torch numpy matplotlib

2. 核心代码解析(简化版)

# 教师模型结构(学霸的大脑)实战中会替换为真正的LLM,这里做演示用
class TeacherNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),  # 更多神经元
            nn.ReLU(),
            nn.Linear(512, 256),  # 更深层次
            nn.ReLU(),
            nn.Linear(256, 10)    # 最终输出
        )

# 学生模型结构(普通学生)
class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 64),   # 更少的参数
            nn.ReLU(),
            nn.Linear(64, 10)     # 更浅的结构
        )

3. 知识传递的关键代码

# 温度就像"知识放大镜"
def distillation_loss(student_out, teacher_out, T=2.0):
    # 软化输出:让1+1=2变成1+1≈2.01的过程
    soft_teacher = nn.Softmax(dim=1)(teacher_out/T)
    log_soft_student = nn.LogSoftmax(dim=1)(student_out/T)
    
    # 计算知识差异(KL散度)
    return nn.KLDivLoss()(log_soft_student, soft_teacher) * T**2

4. 训练过程(师生互动)

for epoch in range(100):
    # 教师生成"解题思路"
    with torch.no_grad():
        teacher_notes = teacher(problems)
    
    # 学生边看笔记边学习
    student_optimizer.zero_grad()
    student_answers = student(problems)
    loss = distillation_loss(student_answers, teacher_notes)
    loss.backward()
    student_optimizer.step()

四、DeepSeek蒸馏其他模型的成功秘诀

  1. 分层蒸馏:先让中等模型向超大模型学习,再让小模型向中等模型学习
  2. 课程学习:先学简单题目(清晰样本),再挑战难题(模糊样本)
  3. 数据增强:给题目加"干扰项"(噪声数据)提升鲁棒性

五、自己动手试试看

import torch
import torch.nn as nn
import torch.optim as optim

# ---------------------------
# (1)定义一个较大的 Teacher 模型
# ---------------------------
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        # 这里随意搭建一个相对稍大一点的模型
        self.net = nn.Sequential(
            nn.Linear(784, 512),  # 更多神经元
            nn.ReLU(),
            nn.Linear(512, 256),  # 更深层次
            nn.ReLU(),
            nn.Linear(256, 10)    # 最终输出
        )

    def forward(self, x):
        return self.net(x)

# ---------------------------
# (2)定义一个较小的 Student 模型
# ---------------------------
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        # 小模型:层数更少 or 参数更少
        self.net = nn.Sequential(
            nn.Linear(784, 64),   # 更少的参数
            nn.ReLU(),
            nn.Linear(64, 10)     # 更浅的结构
        )

    def forward(self, x):
        return self.net(x)

# ---------------------------
# (3)知识蒸馏的损失函数
# ---------------------------
def distillation_loss(student_outputs, teacher_outputs, temperature=2.0, alpha=0.5):
    """
    student_outputs: 学生网络的输出
    teacher_outputs: 教师网络的输出
    temperature:     温度系数,越大越"平滑"
    alpha:           独立蒸馏损失与真实标签交叉熵损失之间的平衡
    """
    # 使用 KLDivLoss 计算的蒸馏损失
    # log_softmax 前需要先 / temperature
    student_log_probs = nn.LogSoftmax(dim=1)(student_outputs / temperature)
    teacher_probs = nn.Softmax(dim=1)(teacher_outputs / temperature)
    distill_loss = nn.KLDivLoss(reduction='batchmean')(student_log_probs, teacher_probs) * (temperature ** 2)

    # 这里简化,仅用蒸馏损失,并没有结合真实标签损失以示例
    # 若有真实标签可以按照 alpha * distill_loss + (1 - alpha) * ce_loss 的形式结合
    return distill_loss

def train(model, optimizer, data, teacher_outputs=None):
    """
    model:         待训练的模型
    optimizer:     优化器
    data:          本示例中用随机生成的数据 (inputs, labels)
    teacher_outputs: 用于蒸馏的小批量 teacher 输出(小模型训练时用)
    """
    inputs, labels = data
    outputs = model(inputs)

    if teacher_outputs is not None:
        # 学生模型进行蒸馏
        loss = distillation_loss(outputs, teacher_outputs)
    else:
        # 教师模型普通训练(交叉熵)
        ce_loss_fn = nn.CrossEntropyLoss()
        loss = ce_loss_fn(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

def main():
    # ---------------------------
    # 1. 创建 Teacher 和 Student 网络
    # ---------------------------
    teacher = TeacherNet()
    student = StudentNet()

    # ---------------------------
    # 2. 优化器定义
    # ---------------------------
    teacher_optimizer = optim.Adam(teacher.parameters(), lr=1e-3)
    student_optimizer = optim.Adam(student.parameters(), lr=1e-3)

    # ---------------------------
    # 3. 创建一些简单的随机数据来演示
    #    假设我们每次训练有32张图片,每张图片28*28=784维度
    # ---------------------------
    EPOCHS = 100
    BATCH_SIZE = 32
    for epoch in range(EPOCHS):
        # 演示 teacher 的训练
        teacher_inputs = torch.randn(BATCH_SIZE, 784)   # 生成随机输入
        teacher_labels = torch.randint(0, 10, (BATCH_SIZE,))  # 生成随机标签(0-9)
        teacher_loss = train(teacher, teacher_optimizer, (teacher_inputs, teacher_labels))

        # 生成学生训练数据
        student_inputs = torch.randn(BATCH_SIZE, 784)
        # teacher 输出作为 student 的蒸馏目标
        with torch.no_grad():
            teacher_preds = teacher(student_inputs)
        # 这里 labels 仅做演示,用不到
        student_labels = torch.randint(0, 10, (BATCH_SIZE,))

        student_loss = train(student, student_optimizer, (student_inputs, student_labels), teacher_preds)

        # 打印各自的 loss
        print(f"Epoch {epoch+1}/{EPOCHS}, Teacher Loss: {teacher_loss:.4f}, Student Loss (Distill): {student_loss:.4f}")

if __name__ == "__main__":
    main()

可以看到student的Loss是有显著降低的哈:

六、为什么这很重要?

  1. 手机端应用:1.3B参数的小模型能在端侧上实现20 tokens/秒的生成速度
  2. 成本降低:推理成本仅为原模型的1/100
  3. 隐私保护:本地运行无需上传数据

七、延伸思考

  • 如果教师教错了怎么办?(对抗蒸馏)
  • 如何让多个教师共同指导一个学生?(集成蒸馏)
  • 学生能否在某些方面超越老师?(逆蒸馏)

附录:常见问题解答

Q:和直接训练小模型有什么区别?
A:就像抄答案 vs 理解解题思路,蒸馏后的模型面对新问题表现更好

Q:温度参数是不是越大越好?
A:温度太高会像过度放大的地图,失去关键细节,一般2-5之间效果最佳