Distilling the Knowledge in a Neural Network 知识蒸馏

张开发
2026/5/22 8:45:17 15 分钟阅读
Distilling the Knowledge in a Neural Network 知识蒸馏
之前看论文的时候一直有学生模型教师模型然后前天刷到:这种叫做知识蒸馏所以来了解一下什么是知识蒸馏目录结合硬目标的改进版更常用一、核心定义二、逐点拆解1. Bagging装袋法2. Boosting提升法三、关键对比一张表记全四、使用时机数学原理简单理解就是假如一组数据使用不同的模型训练然后再把他们类似于加一起然后平均这种思考这样他的泛化性啥的会更优但是在于移动端比如手机平板这些没有所谓的高级显卡即硬件条件有限那么该怎么办呢于是引入数据蒸馏下面需要谨记:Logits是softmax的输入全连接层的输出核心思路硬目标Hard Target就是我们平时用的 one-hot 标签比如 MNIST 里数字 2 就是[0,0,1,0,...]它只告诉模型 “这是 2”但没说 “这个 2 长得更像 3 还是更像 7”。软目标Soft Target是大模型教师模型softmax 输出的概率分布比如[0.000001, 0.999998, 0.000001, ...]它包含了大模型学到的数据相似性信息。对于 MNIST 这种简单任务大模型几乎能 100% 猜对所以 softmax 输出里正确类别的概率接近 1其他类别概率接近 0比如 10⁻⁶、10⁻⁹。这些极小概率藏着宝贵知识比如 “这个 2 看起来更像 3那个 2 看起来更像 7”但在普通交叉熵损失里这些接近 0 的概率对损失影响微乎其微小模型学不到这些细节。前人方案直接用 logits 训练Caruana 等人的办法是不用 softmax 后的概率而是拿大模型最后一层的logitssoftmax 之前的原始输出当目标让小模型的 logits 去拟合大模型的 logits用 MSE 损失。这样能保留那些细微的相似性信息但不够通用。本文方案知识蒸馏软化目标作者提出更通用的知识蒸馏给大模型的 softmax 加一个温度参数 Tsoftmax(logits / T)提高 T 会让概率分布变 “软”原本接近 0 的小概率会被放大分布更平滑相似性信息就凸显出来了。训练小模型时用同样的高温 T 去拟合大模型的软目标这样小模型就能学到大模型的 “暗知识”Dark Knowledge。作者还指出当温度 T 趋近无穷大时拟合 logits 就是知识蒸馏的特殊情况。结合硬目标的改进版更常用当迁移集有部分 / 全部真实标签时用两个损失的加权平均训练效果更好软目标损失学生模型用高温 T的 softmax 输出和教师模型的软目标算交叉熵。作用学习大模型的 “暗知识”数据相似性。硬目标损失学生模型用温度 T1的 softmax 输出和真实标签算交叉熵。作用保证模型不会偏离正确分类方向做 “兜底”。权重设置通常给硬目标损失很低的权重比如 0.1 或更小软目标损失占主导。知识补充集成学习Ensemble Learning的两大核心范式一、核心定义方法核心思想人话解释Bagging并行训练 独立投票 降低方差找一群 “水平差不多的专家”各自独立判断最后少数服从多数避免单个专家的偏见Boosting串行训练 知错就改 降低偏差找一群 “新手”第一个新手犯错后第二个新手专门盯着第一个错的地方学依次迭代最后加权投票越往后的模型越专注难样本二、逐点拆解1. Bagging装袋法全称Bootstrap Aggregating自助聚合核心步骤采样对原始数据集做有放回抽样Bootstrap生成多个不同的子数据集比如 100 个子集训练用每个子数据集训练一个独立的基模型比如决策树所有模型并行训练互不影响融合分类任务用 “投票”回归任务用 “平均”。典型代表随机森林Random Forest—— 在 Bagging 基础上给每个决策树随机选特征进一步降低过拟合。适用场景解决过拟合降低方差适合高方差、易过拟合的模型比如单棵决策树。例子你要判断一张图是不是猫找 10 个独立的 AI每个 AI 看不同的猫图子集最后 9 个说 “是”1 个说 “不是”就判定是猫。2. Boosting提升法核心步骤初始化给所有样本赋相同权重训练第一个基模型比如弱决策树只能比随机猜好一点权重调整把第一个模型判错的样本权重调高让下一个模型重点学这些难样本串行训练训练第二个模型专门纠正第一个的错误重复迭代每一轮都聚焦上一轮的错误融合给每个模型赋权重表现好的模型权重高最后加权投票 / 平均。三、关键对比一张表记全维度BaggingBoosting训练顺序并行所有模型同时训串行模型按顺序训后一个依赖前一个样本使用有放回抽样各模型样本独立全样本训练通过权重聚焦难样本模型权重所有模型权重相等表现好的模型权重高核心目标降低方差防过拟合降低偏差提准确率过拟合风险低多模型平均稀释噪声高过度聚焦难样本易过拟合代表算法随机森林AdaBoost、GBDT、XGBoost四、使用时机什么时候用 Bagging当你的模型 “学太细”比如单棵决策树在训练集 100% 准测试集 50% 准用随机森林Bagging平均多个树的结果能显著提升测试集准确率。什么时候用 Boosting当你的模型 “学不会”比如单棵浅决策树训练集准确率只有 60%用 XGBoostBoosting迭代纠正错误能把准确率提到 90%。避坑提醒Boosting 对异常值敏感会把异常值当成 “难样本” 反复学习使用前要先做数据清洗Bagging 对异常值不敏感因为多模型平均会抵消异常值的影响。模型可解释性差(不知道哪个模型最有用)算力消耗大运算时间长模型的选择具有随机性不能确保是最佳组合模型的参数信息保留了模型学到的知识学习如何从输入向量映射到输出向量将图像分类为母牛的概率是将其分类为汽车的概率的10倍Hinton在其论文中首先描述的正是这种knowledge需要教师网络向学生中蒸馏。比如 [ cat ,dog ,car,cow]这种预测正确值【0.010.90.010.08】这种数学原理温度参数的数学原理温度参数T 是应用在 logits 上的而不是直接应用在概率上# 标准softmaxT1probs F.softmax(logits, dim-1)# 带温度的softmaxprobs F.softmax(logits / T, dim-1)详细解释1. 温度 1 时假设我们有logits: [5, 2, 1]- T1 : softmax([5, 2, 1]) → [0.953, 0.037, 0.010] (尖锐分布)- T2 : softmax([2.5, 1, 0.5]) → [0.786, 0.138, 0.076] (较平滑)- T5 : softmax([1, 0.4, 0.2]) → [0.554, 0.263, 0.183] (更平滑)结论 温度越高概率分布越平滑保留更多类别间的关系信息2. 温度 1 时同样的logits: [5, 2, 1]- T0.5 : softmax([10, 4, 2]) → [0.998, 0.002, 0.000] (更尖锐)- T0.1 : softmax([50, 20, 10]) → [1.000, 0.000, 0.000] (接近one-hot)结论 温度越低概率分布越尖锐接近one-hot编码

更多文章