查看原文
其他

KDD 2023 | MetricPrompt: 基于度量的提示学习少标注文本分类方法

董泓源,张伟男等 PaperWeekly
2024-08-22

©作者 | 董泓源,张伟男等
单位 | 哈尔滨工业大学
来源 | 社媒派SMP




论文介绍

尽管在越来越多地应用于各类少标注自然语言处理任务中,提示学习方法中模板和标签映射的设计十分困难,需要对模型、分类任务的深入理解和大量试错。现有的标签映射自动化设计方法尽管省去了人力,但是却难以将下游少标注文本分类任务和预训练任务统一起来,导致其性能仍不尽如人意。
针对这一问题,本文提出 MetricPrompt,通过将少标注文本分类任务转化为文本相关度预估任务,将人工劳动从标签映射设计中解放出来。MetricPrompt 使用提示学习模型作为相关度度量,良好地契合了预训练任务,并能够捕捉输入文本对之间的交互信息来获得更高的预测精度。
在三个少标注文本分类任务的四种少标注设定下,MetricPrompt 超越了之前最优的的自动化标签映射设计方法,并在无需人工进行任务相关标签映射设计的情况下取得了比人工设计更优秀的性能。

论文标题:

MetricPrompt: Prompting Model as a Relevance Metric for Few-shot Text Classification

论文作者:

董泓源,张伟男,车万翔

收录会议:

KDD 2023, Long Paper

论文地址:

https://dl.acm.org/doi/10.1145/3580305.3599430





简介
文本分类被视为文本挖掘中的最为基础和重要的任务之一,其相关技术被应用于各种文本挖掘应用场景,例如信息检索,情感分析,推荐系统,知识管理等 [1]。近年来备受研究者们关注的预训练语言模型在富标注文本分类任务上能够取得令人满意的文本分类性能,但这些模型的少标注学习能力仍然远远落后于人类智能 [2]。 
提示学习方法通过将下游任务与其预训练目标对齐来更好地利用预训练模型的通用知识。提示学习模型以提示文本为输入,通过标签映射将模型的输出词映射到相应的标签,得到文本分类结果。在这一过程中,标签映射的设计很大程度上决定着提示学习模型的性能。然而,设计一个合适的标签映射十分困难。为此,研究者们提出了自动标签映射设计方法来缓解人工标签映射设计的压力。 
这些算法可以分为离散标签映射设计和软标签映射设计方法。离散标签映射设计方法,如 AVS [3]、LM-BFF [4] 和 AutoPrompt [5],在预训练模型的词汇中搜索每个标签对应的答案词以构建标签映射。软标签映射设计方法,如 WARP [6] 和 ProtoVerb [7],在一个无限连续空间中搜索合适的标签映射参数,从而实现更好的性能。
然而,如图 1 所示,这两种方法都使用预训练模型的内部激活值作为样本的特征表示,通过计算其与各个标签特征表示的欧式距离进行分类预测。这迫使预训练模型适应与其预训练目标不同的任务组织形式。更糟糕的是,这些方法中分类标签的特征表示必须在下游任务中从头开始训练,这可能会导致严重的过拟合问题。

▲ 图1. 各类标签映射设计方法对比。图中“CE”表示交叉熵损失函数,“PCL”是原型对比学习损失函数[7]

为了解决上述问题,本文提出了 MetricPrompt这一方法通过将少标注文本分类任务重构为文本对相关度估计任务,减轻了任务相关标签映射设计上的人力成本。

如图 1 所示,在本文的方法中不再需要显式的任务相关标签映射设计。与预训练模型的预训练目标一致,MetricPrompt 只对预训练模型的输出词概率分布进行处理,从而平滑地适应于下游任务。同时,MetricPrompt 将文本对作为输入,因此在其相关度建模过程中可以使用样本文本间的交叉相关度信息来提升估计精度。 

本文在三个广泛使用的少标注文本分类数据集上进行了四种少标注设定下的实验,结果表明,MetricPrompt 超越了所有自动标签映射设计基线方法,甚至还超越了需要大量人力进行任务相关标签映射设计的人工设计方法。此外,本文对 MetricPrompt 的可扩展性和鲁棒性进行了分析实验,并解释了在使用不同相关度分数池化方法时,模型性能产生变化的原因。




方法

