Model Distillation & Extraction

概念区分:模型蒸馏(M knowledge Distillation)和模型抽取(Model Extraction)都是"复制"模型,但目的和手段不同。蒸馏是正向的知识迁移,用于模型压缩和部署;抽取是逆向的攻击窃取,用于盗取模型功能。

Understanding (理解) - 模型安全基础知识

在深入模型蒸馏和抽取之前,需要理解模型安全中的几个核心概念。

Model Distillation (模型蒸馏)

模型蒸馏是一种模型压缩技术,通过让小模型(学生)学习大模型(老师)的"软标签"来获得接近大模型的性能。

📚 模型蒸馏动画演示
Teacher 🧠 大模型
Student 🧠 小模型

大模型(Teacher)指导 → 小模型(Student)学习软标签

软标签分布:
A
B
C
# 模型蒸馏示例
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4, alpha=0.7):
    # 软标签损失:学生学习老师的概率分布
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)
    
    # 硬标签损失:学生学习真实标签
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # 组合损失
    return alpha * soft_loss + (1 - alpha) * hard_loss

# 训练学生模型
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, labels = batch
        teacher_outputs = teacher_model(inputs)
        student_outputs = student_model(inputs)
        
        loss = distillation_loss(student_outputs, teacher_outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Model Extraction (模型抽取/窃取)

模型抽取是一种攻击技术,攻击者通过查询目标模型来窃取其功能,构建一个功能相近的复制模型。

🔓 模型抽取攻击动画演示
🔍
查询API
💾
收集输出
📋
训练副本

攻击者通过 API 查询 → 收集输入输出对 → 训练复制模型

维度 模型蒸馏 (正常) 模型抽取 (攻击)
目的 模型压缩、部署优化 窃取模型功能
发起者 模型拥有者 外部攻击者
信息访问 完全访问 只通过 API
知识来源 软标签(概率分布) 硬标签或概率
合法性 合法 非法

攻击实例:某公司提供文本分类 API,攻击者通过大量查询收集输入输出对,然后用这些数据训练自己的模型。如果成功复制,不仅可以免费使用原模型功能,还可能进一步利用复制模型进行迁移攻击(Adversarial Attack)。

攻击方法详解

防御方法

📚 本章复习要点