查看原文
其他

【源头活水】通过样本有效估计高维连续数据互信息



“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

来源:知乎—汤圆不说话
地址:https://zhuanlan.zhihu.com/p/412538959


01

引入
在阅读过信息瓶颈的论文,了解过基本原理过后,我相信大部分人的第一想法当然是上手试一试,亲自感受信息在神经网络中的流动方式,希望了解信息瓶颈的基本内容可以参考我的上一篇文章:
https://zhuanlan.zhihu.com/p/409861142
但是,当我们开始将信息瓶颈与神经网络联系起来之前,我们会发现一个略显尴尬的问题,两个随机变量的互信息,是信息瓶颈理论中最关键的量,这个量的计算并不是一件非常容易的事情。
互信息的定义基于随机变量的联合与边缘分布,而在工程中,神经网络的输入数据往往高维连续,我们已有的样本(数据集)对于我们希望得到的分布而言实在太稀疏了!
“参数估计”,“假设检验“,等来自于统计的方法,经过一定的实践,这一条路——希望通过样本估计出分布,再通过分布来获取互信息,是一条不归路。
通过查阅资料,最终找到了一个有神经网络背景的估计方法——MINE,利用神经网络来分析神经网络,套娃了属于是。
MINE能够有效的最主要原因,是其放弃了我们之前通过样本估计总体的思路,绕过中间步骤,直接将随机变量的期望与互信息建立联系——毕竟通过样本估计期望要比通过样本估计分布要可靠太多,这种思想与”随机积分“的思想有异曲同工之妙。
其对于神经网络应用的方法也很带给人以启迪,我认为这种应用应该普遍适用于解决一部分数学问题,待相关领域同志继续发现了。

02

方法原理
MINE主要基于以下定理,分布    与分布    的    散度可以用一个函数    对于两个分布的期望来表示,选择一个最优的函数    ,使得下列右侧表达式的值最大,这个值就是    散度的值。

KL散度的期望表示

如果将    变为    的联合分布    ,将    变为    边缘分布乘积,    ,其    散度就是    之间的互信息了(这是互信息的定义)。
这个定理告诉我们,不必知道联合分布与边缘分布的具体情况,只要找到那个最优的函数    ,就能够通过两个随机变量的期望来得到互信息,这样,我们就成功绕过了最艰难的一步。
接下来唯一的问题就是如何找到使得整个式子最大的那个最优的函数。
敏感的读者应该已经猜到了,这个解决方案那就是神经网络!通过一个神经网络来拟合    ,将上面右式作为损失函数,通过梯度上升,使得损失值直至收敛达到最大,也就是使得右式达到上界,这个最大值就是互信息的估计量。
以下是论文中提供的算法伪码:

MINE算法伪代码

MINE方法的主要原理就是如此,其余还有一些细节,包括采样方式,梯度上升方式,相关证明等等,就请读者阅读论文原文了,链接附在文末。

03

方法实现
MINE作为一种普遍有效的方法,非常实用,其论文在后半部分也进行了大量的实验,其应用包括:
  • 最大化互信息以改善 GAN网络
  • 最大化互信息来改进双向对抗网络
  • 信息瓶颈
接下来我将附上代码帮助大家进行简单实验:
在这个实验中,    是5维的具有一定方差的高斯随机变量(相互独立),    是取    的前3维,并添加一定方差    的高斯噪声,我们通过MINE估计两者的互信息。
引入相关库:
import numpy as npimport torchimport torch.nn as nnfrom tqdm import tqdmfrom torch.utils.tensorboard import SummaryWriter
定义网络(互信息单位为bits,代码中进行了相应转换):
#正向传播网络class Net(nn.Module): def __init__(self,x_dim,y_dim) : super((Net), self).__init__() self.layers = nn.Sequential( nn.Linear((x_dim+y_dim), 10), nn.ReLU(), nn.Linear(10, 1))
def forward(self, x, y): batch_size = x.size(0) tiled_x = torch.cat([x, x, ], dim=0) idx = torch.randperm(batch_size)
shuffled_y = y[idx] concat_y = torch.cat([y, shuffled_y], dim=0) inputs = torch.cat([tiled_x, concat_y], dim=1) logits = self.layers(inputs)
pred_xy = logits[:batch_size] pred_x_y = logits[batch_size:] loss = - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))))        return loss
#估计器class Estimator(): def __init__(self,x_dim,y_dim) -> None: self.net = Net(x_dim,y_dim) self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.01) self.x_dim = x_dim self.y_dim = y_dim
def backward(self,x,y): loss = self.net(x, y) self.net.zero_grad() loss.backward() self.optimizer.step() info = -loss.detach() return info
定义数据,初始化网络,并定义采样函数
power = 3noise = 0.5n_epoch = 2000batch_size = 10000x_dim = 5y_dim = 3

writer = SummaryWriter('./log')estimator = Estimator(x_dim,y_dim)
def gen_x(num, dim ,power): return np.random.normal(0., np.sqrt(power), [num, dim])
def gen_y(x, num, dim,noise): return x[:,:dim] + np.random.normal(0., np.sqrt(noise), [num, dim])
def true_mi(power, noise, dim): return dim * 0.5 * np.log2(1 + power/noise)
#互信息真实值mi = true_mi(power, noise, y_dim)print('True MI:', mi)
开始训练,生成数据,反向传播
for epoch in tqdm(range(n_epoch)): x_sample = gen_x(batch_size, x_dim, power) y_sample = gen_y(x_sample, batch_size, y_dim ,noise)
x_sample = torch.tensor(x_sample,dtype=torch.float32) y_sample = torch.tensor(y_sample,dtype=torch.float32)
info = estimator.backward(x_sample,y_sample)
writer.add_scalar('info',info,epoch) writer.add_scalar('true info',mi,epoch)
使用不同的noise进行实验,不同图示所代表的噪声值如下:

其对应的信息量估计如下,其中横坐标代表训练步数(曲线进行了一定的平滑处理):
其互信息真实值为:
从以上图中,可以看出MINE表现不错,在噪声越大的情况下,    与    的互信息变小,至少互信息估计的相对关系是变现出来了,在实验中较为低维情况下,与真实值也比较相近。
然而在更多的实验中个人发现,MINE对于高维互信息估计仍然偏小,并且容易出现不稳定发散的情况。
虽然论文证明了我们总可以通过更多的采样以及更多参数的神经网络使得估计有效,但是针对什么样的数据,到底什么样的神经网络够用,那可就真的是天知道了。有趣的是,这正是我们希望通过研究神经网络可解释性希望搞清楚的事情,套娃了属于是。
不过无论怎么样,在目前阶段下,MINE也是我们能够使用的,有效并且简单朴实的互信息估计方案,路还要一步一步继续走下去,MINE已经给我们插上了一对翅膀。

参考文献

论文原文:

https://arxiv.org/abs/1801.04062v1

博客参考:

https://zhuanlan.zhihu.com/p/113455332

https://zhuanlan.zhihu.com/p/191155238

代码参考:

https://github.com/gtegner/mine-pytorch

https://cloud.tencent.com/developer/article/1827058

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存