查看原文
其他

【综述专栏】对比学习入门 A Primer on Contrastive Learning

在科学研究中,从方法论上来讲,都应“先见森林,再见树木”。当前,人工智能学术研究方兴未艾,技术迅猛发展,可谓万木争荣,日新月异。对于AI从业者来说,在广袤的知识森林中,系统梳理脉络,才能更好地把握趋势。为此,我们精选国内外优秀的综述文章,开辟“综述专栏”,敬请关注。

来源:知乎—管他叫大靖

地址:https://zhuanlan.zhihu.com/p/374956278


01

准备
近几年,对比学习在CV和NLP中掀起了一股热潮,其学习到的表征向量在多个下游任务上取得了state of art的效果。本文讨论的重心是自监督对比学习,由于不依赖标注数据,其实际上是一种无监督学习方法,在工业中有广阔的应用场景。同时,本文也会讨论监督对比学习,其和自监督对比学习的主要差异在于标注样本,并且能获得更好的效果。深入分析后发现,对比学习和多分类学习有着切不断的联系,不管是contrastive loss和softmax loss还是二者背后的motivation,都紧密相关。因此,本文会先从softmax loss出发,推导softmax loss并讨论其性质;再介绍contrastive loss;最后回到监督对比学习和自监督对比学习。

1.1 多分类监督学习 Multiclass Classification

多分类是机器学习中最常见的学习任务,多分类任务的损失函数一般为softmax loss。下面这篇文章推导了softmax loss的公式,并介绍了它的一些改进形式。
Softmax Loss推导及扩展

https://zhuanlan.zhihu.com/p/374018199

1.2 监督对比学习 Supervised Contrastive Learning

监督对比学习是一种机器学习技术,通过同时教模型哪些数据点相似或不同,来学习数据集的一般特征。监督对比学习是一种监督学习方法,对样本质量要求较高。相比于一般监督学习,监督对比学习通常能获得更好的效果。

对比学习可视化

1.3 自监督学习 Self-Supervised Learning

在大多数实际场景中,我们没有为每个样本添加标签。以医学成像为例,获取样本的难度很大,为了创建标签,专业人士不得不花费无数的时间来手动分类、分割图像。
虽然生成带有干净标签的数据集是昂贵的,但是我们时时刻刻都在生成大量的未标记的数据。自监督学习是让我们能够从这些未标记的数据中学习知识的一种方法。要利用这些大量的未标记数据,一种方法是适当地设定学习目标,以便从数据本身获得监督。
在NLP中,word2vec, Mask Language Model就是典型的自监督学习。在CV中也有很多自监督学习的例子,比如:将一张图片切成小的blocks,预测blocks之间的关系、将小的blocks拼图成原图等。

1.4 自监督对比学习 Self-Supervised Contrastive Learning

自监督对比学习和监督对比学习整体上是相似的,不同之处在于:自监督对比学习通过数据增强来构建有标签的样本,而监督对比学习的有人工标注的对比样本。因此,自监督对比学习的核心和难点是构建优质的对比样本。在CV中,一般通过剪切、旋转、高斯噪声、遮掩、染色等操作生成正样本。在NLP中一般通过回译、对字符的增、删、改来添加噪声,从而生成正样本。由于文本的离散性质,一般较难生成好的标签不变的增强样本。


02

监督对比学习 Supervised Contrastive Learning

2.1 对比损失V0 Contrastive Loss

Dimensionality Reduction by Learning an Invariant Mapping(2006)
这篇文章介绍了把样本从高维空间映射到低维空间的一种通用方法。模型仅依赖于样本的近邻关系,不需要在输入空间中进行任何距离度量。该方法使用基于能量的模型,模型使用给定的近邻关系学习映射函数。设映射函数为    ,参数为    ,学习的目标是找到一个    值,使得投影欧几里得流形上点之间的距离  近似输入空间中的近邻关系。其损失函数为:
其中,  表示两个样本是近邻关系,  表示两个样本不是近邻关系。  是margin参数。对于不相似的样本pair,两个样本离的越近,损失越大。当  时候,loss为零,此时两个样本位于彼此的margin半径外。对于相似的样本pair,两个样本离的越远,损失越大。
下面这张图能更直观的看出contrastive的思路:即让相似的样本pair越来越近(pull),让不相似的样本pair越来越远(push),直到超过margin。
对于多分类任务,Constractive Loss经常会在训练集上过拟合。针对该问题的改进方法有Triplet Loss、四元组损失(Quadruplet loss)、难样本采样三元组损失(Triplet loss with batch hard mining, TriHard loss) 等。

