Unlock a new PyTorch skill: torch.fx
PyTorch shipped torch.fx as a beta feature in version 1.8 and promoted it to stable in 1.10. The main job of this toolkit is to transform nn.Module instances, in other words, to manipulate models programmatically. That may sound like an odd capability at first, but once you dig into it you will find it has plenty of uses. This post first gives a quick tour of the core concepts of torch.fx and then walks through a few concrete applications.
torch.fx has three main components: a symbolic tracer, an intermediate representation (IR), and Python code generation. To get a quick feel for how they fit together, here is a simple example:
import torch

# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()
from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""
The symbolic tracer performs symbolic execution of the module's forward code: it feeds in fake inputs called Proxies, and every operation applied to them is recorded. The process is somewhat similar to building a static graph in TensorFlow, with Proxies playing a role analogous to placeholders.
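One practical consequence worth knowing: because Proxies carry no real values, data-dependent Python control flow cannot be recorded. The following is a minimal sketch (the CondModule here is purely illustrative, not from the examples above) of the failure mode:

import torch
import torch.fx

class CondModule(torch.nn.Module):
    def forward(self, x):
        # The branch depends on the value of x, but during tracing x is only
        # a Proxy with no real data, so the condition cannot be evaluated.
        if x.sum() > 0:
            return x + 1
        return x - 1

try:
    torch.fx.symbolic_trace(CondModule())
except Exception as e:
    # Tracing raises an error complaining that symbolically traced variables
    # cannot be used as inputs to control flow.
    print(type(e).__name__, e)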
This trace ultimately yields an intermediate representation of the computation: a torch.fx.Graph, conceptually much the same as a TensorFlow Graph. The Graph records all of the operations. Concretely, a Graph consists of a sequence of torch.fx.Node objects; a Node is the basic unit of a Graph and corresponds to one operation, with Node.op recording the kind of operation. The main op types are placeholder, get_attr, call_function, call_module, call_method and output. The following example shows what each of these looks like in practice:
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)
# Print all nodes of the graph
gm.graph.print_tabular()
opcode name target args kwargs
------------- ------------- ------------------------------------------------------- ------------------ -----------
placeholder x x () {}
get_attr linear_weight linear.weight () {}
call_function add <built-in function add> (x, linear_weight) {}
call_module linear linear (add,) {}
call_method relu relu (linear,) {}
call_function sum_1 <built-in method sum of type object at 0x7f98c972d200> (relu,) {'dim': -1}
call_function topk <built-in method topk of type object at 0x7f98c972d200> (sum_1, 3) {}
output output output (topk,) {}
Besides op, every Node also has a name, a target, args and kwargs, whose meaning differs slightly by op. A placeholder is an input of the graph and output is its output; for both, target and name are the same. get_attr fetches an attribute of the module, such as a parameter; call_function calls a free function, with target holding the function itself; call_module invokes a submodule, with target being the submodule's qualified name; call_method invokes a method on a value, for example a Tensor method such as relu. args and kwargs are simply the positional and keyword arguments of the op, and as the table shows, many args are themselves other Nodes referenced by name. That is how Nodes are linked together to form the Graph.
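To see those links explicitly, here is a minimal sketch (reusing the traced gm from the example above) that prints, for every node, the nodes it consumes and the nodes that consume it:

# Walk the graph: all_input_nodes are the Nodes feeding this op (via args/kwargs),
# users are the Nodes that take its result as an input.
for node in gm.graph.nodes:
    inputs = [n.name for n in node.all_input_nodes]
    users = [u.name for u in node.users]
    print(f"{node.op:15} {node.name:15} inputs={inputs} users={users}")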
The last component is Python code generation, which automatically emits executable Python code matching the semantics of a Graph.
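Code generation is what makes graph edits take effect. Here is a minimal sketch (reusing the module instance traced in the first example; the clamp tweak is purely illustrative): after mutating the graph, calling recompile() regenerates the module's forward from the edited graph:

gm2 = torch.fx.symbolic_trace(module)
for node in gm2.graph.nodes:
    if node.op == 'call_method' and node.target == 'clamp':
        node.kwargs = {'min': 0.0, 'max': 2.0}   # illustrative tweak of an argument
gm2.graph.lint()   # sanity-check the edited graph
gm2.recompile()    # regenerate gm2.code / gm2.forward from the graph
print(gm2.code)    # forward now clamps to [0.0, 2.0]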
Intuitively, then, torch.fx converts a Module into a static-graph-like representation. What does that have to do with transforming Modules? Consider this: if we take the Graph traced from a Module, transform that Graph, and then feed it to the code generator, we have effectively transformed the Module. The whole pipeline is: symbolic tracing -> intermediate representation -> transforms -> Python code generation. That gives us a Python-to-Python transformation from one Module to another. The overall code flow looks like this:
import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`
    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)
The torch.fx.GraphModule we end up with behaves just like a regular nn.Module apart from its extra graph and code attributes, and its forward executes the code generated from the graph. Let's look at a simple module-rewriting example in which every torch.add() operation in the module is replaced with torch.mul():
import torch
import torch.fx as fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e.:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul
    graph.lint()  # Does some checks to make sure the
                  # Graph is well-formed.
    return fx.GraphModule(m, graph)
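A quick way to check the rewrite (a minimal sketch, assuming the M and transform defined above):

m = M()
new_m = transform(m)
x, y = torch.randn(3), torch.randn(3)
print(torch.allclose(new_m(x, y), x * y))   # True: add has been replaced by mul
print(torch.allclose(m(x, y), x + y))       # the original module is untouched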
Finally, what are torch.fx's use cases? There are in fact quite a few scenarios that call for transforming modules. The official documentation already lists several concrete examples:
- Replace one op
- Conv/Batch Norm fusion
- replace_pattern: Basic usage (a small sketch follows right after this list)
- Quantization
- Invert Transformation
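For instance, replace_pattern rewrites every occurrence of a small subgraph with another. Here is a minimal sketch (the PatternModule and the chosen pattern/replacement are illustrative, loosely following the official documentation):

import torch
from torch.fx import symbolic_trace, replace_pattern

class PatternModule(torch.nn.Module):
    def forward(self, x, w1, w2):
        return x + torch.cat([w1, w2]).sum()

def pattern(w1, w2):
    # Subgraph to search for
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    # Subgraph to substitute in its place
    return torch.stack([w1, w2]).sum()

traced = symbolic_trace(PatternModule())
replace_pattern(traced, pattern, replacement)
print(traced.code)   # the cat(...).sum() subgraph is now stack(...).sum()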
Take Conv/Batch Norm fusion as an example. At inference time, folding BN into the preceding Conv so that they become a single operation speeds up inference, and torch.fx makes this easy to implement. The concrete code is as follows:
import copy
from typing import Any, Dict, Iterable, Tuple, Type

import torch
import torch.fx as fx
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

def _parent_name(target: str) -> Tuple[str, str]:
    """Split a qualified name into parent path and last atom, e.g. 'layer1.0.bn1' -> ('layer1.0', 'bn1')."""
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    if len(node.args) == 0:
        return False
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, current_node in zip(pattern, nodes):
        if not isinstance(current_node, fx.Node):
            return False
        if current_node.op != 'call_module':
            return False
        if not isinstance(current_node.target, str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return True

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)
    for pattern in patterns:
        for node in new_graph.nodes:
            # Find the target node: node.args[0] is a Conv node, node.target is a BN module
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                # Fold the BN statistics into the Conv weights/bias
                fused_conv = fuse_conv_bn_eval(conv, bn)
                # Swap the module behind the Conv node's target for the fused module
                replace_node_module(node.args[0], modules, fused_conv)
                # Redirect every use of the BN node to the (now fused) Conv node
                node.replace_all_uses_with(node.args[0])
                # Remove the BN node from the graph
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)
from torch.fx.experimental.optimization import fuse
from torchvision.models import resnet18

model = resnet18()
model.eval()  # fusion must be done with the model in eval mode
'''
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''
fused_model = fuse(model)
'''
  (layer4): Module(
    (0): Module(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (downsample): Module(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
      )
    )
    (1): Module(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''
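As a quick sanity check (a minimal sketch, assuming the model and fused_model from above), the fused network should produce essentially the same outputs in eval mode:

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    out_ref = model(x)
    out_fused = fused_model(x)
# BN folding is algebraically exact, so outputs should match up to floating-point error.
print(torch.allclose(out_ref, out_fused, atol=1e-5))   # expected: True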
torch.fx has another quite practical use case: extracting features from a model, for instance when we want intermediate features for analysis, or want to reuse intermediate features to build other models such as detectors or segmentation networks. torchvision already supports this through feature_extraction. Here is a short example that extracts the C4 and C5 features of a ResNet:
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# Build the model
model = torchvision.models.resnet50()
# Get all node names of the model (they may differ between train and eval mode)
train_nodes, eval_nodes = get_graph_node_names(model)
# Specify the output nodes we want
return_nodes = {'layer3.5.relu_2': 'C4', 'layer4.2.relu_2': 'C5'}
# Rebuild the model as a feature extractor
n_model = create_feature_extractor(model, return_nodes)
out = n_model(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
# [('C4', torch.Size([1, 1024, 14, 14])), ('C5', torch.Size([1, 2048, 7, 7]))]
Note that the rebuilt model removes every Node that does not contribute to the requested outputs; for instance, the final fc layer is no longer present here.
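A quick way to confirm the pruning (a minimal sketch, assuming the n_model built above): submodules that no longer feed the requested outputs are simply not carried over into the extractor:

# avgpool and fc are not needed for C4/C5, so they are not copied into the extractor.
print(hasattr(n_model, 'avgpool'), hasattr(n_model, 'fc'))   # expected: False False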