查看原文
其他

【他山之石】Pytorch转ONNX-实战篇(tracing机制)

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:立交桥跳水冠军

地址:https://www.zhihu.com/people/li-jiao-qiao-tiao-shui-guan-jun


本文重点结合OpenMMlab系列中用到的Pytorch转ONNX的小技巧来介绍实战部分。

01

tracing的机制

上文提到过,Pytorch转ONNX的方式是基于tracing(追踪),通俗来说,就是ONNX的相关代码在一旁看着Pytorch跑一遍,运行了什么内容就把什么记录下来。但是在这里并不是所有Python的运行内容都会被记录。举个例子,下面的代码中,
c = torch.matmul(a, b)print("Blabla")e = torch.matmul(c, d)
其中只有第1,3行相关的内容会被记录,因为只有他们是和Pytorch相关的,而第二行只是普通的python语句。
具体来说,只有ATen操作会被记录下来。ATen可以被理解为一个Pytorch的基本操作库,一切的Pytorch函数都是基于这些零部件构造出来的(比如ATen就是加减乘除,所有Pytorch的其他操作,比如平方,算sigmoid,都可以根据加减乘除构造出来)
*之前说的ONNX无法记录if语句的问题也是因为if并不是Aten中的操作
虽然ONNX可以记录所有Pytorch的执行(即记录所有ATen操作),但是在输出的时候会做一个剪枝,把没用的操作剪掉
举个例子,下面的程序,显而易见第一句话是没有用的。
t1 = torch.matmul(a, b)t2 = torch.matmul(c, d)return t2
ONNX会在得到全部的操作以及他们之间的输入输出关系后(以DAG作为表示),根据DAG的输出往前推,做遍历,所有可以被遍历到的节点被保留,其他节点直接扔掉。
在MMDetection(https://github.com/open-mmlab/mmdetection)中,在NMS(non-Maximumnon maximum suppression)中有如下代码:
if bboxes.numel() == 0: bboxes = multi_bboxes.new_zeros((0, 5)) labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
if torch.onnx.is_in_onnx_export(): raise RuntimeError('[ONNX Error] Can not record NMS ' 'as it has not been executed this time') return bboxes, labels
dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
代码逻辑很简单,如果之前的网络根本没有输出任何合法的bbox(第一行的分支判断),那么显然nms的结果就是一堆0,所以没必要运行nms直接返回0就可以。
如果我们想将这段代码转换到ONNX,之前我们提到过ONNX不能处理分支逻辑,因此只能选择一条路去走,记录那条路转换得到的模型。很显然,正常情况下我们自然期待会有较多的bbox,并且将这些bbox作为参数调用nms。
所以如果我们发现模型执行的路径触发了if分支,我们必须要进行一个判断,看看是不是在转ONNX,如果是的话我们就需要直接报错,因为显然转出来的ONNX不是我们想要的。
假设什么都不做,在这种情况下我们转出来的模型是什么样呢?思考一下不难发现,假设函数的返回值就是网络的最终输出,那么我们只会得到一个2个节点的DAG,即第2,3行的两个操作。之前说过ONNX拿到所有的DAG之后会做剪枝,在这里ONNX拿到返回值(bboxes, labels)做回溯,发现最头上就是第2,3行的两个操作,就直接停掉了。所有其他的操作,比如backbone,rpn,fpn,都会被扔掉。
因此,在进行MMDet模型的转换的时候,必须用真实的数据和训练好的参数来做转换,否则基本不会得到有效的bbox,于是就会触发第6行的error

02

利用tracing机制做优化

在MMSeg中有一个很巧妙的利用tracing机制做优化的例子。
在slide inference时,我们需要计算一个countmat矩阵,这个矩阵在h, w以及对应的stride都固定的情况下会是一个常量。
不过在训练时,往往这些都是我们要调的参数,所有MMSeg没有选择把这些常数保存下来,而是每次都算一遍
count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) for h_idx in range(h_grids): for w_idx in range(w_grids): y1 = h_idx * h_stride x1 = w_idx * w_stride y2 = min(y1 + h_crop, h_img) x2 = min(x1 + w_crop, w_img) y1 = max(y2 - h_crop, 0) x1 = max(x2 - w_crop, 0) crop_img = img[:, :, y1:y2, x1:x2] crop_seg_logit = self.encode_decode(crop_img, img_meta) preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1 assert (count_mat == 0).sum() == 0 if torch.onnx.is_in_onnx_export(): # cast count_mat to constant while exporting to ONNX count_mat = torch.from_numpy( count_mat.cpu().detach().numpy()).to(device=img.device)
不过在部署时,这些参数往往是固定的,因此我们没必要把它算一遍。因此在倒数第4行的if分支里,我们做了一件看似很没用的事
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)

即我们把算出来的count_mat从tensor转换成numpy,再转回tensor。

其实我们的目的是切断tracing。
之前提到过,ONNX只能记录ATen相关的操作,但是很显然,tensor和numpy的互转肯定不是ATen操作。因此在回溯的时候,当访问到countmat,ONNX并不能发现它是被谁运算出来的,所以count_mat就会被看作一个常数被保存下来,之前计算count_mat的部分都会被扔掉。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

直播预告



“他山之石”历史文章




分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存