torch.fx主要包含3个组件:符号追踪器(symbolic tracer)、中间表示(intermediate representation)和Python代码生成(Python code generation)。为了便于快速理解,这里给出一个简单的例子:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    """Demo module computing ``clamp(linear(x + param), 0, 1)``."""

    def __init__(self):
        # nn.Module.__init__ must run before any Parameter or submodule is
        # assigned; otherwise the assignment below raises AttributeError.
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


from torch.fx import symbolic_trace

module = MyModule()
# Symbolically trace this module.
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# 中间表示
# High-level intermediate representation (IR) - Graph representation
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
# 生成代码
# Code generation - valid Python code
def forward(self, x):
    # Python code generated by FX for MyModule (the "Code generation" output).
    # Each line evaluates one graph node. After a value's last use, the local
    # references are rebound to None, presumably so intermediates can be
    # released as early as possible.
    param = self.param
    add = x + param; x = param = None
    linear = self.linear(add); add = None
    clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
    return clamp
import torch
import torch.fx
class MyModule(torch.nn.Module):
    """Demo module: top-3 of the row sums of ``relu(linear(x + linear.weight))``."""

    def __init__(self):
        # Required before registering parameters/submodules on an nn.Module.
        super().__init__()
        # Not referenced by forward(), so it never appears in the traced graph.
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)


m = MyModule()
gm = torch.fx.symbolic_trace(m)
# 打印graph的所有node(例如调用 gm.graph.print_tabular(),需要安装tabulate),输出如下:
opcode name target args kwargs
------------- ------------- ------------------------------------------------------- ------------------ -----------
placeholder x x () {}
get_attr linear_weight linear.weight () {}
call_function add <built-in function add> (x, linear_weight) {}
call_module linear linear (add,) {}
call_method relu relu (linear,) {}
call_function sum_1 <built-in method sum of type object at 0x7f98c972d200> (relu,) {'dim': -1}
call_function topk <built-in method topk of type object at 0x7f98c972d200> (sum_1, 3) {}
output output output (topk,) {}
直观看起来,torch.fx做的就是将一个Module转换为静态图,那这和变换Module有什么关系呢?试想一下,如果我们对一个Module追踪得到的Graph进行变换,再借助Python代码生成工具,是不是就可以达到变换一个Module的目的?整个流程就是:symbolic tracing -> intermediate representation -> transforms -> Python code generation,由此实现了从一个Module到另一个Module的Python-to-Python转换。整个代码流程如下所示:
import torch
import torch.fx
def transform(m: torch.nn.Module,
              tracer_class: type = torch.fx.Tracer) -> torch.nn.Module:
    """Skeleton of an FX Python-to-Python module transformation.

    Traces ``m`` into a ``torch.fx.Graph`` with ``tracer_class``, leaves a spot
    to rewrite that graph, and wraps the result in a new
    ``torch.fx.GraphModule`` whose generated ``forward`` executes the graph.

    Args:
        m: The module to transform.
        tracer_class: ``torch.fx.Tracer`` (sub)class used for tracing; pass a
            subclass to customize tracing behavior.

    Returns:
        A ``torch.fx.GraphModule`` built from ``m`` and the (modified) graph.
    """
    # Step 1: Acquire a Graph representing the code in `m`.
    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. It is split out here
    # to allow the caller to customize tracing behavior.
    graph: torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph in place (or create a new one) here.
    # Left as an identity transform so the skeleton runs as-is.

    # Check that the (possibly rewritten) Graph is still well-formed.
    graph.lint()

    # Step 3: Construct a Module to return.
    return torch.fx.GraphModule(m, graph)
这里最终得到的torch.fx.GraphModule除了多出graph和code属性外,和普通的nn.Module用法一样,它的forward执行的就是由graph生成的代码。下面来看一个修改Module的简单例子:将模块中所有的torch.add()操作替换成torch.mul():
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
    """Sample module whose entire forward pass is a single ``torch.add`` call."""

    def forward(self, x, y):
        total = torch.add(x, y)
        return total
def transform(m: torch.nn.Module,
              tracer_class: type = torch.fx.Tracer) -> torch.nn.Module:
    """Return a GraphModule of ``m`` with every ``torch.add`` call turned into ``torch.mul``.

    Args:
        m: The module to rewrite; it is traced, not modified.
        tracer_class: ``torch.fx.Tracer`` (sub)class used for tracing.

    Returns:
        A new ``torch.fx.GraphModule`` whose generated code calls
        ``torch.mul`` wherever the traced code called ``torch.add``.
    """
    graph: torch.fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of nodes, so we can iterate
    # through them.
    for node in graph.nodes:
        # For 'call_function' nodes, `target` is the function being called
        # (e.g. torch.add); retargeting the node rewrites the call.
        if node.op == 'call_function' and node.target == torch.add:
            node.target = torch.mul
    graph.lint()  # Some checks to make sure the Graph is still well-formed.
    return torch.fx.GraphModule(m, graph)