2.2 三元损失 Triplet Loss

In Defense of the Triplet Loss for Person Re-Identification(2017)
https://arxiv.org/pdf/1703.07737.pdf
Triplet loss和quadruplet loss是行人重识别中(Person Re-identification, ReID) 提出的损失函数。行人重识别也称行人再识别,是利用计算机视觉技术判断图像或者视频序列中是否存在特定行人的技术。行人重识别的数据特点是:同一个人有不同颜色、角度和姿态的图片,这些图片表表示同一个人,因此自然而然有很多正样本pair。
Triplet Loss是深度学习中的一种损失函数,用于训练差异性较小的样本,如人脸等。输入数据包括锚(Anchor)示例、正(Positive)示例、负(Negative)示例,通过优化锚示例与正示例的距离小于锚示例与负示例的距离,实现样本的相似性计算。
输入是一个三元组 <a, p, n>:
  • a:anchor, 锚点样本
  • p:positive, 与 a 是同一类别的样本
  • n:negative, 与 a 是不同类别的样本
triplet loss的难点是挖掘困难样本(hard sampling)。比如:同一个人,姿态不同,或者着装变了,这就是困难正样本。两个不同的人,穿着同样的衣服,拍摄角度相同,就是困难负样本。对于batch中的每个样本,在形成三元组以计算损失时,我们可以选择batch中最hard的正样本和最hard的负样本:
使用hard样本:
使用所有样本:
Lifted Embedding loss (Soft版本):
Generalization of the Lifted Embedding loss(Soft版本):

2.3 四元损失Quadruplet losses

Beyond Triplet Loss: A Deep Quadruplet Network for Person Re-Identification(2017)
https://openaccess.thecvf.com/content_cvpr_2017/papers/Chen_Beyond_Triplet_Loss_CVPR_2017_paper.pdf
每次输入四个数据,这四个数据都是不同的,其中包含了一对相似数据和一对不相似数据。根据四个数据在模型的输出和对应的相似信息计算损失。
直观的看,当让Ref和Pos靠近的同时让Neg和Neg2远离,并且  ,其损失函数为:
第一项就是triplet loss,它着重于优化同一样本的正负对之间的相对距离。第二项是新约束,它考虑了具有不同label的样本的正负对的距离。在此约束的下,最小的类间距离都必须大于最大的类内距离,进一步增加 intra-class discrepancy。实际应用中一般  。


03

自监督对比学习 Self-Supervised Contrastive Learning

3.1 对比损失 Contrastive loss

其中  是encoder函数,将样本映射到低维空间或低维球表面。  是正样本对,  是从样本分布  中采样的样本,  是  个负样本对。
对比损失会使学习的positive pair的特征相互靠近(pull),同时将来自随机采样的negative pair的特征推开(push)。在下面对齐性和均匀性那一节,我们会具体介绍contrastive loss的性质。
contrastive loss 和 softmax loss有什么关系呢? 记  为  从而有:

这不就是一个  的多分类问题么,  要从  这  个类别中识别出正确的那一个类别  ,即自己所在的类别。与softmax不同,对于每一个样本  这里的  都是不同的。

3.2 对齐性和均匀性 Alignment and Uniformity

Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere
http://proceedings.mlr.press/v119/wang20k/wang20k.pdf
Alignment: 一个正对的两个样本应该映射成近邻特征向量,并且特征对于噪声因子是不变的。
Uniformity: 特征向量应大致均匀地分布在单位超球面  上,并尽可能保留多的数据信息。
Alignment Uniformity 和 softmax loss  这篇文章中的intra-class compactness and inter-class separability是一致的,并不是一个新的概念,但这篇文章给了这两个概念更加数学化的表达,并证明了它们是可以实现的。
记  是encoder函数:

将contrastive loss拆成两项,  和  :
不难看出优化  就是让相似样本的欧氏距离尽可能小。文章指出,使得  最小的分布将弱收敛到球面上的均匀分布,并且  有如下渐进性质:
Alignment 和 Uniformity的可视化:
文章指出,在下游任务中表现好的模型具有更小的  和  :
接下来具体介绍几个CV和NLP中contrastive learning model.

3.3 SimCLR

3.3.1 SimCLR

