查看原文
其他

【强基固本】深度学习从入门到放飞自我:完全解析triplet loss

“强基固本,行稳致远”,科学研究离不开理论基础,人工智能学科更是需要数学、物理、神经科学等基础学科提供有力支撑,为了紧扣时代脉搏,我们推出“强基固本”专栏,讲解AI领域的基础知识,为你的科研学习提供助力,夯实理论基础,提升原始创新能力,敬请关注。

作者:知乎—刘昕宸

地址:https://www.zhihu.com/people/liu-xin-chen-64


本文参考自(感兴趣的同学可直接阅读英文原文):
https://omoindrot.github.io/triplet-loss
在人脸识别领域,triplet loss常被用来提取人脸的embedding。
之前实验室有个做无监督特征学习的小任务,因为没有类别的监督信息,因此也可以用triplet loss来设计约束,以期得到discriminative embedding。
triplet loss原理是比较简单的,关键在于搞懂各种采样triplets的策略。

01

为什么不使用softmax呢?
通常在有监督学习中,我们有固定数量的类别(比如针对Cifar10的图像分类任务,类别数就是10),因此在训练网络时我们通常会在最后一层使用softmax,并结合cross entropy loss作为监督信息。
但是在有些情况下我们需要能够有一个变化数量的类别。比如对实例人脸识别(face recognition for instance),我们需要能够比较两张未知的脸,然后判断这两张脸是否来自同一个人。这种情况下使用triplet loss就可以获得不错的表征,在表征空间,同一个人的脸的表征彼此接近,而不同人的脸的表征则能够被很好分隔。
FaceNet: A Unified Embedding for Face Recognition and Clustering 这篇论文就介绍了triplet loss在人脸识别领域的应用,感兴趣的同学可以关注。
其实除此之外,在无监督学习应用中,也常使用triplet loss来表征。

02

loss定义
  • anchor是基准
  • positive是针对anchor的正样本,表示与anchor来自同一个人
  • negative是针对anchor的负样本
以上  共同构成一个triplet.
triplet loss的目标是使得:
  • 具有相同label的样本,它们的embedding在embedding空间尽可能接近
  • 具有不同label的样本,它们的embedding距离尽可能拉远
对于embedding空间的某个距离  ,一个triplet的loss可以定义为:
最小化loss  的目标是:使得  接近  ,  大于  .
一旦  成为"easy negative",loss就会变成  .

03

Triplet mining
根据loss的定义,我们可以定义3种类型的triplet:
  • easy triplets: 此时loss为  ,这种情况是我们最希望看到的,可以理解成是容易分辨的triplets。即 
  • hard triplets: 此时negative比positive更接近anchor,这种情况是我们最不希望看到的,可以理解成是处在模糊区域的triplets。即 
  • semi-hard triplets: 此时negative比positive距离anchor更远,但是距离差没有达到一个margin,可以理解成是一定会被误识别的triplets。
以下这张图非常清晰地说明了:
确定anchor和positive,negative的位置决定了triplet的类型:easy, semi-hard和hard
据此我们也可以把negative划分为:easy negative, semi-hard negative, hard negative
因此,基于哪种类型的triplet训练将非常大地影响我们的实验效果。
在FaceNet这篇论文种,作者是为每一对anchor和positive随机选择一个semi-hard negative构建semi-hard triplet,并在这些triplets上训练。

04

Offline and online triplet mining
目前我们已经定义了一种基于triplet embedding的loss,接下来最重要的问题就是我们该采样什么样的triplet?我们该如何采样目标triplet?等
Offline triplet mining
第一种生产triplet的方式:offline mining
详细说明:
在训练每个epoch的开始阶段,计算训练集种所有的embedding,并挑选出所有的hard triplets和semi-hard triplets,并在该epoch内训练这些triplets.
这种方式不是很高效,因为每个epoch我们都需要遍历整个数据集来生产triplets。
Online triplet mining
第二种生产triplet的方式:online mining
详细说明:
这种想法就是对于每个batch的输入,动态地计算有用的triplets。给定batch size为  的样本,我们计算其对应的  embeddings,此时我们最多可以找到  triplets. 当然这其中很多triplet都不是合法的(因为triplet中需要有2个是相同label,1个是不同label)
这种方式显然比offline mining高效得多
接下来我们来仔细讨论讨论online mining策略:
我们主要讨论如何从一个batch的B embeddings中产生triplets。
我们先设三个索引:
其中  和   具有相同label,  具有不同label,此时  被称为是一个valid triplet.
我们接下来基于valid triplets来挑选triplets,计算loss
我们还是以人脸识别为例,假设一个batch的人脸,batch-size为  ,由  个人,每个人  张人脸图像构成。一般来说  .
  • batch all:挑选出所有合规的triplets,将hard triplets和semi-hard triplets的loss平均
    • 这个过程产生了  个合规的triplets,其中anchor有  个,positive  个,negative  个
  • batch hard: 对于每个anchor,挑选出hardest positive  ( 距离最大)和hardest negative  (  距离最小)。
    • 这个过程产生了  个triplets
    • 这种方式得到的triplets被称为这个batch中的hardest triplets

