其他
【他山之石】pytorch计算模型FLOPs和Params
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
地址:https://www.zhihu.com/people/zhuang-ming-xi-31-16
01
pip install thop(推荐用这个) 或者 pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git(这个方法需要同时安装pytorch)
用法:
from torchvision.models import resnet50from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')
class YourModule(nn.Module):
model = YourModule()
flops, params = profile(model, input_size=(1, 3, 224,224))
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')
class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule
flops, params = profile(model, input_size=(1, 3, 224,224),
custom_ops={YourModule: count_your_model})
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')
02
pip install ptflops 或者 pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
with torch.cuda.device(0):
net = models.densenet161()
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“他山之石”历史文章
保姆级教程:个人深度学习工作站配置指南
整理 Deep Learning 调参 tricks
Tensorflow模型保存方式大汇总
利用Tensorflow构建CNN图像多分类模型及图像参数、数据维度变化情况实例分析
pytorch中optimizer对loss的影响
使用PyTorch 1.6 for Android
神经网络解微分方程实例:三体问题
pytorch 实现双边滤波
编译PyTorch静态库
工业界视频理解解决方案大汇总
动手造轮子-rnn
凭什么相信你,我的CNN模型?关于CNN模型可解释性的思考
c++接口libtorch介绍& vscode+cmake实践
python从零开始构建知识图谱
更多他山之石专栏文章,
请点击文章底部“阅读原文”查看
分享、点赞、在看,给个三连击呗!