查看原文
其他

【他山之石】pytorch计算模型FLOPs和Params

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—庄六岁

地址:https://www.zhihu.com/people/zhuang-ming-xi-31-16

网络框架模型计算量影响到模型的推断时间,模型的参数量对设备内存有要求,为了进行模型比较给大家介绍两种计算FLOPs和Params实用的小工具:

01

第一种方法
https://github.com/Lyken17/pytorch-OpCounter
安装方法:
pip install thop(推荐用这个) 或者 pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git(这个方法需要同时安装pytorch)

用法:

from torchvision.models import resnet50from thop import profilemodel = resnet50()flops, params = profile(model, input_size=(1, 3, 224,224))print('FLOPs = ' + str(flops/1000**3) + 'G')print('Params = ' + str(params/1000**2) + 'M')
如果要使用自己的模型则:
class YourModule(nn.Module):model = YourModule()flops, params = profile(model, input_size=(1, 3, 224,224))print('FLOPs = ' + str(flops/1000**3) + 'G')print('Params = ' + str(params/1000**2) + 'M')
如果需要自己定义规则则:
class YourModule(nn.Module):# your definitiondef count_your_model(model, x, y):# your ruleflops, params = profile(model, input_size=(1, 3, 224,224),custom_ops={YourModule: count_your_model})print('FLOPs = ' + str(flops/1000**3) + 'G')print('Params = ' + str(params/1000**2) + 'M')

02

第二种方法
https://github.com/sovrasov/flops-counter.pytorch
安装方法:
pip install ptflops 或者 pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
使用方法:
import torchvision.models as modelsimport torchfrom ptflops import get_model_complexity_info
with torch.cuda.device(0): net = models.densenet161() macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True, verbose=True) print('{:<30} {:<8}'.format('Computational complexity: ', macs)) print('{:<30} {:<8}'.format('Number of parameters: ', params))
其中print_per_layer_stat用来管理是否输出每一层的参数量和计算量。
这两种方法计算结果相差不大,亲测简洁又好用

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。



“他山之石”历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存