3.1 数据构建
给定一个少标注文本分类数据集 ,本文用 表示训练数据,用 表示测试样本集。一个样本表示为 , 其中 表示样本文本, 表示其标签。由于 MetricPrompt 接受一对样本文本作为输入,本文按如下方式构建训练数据:
其中 是 MetricPromt 的提示学习模板函数。该函数将两段样本文本填入提示学习模板中以生成提示学习模型的输入。本文使用 " Anewsof topic: " 作为提示学习模板 示输入的样本文本对是否属于同一类别。类似地,本文以如下方式构建 MetricPrompt 的测试数据:
整个数据构建过程如图 2 所示:
▲ 图2. MetricPrompt的数据构造和训练过程

3.2 优化

为由参数 参数化的 MLM 模型, 为其在 位置的输出词概率分布。本文将 MetricPrompt 的优化目标定义如下:
其中 表示标签类别上的概率分布。输入样本的真实标签对应位置设为 1,而其他位置设为 0。 表示一个预定义的任务通用的标签映射,它将输出词汇概率分布 映射到二项分布 上。
该标签映射将 relevant, similar, consistent 的 logits 聚合为标签 1 的预测 logit,同时将 irrelevant, inconsistent, different 的 logits 聚合为标签 0 的 logit。这个标签映射可用于所有少标注文本分类任务,因此 MetricPrompt 不需要进行任何任务相关标签映射设计。 是 MetricPrompt 的损失函数,其形式为标签映射 生成的概率分布与真实分布之间的交叉嫡损失。

3.3 推理

▲ 图3. MetricPrompt的推理过程

经过优化后,提示学习模型在推理过程中充当相关度度量。如图 3 所示,本文将原始测试样本 (以黑色表示)与所有不同类别的训练样本(彩色)配对,形成推理阶段的样本。给定一个原始训练样本 ,MetricPrompt 根据以下方式计算其 与的相关度得分
在这里, 用于计算二项概率分布在 1 和 0 位置概率之间的差值。本文用 表示少标注文本分类任务的标签。令:
MetricPrompt 通过汇集该标签相应样本与 的相关度,来计算标签 的分类得
最后,MetricPrompt 选择具有最高相关度得分的标签 作为分类结果:
上文阐述了使用求和池化的 MetricPrompt 工作方式。这个池化函数也可以用最大池化和 K-最近邻 (KNN) 池化来替换。最大池化将测试样本 归类为与其最相关的训练样本所对应的类别。将 的计算方式替换为以下公式,MetricPrompt 可以采用最大池化方法处理文本分类任务:
对于 KNN 池化,本文用 表示在 中与 最相关的 个训练样本,并将 的计算方式重写为:
本文将 设为训练集大小的一半。
当多个标签出现相同的次数时,本文选择这些标签对应样本中获得最高相关度分数的样本,并将 分类到其类别。

3.4 更高效的推理

为了进一步提升 MetricPrompt 的效率,本文提出使用代表性样本的方法,来减小 MetricPrompt 推理阶段的开销。对于一个标签为 的训练样本 ,使用 表示该样本的代表性,其计算方式如下:
其中, 表示样本 的相关度分数。基于该代表性指标,本文从每个标签所对应的训练样本中,选取获得最高代表性分数的 个样本参与推理过程。

代表性样本可以大大减小 MetricPrompt 推理过程的时间复杂度。对于一个标签数量为 n,每个标签对应 k 个样本的少标注文本分类任务,在不引入代表性样本时,每一个测试样本需要和 n*k 个训练样本配对并进行相关度分数计算,该过程的时间复杂度为 O(n*k)。

作为对比,传统的提示学习方法和其他不需要人工标签映射设计的提示学习方法则只需要将预训练模型抽取得到的测试样本特征表示与各个标签的特征表示进行点积相似度计算,时间复杂度仅为 O(n)。

在引入代表性样本对推理过程进行优化后,MetricPrompt 仅需计算每个测试样本与各标签下的代表性样本进行相关度预估,因此时间复杂度减少为 O(p*n)。其中 p 是人为设定的一个常数。因此, MetricPrompt 在使用代表性样本进行推理加速后,时间复杂度为 O(n),与其他常用提示学习方法一致。实验中,本文将各标签代表性样本数量 p 设置为 2。





实验

4.1 数据集

本文采用 AG's News、Yahoo Answers Topics 和 DBPedia 三个文本分类数据集进行实验。数据集的统计数据在表 1 中给出:

▲ 表1. 数据集统计信息


4.2 实现细节

