[PyTorch] torch.nn.Module Source Code Analysis
Author | 药师
Original post | https://zhuanlan.zhihu.com/p/88712978
Column | 我的机器学习笔记 (My Machine Learning Notes)
torch.nn.Module
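This class contains as many as 48 functions. It is the base class of every neural network module in PyTorch: any network model you write yourself is a subclass of it, as in the example below. In this article we will read through this base class together.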
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
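Let's start with __init__ and forward. __init__ mainly initializes the internal state the module needs (the dictionaries for parameters, buffers, submodules, hooks, and so on). forward has no concrete implementation here; every subclass must implement it, and if a subclass does not, calling it raises NotImplementedError.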
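The split of responsibilities between the base class and its subclasses can be sketched in plain Python. ToyModule, Doubler, and Forgetful are hypothetical names, not PyTorch classes, and the toy __call__ skips all hooks; it only illustrates that the base forward raises NotImplementedError while calling a module dispatches to the subclass's forward.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module (not the real class)."""

    def __init__(self):
        # Internal state containers, named after the ones nn.Module uses.
        self._parameters = {}
        self._buffers = {}
        self._modules = {}
        self.training = True

    def forward(self, *inputs):
        # The base class deliberately leaves forward unimplemented.
        raise NotImplementedError

    def __call__(self, *inputs):
        # Calling the module dispatches to forward (hooks omitted here).
        return self.forward(*inputs)


class Doubler(ToyModule):
    def forward(self, x):
        return 2 * x


class Forgetful(ToyModule):
    pass  # forward was never implemented


print(Doubler()(21))  # 42
try:
    Forgetful()(1)
except NotImplementedError:
    print("forward not implemented")
```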
The functions cuda and cpu are simple. cuda "Moves all model parameters and buffers to the GPU."; cpu "Moves all model parameters and buffers to the CPU." Both return the Module itself, and both do their work by calling _apply.
def cuda(self, device=None):
    return self._apply(lambda t: t.cuda(device))

def cpu(self):
    return self._apply(lambda t: t.cpu())
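Next, let's look at _apply. It first loops over all child modules and recursively runs the same operation on each of them. The next loop walks over self._parameters; the helper compute_should_use_set_data decides whether to change the tensor in place. If so, the new tensor simply replaces the old one via param.data = param_applied; otherwise a new Parameter wrapping the new tensor is registered in the dictionary self._parameters. If a parameter param has a gradient param.grad, the same operation is applied to param.grad as well. The last loop applies fn to every tensor in the dictionary self._buffers (for cuda and cpu, this moves it between CPU and GPU) and stores the result back in self._buffers. Finally, the Module itself is returned.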
def _apply(self, fn):
    # Recurse into every child module first.
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        # ... (body elided in this excerpt)
        ...

    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self
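To see why cuda, cpu, float and similar methods can each be a one-liner, here is a hypothetical pure-Python sketch of the _apply recursion. FakeTensor, Node, and the device strings are assumptions of this sketch (no real tensors or GPUs are involved), and the in-place vs. replace decision made by compute_should_use_set_data is left out: every entry is simply replaced by fn(entry).

```python
class FakeTensor:
    """Hypothetical tensor stand-in: a device tag plus some values."""

    def __init__(self, values, device="cpu"):
        self.values = list(values)
        self.device = device

    def to(self, device):
        # Return a new object, as a real device transfer would.
        return FakeTensor(self.values, device)


class Node:
    """Hypothetical module with one parameter, one buffer and children."""

    def __init__(self, **children):
        self._parameters = {"weight": FakeTensor([1.0, 2.0])}
        self._buffers = {"running_mean": FakeTensor([0.0])}
        self._modules = dict(children)

    def children(self):
        return iter(self._modules.values())

    def _apply(self, fn):
        # 1) Recurse into every child first.
        for module in self.children():
            module._apply(fn)
        # 2) Replace each parameter with fn(param).
        for key, param in self._parameters.items():
            if param is not None:
                self._parameters[key] = fn(param)
        # 3) Do the same for buffers.
        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        return self

    # Each conversion method is just _apply with a different function.
    def cuda(self, device=0):
        return self._apply(lambda t: t.to("cuda:%d" % device))

    def cpu(self):
        return self._apply(lambda t: t.to("cpu"))


net = Node(layer=Node())
net.cuda()
print(net._modules["layer"]._buffers["running_mean"].device)  # cuda:0
```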
With _apply in place, many operations become easy to implement. For example, share_memory simply calls _apply to run share_memory_ on every tensor, which "Moves the underlying storage to shared memory. This is a no-op if the underlying storage is already in shared memory and for CUDA tensors. Tensors in shared memory cannot be resized." In short, it moves the tensors into shared memory.
def share_memory(self):
    return self._apply(lambda t: t.share_memory_())
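Now let's look at apply (not to be confused with _apply above). It is very simple: it passes the Module and every one of its SubModules to the given function fn, children first and the Module itself last. For example, we can use it to initialize the parameters of a network with a method of our choice.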
def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
The example below sets the weights of every Linear submodule of the network net to 1 (the biases are left unchanged).
Example::
>>> def init_weights(m):
...     print(m)
...     if type(m) == nn.Linear:
...         m.weight.data.fill_(1.0)
...         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
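Because apply recurses before calling fn(self), it visits the tree in post-order: every child is handled before its parent. The hypothetical pure-Python toy below (Box and its name field are assumptions for illustration, not PyTorch code) records the visiting order to make that visible.

```python
class Box:
    """Hypothetical module stand-in with named children."""

    def __init__(self, name, *children):
        self.name = name
        self._children = list(children)

    def children(self):
        return iter(self._children)

    def apply(self, fn):
        # Same shape as nn.Module.apply: recurse first, then call fn on self.
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self


visited = []
net = Box("net", Box("linear0"), Box("block", Box("linear1")))
net.apply(lambda m: visited.append(m.name))
print(visited)  # ['linear0', 'linear1', 'block', 'net']
```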
Next come type, float, double, and half. type casts all parameters and buffers to the target type dst_type; float, double, and half cast all floating-point parameters and buffers to the float, double, and half datatypes respectively, leaving non-floating-point tensors untouched. torch.Tensor.float corresponds to torch.float32, torch.Tensor.double to torch.float64, and torch.Tensor.half to torch.float16.
def type(self, dst_type):
    return self._apply(lambda t: t.type(dst_type))

def float(self):
    return self._apply(lambda t: t.float() if t.is_floating_point() else t)

def double(self):
    return self._apply(lambda t: t.double() if t.is_floating_point() else t)

def half(self):
    return self._apply(lambda t: t.half() if t.is_floating_point() else t)
The function to modifies the Module in place. It can be called in three ways: to(device=None, dtype=None, non_blocking=False); to(dtype, non_blocking=False); and to(tensor, non_blocking=False). Usage is shown below.
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)
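The function state_dict returns a dictionary containing the whole state of the module; its keys are the names of the parameters and buffers. The source contains a loop that recursively visits every SubModule of the Module, which is why the keys of nested modules carry a prefix such as 0. in the second example below.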
>>> net = torch.nn.Linear(2, 2)
>>> net.state_dict()
OrderedDict([('weight', tensor([[-0.3558, 0.2153],
        [-0.2785, 0.6982]])), ('bias', tensor([ 0.5771, -0.6232]))])
>>> net.state_dict().keys()
odict_keys(['weight', 'bias'])
>>> net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
>>> net.state_dict()
OrderedDict([('0.weight', tensor([[ 0.4792, 0.5772], [ 0.1039, -0.0552]])),
             ('0.bias', tensor([-0.5175, -0.6469])),
             ('1.weight', tensor([[-0.5346, -0.0173], [-0.2092, 0.0794]])),
             ('1.bias', tensor([-0.2150, 0.2323]))])
>>> net.state_dict().keys()
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias'])
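The dotted keys ('0.weight', '1.bias', ...) fall out of the recursion: each module saves only its own entries, and each child is visited with the prefix extended by its name plus a dot. Below is a hypothetical pure-Python sketch of that recursion with floats in place of tensors; Mod, linear, and the simplified state_dict signature (no keep_vars, no hooks) are assumptions, not PyTorch's API.

```python
from collections import OrderedDict


class Mod:
    """Hypothetical module stand-in: own parameters plus named children."""

    def __init__(self, params=None, children=None):
        self._parameters = OrderedDict(params or {})
        self._buffers = OrderedDict()
        self._modules = OrderedDict(children or {})

    def _save_to_state_dict(self, destination, prefix):
        # Save only this module's own parameters and buffers.
        for name, value in self._parameters.items():
            if value is not None:
                destination[prefix + name] = value
        for name, value in self._buffers.items():
            if value is not None:
                destination[prefix + name] = value

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = OrderedDict()
        self._save_to_state_dict(destination, prefix)
        # Recurse into children, extending the prefix with "<child name>.".
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".")
        return destination


def linear(w, b):
    # Hypothetical helper: a leaf module with a weight and a bias.
    return Mod(params={"weight": w, "bias": b})


net = Mod(children={"0": linear(0.5, -0.1), "1": linear(0.2, 0.3)})
print(list(net.state_dict().keys()))  # ['0.weight', '0.bias', '1.weight', '1.bias']
```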
The function load_state_dict does exactly the opposite of state_dict: it loads parameters and buffers into the Module and its SubModules.
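The function parameters returns an iterator, so we can walk over a model's parameters with for param in model.parameters(). When we use an optimization algorithm, model.parameters() is exactly what we pass to the Optimizer. The functions buffers, children, and modules work in a similar way.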
def parameters(self, recurse=True):
    for name, param in self.named_parameters(recurse=recurse):
        yield param

def buffers(self, recurse=True):
    for name, buf in self.named_buffers(recurse=recurse):
        yield buf

def children(self):
    for name, module in self.named_children():
        yield module

def modules(self):
    for name, module in self.named_modules():
        yield module
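Correspondingly, there are four named variants: named_parameters, named_buffers, named_children, and named_modules. Each returns an iterator that yields (name, member) pairs. The memo set in each of them makes sure that a member reachable more than once is yielded only once.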
def _named_members(self, get_members_fn, prefix='', recurse=True):
    r"""Helper method for yielding various names + members of modules."""
    memo = set()
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        members = get_members_fn(module)
        for k, v in members:
            if v is None or v in memo:
                continue
            memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v

def named_parameters(self, prefix='', recurse=True):
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

def named_buffers(self, prefix='', recurse=True):
    gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

def named_children(self):
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module

def named_modules(self, memo=None, prefix=''):
    if memo is None:
        memo = set()
    if self not in memo:
        memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            for m in module.named_modules(memo, submodule_prefix):
                yield m
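The effect of the memo sets is easiest to see with a module registered twice, for example one layer shared by two branches. The hypothetical pure-Python toy below copies the shape of named_modules and a simplified named_parameters; Param and Mod are assumptions of this sketch (Param instances hash by identity, like tensors do), not PyTorch classes.

```python
class Param:
    """Hypothetical parameter stand-in; hashes by identity like a tensor."""

    def __init__(self, value):
        self.value = value


class Mod:
    """Hypothetical module stand-in with parameters and child modules."""

    def __init__(self, params=None, children=None):
        self._parameters = dict(params or {})
        self._modules = dict(children or {})

    def named_modules(self, memo=None, prefix=""):
        # Same structure as nn.Module.named_modules: one shared memo set.
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ("." if prefix else "") + name
                yield from module.named_modules(memo, submodule_prefix)

    def named_parameters(self):
        # Simplified _named_members: skip None and already-seen members.
        memo = set()
        for module_prefix, module in self.named_modules():
            for k, v in module._parameters.items():
                if v is None or v in memo:
                    continue
                memo.add(v)
                yield module_prefix + ("." if module_prefix else "") + k, v


shared = Mod(params={"weight": Param(1.0)})
net = Mod(children={"a": shared, "b": shared, "c": Mod(params={"weight": Param(2.0)})})
print([name for name, _ in net.named_modules()])     # ['', 'a', 'c']
print([name for name, _ in net.named_parameters()])  # ['a.weight', 'c.weight']
```

The shared layer is reported only under the first name it was reached by ('a'), so an optimizer fed from such an iterator would not see the same parameter twice.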
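The functions train and eval put the Module and all of its SubModules into training mode and evaluation mode respectively. The mode only changes the behavior of modules that act differently in the two modes, such as Dropout and BatchNorm.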
def train(self, mode=True):
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

def eval(self):
    return self.train(False)
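Since train recurses through children(), a single call flips the training flag on the whole tree, and eval() is just train(False). The hypothetical toy below shows how a layer can branch on that flag: RunningMean, like BatchNorm, updates its running statistic only in training mode. Mod and RunningMean are assumptions for illustration, and the arithmetic is far simpler than BatchNorm's.

```python
class Mod:
    """Hypothetical module stand-in with a training flag and children."""

    def __init__(self, *children):
        self.training = True
        self._children = list(children)

    def children(self):
        return iter(self._children)

    def train(self, mode=True):
        # Same shape as nn.Module.train: set own flag, then recurse.
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        return self.train(False)


class RunningMean(Mod):
    """Hypothetical layer that only updates its statistic while training."""

    def __init__(self):
        super().__init__()
        self.total = 0.0
        self.count = 0

    def __call__(self, x):
        if self.training:
            self.total += x
            self.count += 1
        # Center the input with the (possibly frozen) running mean.
        return x - self.total / max(self.count, 1)


stat = RunningMean()
net = Mod(Mod(stat))   # stat is nested two levels deep
print(stat(4.0))       # 0.0  (training: the mean becomes 4.0)
net.eval()             # reaches stat through the recursion
print(stat(10.0))      # 6.0  (eval: the mean stays 4.0)
```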
The function requires_grad_ sets whether autograd should record operations on self.parameters(); its argument defaults to True. The function zero_grad sets the gradients of self.parameters() to zero.
def requires_grad_(self, requires_grad=True):
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self

def zero_grad(self):
    for p in self.parameters():
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()
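The functions _get_name, extra_repr, __repr__, and __dir__ are all used to display information about a Module. _get_name returns the name of the Module's class. extra_repr is meant to be overridden by subclasses of torch.nn.Module to print module-specific information, as one line or several lines of text; see the torch.nn.Linear example after the code below. __repr__ prints the information of all SubModules of the Module, one item per line. __dir__ collects the names from self.__class__, self.__dict__.keys(), self._parameters.keys(), self._modules.keys(), and self._buffers.keys(), and uses key for key in keys if not key[0].isdigit() to filter out names that are not legal Python variable names (such as the '0', '1' children of a Sequential).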
def _get_name(self):
    return self.__class__.__name__

def extra_repr(self):
    return ''

def __repr__(self):
    extra_lines = []
    extra_repr = self.extra_repr()
    if extra_repr:  # empty string will be split into list ['']
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        # _addindent is a helper defined at module level in torch/nn/modules/module.py
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'

    main_str += ')'
    return main_str

def __dir__(self):
    module_attrs = dir(self.__class__)
    attrs = list(self.__dict__.keys())
    parameters = list(self._parameters.keys())
    modules = list(self._modules.keys())
    buffers = list(self._buffers.keys())
    keys = module_attrs + attrs + parameters + modules + buffers
    # Eliminate attrs that are not legal Python variable names
    keys = [key for key in keys if not key[0].isdigit()]
    return sorted(keys)

# --------------------------
# torch.nn.Linear -- class Linear(Module)
def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
        self.in_features, self.out_features, self.bias is not None
    )

# Example
>>> l = torch.nn.Linear(2, 2)
>>> l.extra_repr()
'in_features=2, out_features=2, bias=True'
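The nested layout printed by __repr__ comes from one small pattern: compute each child's repr recursively, indent every line after its first by two spaces, and prefix it with (name):. The sketch below re-creates that pattern in plain Python; indent_tail is a hypothetical stand-in for the _addindent helper, and Mod, Lin, and Net are toy classes, not PyTorch's.

```python
def indent_tail(s, num_spaces):
    """Indent every line except the first (hypothetical stand-in for _addindent)."""
    first, *rest = s.split("\n")
    if not rest:
        return s
    return "\n".join([first] + [" " * num_spaces + line for line in rest])


class Mod:
    def __init__(self, **children):
        self._modules = dict(children)

    def extra_repr(self):
        return ""

    def __repr__(self):
        extra = self.extra_repr()
        extra_lines = extra.split("\n") if extra else []
        child_lines = [
            "(" + key + "): " + indent_tail(repr(module), 2)
            for key, module in self._modules.items()
        ]
        lines = extra_lines + child_lines
        main_str = type(self).__name__ + "("
        if lines:
            if len(extra_lines) == 1 and not child_lines:
                # One-liner, e.g. Lin(in_features=2, out_features=3)
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"
        return main_str + ")"


class Lin(Mod):
    def __init__(self, i, o):
        super().__init__()
        self.i, self.o = i, o

    def extra_repr(self):
        return "in_features={}, out_features={}".format(self.i, self.o)


class Net(Mod):
    pass


print(repr(Net(body=Net(fc=Lin(2, 3)), head=Lin(3, 1))))
# Net(
#   (body): Net(
#     (fc): Lin(in_features=2, out_features=3)
#   )
#   (head): Lin(in_features=3, out_features=1)
# )
```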
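__setstate__ restores the state of an unpickled Module. If _forward_pre_hooks, _state_dict_hooks, or _load_state_dict_pre_hooks cannot be found in self.__dict__, it defines the missing one on self as an empty OrderedDict, so that old checkpoints saved before these attributes existed can still be loaded.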
def __setstate__(self, state):
    self.__dict__.update(state)
    # Support loading old checkpoints that don't have the following attrs:
    if '_forward_pre_hooks' not in self.__dict__:
        self._forward_pre_hooks = OrderedDict()
    if '_state_dict_hooks' not in self.__dict__:
        self._state_dict_hooks = OrderedDict()
    if '_load_state_dict_pre_hooks' not in self.__dict__:
        self._load_state_dict_pre_hooks = OrderedDict()
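The back-filling trick can be reproduced without PyTorch. When unpickling finds a __setstate__ method, it creates the object without calling __init__ and then hands the saved state to __setstate__; the sketch below performs those two steps by hand on a state dict that lacks _forward_pre_hooks, as an older checkpoint might. Toy and old_state are hypothetical names for this illustration.

```python
from collections import OrderedDict


class Toy:
    """Hypothetical module stand-in with two hook dictionaries."""

    def __init__(self):
        self.weight = 1.0
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()

    def __setstate__(self, state):
        # Same idea as nn.Module.__setstate__: restore, then back-fill.
        self.__dict__.update(state)
        if "_forward_pre_hooks" not in self.__dict__:
            self._forward_pre_hooks = OrderedDict()
        if "_state_dict_hooks" not in self.__dict__:
            self._state_dict_hooks = OrderedDict()


# State as an older version might have saved it: no _forward_pre_hooks entry.
old_state = {"weight": 1.0, "_state_dict_hooks": OrderedDict()}

# Roughly what unpickling does: create the object without running __init__,
# then pass the saved state to __setstate__.
restored = Toy.__new__(Toy)
restored.__setstate__(old_state)
print(restored.weight, len(restored._forward_pre_hooks))  # 1.0 0
```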
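__getattr__ looks up the member of the Module with the given name. Python only calls __getattr__ after normal attribute lookup has failed, which is exactly what happens for parameters, buffers, and submodules: they are stored in self._parameters, self._buffers, and self._modules rather than directly in self.__dict__. The method searches self.__dict__['_parameters'], self.__dict__['_buffers'], and self.__dict__['_modules'] in that order and returns the first match; if nothing is found, it raises AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)).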
def __getattr__(self, name):
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise AttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))
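__setattr__(self, name, value) sets an attribute. It works with self.__dict__.get('_parameters'), self.__dict__.get('_buffers'), and self.__dict__.get('_modules'): when value is a Parameter (or a Module), any existing entry under the same name is first removed from the other places it could live, and the name and value are then registered again, through register_parameter for a Parameter or in self._modules for a Module. Assigning a tensor to a name that already exists in self._buffers updates that buffer, and any other value is stored as an ordinary attribute with object.__setattr__.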
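The routing done by __setattr__ and __getattr__ can be imitated in a few lines. This is a hypothetical simplification, not PyTorch's code: Param stands in for nn.Parameter, Mod for nn.Module, buffers are left out, and the error checks are dropped. It shows that an assigned Param never lands in __dict__, so reading it back has to go through __getattr__.

```python
class Param:
    """Hypothetical stand-in for nn.Parameter."""

    def __init__(self, value):
        self.value = value


class Mod:
    def __init__(self):
        # Bypass our own __setattr__ while the containers do not exist yet.
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, Param):
            # Remove any same-named entry elsewhere, then register it.
            self.__dict__.pop(name, None)
            self._modules.pop(name, None)
            self._parameters[name] = value
        elif isinstance(value, Mod):
            self.__dict__.pop(name, None)
            self._parameters.pop(name, None)
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal lookup (instance dict, class) fails.
        for store in ("_parameters", "_modules"):
            d = self.__dict__.get(store, {})
            if name in d:
                return d[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))


m = Mod()
m.weight = Param(3.0)
m.child = Mod()
m.note = "plain attribute"
print("weight" in m.__dict__, m.weight.value)     # False 3.0
print(sorted(m._parameters), sorted(m._modules))  # ['weight'] ['child']
```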
__delattr__ deletes the member of the Module with the given name. It looks in self._parameters, self._buffers, and self._modules and removes the entry with del if it is found; otherwise it falls back to object.__delattr__(self, name).
def __delattr__(self, name):
    if name in self._parameters:
        del self._parameters[name]
    elif name in self._buffers:
        del self._buffers[name]
    elif name in self._modules:
        del self._modules[name]
    else:
        object.__delattr__(self, name)
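_save_to_state_dict(self, destination, prefix, keep_vars) stores the module state in destination, and it handles only this module's own parameters and buffers, not those of its SubModules. That is why it is called on every submodule: "This is called on every submodule in method torch.nn.Module.state_dict." With keep_vars=False it stores .data, that is, the tensors detached from the autograd graph. _load_from_state_dict does the reverse and loads the state into the module; it, too, handles only the module itself, so it is likewise called on every submodule: "This is called on every submodule in method torch.nn.Module.load_state_dict." The argument prefix is the prefix of this Module's parameter and buffer names, for example "0." for the first layer of a Sequential.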
def _save_to_state_dict(self, destination, prefix, keep_vars):
    for name, param in self._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.data
    for name, buf in self._buffers.items():
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.data

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            ...  # (elided in this excerpt)
        elif strict:
            missing_keys.append(key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)
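How strict turns into missing_keys and unexpected_keys can be checked with a hypothetical pure-Python version of the same bookkeeping. Mod is a toy class with floats in place of tensors, shape checks and tensor copying are reduced to a dict assignment, and this toy load_state_dict returns the two lists instead of raising an error the way the real method does in strict mode.

```python
class Mod:
    """Hypothetical module stand-in: own parameters plus named children."""

    def __init__(self, params=None, children=None):
        self._parameters = dict(params or {})
        self._modules = dict(children or {})

    def _load_from_state_dict(self, state_dict, prefix, strict,
                              missing_keys, unexpected_keys):
        # Only this module's own entries are loaded here.
        for name in self._parameters:
            key = prefix + name
            if key in state_dict:
                self._parameters[name] = state_dict[key]  # real code copies into the tensor
            elif strict:
                missing_keys.append(key)
        if strict:
            for key in state_dict:
                if key.startswith(prefix):
                    # First component after the prefix: a param or a child name.
                    input_name = key[len(prefix):].split(".", 1)[0]
                    if input_name not in self._modules and input_name not in self._parameters:
                        unexpected_keys.append(key)

    def load_state_dict(self, state_dict, strict=True):
        missing, unexpected = [], []

        def load(module, prefix=""):
            # Called on every submodule, with the prefix extended per level.
            module._load_from_state_dict(state_dict, prefix, strict, missing, unexpected)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(self)
        return missing, unexpected


net = Mod(children={"0": Mod({"weight": 0.0, "bias": 0.0}), "1": Mod({"weight": 0.0})})
missing, unexpected = net.load_state_dict({"0.weight": 1.5, "1.weight": 2.5, "1.bias": 9.9})
print(missing, unexpected)  # ['0.bias'] ['1.bias']
```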
The functions _register_state_dict_hook and _register_load_state_dict_pre_hook are simple as well; their docstrings explain them. hooks refers to the module torch.utils.hooks, whose class RemovableHandle is "A handle which provides the capability to remove a hook".
def _register_state_dict_hook(self, hook):
    r"""These hooks will be called with arguments: `self`, `state_dict`,
    `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
    Note that only parameters and buffers of `self` or its children are
    guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
    inplace or return a new one.
    """
    handle = hooks.RemovableHandle(self._state_dict_hooks)
    self._state_dict_hooks[handle.id] = hook
    return handle

def _register_load_state_dict_pre_hook(self, hook):
    r"""These hooks will be called with arguments: `state_dict`, `prefix`,
    `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
    `error_msgs`, before loading `state_dict` into `self`. These arguments
    are exactly the same as those of `_load_from_state_dict`.
    """
    handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
    self._load_state_dict_pre_hooks[handle.id] = hook
    return handle
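All the hook registries follow one pattern: an ordered dict keyed by a fresh integer id, plus a handle that can delete exactly that id again. Below is a hypothetical, simplified re-creation of the pattern; Handle is not torch.utils.hooks.RemovableHandle (the real class has extra details, such as holding only a weak reference to the dict), and Mod is a toy class.

```python
from collections import OrderedDict


class Handle:
    """Hypothetical, simplified removable hook handle."""

    next_id = 0  # class-level counter so every handle gets a unique id

    def __init__(self, hooks_dict):
        self.hooks_dict = hooks_dict
        self.id = Handle.next_id
        Handle.next_id += 1

    def remove(self):
        # Deleting by id leaves all other registered hooks untouched.
        self.hooks_dict.pop(self.id, None)


class Mod:
    def __init__(self):
        self._forward_hooks = OrderedDict()

    def register_forward_hook(self, hook):
        handle = Handle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle


m = Mod()
h1 = m.register_forward_hook(lambda *args: "first")
h2 = m.register_forward_hook(lambda *args: "second")
h1.remove()
print([hook() for hook in m._forward_hooks.values()])  # ['second']
```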
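The function _tracing_name is mainly called by _slow_forward. _slow_forward and __call__ play similar roles: both compute the result by calling forward. __call__ is what runs when you write model(x): it first runs the forward pre-hooks, then calls forward directly, or _slow_forward when a JIT tracing state is active (torch._C._get_tracing_state() returns one), then runs the forward hooks, and finally attaches any backward hooks to the grad_fn of the output. _slow_forward additionally pushes a tracing scope named after the module (and, via _tracing_name, after its attribute name in the parent) around the call to forward.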
def _tracing_name(self, tracing_state):
    if not tracing_state._traced_module_stack:
        return None
    module = tracing_state._traced_module_stack[-1]
    for name, child in module.named_children():
        if child is self:
            return name
    return None

def _slow_forward(self, *input, **kwargs):
    tracing_state = torch._C._get_tracing_state()
    if not tracing_state:
        return self.forward(*input, **kwargs)
    if not hasattr(tracing_state, '_traced_module_stack'):
        tracing_state._traced_module_stack = []
    name = self._tracing_name(tracing_state)
    if name:
        tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
    else:
        tracing_state.push_scope(self._get_name())
    tracing_state._traced_module_stack.append(self)
    try:
        result = self.forward(*input, **kwargs)
    finally:
        tracing_state.pop_scope()
        tracing_state._traced_module_stack.pop()
    return result

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result
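The order of operations in __call__ (pre-hooks may replace the input, then forward runs, then forward hooks may replace the output) can be traced with a hypothetical toy. Tracing and backward hooks are left out, and Mod and AddOne are assumptions of this sketch rather than PyTorch classes.

```python
from collections import OrderedDict


class Mod:
    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    def __call__(self, *input):
        # 1) Pre-hooks see (module, input) and may return a replacement input.
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        # 2) The actual computation.
        result = self.forward(*input)
        # 3) Forward hooks see (module, input, output) and may replace the output.
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        return result


class AddOne(Mod):
    def forward(self, x):
        return x + 1


log = []
m = AddOne()
m._forward_pre_hooks[0] = lambda mod, inp: inp[0] * 10              # replaces the input
m._forward_hooks[0] = lambda mod, inp, out: log.append((inp, out))  # only observes
print(m(2), log)  # 21 [((20,), 21)]
```

Because the observing hook returns None (the return value of list.append), the output is left alone; a hook that returns a value would replace it.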
The three functions register_parameter, register_buffer, and add_module are best read together. register_parameter adds the given name - param pair to the dictionary self._parameters. register_buffer is typically used to register state that is not a model parameter; for example, BatchNorm's running_mean is not a parameter. add_module adds a SubModule to the current Module under the given name - module pair.
def register_parameter(self, name, param):
    # Checks that raise AttributeError / TypeError / KeyError are omitted here ...
    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError(...)  # message omitted in this excerpt
    elif param.grad_fn:
        raise ValueError("Cannot assign non-leaf Tensor to parameter '{}'".format(name))
    else:
        self._parameters[name] = param

def register_buffer(self, name, tensor):
    r"""Example:
        >>> self.register_buffer('running_mean', torch.zeros(num_features))
    """
    if '_buffers' not in self.__dict__:
        raise AttributeError("cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string")
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError(...)  # message omitted in this excerpt
    else:
        self._buffers[name] = tensor

def add_module(self, name, module):
    # Checks that raise TypeError / KeyError are omitted here ...
    self._modules[name] = module
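The practical difference between a parameter and a buffer: both belong to the module state, so both end up in state_dict, but only parameters are returned by parameters() and therefore handed to the optimizer. The hypothetical toy below (Mod, with flat dicts and floats, is an assumption of this sketch) mirrors that split and a few of register_buffer's name checks; a dot in a name is rejected because it would clash with the '.' separators used in state_dict keys.

```python
class Mod:
    """Hypothetical module stand-in with flat parameter/buffer dicts."""

    def __init__(self):
        self._parameters = {}
        self._buffers = {}

    def register_parameter(self, name, value):
        self._parameters[name] = value

    def register_buffer(self, name, value):
        # Simplified versions of the name checks in register_buffer.
        if not isinstance(name, str):
            raise TypeError("buffer name should be a string")
        if "." in name:
            raise KeyError("buffer name can't contain '.'")
        if name == "":
            raise KeyError("buffer name can't be empty string")
        self._buffers[name] = value

    def parameters(self):
        # Only parameters are exposed here, so an optimizer never sees buffers.
        return iter(self._parameters.values())

    def state_dict(self):
        # Both parameters and buffers are part of the saved state.
        state = dict(self._parameters)
        state.update(self._buffers)
        return state


bn = Mod()
bn.register_parameter("weight", 1.0)
bn.register_buffer("running_mean", 0.0)
print(list(bn.parameters()))    # [1.0]
print(sorted(bn.state_dict()))  # ['running_mean', 'weight']
```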
The functions register_backward_hook, register_forward_pre_hook, and register_forward_hook work like _register_state_dict_hook and _register_load_state_dict_pre_hook introduced earlier: each registers a hook on the Module. The official documentation describes them as follows: "register_backward_hook will be called every time the gradients with respect to module inputs are computed. register_forward_pre_hook will be called every time before forward is invoked. register_forward_hook will be called every time after forward has computed an output."
def register_backward_hook(self, hook):
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle

def register_forward_pre_hook(self, hook):
    handle = hooks.RemovableHandle(self._forward_pre_hooks)
    self._forward_pre_hooks[handle.id] = hook
    return handle

def register_forward_hook(self, hook):
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle
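That completes the walkthrough of the functions in torch.nn.Module (the source of this class alone runs to more than a thousand lines).
My biggest takeaway from reading this source is that patience is essential: keep reading, and you will get through it.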