[Learning from Others] How to Support Face Training with Hundreds of Millions of Classes: Memory-Balanced Model Parallelism (a PyTorch Implementation)
This is work from 2018 that I never found the time to write up. Model parallelism looks mysterious, and most of what you can find online only talks about the principles; hardly anyone puts up code that actually demystifies it. Here I release a PyTorch demo for reference and discussion.
Demo repository: https://github.com/bindog/pytorch-model-parallel
Author's blog: http://bindog.github.io/
01
What is model parallelism?
In the familiar data-parallel setup, every GPU holds a complete copy of the model and processes its own slice of the mini-batch. Model parallelism instead splits the model's parameters themselves across GPUs. For face recognition with tens or hundreds of millions of identities, the backbone is comparatively small, but the final fully connected classification layer is enormous; it is that layer that has to be split across cards.
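To get a feel for the scale, here is a rough back-of-the-envelope calculation; the embedding size and class count are my own illustrative numbers, not figures from the demo:

# Memory taken by the final FC layer alone, assuming a 512-d embedding,
# 100 million identities and fp32 weights (illustrative numbers)
feat_dim = 512
num_classes = 100_000_000
bytes_per_param = 4  # fp32
weight_gb = feat_dim * num_classes * bytes_per_param / 1024 ** 3
print(f"FC weight: {weight_gb:.0f} GB")  # ~191 GB, far beyond any single GPU

And that is only the weight matrix; gradients, optimizer state and the batch × num_classes logit matrix come on top.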
02
Naive model parallelism
The most direct approach is to split the classes of the final fully connected layer across GPUs: each card holds one chunk of the weight matrix, computes the logits for its own classes, and the partial results are concatenated.
import torch
import torch.nn as nn

class FullyConnected(nn.Module):
    def __init__(self, in_dim, out_dim, num_gpu, model_parallel=False):
        super(FullyConnected, self).__init__()
        self.num_gpu = num_gpu
        self.model_parallel = model_parallel
        if model_parallel:
            # split the out_dim classes as evenly as possible across GPUs
            self.fc_chunks = nn.ModuleList()
            for i in range(num_gpu):
                _class_num = out_dim // num_gpu
                if i < (out_dim % num_gpu):
                    _class_num += 1
                self.fc_chunks.append(
                    nn.Linear(in_dim, _class_num, bias=False).cuda(i)
                )
        else:
            self.classifier = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        if self.model_parallel:
            x_list = []
            for i in range(self.num_gpu):
                _x = self.fc_chunks[i](x.cuda(i))  # compute each chunk on its own GPU
                x_list.append(_x)
            # gather all partial logits onto GPU 0 and concatenate -- this is
            # exactly where the naive approach piles memory onto a single card
            x = torch.cat([_x.cuda(0) for _x in x_list], dim=1)
            return x
        else:
            return self.classifier(x)
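A minimal usage sketch; the dimensions, batch size and GPU count below are placeholders of mine, not from the post:

fc = FullyConnected(in_dim=512, out_dim=3_000_000, num_gpu=4, model_parallel=True)
feat = torch.randn(64, 512).cuda(0)  # embeddings coming out of the backbone
logits = fc(feat)                    # shape (64, 3_000_000), gathered on GPU 0

The weakness is already visible here: although the weight matrix is spread over four cards, the concatenated logits (and the softmax computed on top of them) all land on GPU 0, so GPU 0 needs far more memory than the others. The next section removes this imbalance.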
03
Memory-balanced model parallelism
The fix is to never materialize the full logit matrix on a single card: each GPU keeps only its own chunk of logits and the softmax cross-entropy is computed cooperatively, exchanging only small per-sample statistics (the row-wise maxima and the partial exponential sums). The starting point is the usual numerically stable softmax, shown here in NumPy:
import numpy as np

def stable_softmax(x):
    # subtract the max before exponentiating so that exp() cannot overflow
    z = x - np.max(x)
    numerator = np.exp(z)
    denominator = np.sum(numerator)
    softmax = numerator / denominator
    return softmax
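The identity behind this trick is the standard one:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - \max_k x_k}}{\sum_j e^{x_j - \max_k x_k}}$$

In the model-parallel case both the max and the denominator run over classes that live on different GPUs, so the code below first gathers the tiny per-sample maxima, then adds up the per-GPU partial denominators with torch.cuda.comm.reduce_add.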
import torch
import torch.nn as nn
import torch.cuda.comm as comm
from torch.autograd import Function

class ModelParallelCrossEntropy(nn.Module):
    def __init__(self):
        super(ModelParallelCrossEntropy, self).__init__()

    # args[0] is the compute-loss flag, args[1] is the label tuple,
    # args[2:] are the logit parts (one chunk per GPU)
    def forward(self, *args):
        return ModelParallelCrossEntropyFunc(args[0], args[1])(*args[2:])

# Note: this is the legacy, stateful autograd.Function style used in the 2018 code;
# recent PyTorch versions require the static forward/backward style instead.
class ModelParallelCrossEntropyFunc(Function):
    def __init__(self, compute_loss, label_tuple):
        self.batch_size = label_tuple[0].size()[0]
        self.compute_loss = compute_loss
        self.label_split = label_tuple

    def forward(self, *args):  # args is the list of logit parts
        # for numerical stability: subtract the global max over all chunks
        max_list = []
        for arg in args:
            m, _ = torch.max(arg, dim=1, keepdim=True)
            max_list.append(m)
        # only the tiny (batch, 1) per-sample maxima are moved to one card
        mc = torch.cat([m.cuda(0) for m in max_list], dim=1)
        m, _ = torch.max(mc, dim=1, keepdim=True)
        nargs = [arg - m.cuda(gpu_id) for gpu_id, arg in enumerate(args)]
        # per-GPU exponentials and partial denominators, summed across GPUs
        exp_logit_list = []
        exp_sum_list = []
        for gpu_id, narg in enumerate(nargs):
            exp_logit = torch.exp(narg)
            exp_logit_list.append(exp_logit)
            exp_sum = torch.sum(exp_logit, dim=1, keepdim=True)
            exp_sum_list.append(exp_sum)
        exp_sum_all = comm.reduce_add(exp_sum_list, 0)
        # compute the softmax output, still split by class across GPUs
        softmax_list = []
        for gpu_id, narg in enumerate(nargs):
            softmax = exp_logit_list[gpu_id] / exp_sum_all.cuda(gpu_id)
            softmax_list.append(softmax)
        # save the softmax output, we will need it in backward
        self.save_for_backward(*softmax_list)
        loss = torch.zeros(1)
        if self.compute_loss:
            _loss_list = []
            for gpu_id, softmax in enumerate(softmax_list):
                # probability of the target class within this GPU's chunk
                _loss = torch.sum(softmax * self.label_split[gpu_id], dim=1)
                _loss_list.append(_loss)
            _loss = comm.reduce_add(_loss_list, 0)
            log_loss = -torch.log(_loss)
            loss = torch.mean(log_loss)
        return loss

    def backward(self, loss_grad):
        # gradient w.r.t. each logit chunk: (softmax - one_hot) / batch_size
        grad_logit_list = []
        for gpu_id, softmax in enumerate(self.saved_variables):
            grad_logit = (softmax - self.label_split[gpu_id]) / self.batch_size
            grad_logit_list.append(grad_logit)
        return tuple(grad_logit_list)
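The backward pass relies on the textbook gradient of softmax cross-entropy: with softmax output p, one-hot label y and a mean reduction over a batch of size N,

$$\frac{\partial L}{\partial z_j} = \frac{1}{N}\,(p_j - y_j)$$

so every GPU can form the gradient for its own chunk of logits purely from the saved softmax chunk and its slice of the one-hot labels; backward needs no further cross-GPU communication.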
04
Combining with other modified losses
Margin-based losses such as AM-Softmax only change the logits that go into the softmax, so they fit into the same scheme: each GPU applies the margin to its own chunk of classes and hands the modified logits to the model-parallel cross-entropy above.
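As a reminder of what the AM_Branch below computes (the standard AM-Softmax formulation, with scale s and margin m):

$$f_j = \begin{cases} s\,(\cos\theta_j - m), & j = y \\ s\,\cos\theta_j, & j \neq y \end{cases}$$

where cos θ_j is the cosine between the normalized feature and the normalized weight of class j.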
class FullyConnected_AM(nn.Module):
    def __init__(self, in_dim, out_dim, num_gpus=1, model_parallel=False,
                 class_split=None, margin=0.35, scale=30):
        super(FullyConnected_AM, self).__init__()
        self.num_gpus = num_gpus
        self.model_parallel = model_parallel
        if self.model_parallel:
            self.am_branches = nn.ModuleList()
            for i in range(num_gpus):
                self.am_branches.append(
                    AM_Branch(in_dim, class_split[i], margin, scale).cuda(i)
                )
        else:
            self.am = AM_Branch(in_dim, out_dim, margin, scale)

    # Without model parallelism, labels are the one-hot labels.
    # With model parallelism, labels is a list: each element is a slice of the
    # one-hot labels (split along the class dimension, already on the matching
    # GPU); concatenated they give the full one-hot label. A sketch of how such
    # label chunks can be produced follows after this listing.
    def forward(self, x, labels=None):
        if self.model_parallel:
            output_list = []
            for i in range(self.num_gpus):
                output = self.am_branches[i](x.cuda(i), labels[i])
                output_list.append(output)
            return tuple(output_list)
        else:
            return self.am(x, labels)
class AM_Branch(nn.Module):
    def __init__(self, in_dim, out_dim, margin=0.35, scale=30):
        super(AM_Branch, self).__init__()
        self.m = margin
        self.s = scale
        # trainable class-weight matrix of shape (in_dim, out_dim)
        self.weight = nn.Parameter(torch.Tensor(in_dim, out_dim), requires_grad=True)
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    # label must already be in one-hot form; under model parallelism it is a
    # slice of the one-hot labels, located on the corresponding GPU
    def forward(self, x, label):
        x_norm = x.pow(2).sum(1).pow(0.5)
        w_norm = self.weight.pow(2).sum(0).pow(0.5)
        cos_theta = torch.mm(x, self.weight) / x_norm.view(-1, 1) / w_norm.view(1, -1)
        cos_theta = cos_theta.clamp(-1, 1)
        phi = cos_theta - self.m
        index = label.data.byte()  # mask of target positions (use .bool() on newer PyTorch)
        output = cos_theta * 1.0
        output[index] = phi[index]  # apply the margin only to the target class
        output *= self.s
        return output
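For completeness, here is a minimal sketch of how per-GPU one-hot label chunks of the kind forward expects could be produced; the helper name and its interface are my own, not part of the demo repository, and it assumes classes are split contiguously in the same order as in FullyConnected:

import torch

def split_onehot_labels(labels, class_split):
    # labels: integer class ids of shape (batch,)
    # class_split: number of classes owned by each GPU, e.g. [750000] * 4
    chunks = []
    start = 0
    for gpu_id, num_cls in enumerate(class_split):
        onehot = torch.zeros(labels.size(0), num_cls)
        in_range = (labels >= start) & (labels < start + num_cls)
        rows = torch.nonzero(in_range).squeeze(1)  # samples whose class falls in this chunk
        onehot[rows, labels[rows] - start] = 1.0
        chunks.append(onehot.cuda(gpu_id))         # place the chunk on its GPU
        start += num_cls
    return chunks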
05
Usage
For example, to train on 4 GPUs with 3 million classes, AM-Softmax and model parallelism all enabled:
python train.py --gpus=0,1,2,3 --data_path=/your/data/path --num_classes=3000000 --am --model_parallel