本文在 2, 4, 8 和 16-shot 设置下进行实验,其中相应数量的训练样本从每个数据集的训练集中随机抽样。本文为每个数据集和每个少标注设定抽取 10 个训练集,以减轻训练集选择中随机性的影响。所有实验结果均以模型在 10 个训练集上的性能的平均值给出。 

为了公平比较,本文将 BERT-base-uncased 作为 MetricPrompt 和所有基线模型的主干模型。本文根据训练集的大小设置总训练步数,并相应地调整训练轮次数。训练集的大小因数据集标签数量和少标注设定而异,各个设定下的具体训练轮次数参见表 2。

▲ 表2. 不同实验设定下的训练轮次数

与 ProtoVerb 相比,MetricPrompt 在使用平均池化和最大池化的情况下,性能下降较少,达到了更高的分类精度。

4.3 主要实验结果

本文在四种少标注设定下对三个具有不同文本风格的文本分类数据集进行实验。2、4-shot 实验结果列在表 3 中,而 8、16-shot 的实验结果在表 4 中列出。

▲ 表3. 2-shot和4-shot设定下的实验结果,实验结果以准确率作为指标。斜体表示该方法需要人工进行任务相关标签映射设计,粗体表示在无需人工任务相关标签映射设计方法中的最佳结果


▲ 表4. 8-shot和16-shot设定下的实验结果,实验结果以准确率作为指标。斜体表示该方法需要人工进行任务相关标签映射设计,粗体表示在无需人工任务相关标签映射设计方法中的最佳结果


与无需人工标签设计的 SOTA 提示学习方法 ProtoVerb 相比,MetricPrompt 在 2-shot 准确率上提高了 5.88,4-shot 准确率上提高了 11.92,8-shot 准确率上提高了 6.80,16-shot 准确率上提高了 1.56。

MetricPrompt 甚至在无需人工任务相关的标签映射设计的情况下,在所有少标注设定中超过了 ManualVerb 的表现。在每个标签仅选用2个代表性样本的实验设定下,MetricPrompt 仍取得了优秀的性能表现。在相同的时间复杂度下,性能大幅超越之前的 SOTA 基线模型 ProtoVerb,并取得了与 ManualVerb 相当的分数。



分析

5.1 使用领域外数据进行可扩展性测试

在实际应用中,无法总是在真实场景中获得足够的标注数据。因此,少标注文本分类方法在领域外(OOD)数据上的可扩展性对其实用性至关重要。为了证明 MetricPrompt 的可扩展性,本文利用各个数据集 16-shot 的训练集来辅助其他数据集的少标注文本分类任务。

▲ 表5. MetricPrompt和ProtoVerb在引入额外OOD数据情况下的模型性能

如表 5 所示,MetricPrompt 在 OOD 训练数据的加持下,获得了更高的准确率。与先前的 SOTA 基线 ProtoVerb 相比,MetricPrompt 在 18 个少标注和 OOD 数据设定下的 17 个中获得了更高的预测准确率(表中下划线数字)。

值得注意的是,MetricPrompt 在 1-shot 设置下的性能提升显著高于其他少标注设定。这是因为在 1-shot 设置下,MetricPrompt 只采用两个完全相同的文本作为正样本,导致过拟合问题严重。引入多样化的 OOD 数据有效地缓解了过拟合问题,因此大幅提高了 MetricPrompt 在 1-shot 任务中的性能。

5.2 对抗噪声的鲁棒性

由于缺乏监督信号,噪声样本会严重影响少标注文本分类模型的性能。本节对 MetricPrompt 在 AG's News 数据集上对抗噪声样本的鲁棒性进行评估。本文随机替换 1、2 和 4 个训练样本的标签引入噪声,并测试当 MetricPrompt 使用平均、最大和 KNN 池化时的性能。噪声样本引起的性能下降如表 6 所示:

▲ 表6. 模型在AG's News数据集的8、16-shot设定下,分别引入1、2和4个噪声样本时的性能下降。粗体表示所有方法中最少的性能下降量


5.3 不同池化方法的比较

首先,本文对不含噪声样本的场景进行分析。本文收集了 MetricPrompt 计算的相关度分数分布的统计信息。如图 4 所示,相关度分数的分布是高度不均匀的。因此,最大相关度分数在使用平均池化的 MetricPrompt 中起到了决定性的作用,导致了与最大池化类似的行为。然而,KNN 池化采用投票策略,忽略了分数值信息,带来了个更多的分类错误。

▲ 图4. AG's News数据集2-shot设定下,各个测试样本与训练样本之间的平均相关度分数

