说明:收录全网最新的团体标准 提供单次或批量下载
(19)中华 人民共和国 国家知识产权局 (12)发明 专利申请 (10)申请公布号 (43)申请公布日 (21)申请 号 202111641566.9 (22)申请日 2021.12.2 9 (71)申请人 成都晓多科技有限公司 地址 610000 四川省成 都市天府新区兴隆 街道湖畔路西段123号 (72)发明人 郭湘 黄鹏 江岭  (74)专利代理 机构 成都睿道专利代理事务所 (普通合伙) 51217 代理人 薛波 (51)Int.Cl. G06F 16/33(2019.01) G06F 17/11(2006.01) G06F 17/16(2006.01) G06K 9/62(2022.01) (54)发明名称 一种模型蒸馏方法、 系统以及文本 检索方法 (57)摘要 本发明提供了一种模 型蒸馏方法、 系统以及 文本检索方法, 其模型蒸馏方法包括如下步骤: S1.构建一个学生模型以及多个教师预训练模 型, 采用对比学习方法对多个教师预训练模型以 及学生模型进行训练, 在对比学习训练中, 获取 平均教师模型相似度矩 阵以及学生模型相似度 矩阵; S2.获取预蒸馏学生模型的第一输出以及 多个教师模型的平均输出; S3.获取集成蒸馏损 失函数, 并根据集成蒸馏损失函数训练所述预蒸 馏学生模型, 并得到蒸馏学生模型; 本发明通过 对比学习训练得到平均教师模型相似度矩 阵与 学生模型相似度矩阵, 将其加入集成蒸馏损失函 数中, 通过集成蒸馏损失函数训练预蒸馏学生模 型, 使得学生模 型能够更好的拟合教师模型的检 索能力。 权利要求书2页 说明书6页 附图1页 CN 114328834 A 2022.04.12 CN 114328834 A 1.一种模型蒸馏方法, 其特 征在于, 包括如下步骤: S1.构建一个学生模型以及多个教师预训练模型, 采用对比学习方法对多个教师预训 练模型进 行训练, 训练完成得到多个教师模型; 采用对比学习方法对 学生模型进 行训练, 训 练得到预蒸馏学生模型, 在对比学习训练中, 获取多个教师预训练模型 的平均教师模型相 似度矩阵, 获取 学生模型的学生模型相似度矩阵; S2.获取预蒸馏学生模型的第一输出以及多个教师模型的第二输出; 对多个教师模型 的第二输出求取平均, 得到平均输出; S3.根据平均教师模型相似度矩阵、 学生模型相似度矩阵、 学生模型的第一输出、 多个 教师模型的平均输出以及样本数据的真实标签获取集成蒸馏损失函数, 并根据集成蒸馏损 失函数训练所述预蒸馏学生模型, 并得到蒸馏学生模型。 2.根据权利要求1所述的基于对比学习与集成蒸馏的文本检索方法, 其特征在于, S1具 体为: 获取教师预训练模型以及学生模型 的训练样本数据集, 多个教师预训练的训练样本 数据集的数据量相同; 分别对学生模型以及多个教师预训练模型进行对比学习训练, 其训 练方式如下: A.在一个batch内, 选定输入样本, 分别两次向教师预训练模型或学生模型输入输入样 本, 通过dropout方法得到 输入样本的正样本对; 计算 正样本对的余弦相似度; B.随机采样输入样本所在的同一batc h内的其他输入样本作为负 样本; C.根据正样本对的余弦相似度以及负样本的余弦相似度设计对比学习损失函数, 根据 对比学习损失函数训练教师 预训练模型以及学生模型。 3.根据权利要求2所述的模型蒸馏方法, 其特 征在于, 所述对比学习损失函数 具体为: 其中, sim()表示 余弦相似度函数, z是dropoutmask, z′是dropout mask; 以及 构成正样本对, 为负样本, τ为温度超 参数。 4.根据权利要求1所述的基于对比学习与集成蒸馏的文本检索方法, 其特征在于, 所述 集成蒸馏损失函数 具体为: Loss=β *((1‑α )*loss1+α *loss2)+(1‑β )*batch_loss 其中, loss1=CE(q, y); loss2=CE(p, q); batch_loss=MSE(M_t, M_s), 上式中, q为学生模型的输出logits, p为多个教师模型输出的logits的平均logits, y 为样本数据的真正标签, M_t为平均 教师模型相 似度矩阵, M_s为学生模型相 似度矩阵, α 以 及β 分别为超参数。 5.根据权利要求1所述的模型蒸馏方法, 其特征在于, 其中获取平均教师相似度矩阵包权 利 要 求 书 1/2 页 2 CN 114328834 A 2括如下步骤: A.获取同一个batch两次输入一个教师预训练模型的输出, 分别为第一向量矩阵以及 第二向量矩阵, 向量矩阵用(batc h_size, dimensi on)表示; B.计算第一向量矩阵以及第二向量矩阵的相似度, 得到教师模型相似度矩阵, 表示为 (batch_size, batc h_size); C.计算多个教师模型余弦相似度矩阵的平均, 得到平均教师模型相似度矩阵; 其中batc h_size为batc h的大小, dimensi on为教师预训练模型输出的向量维度。 6.根据权利要求1所述的模型蒸馏方法, 其特征在于, 所述学生模型包括多层全连接 层, 多层所述全连接层用于对齐教师模型的输出维度。 7.一种文本检索方法, 其特征在于, 应用如权利要求1 ‑6任意一项所述模型蒸馏方法获 得蒸馏学生模型, 包括如下步骤: S1.获取查询文本以及候选文本; S2.将所述查询文本以及候选文本输入蒸馏学生模型, 获得蒸馏学生模型输出的所述 查询文本与所述 候选文本的匹配率。 8.一种模型蒸馏系统, 其特征在于, 应用如权利要求1 ‑6任意一项所述的模型蒸馏方 法, 包括: 获取模块, 所述获取模块用于获取教师 预训练模型以及学生模型的训练样本数据集; 对比学习 模块, 所述对比学习 模块用于构建教师预训练模型或学生模型的正样本对以 及负样本, 设计对比学习损失函数, 训练教师 预训练模型或学生模型; 蒸馏模块, 用于获取教师模型的输出、 预蒸馏学生模型的输出、 学生模型相似度矩阵以 及平均教师模型相似度矩阵, 设计集成蒸馏损失函数, 根据集成蒸馏损失函数训练学生模 型。权 利 要 求 书 2/2 页 3 CN 114328834 A 3

.PDF文档 专利 一种模型蒸馏方法、系统以及文本检索方法

文档预览
中文文档 10 页 50 下载 1000 浏览 0 评论 309 收藏 3.0分
温馨提示:本文档共10页,可预览 3 页,如浏览全部内容或当前文档出现乱码,可开通会员下载原始文档
专利 一种模型蒸馏方法、系统以及文本检索方法 第 1 页 专利 一种模型蒸馏方法、系统以及文本检索方法 第 2 页 专利 一种模型蒸馏方法、系统以及文本检索方法 第 3 页
下载文档到电脑,方便使用
本文档由 人生无常 于 2024-03-18 20:42:20上传分享
友情链接
站内资源均来自网友分享或网络收集整理,若无意中侵犯到您的权利,敬请联系我们微信(点击查看客服),我们将及时删除相关资源。