查看原文
其他

【源头活水】Deep InfoMax损失函数小记

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

来源:知乎—罗驳思
地址:https://zhuanlan.zhihu.com/p/157871559
无监督表示学习随着深度学习的迅猛发展,近年来呈现出百花齐放之态。19年ICLR会议中收录的Learning Deep Representations by Mutual Information Estimation and Maximization这篇Bengio组大作,从互信息的视角给出了一种新的思路。
关于Deep InfoMax(DIM)的解读已有珠玉在前, 只是多数分析的博客更关注本文的理论推导,而在这篇文章的实现上,仅凭理论上的指引,我无法复原出loss function的细节,于是在github搜了本文的源码。原作者的源码封装较好,略略一读有些头大,于是我转而研读另一版的pytorch代码,初步明确了DIM的loss function一种实现方式。
原文中,DIM的优化目标是:

论文中的等式(8)

根据论文可知,上式中的字母含义分别为:
   :表示用于提取输入图像特征的编码器Encoder。
   :输入的原始图像。
从输入空间得到的训练样本集
   :调节各部分占比的超参数。
   :Mutual Information(MI) estimator,互信息估计器。
   :全局和局部目标的判别器模型的参数。
   :看论文示意图中M x M的feature map可知,    就是低层特征的数量。
   :第i个低层特征(共    个)。
   :仿照adversarial autoencoders(AAE)设计的Discriminator。
 :原始数据集的经验概率分布。


 :将样本从分布 推向    时得到的边际分布。
先验分布。
 :一个由神经网络构建的判别器,参数为 。
 :从分布 


若用文中提到的Jensen-Shannon MI estimator来最大化前一项,即:

论文中的等式(4)

其中sp表示softplus函数:

上述公式似乎并不那么直观易懂_(:з」∠)_接下来我们就根据代码,看看该如何实现这个优化问题。
这个版本的代码里,作者直接定义了DeepInfoMaxLoss类
class DeepInfoMaxLoss(nn.Module): def __init__(self, alpha=0.5, beta=1.0, gamma=0.1): super().__init__() self.global_d = GlobalDiscriminator() self.local_d = LocalDiscriminator() self.prior_d = PriorDiscriminator() self.alpha = alpha self.beta = beta self.gamma = gamma
def forward(self, y, M, M_prime):
# see appendix 1A of https://arxiv.org/pdf/1808.06670.pdf
y_exp = y.unsqueeze(-1).unsqueeze(-1) y_exp = y_exp.expand(-1, -1, 26, 26)
y_M = torch.cat((M, y_exp), dim=1) y_M_prime = torch.cat((M_prime, y_exp), dim=1)
Ej = -F.softplus(-self.local_d(y_M)).mean() Em = F.softplus(self.local_d(y_M_prime)).mean() LOCAL = (Em - Ej) * self.beta
Ej = -F.softplus(-self.global_d(y, M)).mean() Em = F.softplus(self.global_d(y, M_prime)).mean() GLOBAL = (Em - Ej) * self.alpha
prior = torch.rand_like(y)
term_a = torch.log(self.prior_d(prior)).mean() term_b = torch.log(1.0 - self.prior_d(y)).mean() PRIOR = - (term_a + term_b) * self.gamma
        return LOCAL + GLOBAL + PRIOR
