其他
漫谈度量学习(Distance Metric Learning)
1 前言
2 Distance
3 Distance metric learning 与 LMNN
这篇文章除了作为 distance metric learning 的开山之作以外,还创造性的将 metric learning 的过程构造成了一个凸优化问题,这给当时初生的 metric learning 的可靠性和稳定性提供了很大的保障和帮助。在这篇文章中,他们定义 distance metric 并且从 side-information 中学习出 metric 的做法在今天其实可以被称为 mahalanobis distance metric learning。顾名思义,其原因是直接借用了马氏距离中度量距离的方法,其距离可写成此形式:,其中 是一个对称的半正定矩阵,而在马氏距离中 是协方差矩阵的逆 。这种构造方法还可以从另外一个角度来理解,由于 是一个对称的半正定矩阵,则有:
很显然,这实际上就是对原样本空间利用矩阵 进行线性投影后在新样本空间内的欧氏距离!而 distance metric learning 的目标,就是学出这样的矩阵 (或矩阵 )。到这里可能有伙伴会好奇,为什么不在原样本空间用欧氏距离,一定要经过一个线性投影 呢?答案很简单,在原样本空间里我们无法通过欧氏距离来很好的区分不同类别的样本,原样本空间中不同类别的样本混在一起,导致其欧氏距离很小,如果此时我们用一些 metric-based 的分类器效果就很差。
为了提高分类准确率,我们对原样本空间做线性投影后,在新的空间可能不同类别的样本就会分得比较开,此时用”欧氏距离“(投影后的新空间内)效果就会比较好。下图一看便知:
当然,不是所有的 distance metric 都是线性的,有线性的、也有非线性的,有监督的、也有非监督的。本文只为简单介绍 distance metric learning 的概念,更多的内容留待感兴趣的伙伴自由探索。在过去二十年 CS、AI 领域的研究中,涌现出了众多 distance metric learning 的经典方法,例如 NCA、LMNN、ITML 等,还有后续的一系列 online learning、linear similarity learning、nonlinear metric learning、multi-task metric learning、deep distance metric learning 等等。20 年前的方兴未艾,20 年后的浩如烟海,大致就是如此了。
在此我以我个人比较熟悉的 LMNN 方法为例,给大家分享一下该方法中 metric learning 的思想。LMNN 是 Large Margin Nearest Neighbor的缩写,原文最初发表于 2005 年的 NIPS,后来 2009 年在 JMLR 上补发了长文,逻辑很清晰、也很易懂,提纲挈领的同时又不失细节,算得上小半篇 review,强烈推荐:
LMNN 方法的核心思想也比较简单:在 LMNN 之前的很多方法(如 Eric Xing 等人的 MMC)约束都很强,要求同类别样本之间差距小而不同类别样本之间差距大,这在实际中是很难实现的。而且结合最近邻算法的思想,影响分类结果的就是最近的几个样本的类别,没有必要去要求那么强的约束,我们只要实现利用 metric 把离每个样本最近的几个样本都拉近到同类别,不就可以了吗?LMNN 的思想可以用下图来理解:
▲ Weinberger K Q, Saul L K. Distance metric learning for large margin nearest neighbor classification[J]. Journal of machine learning research, 2009, 10(2).
那么,LMNN 是如何实现这一推一拉的呢?这得益于 LMNN 算法中巧妙的 loss function设计。LMNN 算法的 loss function 分为两个部分,一个部分负责拉近最近的数个同类别样本(target neighbors),一部分负责推开每个样本的入侵者(impostors)。
负责拉近的部分被构造为:
很显然,这部分通过惩罚同类别样本间的大距离达到拉近同类别样本的效果。负责推开的部分被构造为:
其中 代表标准的 hinge loss。如果 是 的 target neighbors,,那自然就不用推开了,此时这一项 loss 为 0;如果 是 的 impostors,,如果括号内大于 0,意味着在投影后的新空间中,仍然有 impostors 入侵到 margin 之内。所以此项 loss 通过惩罚括号内大于 0 的项(impostors 入侵进入 margin)来实现推开的 impostors 的功能。
最终这一推一拉利用一个参数 实现了平衡和调整:
4 Distance metric learning 在工程应用中的实践
为什么做:在智能制造的大背景下,随着传感器、控制、计算机等一系列技术的发展,各种过去无法被记录的制造过程中的数据现在可以被各种各样的传感器记录下来了,这给我们提供了很多数据驱动的制造过程分析和决策的机会,制造产品的质量预测与监控就是其中重要的一个分支。
过去传统的质量控制大多都是 post-process 的,也就是说得等产品完成很多工序直到质检工序时才被发现,而某些产品质量可能在其中某一个工序就出问题了,可这些有问题的产品依旧经过了后续的很多工序加工,这造成了很大的资源和成本浪费。所以我们需要一个 in-process 的质量监测工具帮助我们根据制造过程数据及时判断出哪些产品可能是有问题的、及时发现并防止浪费。
做什么:在某个特殊的产品制造过程中,行业专家告诉我们制造过程中的某种数据(记为 )是跟产品的质量有重要联系的。于是,我们借由传感器收集到了多个样本制造过程种每个样本的 数据,根据现有制造产线后续的质检工序,制造方为我们提供了这些样本的真实标签(记为 ,两种:合格品 和不合格品 )。
这个问题中有两个小的 challenges:由于制造过程中的随机性,每个样本的 数据其长度大多是不相等的;此外,我们收集数据的样本中,样本量较为有限,且合格品的数量远大于不合格品的数量,比率大概在 20:1 左右,换句话说,这是个样本量有限的不平衡数据集。而我们的目标就是根据收集到的 数据来构建一个预测产品质量是否合格(标签 )的模型,这样就可以针对每个制造过程中的新样本根据收集到的过程数据实现 in-process 的质量预测和监控了。
怎么做:前期我做了其他方法的很多探索,就不在此一一赘述了,此处直接阐述利用 distance metric learning 的做法。在将样本简单划分为 training set 和 test set 后,由于 distance metric learning 不能处理非等长的数据,我先随机选取了一个样本作为参考样本利用 DTW 将其他所有样本(包括 training set 和 test set)进行拉齐处理(这里有一个 trick:利用各种 MATLAB 或 python 内嵌的包确实可以根据 DTW 拉齐两个不等长的样本,但拉齐后的长度跟两个样本的原长度都不相等,所以我后面自己写了一段代码解决了),这样就得到了等长的数据集。如下所示:
尝试过多种 metric learning 方法后,我最终使用了经典的 LMNN 方法在 training set 上学习了特定的 distance metric,并将其应用到 test set 上,结合 1NN 分类器最终输出预测的样本标签 。由于样本量有限,我对整个数据集做了 100 次 random split,最后取平均的 performance。以上就是我之前所做的对于不等长的数据集如何进行 distance metric learning 的基本思路。
但是当时我觉得 performance 还是不够好,于是引入了另一个基于 DTW 拉齐处理的 trick:在拉齐过程中,我针对每一个 split 下的 training set 和 test set 利用多个样本 作为参考样本来进行拉齐,这样经过 DTW 的处理,每一个 split 下我就有了 个不同长度的数据集,根据这 个不同长度的数据集可以用 metric learning 学习出 个不同的 distance metric,再构建 个 1NN 分类器并将它们 ensembling 在一起根据 个分类结果利用 majority voting 的方法最终得出预测标签。这一 trick 大大的帮助我在这一问题上提升了预测的 performance。
结果:在此我比较了三种方法:1、直接利用 DTW 距离结合 1NN 进行分类;2、先利用 DTW 拉齐、LMNN 学习 distance metric、1NN 分类;3、先利用 DTW 基于多个参考样本拉齐、LMNN 学习多个 distance metrics、多个 1NN 分类最终 majority voting。结果如下:在 accuracy 方面三种方法都很高,但这是由之前提到的不平衡数据所造成的,而且可以由 precision, recall, F1-score 三项很明显的看出来:DTW+1NN 对于我们最关注的不合格品的分类准确度并不高。利用 distance metric learning 的确进一步改进了对不合格品的分类准确度,但最终引入的 ensembling 的 trick 帮助我们达到了整体大于 99.4% 的准确度,而且对不合格品的分类准确度超过了95%。
事实上,在解决这个问题的过程中,我尝试了很多很多方法,也曾经试图对 LMNN 的模型结构做更改,但效果都不好,最后还是老老实实把 distance metric learning 单纯的当作工具来使用。没能在理论方面做出啥贡献确实汗颜,但我想这个 toy example 的用处也正是跟大家分享 distance metric learning 的用法和其中遇到的坑和可能的解决办法,倒也恰如其分。请大佬们轻拍~
5 总结
©作者 | 黄春喜
原文| https://zhuanlan.zhihu.com/p/458114525