05

Code implementation
triplet loss原理还是比较简单的,因此我们不难直观实现如下:
anchor_output = ... # shape [None, 128]positive_output = ... # shape [None, 128]negative_output = ... # shape [None, 128]
d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)
loss = tf.maximum(0.0, margin + d_pos - d_neg)loss = tf.reduce_mean(loss)
我们使用网络提取anchor, positive, negative的embedding。
anchor_output, positive_output, negative_output分别表示  anchor embeddings,  positive embeddings,  negative embeddings
以上这种方式虽然简单,但却非常低效,因为以上采用的是offline triplet mining
接下来我们来尝试下online triplet mining版本的triplet loss:
Compute distance matrix
triplet loss需要计算  和  ,因此我们需要高效地计算pairwise distance matrix
对于输入的  embeddings,我们期待得到  的距离矩阵。距离计算公式:
参数squared为True表示计算的是距离的平方,为False表示计算的是欧式距离
def _pairwise_distance(embeddings, squared=False): ''' 计算两两embedding的距离 ------------------------------------------ Args: embedding: 特征向量, 大小(batch_size, vector_size) squared: 是否距离的平方,即欧式距离 Returns: distances: 两两embeddings的距离矩阵,大小 (batch_size, batch_size) ''' # 矩阵相乘,得到(batch_size, batch_size),因为计算欧式距离|a-b|^2 = a^2 -2ab + b^2, # 其中 ab 可以用矩阵乘表示 dot_product = tf.matmul(embeddings, tf.transpose(embeddings)) # dot_product对角线部分就是 每个embedding的平方 square_norm = tf.diag_part(dot_product) # |a-b|^2 = a^2 - 2ab + b^2 # tf.expand_dims(square_norm, axis=1)是(batch_size, 1)大小的矩阵,减去 (batch_size, batch_size)大小的矩阵,相当于每一列操作 distances = tf.expand_dims(square_norm, axis=1) - 2.0 * dot_product + tf.expand_dims(square_norm, axis=0) distances = tf.maximum(distances, 0.0) # 小于0的距离置为0 if not squared: # 如果不平方,就开根号,但是注意有0元素,所以0的位置加上 1e*-16 distances = distances + mask * 1e-16 distances = tf.sqrt(distances) distances = distances * (1.0 - mask) # 0的部分仍然置为0
return distances
Online triplet mining strategy 1: batch all strategy
挑选合规的triplets,即:
  •  的  必须不等(也就是  不能是同一个样本)
  •  的label相同,  的label不同