最后一行告诉我们:loss function由LOCAL+GLOBAL+PRIOR三部分构成。
在研读loss的定义前,先确认下它的输入:
for x, target in batch: x = x.to(device)
optim.zero_grad() loss_optim.zero_grad() y, M = encoder(x) # rotate images to create pairs for comparison M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0) loss = loss_fn(y, M, M_prime) train_loss.append(loss.item()) batch.set_description(str(epoch) + ' Loss: ' + str(stats.mean(train_loss[-20:]))) loss.backward() optim.step()            loss_optim.step()
该类的forward方法接受三个输入:y, M, M_prime。在main函数中,y和M源自encoder(x)。x就是一个batch的原始图片输入,假设batch size为64,则x的size就是[64, 3, 32, 32]。看下Encoder类的定义:
class Encoder(nn.Module): def __init__(self): super().__init__() self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1) self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1) self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1) self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1) self.l1 = nn.Linear(512*20*20, 64)
self.b1 = nn.BatchNorm2d(128) self.b2 = nn.BatchNorm2d(256) self.b3 = nn.BatchNorm2d(512)
def forward(self, x): h = F.relu(self.c0(x)) features = F.relu(self.b1(self.c1(h))) h = F.relu(self.b2(self.c2(features))) h = F.relu(self.b3(self.c3(h))) encoded = self.l1(h.view(x.shape[0], -1)) return encoded, features
forward()返回的encoded就是Encoder网络最后提取出的特征,size是[64, 64];features是网络中间层特征,size为[64, 128, 26, 26]。也就是说:
y表示图片的全局特征;
M表示图片的中间层特征;
M_prime就是将每个batch的第一张图片对应的中间层特征置于该batch特征的末尾(由此构造出DeepInfoMax中,用于生成"Fake" pair的another image)。
明确了这三个原始输入后,我们回过头看看loss function的构成。
首先是GLOBAL:

横线表示求均值,    就是代码中的global_d,是类GlobalDiscriminator的实例:
class GlobalDiscriminator(nn.Module): def __init__(self): super().__init__() self.c0 = nn.Conv2d(128, 64, kernel_size=3) self.c1 = nn.Conv2d(64, 32, kernel_size=3) self.l0 = nn.Linear(32 * 22 * 22 + 64, 512) self.l1 = nn.Linear(512, 512) self.l2 = nn.Linear(512, 1)
def forward(self, y, M): h = F.relu(self.c0(M)) h = self.c1(h) h = h.view(y.shape[0], -1) h = torch.cat((y, h), dim=1) h = F.relu(self.l0(h)) h = F.relu(self.l1(h))        return self.l2(h)
对每张图片的特征对  ,forward()把中间层特征M先做卷积,然后和全局特征拼接再经过线性层,输出得到实数值。
其次是LOCAL:

 :将中间层特征M与最后一层特征y拼接得到。

 :将另一张图片的中间层特征M_prime和最后一层特征y拼接得到。

   就是代码中的LocalDiscriminator类的实例local_d:
class LocalDiscriminator(nn.Module): def __init__(self): super().__init__() self.c0 = nn.Conv2d(192, 512, kernel_size=1) self.c1 = nn.Conv2d(512, 512, kernel_size=1) self.c2 = nn.Conv2d(512, 1, kernel_size=1)
def forward(self, x): h = F.relu(self.c0(x)) h = F.relu(self.c1(h))        return self.c2(h)
它直接对拼接后的特征进行卷积,返回一个实值。
最后是PRIOR:

   :值服从U(0, 1)均匀分布的随机向量,与y的size保持一致。
   表示PriorDiscriminator类的实例prior_d:
class PriorDiscriminator(nn.Module): def __init__(self): super().__init__() self.l0 = nn.Linear(64, 1000) self.l1 = nn.Linear(1000, 200) self.l2 = nn.Linear(200, 1)
def forward(self, x): h = F.relu(self.l0(x)) h = F.relu(self.l1(h))        return torch.sigmoid(self.l2(h))
其返回值是经过线性层之后,再经过sigmoid调整至(0, 1)值域内的实数。
现在,我们把GLOBAL、LOCAL、PRIOR对应回原先的优化目标(原文中的等式(8)),可知:

代码中的encoder、global_d、local_d、prior_d分别对应于  四项。在实际计算过程中,要求loss function最小化,因此将论文中等式(4),即JSD MI estimator中的负号变为正号,并减去论文中的等式(7)这一项:

原论文中等式(7)

对照着论文的示意图来理解优化目标:
假设"Real"、"Fake"分别指向1(True)和0(False)。要使得loss减小,则:
对于GLOBAL,    趋向Real,    趋向Fake。
对于LOCAL,    趋向Real,    趋向Fake。
对于PRIOR,    趋向于1,    趋向于0。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存