接下来分析 MetricPrompt 在存在噪声样本的情况下的性能表现。如图 4 所示,除了前几个最相关的样本,其余相关度分数的分布相对均匀。假设相关度分数的分布是均匀的极端情况,KNN 池化的预测结果将受到每个类别训练样本数量方差的显著影响。

基于这一现象,本文将引入噪声样本时 KNN 池化的性能较差归咎于其投票机制使其容易受到各个类别训练样本数量方差的影响。为了验证这一点,本文对收集每个类型类别的平均预测测试样本数量进行了统计。

▲ 图5. AG's News数据集8-shot设定下,包含7、8和9个训练样本的类别在测试阶段平均作为预测目标的次数。“\# Predicted query sample”表示测试阶段被预测为该类别的测试样本平均个数

如图 5 所示,KNN 池化对拥有更多训练样本的类别表现出更强的偏好,显著高于平均池化和最大池化,导致其对应的预测测试样本数量异常地高。因此,在引入噪声样本时,KNN 池化的性能大幅下降。

5.4 代表性样本数量影响分析

本节研究代表性样本数量对 MetricPrompt 性能的影响。本文在三个数据集的四种少标注设定下进行实验,并将代表性样本数量分别设为1,2 和 4。

▲ 表7. 2-shot和4-shot设定下使用代表性样本的实验结果,实验结果以准确率作为指标。粗体表示该任务上的最佳结果


▲ 表8. 8-shot和16-shot设定下使用代表性样本的实验结果,实验结果以准确率作为指标。粗体表示该任务上的最佳结果


如表 7 和表 8 所示,MetricPrompt 的性能与代表性样本数量正相关。值得注意的是,即使仅为每个类别保留一个代表性样本参与推理,MetricPrompt 在四个少标注设置下的分类精度也仍然优于先前的 SOTA 方法 ProtoVerb。通过调整代表性样本数量 p,可以使 MetricPrompt 实现分类准确率和效率之间的平衡。



总结


针对基于提示学习的少标注文本分类方法性能严重依赖人工标签映射设计,而自动化标签映射设计方法性能较差的问题,本文提出了 MetricPrompt,通过将少标注文本分类任务转化成文本对相关度预估任务来减轻人工标签映射设计的负担。MetricPrompt 将少标注训练数据两两配对,并训练提示学习模型对文本对相关度进行估计。优化后的提示学习模型作为一个文本相关度度量来估计测试样本与各训练样本之间的相关度,从而完成分类预测。

相较于其他自动标签映射设计方法,MetricPrompt 无需引入任务特定的标签特征表示,避免了下游任务中标注数据过少引发的过拟合问题。同时,MetricPrompt 的工作方式可以视为一种广义掩码语言建模任务,使得预训练模型能够更顺利地适配于下游少标注文本分类任务。

三个数据集上四种少标注设定下的实验结果表明,MetricPrompt 性能显著优于之前的 SOTA 模型,且在未引入人工知识进行任务相关标签映射设计的情况下,取得了比人工设计方法更好的文本分类性能。

参考文献

[1] Vandana Korde and C Namrata Mahender. 2012. TEXT CLASSIFICATION AND CLASSIFIERS: A SURVEY. International Journal of Artificial Intelligence & Applications 3, 2 (2012), 85.

[2] Tom B. Brown, Benjamin Mann, et al. 2020. Language Models are Few-Shot Learners. NeurIPS 2020.

[3] Timo Schick and Hinrich Schütze. 2021. Exploiting Cloze-Questions for FewShot Text Classification and Natural Language Inference. EACL 2021

[4] Tianyu Gao, Adam Fisch, and Danqi Chen. 2021. Making Pre-trained Language Models Better Few-shot Learners. ACL 2021.

[5] Taylor Shin, Yasaman Razeghi, Robert L. Logan IV, Eric Wallace, and Sameer Singh. 2020. AutoPrompt: Eliciting Knowledge from Language Models with Automatically Generated Prompts. EMNLP 2020.

[6] Karen Hambardzumyan, Hrant Khachatrian, and Jonathan May. 2021. WARP: Word-level Adversarial ReProgramming. ACL 2021.

[7] Ganqu Cui, Shengding Hu, Ning Ding, Longtao Huang, and Zhiyuan Liu. 2022. Prototypical Verbalizer for Prompt-based Few-shot Tuning. ACL 2022.


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·

继续滑动看下一个
PaperWeekly
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存