[PyTorch] torch.nn.Module Source Code Analysis


Author | 药师

Link | https://zhuanlan.zhihu.com/p/88712978

Column | My Machine Learning Notes (我的机器学习笔记)


torch.nn.Module contains no fewer than 48 functions. It is the base class for every neural network module in PyTorch; any model you define yourself is a subclass of it, as in the example below. In this post we read through this base class together.

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Start with __init__ and forward. __init__ mainly initializes the internal state the module needs. forward has no concrete implementation here; every subclass is expected to provide its own, and if a subclass does not, calling it ends in raise NotImplementedError.
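A tiny sketch to make this concrete (the Empty class below is made up purely for illustration): calling a Module subclass that never overrides forward raises NotImplementedError.

import torch
import torch.nn as nn

class Empty(nn.Module):  # forgets to override forward
    pass

m = Empty()
try:
    m(torch.randn(1))    # __call__ dispatches to forward, which is not implemented
except NotImplementedError:
    print('forward() must be overridden in a subclass')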

The cuda and cpu functions are straightforward. cuda "Moves all model parameters and buffers to the GPU", and cpu "Moves all model parameters and buffers to the CPU". Both return the Module itself, and both delegate to _apply. A short usage sketch follows the code below.

def cuda(self, device=None):
    return self._apply(lambda t: t.cuda(device))

def cpu(self):
    return self._apply(lambda t: t.cpu())
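A quick usage sketch (assumes a CUDA device is available and that the Model class from the first snippet is in scope):

import torch

model = Model()
if torch.cuda.is_available():
    model.cuda()                              # parameters and buffers move to the GPU
    print(next(model.parameters()).device)    # cuda:0
model.cpu()                                   # and back to the CPU
print(next(model.parameters()).device)        # cpu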

Next comes _apply. It starts with a loop that recursively runs the same operation on every child module. The next loop walks over self._parameters; the helper compute_should_use_set_data decides whether to change the tensor in place. If so, the new tensor simply overwrites the old one; otherwise the new tensor is registered in the self._parameters dictionary as a fresh Parameter. If a parameter carries a gradient param.grad, the same treatment is applied to param.grad as well. The final loop migrates the tensors in self._buffers between CPU and GPU and writes the converted tensors back into self._buffers. Finally the Module itself is returned.

def _apply(self, fn):
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        # (body abbreviated) decides whether the converted tensor can simply
        # overwrite the old one in place via `.data`
        ...

    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self

With _apply in place, many operations become one-liners. share_memory, for example, calls _apply to run share_memory_ on every tensor, which "Moves the underlying storage to shared memory. This is a no-op if the underlying storage is already in shared memory and for CUDA tensors. Tensors in shared memory cannot be resized." In short, it moves the tensors into shared memory.

def share_memory(self):
    return self._apply(lambda t: t.share_memory_())

Now look at apply (not to be confused with _apply above). It is very simple: it passes the Module and every one of its submodules through the given function fn. For example, we can use it to initialize a model's parameters with a scheme of our choice.

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

The example below sets all parameters of the Linear submodules inside the model net to 1.

Example::
    >>> def init_weights(m):
    >>>     print(m)
    >>>     if type(m) == nn.Linear:
    >>>         m.weight.data.fill_(1.0)
    >>>         print(m.weight)
    >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    >>> net.apply(init_weights)

Next are the type, float, double and half functions. type casts all parameters and buffers to the given target type dst_type. float, double and half cast all floating point parameters to float, double and half precision respectively: torch.Tensor.float corresponds to torch.float32, torch.Tensor.double to torch.float64, and torch.Tensor.half to torch.float16. A small dtype sketch follows the code below.

def type(self, dst_type):
    return self._apply(lambda t: t.type(dst_type))

def float(self):
    return self._apply(lambda t: t.float() if t.is_floating_point() else t)

def double(self):
    return self._apply(lambda t: t.double() if t.is_floating_point() else t)

def half(self):
    return self._apply(lambda t: t.half() if t.is_floating_point() else t)
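For instance (a minimal sketch):

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
print(net.weight.dtype)   # torch.float32 by default
net.double()
print(net.weight.dtype)   # torch.float64
net.half()
print(net.weight.dtype)   # torch.float16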

The to function modifies the Module in place, moving and/or casting its parameters and buffers. It can be used in three forms: to(device=None, dtype=None, non_blocking=False), to(dtype, non_blocking=False) and to(tensor, non_blocking=False). Its usage is shown below.

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

That covers roughly the first dozen functions.

state_dict returns a dictionary containing the whole state of the module, whose keys are the names of the parameters and buffers. Its source contains a loop that recursively visits every submodule of the Module.

>>> net = torch.nn.Linear(2, 2)
>>> net.state_dict()
OrderedDict([('weight', tensor([[-0.3558,  0.2153],
                                [-0.2785,  0.6982]])),
             ('bias', tensor([ 0.5771, -0.6232]))])
>>> net.state_dict().keys()
odict_keys(['weight', 'bias'])

>>> net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
>>> net.state_dict()
OrderedDict([('0.weight', tensor([[ 0.4792,  0.5772], [ 0.1039, -0.0552]])),
             ('0.bias', tensor([-0.5175, -0.6469])),
             ('1.weight', tensor([[-0.5346, -0.0173], [-0.2092,  0.0794]])),
             ('1.bias', tensor([-0.2150,  0.2323]))])
>>> net.state_dict().keys()
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias'])

load_state_dict does the opposite of state_dict introduced above: it loads the parameters and buffers into the Module and its submodules.
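A typical save-and-restore sketch (the file name model.pt is arbitrary):

import torch

net = torch.nn.Linear(2, 2)
torch.save(net.state_dict(), 'model.pt')        # persist parameters and buffers

net2 = torch.nn.Linear(2, 2)
net2.load_state_dict(torch.load('model.pt'))    # copy them into a fresh module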

parameters returns an iterator, so we can walk over a model's parameters with for param in model.parameters(). It is also what we hand to an Optimizer when setting up training: model.parameters(). The functions buffers, children and modules follow the same pattern. A small optimizer sketch follows the code below.

def parameters(self, recurse=True):
    for name, param in self.named_parameters(recurse=recurse):
        yield param

def buffers(self, recurse=True):
    for name, buf in self.named_buffers(recurse=recurse):
        yield buf

def children(self):
    for name, module in self.named_children():
        yield module

def modules(self):
    for name, module in self.named_modules():
        yield module
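For example (a sketch; the learning rate is arbitrary and Model is the class from the first snippet):

import torch

model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # the iterator is consumed by the optimizer

for p in model.parameters():                                # it yields every Parameter in the model
    print(p.shape, p.requires_grad)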

Correspondingly there are four named variants: named_parameters, named_buffers, named_children and named_modules. Each returns an iterator that yields the names together with the members. An iteration sketch follows the code below.

def _named_members(self, get_members_fn, prefix='', recurse=True):
    r"""Helper method for yielding various names + members of modules."""
    memo = set()
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        members = get_members_fn(module)
        for k, v in members:
            if v is None or v in memo:
                continue
            memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v

def named_parameters(self, prefix='', recurse=True):
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

def named_buffers(self, prefix='', recurse=True):
    gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

def named_children(self):
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module

def named_modules(self, memo=None, prefix=''):
    if memo is None:
        memo = set()
    if self not in memo:
        memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            for m in module.named_modules(memo, submodule_prefix):
                yield m
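A quick look at what the named variants yield (sketch):

import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

for name, param in net.named_parameters():
    print(name, tuple(param.shape))
# 0.weight (2, 2)
# 0.bias (2,)
# 1.weight (2, 2)
# 1.bias (2,)

for name, module in net.named_modules():
    print(name or '(root)', type(module).__name__)
# (root) Sequential
# 0 Linear
# 1 Linear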

That takes care of another batch of functions.

train and eval put the Module and its submodules into training mode and evaluation mode respectively. They only affect certain modules, such as Dropout and BatchNorm. A usage sketch follows the code below.

def train(self, mode=True):
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

def eval(self):
    return self.train(False)
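This is why we call eval() before validation or inference, usually together with torch.no_grad() (a sketch):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))
model.eval()                                 # Dropout now acts as the identity
with torch.no_grad():
    out = model(torch.randn(3, 4))
model.train()                                # back to training mode
print(model.training, model[1].training)     # True True -- train()/eval() recurse into children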

requires_grad_ sets whether autograd should record gradients for self.parameters(); the argument defaults to True. zero_grad sets the gradients of self.parameters() to zero. Two short sketches follow the code below.

def requires_grad_(self, requires_grad=True):
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self

def zero_grad(self):
    for p in self.parameters():
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()
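Two common uses (a sketch): freezing a module's parameters, and clearing gradients between optimizer steps.

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
net.requires_grad_(False)                                 # freeze: autograd stops recording these parameters
print(any(p.requires_grad for p in net.parameters()))     # False

net.requires_grad_(True)
net(torch.randn(8, 4)).sum().backward()
net.zero_grad()                                           # gradients are detached and zeroed in place
print(net.weight.grad.abs().sum())                        # tensor(0.)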

_get_name, extra_repr, __repr__ and __dir__ all output information about the Module. _get_name returns the name of the Module's class. extra_repr is meant to be overridden by subclasses of torch.nn.Module to print module-specific information, as one line or several; a concrete example is shown below. __repr__ prints the information of all submodules of the Module, one item per line. __dir__ returns everything the Module contains: the attributes of self.__class__, self.__dict__.keys(), self._parameters.keys(), self._modules.keys() and self._buffers.keys(), and it filters out names that are not legal Python identifiers with key for key in keys if not key[0].isdigit().

def _get_name(self):
    return self.__class__.__name__

def extra_repr(self):
    return ''

def __repr__(self):
    extra_lines = []
    extra_repr = self.extra_repr()
    if extra_repr:  # empty string will be split into list ['']
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'
    main_str += ')'
    return main_str

def __dir__(self):
    module_attrs = dir(self.__class__)
    attrs = list(self.__dict__.keys())
    parameters = list(self._parameters.keys())
    modules = list(self._modules.keys())
    buffers = list(self._buffers.keys())
    keys = module_attrs + attrs + parameters + modules + buffers

    # Eliminate attrs that are not legal Python variable names
    keys = [key for key in keys if not key[0].isdigit()]
    return sorted(keys)

# --------------------------

# torch.nn.Linear -- class Linear(Module)
def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
        self.in_features, self.out_features, self.bias is not None
    )

# Example
>>> l = torch.nn.Linear(2, 2)
>>> l.extra_repr()
'in_features=2, out_features=2, bias=True'

That is a few more functions covered.

__setstate__ restores the state. If _forward_pre_hooks, _state_dict_hooks or _load_state_dict_pre_hooks are missing from self.__dict__, these three attributes are created on self as OrderedDicts.

def __setstate__(self, state):
    self.__dict__.update(state)
    # Support loading old checkpoints that don't have the following attrs:
    if '_forward_pre_hooks' not in self.__dict__:
        self._forward_pre_hooks = OrderedDict()
    if '_state_dict_hooks' not in self.__dict__:
        self._state_dict_hooks = OrderedDict()
    if '_load_state_dict_pre_hooks' not in self.__dict__:
        self._load_state_dict_pre_hooks = OrderedDict()

__getattr__ looks up a member of the Module by the given name. It searches self.__dict__['_parameters'], self.__dict__['_buffers'] and self.__dict__['_modules'] in turn and returns the member once found; if nothing matches, it raises AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)).

def __getattr__(self, name):
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise AttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))

__setattr__(self, name, value) sets an attribute. It first looks through self.__dict__.get('_parameters'), self.__dict__.get('_buffers') and self.__dict__.get('_modules'); if the name is already registered there, the old key-value entry is removed and the value is re-registered under the given name. A simplified sketch of the idea follows.
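A heavily simplified sketch of that dispatch (not the real source, which also rejects assignment before __init__, type-checks values, and handles plain tensors assigned to existing buffer names):

def __setattr__(self, name, value):
    # simplified: route Parameters into _parameters and Modules into _modules
    if isinstance(value, Parameter):
        self.__dict__.get('_buffers', {}).pop(name, None)
        self.__dict__.get('_modules', {}).pop(name, None)
        self.register_parameter(name, value)
    elif isinstance(value, Module):
        self.__dict__.get('_parameters', {}).pop(name, None)
        self.__dict__.get('_buffers', {}).pop(name, None)
        self._modules[name] = value
    else:
        object.__setattr__(self, name, value)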

__delattr__ deletes the member of the Module with the given name. It looks in self._parameters, self._buffers and self._modules and removes the entry with del once found; if the name is in none of them, it falls back to object.__delattr__(self, name).

def __delattr__(self, name):
    if name in self._parameters:
        del self._parameters[name]
    elif name in self._buffers:
        del self._buffers[name]
    elif name in self._modules:
        del self._modules[name]
    else:
        object.__delattr__(self, name)

_save_to_state_dict(self, destination, prefix, keep_vars) stores the state of this module, and only this module, into destination; it is therefore invoked once per submodule: "This is called on every submodule in method torch.nn.Module.state_dict". _load_from_state_dict does the reverse, loading state into this module only, and is likewise invoked once per submodule: "This is called on every submodule in method torch.nn.Module.load_state_dict". The prefix parameter is the prefix prepended to the names of this Module's parameters and buffers.

def _save_to_state_dict(self, destination, prefix, keep_vars):
    for name, param in self._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.data
    for name, buf in self._buffers.items():
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.data

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            pass  # (body abbreviated) shape-check and copy state_dict[key] into param
        elif strict:
            missing_keys.append(key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)

_register_state_dict_hook and _register_load_state_dict_pre_hook are also simple; their docstrings below say it all. hooks refers to the file torch.utils.hooks, whose class RemovableHandle is "A handle which provides the capability to remove a hook".

def _register_state_dict_hook(self, hook):
    r"""These hooks will be called with arguments: `self`, `state_dict`,
    `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
    Note that only parameters and buffers of `self` or its children are
    guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
    inplace or return a new one.
    """
    handle = hooks.RemovableHandle(self._state_dict_hooks)
    self._state_dict_hooks[handle.id] = hook
    return handle

def _register_load_state_dict_pre_hook(self, hook):
    r"""These hooks will be called with arguments: `state_dict`, `prefix`,
    `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
    `error_msgs`, before loading `state_dict` into `self`. These arguments
    are exactly the same as those of `_load_from_state_dict`.
    """
    handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
    self._load_state_dict_pre_hooks[handle.id] = hook
    return handle

_tracing_name is mainly called by _slow_forward. _slow_forward plays a similar role to __call__: both ultimately run the computation through forward (the slow path additionally pushes tracing scopes for the JIT tracer).

def _tracing_name(self, tracing_state):
    if not tracing_state._traced_module_stack:
        return None
    module = tracing_state._traced_module_stack[-1]
    for name, child in module.named_children():
        if child is self:
            return name
    return None

def _slow_forward(self, *input, **kwargs):
    tracing_state = torch._C._get_tracing_state()
    if not tracing_state:
        return self.forward(*input, **kwargs)
    if not hasattr(tracing_state, '_traced_module_stack'):
        tracing_state._traced_module_stack = []
    name = self._tracing_name(tracing_state)
    if name:
        tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
    else:
        tracing_state.push_scope(self._get_name())
    tracing_state._traced_module_stack.append(self)
    try:
        result = self.forward(*input, **kwargs)
    finally:
        tracing_state.pop_scope()
        tracing_state._traced_module_stack.pop()
    return result

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

register_parameter, register_buffer and add_module can be read as a group. register_parameter adds the given name - param pair to the dictionary self._parameters. register_buffer registers attributes that are not model parameters, for example BatchNorm's running_mean, which is a buffer rather than a parameter. add_module attaches the given name - module pair to the current Module as a submodule. A small custom-module sketch follows the code below.

def register_parameter(self, name, param):
    # Checks for AttributeError / TypeError / KeyError omitted ...
    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError("Cannot assign non-leaf Tensor to parameter '{}'".format(name))
    else:
        self._parameters[name] = param

def register_buffer(self, name, tensor):
    r"""Example:
        >>> self.register_buffer('running_mean', torch.zeros(num_features))
    """
    if '_buffers' not in self.__dict__:
        raise AttributeError("cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string")
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    else:
        self._buffers[name] = tensor

def add_module(self, name, module):
    # Checks for TypeError / KeyError omitted ...
    self._modules[name] = module
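A small sketch of how the three registrations are typically used in a custom module (RunningMeanNorm is a made-up class purely for illustration):

import torch
import torch.nn as nn

class RunningMeanNorm(nn.Module):
    def __init__(self, num_features):
        super(RunningMeanNorm, self).__init__()
        # learnable parameter -> ends up in self._parameters
        self.register_parameter('scale', nn.Parameter(torch.ones(num_features)))
        # persistent but non-learnable state -> ends up in self._buffers
        self.register_buffer('running_mean', torch.zeros(num_features))
        # same effect as `self.proj = nn.Linear(...)` -> ends up in self._modules
        self.add_module('proj', nn.Linear(num_features, num_features))

    def forward(self, x):
        if self.training:
            # update the running statistic; detach so no autograd history lands in the buffer
            self.running_mean = 0.9 * self.running_mean + 0.1 * x.mean(0).detach()
        return self.proj((x - self.running_mean) * self.scale)

m = RunningMeanNorm(4)
print(list(m.state_dict().keys()))   # ['scale', 'running_mean', 'proj.weight', 'proj.bias']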

register_backward_hook, register_forward_pre_hook and register_forward_hook play a similar role to _register_state_dict_hook and _register_load_state_dict_pre_hook introduced earlier: each registers a hook on the Module. In the words of the official docs: the hook passed to register_backward_hook "will be called every time the gradients with respect to module inputs are computed"; the hook passed to register_forward_pre_hook "will be called every time before forward is invoked"; the hook passed to register_forward_hook "will be called every time after forward has computed an output". A usage sketch follows the code below.

def register_backward_hook(self, hook):
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle

def register_forward_pre_hook(self, hook):
    handle = hooks.RemovableHandle(self._forward_pre_hooks)
    self._forward_pre_hooks[handle.id] = hook
    return handle

def register_forward_hook(self, hook):
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle
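For example, a forward hook can capture intermediate activations (a sketch; Model is the class from the first snippet, so conv1 refers to its first convolution):

import torch

model = Model()
activations = {}

def save_activation(module, input, output):
    # runs after the hooked submodule's forward has produced `output`
    activations['conv1'] = output.detach()

handle = model.conv1.register_forward_hook(save_activation)
model(torch.randn(1, 1, 28, 28))
print(activations['conv1'].shape)    # torch.Size([1, 20, 24, 24])
handle.remove()                      # the returned RemovableHandle detaches the hook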

That is the whole tour of the functions in torch.nn.Module (the class runs to well over a thousand lines of code).


My biggest takeaway from reading this source is that you simply need the patience to keep going; stick with it and you will get through.