torch.fx官方提供的变换示例包括:Replace one op、Conv/Batch Norm fusion、replace_pattern: Basic usage、Quantization以及Invert Transformation。
以Conv/Batch Norm fusion为例:在推理阶段将BN融合进Conv、合并成一个操作可以提升推理速度,而用torch.fx可以很容易地实现这个功能,具体代码如下:
# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    """Check whether ``(node.args[0], node)`` is a chain of module calls matching ``pattern``.

    Both nodes must be ``call_module`` nodes whose string target is a key of
    ``modules``. The looked-up modules must have exactly the types listed in
    ``pattern``, in order. Exact type equality is required; subclasses do not
    match. Only the first two pattern entries are checked, because at most two
    nodes are paired up.

    Returns:
        True on a full match, False otherwise (including when ``node`` has no args).
    """
    if len(node.args) == 0:
        return False
    candidates: Tuple[Any, fx.Node] = (node.args[0], node)
    for wanted_type, candidate in zip(pattern, candidates):
        is_known_module_call = (
            isinstance(candidate, fx.Node)
            and candidate.op == 'call_module'
            and isinstance(candidate.target, str)
            and candidate.target in modules
        )
        # Short-circuit keeps the dict lookup from running on non-matching nodes.
        if not is_known_module_call or type(modules[candidate.target]) is not wanted_type:
            return False
    return True
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    """Swap the submodule invoked by ``node`` for ``new_module``.

    Updates the ``modules`` lookup table under ``node.target``. It also sets
    ``new_module`` as an attribute on the owning parent module, so the module
    tree itself changes.

    NOTE(review): ``_parent_name`` is defined outside this snippet. It is
    presumably a helper that splits a dotted target into
    (parent path, attribute name); confirm against its definition.
    """
    assert isinstance(node.target, str)
    qualified_name = node.target
    owner_path, attr = _parent_name(qualified_name)
    modules[qualified_name] = new_module
    setattr(modules[owner_path], attr, new_module)
def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """Fuse convolution/BN layers for inference purposes.

    Every ``ConvNd -> BatchNormNd`` pair (N = 1, 2, 3) is eligible when the
    conv output is consumed only by the BN. The conv is replaced by a single
    conv whose weight/bias absorb the BN statistics, and the BN node is
    removed from the graph.

    Args:
        model: Module to fuse. ``fuse_conv_bn_eval`` expects both the conv and
            the BN to be in eval mode, so call ``model.eval()`` first.
        inplace: If False (default), ``model`` is deep-copied first and left
            untouched; if True, its submodules are replaced in place.

    Returns:
        A ``torch.fx.GraphModule`` computing the same function with fused
        convolutions.
    """
    import copy

    from torch.nn.utils.fusion import fuse_conv_bn_eval

    patterns = [(torch.nn.Conv1d, torch.nn.BatchNorm1d),
                (torch.nn.Conv2d, torch.nn.BatchNorm2d),
                (torch.nn.Conv3d, torch.nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = torch.fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        # Iterate over a snapshot: matched BN nodes are erased inside the loop.
        for node in list(new_graph.nodes):
            # Target node: a BN call whose first argument is a Conv call.
            if not matches_module_pattern(pattern, node, modules):
                continue
            conv_node = node.args[0]
            if len(conv_node.users) > 1:
                # The conv output is also used by other nodes; folding BN into
                # the conv would change what those other users see.
                continue
            conv = modules[conv_node.target]
            bn = modules[node.target]
            if not bn.track_running_stats:
                # Without running statistics there is nothing to fold in.
                continue
            # Fold the BN affine transform into the conv's weight and bias.
            fused_conv = fuse_conv_bn_eval(conv, bn)
            # Install the fused module under the conv node's target path.
            replace_node_module(conv_node, modules, fused_conv)
            # Redirect all consumers of the BN output to the fused conv,
            # then delete the now-unused BN node.
            node.replace_all_uses_with(conv_node)
            new_graph.erase_node(node)
    return torch.fx.GraphModule(fx_model, new_graph)
from torch.fx.experimental.optimization import fuse
from torchvision.models import resnet18

# Build a ResNet-18 and switch it to inference mode in one expression
# (Module.eval() returns the module itself); fuse() must run in eval mode.
model = resnet18().eval()
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
# Fuse every Conv/BN pair. In the printout below, the BatchNorm layers are gone
# and the convolutions no longer show bias=False (they now carry a bias).
fused_model = fuse(model)
(layer4): Module(
(0): Module(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(downsample): Module(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
(1): Module(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# Build the model.
model = torchvision.models.resnet50()
# List every graph node name that can be requested as an output
# (separately for the train-mode and eval-mode traces).
train_nodes, eval_nodes = get_graph_node_names(model)
# Map the graph nodes to expose -> keys of the returned dict.
return_nodes = {'layer3.5.relu_2': 'C4', 'layer4.2.relu_2': 'C5'}
# Rebuild the model so that its forward returns the requested intermediate outputs.
n_model = create_feature_extractor(model, return_nodes)
# Run the feature extractor. The original `model` returns a single logits
# tensor, which has no .items().
out = n_model(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
[('C4', torch.Size([1, 1024, 14, 14])), ('C5', torch.Size([1, 2048, 7, 7]))]
