概念区分:模型蒸馏(M knowledge Distillation)和模型抽取(Model Extraction)都是"复制"模型,但目的和手段不同。蒸馏是正向的知识迁移,用于模型压缩和部署;抽取是逆向的攻击窃取,用于盗取模型功能。
在深入模型蒸馏和抽取之前,需要理解模型安全中的几个核心概念。
模型蒸馏是一种模型压缩技术,通过让小模型(学生)学习大模型(老师)的"软标签"来获得接近大模型的性能。
大模型(Teacher)指导 → 小模型(Student)学习软标签
# 模型蒸馏示例
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, T=4, alpha=0.7):
# 软标签损失:学生学习老师的概率分布
soft_teacher = F.softmax(teacher_logits / T, dim=-1)
soft_student = F.log_softmax(student_logits / T, dim=-1)
soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)
# 硬标签损失:学生学习真实标签
hard_loss = F.cross_entropy(student_logits, labels)
# 组合损失
return alpha * soft_loss + (1 - alpha) * hard_loss
# 训练学生模型
for epoch in range(num_epochs):
for batch in dataloader:
inputs, labels = batch
teacher_outputs = teacher_model(inputs)
student_outputs = student_model(inputs)
loss = distillation_loss(student_outputs, teacher_outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
模型抽取是一种攻击技术,攻击者通过查询目标模型来窃取其功能,构建一个功能相近的复制模型。
攻击者通过 API 查询 → 收集输入输出对 → 训练复制模型
| 维度 | 模型蒸馏 (正常) | 模型抽取 (攻击) |
|---|---|---|
| 目的 | 模型压缩、部署优化 | 窃取模型功能 |
| 发起者 | 模型拥有者 | 外部攻击者 |
| 信息访问 | 完全访问 | 只通过 API |
| 知识来源 | 软标签(概率分布) | 硬标签或概率 |
| 合法性 | 合法 | 非法 |
攻击实例:某公司提供文本分类 API,攻击者通过大量查询收集输入输出对,然后用这些数据训练自己的模型。如果成功复制,不仅可以免费使用原模型功能,还可能进一步利用复制模型进行迁移攻击(Adversarial Attack)。