查看原文
其他

Huggingface BERT源码详解:应用模型与训练优化

李泺秋 PaperWeekly 2022-07-04


©PaperWeekly 原创 · 作者|李泺秋

学校|浙江大学硕士生

研究方向|自然语言处理、知识图谱


上篇,记录一下对 HuggingFace 开源的 Transformers 项目代码的理解。

本文基于 Transformers 版本 4.4.2(2021 年 3 月 19 日发布)项目中,pytorch 版的 BERT 相关代码,从代码结构、具体实现与原理,以及使用的角度进行分析,包含以下内容:

1. BERT Tokenization 分词模型(BertTokenizer)

2. BERT Model 本体模型(BertModel)

3. 1. BertEmbeddings

    2. BertEncoder

    3.1. BertLayer

        2.1. BertAttention

            2.1. BertIntermediate

               2. BertOutput

            3. BertEmbeddings

            4. BertEncoder

        3. BERT-based Models应用模型

4. BertForPreTraining

5. 1. BertForSequenceClassification

    2. BertForMultiChoice

    3. BertForTokenClassification

    4. BertForQuestionAnswering

    5. BERT训练与优化

6. BERT训练与优化

7. 1. Pre-Training

    2. Fine-Tuning

    3. 1. AdamW

        2. Warmup




BERT-based Models

基于 BERT 的模型都写在/models/bert/modeling_bert.py里面,包括 BERT 预训练模型和 BERT 分类模型,UML 图如下:
BERT模型一图流(建议保存后放大查看):

▲ 画图工具:Pyreverse

首先,以下所有的模型都是基于BertPreTrainedModel这一抽象基类的,而后者则基于一个更大的基类PreTrainedModel。这里我们关注BertPreTrainedModel的功能:

  • 用于初始化模型权重,同时维护继承自PreTrainedModel的一些标记身份或者加载模型时的类变量。
下面,首先从预训练模型开始分析。

3.1 BertForPreTraining

众所周知,BERT 预训练任务包括两个:

  • Masked Language Model(MLM):在句子中随机用[MASK]替换一部分单词,然后将句子传入 BERT 中编码每一个单词的信息,最终用[MASK]的编码信息预测该位置的正确单词,这一任务旨在训练模型根据上下文理解单词的意思;
  • Next Sentence Prediction(NSP):将句子对 A 和 B 输入 BERT,使用[CLS]的编码信息进行预测 B 是否 A 的下一句,这一任务旨在训练模型理解预测句子间的关系。


▲ 图源网络

而对应到代码中,这一融合两个任务的模型就是BertForPreTraining,其中包含两个组件:

class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()
    # ...
这里的BertModel上一篇文章中已经详细介绍了(注意,这里设置的是默认add_pooling_layer=True,即会提取[CLS]对应的输出用于 NSP 任务),BertPreTrainingHeads则是负责两个任务的预测模块:
class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score 
又是一层封装:BertPreTrainingHeads包裹了BertLMPredictionHead 和一个代表 NSP 任务的线性层。这里不把 NSP 对应的任务也封装一个BertXXXPredictionHead,估计是因为它太简单了,没有必要……
补充:其实是有封装这个类的,不过它叫做BertOnlyNSPHead,在这里用不上……
继续下探BertPreTrainingHeads
class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

这个类用于预测[MASK]位置的输出在每个词作为类别的分类输出,注意到:

  • 该类重新初始化了一个全 0 向量作为预测权重的 bias;

  • 该类的输出形状为[batch_size, seq_length, vocab_size],即预测每个句子每个词是什么类别的概率值(注意这里没有做 softmax);
  • 又一个封装的类:BertPredictionHeadTransform,用来完成一些线性变换:


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


补充:感觉这一层去掉也行?输出的形状也没有发生变化。我个人的理解是和 Pooling 那里做一个对称的操作,同样过一层 dense 再接分类器……