A Simple Framework for Contrastive Learning of Visual Representations
http://proceedings.mlr.press/v119/chen20j/chen20j.pdf
1. Data Augmentation: 通过数据增强(crop,resize,recolor)来形成自监督样本
2. Encoding: 将这两幅图像输入深度学习模型(Big-CNN,如ResNet),为每幅图像创建向量表示。
3. Loss Minimization of Representations: 通过最小化对比损失函数来最大化两个向量表示的相似性。

3.3.2 SimCSE

SimCSE: Simple Contrastive Learning of Sentence Embeddings
https://arxiv.org/pdf/2104.08821.pdf
SimCSE 是 SimCLR在NLP中的应用,无监督的SimCSE并没有使用上文提到增、删、改等数据增强方法,而是使用dropout来给样本添加噪声。
We pass the same input sentence to the pre- trained encoder twice and obtain two embeddings as “positive pairs”, by applying independently sampled dropout masks.

3.4 CPC

Representation Learning with Contrastive Predictive Coding
https://arxiv.org/pdf/1807.03748.pdf
CPC用到了 Noise-Contrastive Estimation(NCE) Loss, 在这篇文献里定义为InfoNCE:
Given a set  of  random samples containing one positive sample from  and  negative samples from the 'proposal' distribution  we optimize:
模型结构:
1. Encoder: First, we compress high-dimensional data into a much more compact latent embedding space in which conditional predictions are easier to model.
2. Contex: Secondly, we use powerful autoregressive models in this latent space to make predictions many steps in the future.
3. Loss: Finally, we rely on Noise-Contrastive Estimation for the loss function in similar ways that have been used for learning word embeddings in natural language models, allowing for the whole model to be trained end-to-end.
文章指出,优化contrastive loss其实是在优化互信息的一个下界 (lower bound):



3.5 InfoMax Princle

InfoMax 准则像极大似然准则、最大熵准则一样,是指导机器学习确定优化目标、构建损失函数的一种基本准则。InfoMax 准则通过极大化互信息来优化参数:


3.5.1 CV 中的应用

Learning deep representations by mutual information estimation and maximization
https://arxiv.org/abs/1808.06670
1. Mutual information maximization: Find the set of parameters, ψ, such that the mutual information,  , is maximized. Depending on the end-goal, this maximization can be done over the complete input, X, or some structured or “local” subset.
2. Statistical constraints: Depending on the end-goal for the representation, the marginal 
 should match a prior distribution, V. Roughly speaking, this can be used to encourage the output of the encoder to have desired characteristics (e.g., independence).
DIM通过图像中的局部特征来构造对比学习任务。具体来说,要求模型判别全局特征和局部特征是否来自于同一幅图像:
  • 全局特征为锚点(anchor), 
  • 正样本为来自同一张图像的局部特 
  • 负样本为来自不同图像的局部特征, 

3.5.2 NLP 中的应用

A Mutual Information Maximization Perspective of Language Representation Learning
https://arxiv.org/pdf/1910.08350.pdf
DIM通过最大化局部语义和全局语义的互信息来预训练语言模型。
  • 随机选择训练语料中的n-gram  ,将其mask,作为一个view(作为全局语义)
  • 将mask掉的n-gram提取出来,作为另一个view(作为局部语义)
  • 这两个view都通过Transformer进行编码,同时随机采样其他文本中的n-gram作为负样本,要求模型识别出正确的n-gram


04

一些脑洞

4.1 Word2Vec

In addition, we present a simplified variant of Noise Contrastive Estimation (NCE) for training the Skip-gram model that results in faster training and better vector representations for frequent words, compared to more complex hierarchical softmax that was used in the prior work.
The main difference between the Negative sampling(NEG) and NCE is that NCE needs both samples and the numerical probabilities of the noise distribution, while Negative sampling uses only samples.
NEG通过二分类来拉开positive sample和negative samples,如果通过多分类来做这件事,损失函数就成了contrastive loss:


4.2 GAN

将loss改成contrastive loss:


05

总结
本文会从softmax loss出发,推导了softmax loss的公式、讨论其性质、并介绍了一些改进方法。然后,介绍contrastive loss及其早期在降维中的应用。随着深度学习的发展,contrastive loss有了新的formulation,我们继续讨论了contrastive loss的对齐性和均匀性,并介绍中其在CV和NLP中的具体应用。值得注意的是,对比学习具有久远的历史,在近几年的深度学习中,对比学习先是从CV流行起来的,现在NLP领域正在借鉴、改进相关的方法。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“综述专栏”历史文章


更多综述专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存