此时代码中所谓的mask,就是一个3D tensor,就是将mask在合规triplet index位置置为1,其它位置置为0.
def _get_triplet_mask(labels): ''' 得到一个3D的mask [a, p, n], 对应triplet(a, p, n)是valid的位置是True ---------------------------------- Args: labels: 对应训练数据的labels, shape = (batch_size,)
Returns: mask: 3D,shape = (batch_size, batch_size, batch_size) ''' # 初始化一个二维矩阵,坐标(i, j)不相等置为1,得到indices_not_equal indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool) indices_not_equal = tf.logical_not(indices_equal) # 因为最后得到一个3D的mask矩阵(i, j, k),增加一个维度,则 i_not_equal_j 在第三个维度增加一个即,(batch_size, batch_size, 1), 其他同理 i_not_equal_j = tf.expand_dims(indices_not_equal, 2) i_not_equal_k = tf.expand_dims(indices_not_equal, 1) j_not_equal_k = tf.expand_dims(indices_not_equal, 0) # 想得到i!=j!=k, 三个不等取and即可, 最后可以得到当下标(i, j, k)不相等时才取True distinct_indices = tf.logical_and(tf.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k) # 同样根据labels得到对应i=j, i!=k label_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1)) i_equal_j = tf.expand_dims(label_equal, 2) i_equal_k = tf.expand_dims(label_equal, 1) valid_labels = tf.logical_and(i_equal_j, tf.logical_not(i_equal_k)) # mask即为满足上面两个约束,所以两个3D取and mask = tf.logical_and(distinct_indices, valid_labels)
return mask
我们需要一个形状为  的3D tensor,其中  表示triplet  的loss
我们通过_get_triplet_mask来获得合规triplets的index mask,从而获得合规的triplets
统计合规triplets中loss不为0.0的triplets,最后计算平均得到batch_all_triplet_loss.
def batch_all_triplet_loss(labels, embeddings, margin, squared=False): ''' triplet loss of a batch ------------------------------- Args: labels: 标签数据,shape = (batch_size,) embeddings: 提取的特征向量, shape = (batch_size, vector_size) margin: margin大小, scalar
Returns: triplet_loss: scalar, 一个batch的损失值 fraction_postive_triplets : valid的triplets占的比例 ''' # 得到每两两embeddings的距离,然后增加一个维度,一维需要得到(batch_size, batch_size, batch_size)大小的3D矩阵 # 然后再点乘上valid 的 mask即可 pairwise_dis = _pairwise_distance(embeddings, squared=squared) anchor_positive_dist = tf.expand_dims(pairwise_dis, 2) assert anchor_positive_dist.shape[2] == 1, "{}".format(anchor_positive_dist.shape) anchor_negative_dist = tf.expand_dims(pairwise_dis, 1) assert anchor_negative_dist.shape[1] == 1, "{}".format(anchor_negative_dist.shape) triplet_loss = anchor_positive_dist - anchor_negative_dist + margin mask = _get_triplet_mask(labels) mask = tf.to_float(mask) triplet_loss = tf.multiply(mask, triplet_loss) triplet_loss = tf.maximum(triplet_loss, 0.0) # 计算valid的triplet的个数,然后对所有的triplet loss求平均 valid_triplets = tf.to_float(tf.greater(triplet_loss, 1e-16)) num_positive_triplets = tf.reduce_sum(valid_triplets) num_valid_triplets = tf.reduce_sum(mask) fraction_postive_triplets = num_positive_triplets / (num_valid_triplets + 1e-16) triplet_loss = tf.reduce_sum(triplet_loss) / (num_positive_triplets + 1e-16)
return triplet_loss, fraction_postive_triplets
Online triplet mining strategy 2: batch hard strategy
在batch hard strategy中,我们期待为每个anchor找到hardest positive和hardest negative.
Hardest positive
step 1. 构建embedding pairwise距离矩阵
step 2. 计算合规pair的2D mask(合规要求:  且  和  具有相同的label). 距离矩阵在mask之外的位置均置为0
step 3. 此时取距离矩阵中每一行的最大值,其所对应的triplet,就是该行对应anchor的hardest positive.
Hardest negative
提取方法与上面hardest positive类似,不再赘述。
最后计算得到的triplet loss:
triplet_loss = tf.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)
所以batch hard策略计算triplet loss的代码实现如下所示:
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False): """Build the triplet loss over a batch of embeddings. For each anchor, we get the hardest positive and hardest negative to form a triplet. Args: labels: labels of the batch, of size (batch_size,) embeddings: tensor of shape (batch_size, embed_dim) margin: margin for triplet loss squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. If false, output is the pairwise euclidean distance matrix. Returns: triplet_loss: scalar tensor containing the triplet loss """ # Get the pairwise distance matrix pairwise_dist = _pairwise_distances(embeddings, squared=squared)
# For each anchor, get the hardest positive # First, we need to get a mask for every valid positive (they should have same label) mask_anchor_positive = _get_anchor_positive_triplet_mask(labels) mask_anchor_positive = tf.to_float(mask_anchor_positive)
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) anchor_positive_dist = tf.multiply(mask_anchor_positive, pairwise_dist)
# shape (batch_size, 1) hardest_positive_dist = tf.reduce_max(anchor_positive_dist, axis=1, keepdims=True)
# For each anchor, get the hardest negative # First, we need to get a mask for every valid negative (they should have different labels) mask_anchor_negative = _get_anchor_negative_triplet_mask(labels) mask_anchor_negative = tf.to_float(mask_anchor_negative)
# We add the maximum value in each row to the invalid negatives (label(a) == label(n)) max_anchor_negative_dist = tf.reduce_max(pairwise_dist, axis=1, keepdims=True) anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
# shape (batch_size,) hardest_negative_dist = tf.reduce_min(anchor_negative_dist, axis=1, keepdims=True)
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss triplet_loss = tf.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)
# Get final mean triplet loss triplet_loss = tf.reduce_mean(triplet_loss)
return triplet_loss

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。



“强基固本”历史文章




分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存