查看原文
其他

【他山之石】神经网络解微分方程实例:三体问题

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—xlvector

原文地址:https://www.zhihu.com/people/xlvector

三体问题(three body problem)是一个混沌系统,也就是说这个系统有蝴蝶效应,初始值的一点扰动会带来一定时刻之后的巨大变化。
三体问题由以下微分方程决定:

来自维基百科 https://en.wikipedia.org/wiki/Three-body_problem
然后给定三体的每个点的初始位置以及初始速度,求解每个物体随着时刻T变化的运动轨迹。
参考下面这篇文章对边界值的处理:
https://zhuanlan.zhihu.com/p/312946002
我们可以假设三体的位置和时间的方程是
简单推导以下,可以发现这个方程满足  ,其中  是三个物体的初始位置,  是三个物体的初始速度。然后,我们就可以写代码了。
首先定义用来拟合  的神经网络:
# 这里采用了 x + sin^(x)做激活函数,这个函数会取得比较好的精度class Act(nn.Module): def __init__(self): super(Act, self).__init__() self.sigmoid = nn.Sigmoid()
def forward(self, x): return th.sin(x).pow(2) + x
# 子结构里采用了ResNet,也有助于提高精度class ResNet(nn.Module): def __init__(self, dim): super(ResNet, self).__init__()
self.net = nn.Sequential( Act(), nn.Linear(dim, dim))
def forward(self, x): return x + self.net(x)
class Obj3Net(nn.Module): def __init__(self): super(Obj3Net, self).__init__() self.net = nn.ModuleList([ nn.Sequential( nn.Linear(1, 64), ResNet(64), ResNet(64), ResNet(64), ResNet(64), ResNet(64), Act(), nn.Linear(64, 1)) for i in range(6)])
然后,定义解微分方程的损失函数:
class Obj3Net(nn.Module): def forward(self, t, x0, dx0): t.requires_grad_() t2 = t * t x = [self.net[i](t) * t2 + x0[i] + dx0[i] * t for i in range(6)] w = th.ones_like(x[0])
dx = [tg.grad(x[i], t, grad_outputs = w, create_graph = True)[0] for i in range(6)] dx2 = [tg.grad(dx[i], t, grad_outputs = w, create_graph = True)[0] for i in range(6)] x02 = x[0] - x[2] x13 = x[1] - x[3] x04 = x[0] - x[4] x15 = x[1] - x[5] x24 = x[2] - x[4] x35 = x[3] - x[5]
d12 = th.sqrt(x02.pow(2) + x13.pow(2)) d13 = th.sqrt(x04.pow(2) + x15.pow(2)) d23 = th.sqrt(x24.pow(2) + x35.pow(2))
norm = th.min(d12.pow(2), d13.pow(2)) norm = th.min(norm, d23.pow(2)).detach()
d12 = norm / d12.pow(3) d13 = norm / d13.pow(3) d23 = norm / d23.pow(3)
p12_x = x02 * d12 p12_y = x13 * d12 p13_x = x04 * d13 p13_y = x15 * d13
p21_x = -x02 * d12 p21_y = -x13 * d12 p23_x = x24 * d23 p23_y = x35 * d23
p31_x = -x04 * d13 p31_y = -x15 * d13 p32_x = -x24 * d23 p32_y = -x35 * d23
loss = (dx2[0] * norm + p12_x + p13_x).pow(2) \ + (dx2[1] * norm + p12_y + p13_y).pow(2) \ + (dx2[2] * norm + p21_x + p23_x).pow(2) \ + (dx2[3] * norm + p21_y + p23_y).pow(2) \ + (dx2[4] * norm + p31_x + p32_x).pow(2) \ + (dx2[5] * norm + p31_y + p32_y).pow(2)
return loss.mean(), loss.max(), [e.detach() for e in x], [e.detach() for e in dx]
上面这个损失函数只考虑了二维的情况。如果要搞三维,可以自己改改。
训练过程中会发现一些问题:
  1. 如果一次要拟合的时间段很长,会出现很难收敛的问题,所以最好分段拟合,比如每次向前拟合长度为T的时间,然后得到T时刻后的位置,速度,作为下次拟合的初始值。
  2. 拟合精度很重要,三体是一个混沌系统,精度太差,会发现结果相差很大。
  3. 微分方程里有  这一项,其中r是物体之间的距离。如果距离太近,会导致接近无穷,方程发散。我上面的代码已经处理了这种情况。
最后发个结果,这是三体里一个著名的周期性稳定解:

这个地方的精度达到了1e-6,但是可以发现,随着时间的推移,他们的轨迹还是有稍许偏离的。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。




“他山之石”历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存