回到BertForPreTraining,继续看两块 loss 是怎么处理的。它的前向传播和BertModel的有所不同,多了labelsnext_sentence_label 两个输入:

  • labels:形状为[batch_size, seq_length] ,代表 MLM 任务的标签,注意这里对于原本未被遮盖的词设置为 -100,被遮盖词才会有它们对应的 id,和任务设置是反过来的

    • 例如,原始句子是I want to [MASK] an apple,这里我把单词eat给遮住了输入模型,对应的label设置为[-100, -100, -100, 【eat对应的id】, -100, -100]
    • 为什么要设置为 -100 而不是其他数?因为torch.nn.CrossEntropyLoss默认的ignore_index=-100,也就是说对于标签为 100 的类别输入不会计算 loss。

  • next_sentence_label:这一个输入很简单,就是 0 和 1 的二分类标签。

 # ...
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
 ...
OK,接下来两部分 loss 的组合:
   # ...
        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-12), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
        # ...
直接相加,就是这么单纯的策略。
当然,这份代码里面也包含了对于只想对单个目标进行预训练的 BERT 模型(具体细节不作展开):
  • BertForMaskedLM:只进行 MLM 任务的预训练;
    • 基于BertOnlyMLMHead,而后者也是对BertLMPredictionHead的另一层封装;
  • BertLMHeadModel:这个和上一个的区别在于,这一模型是作为 decoder 运行的版本;
    • 同样基于BertOnlyMLMHead
  • BertForNextSentencePrediction:只进行 NSP 任务的预训练。
    • 基于BertOnlyNSPHead,内容就是一个线性层……
接下来介绍的是各种 Fine-tune 模型,基本都是分类任务:
▲ 图源:原始BERT论文附录
3.2 BertForSequenceClassification
这一模型用于句子分类(也可以是回归)任务,比如 GLUE benchmark 的各个任务。
  • 句子分类的输入为句子(对),输出为单个分类标签。

结构上很简单,就是BertModel(有 pooling)过一个 dropout 后接一个线性层输出分类:
 class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()
        # ...

在前向传播时,和上面预训练模型一样需要传入labels输入。

  • 如果初始化的num_labels=1,那么就默认为回归任务,使用 MSELoss;

  • 否则认为是分类任务。

3.3 BertForMultipleChoice

这一模型用于多项选择,如 RocStories/SWAG 任务。
  • 多项选择任务的输入为一组分次输入的句子,输出为选择某一句子的单个标签。

结构上与句子分类相似,只不过线性层输出维度为 1,即每次需要将每个样本的多个句子的输出拼接起来作为每个样本的预测分数。

  • 实际上,具体操作时是把每个 batch 的多个句子一同放入的,所以一次处理的输入为[batch_size, num_choices]数量的句子,因此相同 batch 大小时,比句子分类等任务需要更多的显存,在训练时需要小心。

3.4 BertForTokenClassification

这一模型用于序列标注(词分类),如 NER 任务。
  • 序列标注任务的输入为单个句子文本,输出为每个 token 对应的类别标签。
由于需要用到每个 token对应的输出而不只是某几个,所以这里的BertModel不用加入 pooling 层;
  • 同时,这里将_keys_to_ignore_on_load_unexpected这一个类参数设置为[r"pooler"],也就是在加载模型时对于出现不需要的权重不发生报错。

3.5 BertForQuestionAnswering

这一模型用于解决问答任务,例如 SQuAD 任务。

  • 问答任务的输入为问题 +(对于 BERT 只能是一个)回答组成的句子对,输出为起始位置和结束位置用于标出回答中的具体文本。

这里需要两个输出,即对起始位置的预测和对结束位置的预测,两个输出的长度都和句子长度一样,从其中挑出最大的预测值对应的下标作为预测的位置。

  • 对超出句子长度的非法 label,会将其压缩(torch.clamp_)到合理范围。


作为一个迟到的补充,这里稍微介绍一下ModelOutput这个类。它作为上述各个模型输出包装的基类,同时支持字典式的存取和下标顺序的访问,继承自python原生的OrderedDict 类。
以上就是关于 BERT 源码的介绍,下面介绍一些关于 BERT 模型实用的训练细节。



BERT训练和优化

4.1 Pre-Training

预训练阶段,除了众所周知的 15%、80% mask 比例,有一个值得注意的地方就是参数共享。

不止 BERT,所有 huggingface 实现的 PLM 的 word embedding 和 masked language model 的预测权重在初始化过程中都是共享的:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
    # ...
    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
        the weights instead.
        """

        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None and self.config.tie_word_embeddings:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
            if hasattr(self, self.base_model_prefix):
                self = getattr(self, self.base_model_prefix)
            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
    # ...

至于为什么,应该是因为 word_embedding 和 prediction 权重太大了,以 bert-base 为例,其尺寸为(30522, 768),降低训练难度。

4.2 Fine-Tuning

微调也就是下游任务阶段,也有两个值得注意的地方。

4.2.1 AdamW

首先介绍一下 BERT 的优化器:AdamW(AdamWeightDecayOptimizer)。

这一优化器来自 ICLR 2017 的 Best Paper:《Fixing Weight Decay Regularization in Adam》中提出的一种用于修复 Adam 的权重衰减错误的新方法。论文指出,L2 正则化和权重衰减在大部分情况下并不等价,只在 SGD 优化的情况下是等价的;而大多数框架中对于 Adam+L2 正则使用的是权重衰减的方式,两者不能混为一谈。

AdamW 是在 Adam+L2 正则化的基础上进行改进的算法,与一般的 Adam+L2 的区别如下:

关于 AdamW 的分析可以参考:

  • AdamW and Super-convergence is now the fastest way to train neural nets [1]

  • paperplanet:都 9102 年了,别再用 Adam + L2 regularization了 [2]

  • ICLR 2018 有什么值得关注的亮点?[3]
话说,《STABLE WEIGHT DECAY REGULARIZATION》这篇好像吐槽AdamW 的 Weight Decay 实现还是有问题…… 有空整整优化器相关的内容。
通常,我们会选择模型的 weight 部分参与 decay 过程,而另一部分(包括 LayerNorm 的 weight)不参与(代码最初来源应该是 Huggingface 的示例):
补充:关于这么做的理由,我暂时没有找到合理的解答,但是找到了一些相关的讨论https://forums.fast.ai/t/is-weight-decay-applied-to-the-bias-term/73212/4forums.fast.ai


# model: a Bert-based-model object
    # learning_rate: default 2e-5 for text classification
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias''LayerNorm.bias''LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay'0.01},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay'0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=learning_rate)
    # ...

4.2.2 Warmup

BERT 的训练中另一个特点在于 Warmup,其含义为:

  • 在训练初期使用较小的学习率(从 0 开始),在一定步数(比如 1000 步)内逐渐提高到正常大小(比如上面的 2e-5),避免模型过早进入局部最优而过拟合;

  • 在训练后期再慢慢将学习率降低到 0,避免后期训练还出现较大的参数变化。

在 Huggingface 的实现中,可以使用多种 warmup 策略:

TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
}

具体而言:

  • CONSTANT:保持固定学习率不变;
  • CONSTANT_WITH_WARMUP:在每一个 step 中线性调整学习率;
  • LINEAR:上文提到的两段式调整;
  • COSINE:和两段式调整类似,只不过采用的是三角函数式的曲线调整;
  • COSINE_WITH_RESTARTS:训练中将上面 COSINE 的调整重复 n 次;
  • POLYNOMIAL:按指数曲线进行两段式调整。
具体使用参考transformers/optimization.py
  • 最常用的还是get_linear_scheduler_with_warmup即线性两段式调整学习率的方案……

def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
):
 ...

以上即为关于 transformers 库(4.4.2 版本)中 BERT 相关代码的具体实现分析,欢迎与读者共同交流探讨。


参考文献

[1] https://www.fast.ai/2018/07/02/adam-weight-decay/
[2] https://zhuanlan.zhihu.com/p/63982470
[3] https://www.zhihu.com/question/67335251/answer/262989932


特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。



更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编





🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。



您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存