查看原文
其他

复旦 DISC 原创 | 开放域对话中粗粒度响应选择的上下文细到粗蒸馏

陈伟 复旦DISC
2024-11-04

引言

本文介绍了复旦大学数据智能与社会计算实验室 (Fudan DISC) 在 ACL 2022 上录用的一篇长文:Contextual Fine-to-Coarse Distillation for Coarse-grained Response Selection in Open-Domain Conversations
论文地址:https://arxiv.org/abs/2109.13087

文章摘要


在本文中,我们提出了一个上下文细到粗蒸馏模型(CFC),用于开放域对话中的粗粒度响应选择。在 CFC 模型中,我们基于使用上下文匹配的多塔架构学习查询、候选上下文和响应的密集表示,并将单塔架构(细粒度)中学习的知识提炼到多塔架构中(粗粒度)以增强检索器的性能。为了评估所提出模型的性能,我们基于 Reddit 评论和 Twitter 语料库构建了两个新数据集。在两个数据集上的大量实验结果表明,与传统的基线方法相比,所提出的方法在所有评估指标上都得到了提升。

论文细节

1 研究背景

检索式对话系统通常包含两个阶段:粗粒度响应选择细粒度响应选择。如下图所示,在粗粒度阶段,由检索器从大规模候选数据库中识别出一个小得多的候选列表,然后在细粒度阶段中从检索到的候选列表中选择最佳响应。最近的研究更关注细粒度响应选择,并且提出了各种复杂模型来计算查询和候选响应之间的相似性。尽管已经报道了有希望的改进,但细粒度阶段的性能不可避免地受到构建的候选列表质量的限制。因此,高质量的粗粒度响应选择模块至关重要,这在现有文献中探索较少。
在本文中,我们专注于粗粒度响应选择的任务。该任务主要存在两个主要挑战:首先,和一般的文本匹配任务不同,对话中的上下文和响应之间重叠的关键字可能很少见,这使得很难将查询与候选响应直接匹配;其次,与细粒度响应选择相比,粗粒度响应选择处理的候选者数量要多得多。因此,应用复杂匹配模型来联合处理查询和响应以进行相似性计算是不切实际的。相反,基于稀疏表示的 BM25 系统是粗粒度文本匹配中的主流算法。
为了缓解上述两个问题,我们提出了一个用于粗粒度响应选择的上下文细到粗蒸馏模型。我们提出上下文匹配,即给定一个查询,将其与候选上下文进行匹配以找到最相似的上下文,并将相应的响应作为检索结果返回。在这种情况下,可以利用上下文中潜在的更丰富的关键字。为了利用复杂模型的优势并保持计算成本可接受,我们在保持原始结构不变的基础上,将从细粒度响应选择中学到的知识提炼成粗粒度响应选择中。我们基于 Reddit 评论数据和 Twitter 语料库构建了两个数据集。大量的实验结果表明,我们提出的模型在两个数据集上显著提高了检索召回率以及检索到的响应的困惑度和相关性。

2 方法

2.1 上下文匹配

为了将查询与候选上下文匹配,我们考虑了三种匹配方式:(i)查询 - 上下文匹配(QC matching),其中上下文而不是响应被视为候选文档;(ii) 查询 - 会话匹配 (QS matching),其中会话被视为候选文档,并返回最相似会话对应的响应(会话指的是上下文和响应的拼接文本);(iii) 解耦查询 - 会话匹配 (DQS matching),其中首先计算查询和上下文之间的相似度,然后计算查询和响应之间的相似度,然后计算这两者的加权和作为查询会话相似度。

2.2 多塔架构

为了以低延迟搜索大规模候选者,基于神经的检索器通常设计为(或限于)下图中的多塔式架构。在多塔模型中,查询和候选被不同的编码器独立映射到一个公共向量空间,并计算相似度。在推理阶段,大规模候选的 embeddings 可以提前离线计算,只需要在线计算 query 的 embeddings,并利用近似最近邻搜索等快速亚线性时间逼近方法来搜索与查询最相似的 Top-K 向量,从而在推理过程中实现可接受的检索延迟。

对于 QC 和 QS 匹配,我们使用双塔(two-tower)结构;对于 DQS 匹配,我们采用三塔(three-tower)结构。和 DPR 类似,我们使用向量的点积作为查询和候选之间的相似度。双塔和三塔模型的训练可以被形式化为相同的 metric learning 问题,目标是通过学习更好的表示函数来构建一个匹配空间,其中正对(positive pairs)之间的相似性高于负对(negative pairs)之间的相似性,我们的损失函数为所有正对的负对数似然之和。
为上下文匹配训练多塔模型的核心问题是找到查询和上下文(或查询和会话)的正对。在本文中,我们假设具有完全相同的响应的上下文是彼此的正样本,这是一种谨慎但可靠的策略。形式上,给定一个响应,如果有多个上下文的响应是,那么我们可以随机选择一个上下文作为查询,其他上下文是 的 positive contexts, 的 positive response。上下文和响应的负样本可以从一个 batch 或从数据库中随机抽样获得。

2.3 细到粗蒸馏

在多塔架构中,查询和候选通过它们的 embeddings 独立表示,这可能会导致信息丢失,并且它们单调的交互方式(内积)进一步限制了表达能力。与多塔模型相比,单塔模型将查询和候选作为连接进行输入,并允许在自注意力层中查询和候选之间的交叉注意。尽管参数较少,但单塔模型已被证明比多塔模型学习更多信息表示,因此它在细粒度响应中是首选。为了利用单塔模型学习到的更丰富的表达能力,我们将单塔模型的知识提炼成多塔模型以增强检索器。在蒸馏之前,我们需要训练基于单塔架构的教师模型。例如,训练单个编码器以区分查询和会话是否匹配,在形式上和 BERT 预训练中的下一句预测任务完全相同。
单塔和多塔之间的主要区别总结如下:单塔模型更具表现力,但效率较低且无法处理大规模候选者。主要原因是基于特征的计算相似度得分而不是内积的方法限制了离线缓存的能力。对于新查询,只能通过遍历计算与所有候选的相似度。巨大的延迟使得在粗粒度响应检索中无法使用单塔模型。
和知识蒸馏一样,精到粗蒸馏的方法是推动学生模型(多塔)学习教师模型(单塔)的预测标签作为软目标,而不是原来的 one-hot 标签。通过对教师模型预测的标签进行拟合,多塔模型可以在保持结构不变的情况下,从单塔模型中学习到更准确的相似度分数分布。

3 数据集构建

为了评估所提出模型的性能,我们基于 Reddit 评论数据 和 Twitter 语料库构建了两个新数据集。我们分别创建了一个训练集、一个多上下文测试集(MC test set) 和一个候选数据库。对于 Reddit,我们创建了一个额外的单上下文测试集(SC test set)。我们的候选数据库在 Twitter 和 Reddit 中的大小分别为 100 万和 1000 万,下表显示了详细的统计信息。

在 MC test set 中,每个 query 有多个对应的上下文,这确保了在候选数据库中可以找出对应该 query 的上下文;在 SC test set 中,每个 query 只有一个上下文。算法 1 中详细展示了 MC 和 SC test set 的构建方式。

4 实验

4.1 模型对比

对于基线,我们选择 BM25 作为基于稀疏表示的方法,该方法广泛用于文本匹配的实际场景中。基于 BM25 系统和两种匹配方法(QC 和 QS 匹配),可以得到两个检索器,分别记为 BM25-QC 和 BM25-QS。我们选择多塔模型作为基于密集表示的方法,其中双塔模型用于 QC 匹配和 QS 匹配(表示为 BE-QC 和 BE-QS),以及基于 DQS 匹配的三塔模型(表示为 TE-DQS)。此外,我们还报告了基于查询 - 响应(QR)匹配的结果,分别基于 BM25 系统和双塔模型(表示为 BM-QR 和 BE-QR)构建了两个检索器。
我们提出的 CFC 模型有三种变体,它们是 BE-QC、BE-QS 和 TE-DQS 的蒸馏版本,分别称为 CFC-QC、CFC-QS 和 CFC-DQS。每个学生模型的蒸馏都需要训练相应的教师模型。特别是,从 TE-DQS 到 CFC-DQS 的蒸馏需要两个教师模型,因为需要计算 query-context 和 query-response 之间的相似度。

4.2 评价指标

基于之前的研究,我们使用 Coverage@K 用于评估 Top-K 检索到的候选者是否包含真实响应。但是,Coverage@K 仅适用于评估 MC 测试集,由于上下文和响应之间的一对多关系,它无法评估整体检索质量。作为补充,我们提出了两个基于预训练模型的自动评估指标,即 Perplexity@K 和 Relevance@K。对于检索到的 Top-K 响应,我们使用 DialogGPT 用于计算给定查询的检索响应的条件困惑度。Perplexity@K 是 Top-K 检索到的响应的平均困惑度。除了 Perplexity,我们还评估查询和检索到的响应之间的相关性。我们使用 DialogRPT,它通过 human-vs-rand 任务对大规模人类反馈数据进行了预训练,该任务预测响应与给定上下文而不是随机响应对应的可能性有多大。Relevance@K 是查询和 Top-K 检索到的响应之间的平均预测相关度。Perplexity@K 和 Relevance@K 是基于所有 Top-K 检索响应的平均指标,因此它们可以反映整体检索质量。

4.3 实验结果

可以看出,密集检索器的性能远超 BM25 系统,显示出预训练模型丰富的语义信息,额外的训练可以提升检索器的性能。例如,与 BM25 系统相比,最佳的无蒸馏密集检索器(BE-QS)在三个指标上都有明显的提升。对于 Coverage@K,BE-QS 在 Reddit 和 Twitter 的 MC 测试集上的 Top-500 召回率与 BM25-QS 相比提高了 12.1% 和 17.4% 绝对值。对于 Perplexity@K,与 BM25-QS 相比,BE-QS 在 Reddit 的 MC 和 SC 测试集上的 Top-20 平均 perplexity 减少了 8.1 和 8.5 绝对值。对于 Relevance@K,与 BM25-QS 相比,BE-QS 在 MC 和 SC 测试集上的 Top-20 平均相关性比 BM25-QS 增加了 6.3% 和 6.5% 绝对值。Coverage@K 衡量检索者检索黄金响应的能力,而 Perplexity@K 和 Relevance@K 衡量整体检索质量。我们的结果表明三个指标的一致性,即召回率和整体检索质量呈正相关。


与上下文匹配相比,查询 - 响应(QR)匹配的检索召回率要低得多,这在 lan2020ultra 中也得到了验证。我们认为这是因为该响应通常是一个句子的短文本,并且包含的信息不足,并且与查询重叠的关键字可能很少。因此,在 RBD 系统中考虑上下文匹配是很重要的。
与 QC 匹配相比,QS 和 DQS 匹配应在实践中得到鼓励,因为响应提供了额外的信息。然而,BM25 系统并不能很好地利用响应信息,因为 BM25-QS 模型在 Reddit 和 Twitter 数据集上都没有显示出明显优于 BM25-QC 的优势。相比之下,密集检索模型可以有效地利用响应。例如,在 Reddit 的 MC 测试集的 Top-500 响应检索召回率方面,BE-QS 绝对优于 BE-QC 7.9%。对于 QS 和 DQS 匹配,性能差别不大。尤其是 Reddit 上的 SC 测试集和 Twitter 上的 MC 测试集,性能差异很小。DQS 的一个潜在优势是它可以利用正查询 - 响应对,其数量远大于正查询 - 上下文对。
我们进一步关注从细到粗蒸馏的性能增益。蒸馏模型在所有三个指标上都取得了明显的改进。一个明显的规律是,经过提炼的模型 K 越小,改进越大。以 Twitter 数据集为例,CFC 模型的 Top-500 检索召回率在提炼后增加 1.52.4,而 Top-1 检索召回率费率增加了 4.66.7。在 Perplexity@K 和 Relevance@K 上,我们的 CFC 模型具有相似的性能。小 K 时检索召回率的显着提高特别有利于细粒度的响应选择,因为它为 ranker 提供了更多的可能性,可以在看到更少的候选者的同时选择好的响应。上述结果表明,我们的学生模型受益于学习或从教师模型中继承细粒度的知识。
由于 DialogGPT 和 DialogRPT 没有在 Twitter 上进行预训练,Perplexity@K 和 Relevance@K 不适合评估 Twitter 数据集。因此,我们不为 Twitter 构建 SC 测试集。与 Twitter 相比,我们使用的 Reddit 数据集更大,具有更常见的多轮对话,并且检索难度明显更高。Twitter 上 Top-500 检索召回率达到 60%,而 Reddit 仅达到 20% 左右,这表明开放域对话中的粗粒度响应检索任务仍然存在很大挑战。

5 分析

5.1 数据库大小的影响

我们讨论了候选数据库大小对模型性能的影响。对于不同的候选数据库大小(从一百万到一千万),我们在 Reddit 的 MC 测试集上比较了 BM25-QS、BE-QS 和 CFC-QS 的 Coverage@500 指标(图 3)。可以看出 Coverage@500 随着数据库大小的增加呈现出缓慢下降的趋势。增加数据库大小不会使模型性能迅速下降,这说明了我们模型的有效性和鲁棒性。

5.2 人工评价

为了进一步评估和比较我们的模型,我们进行了人工评估实验。我们从 Reddit 数据集的 MC 和 SC 测试集(每个 500 个)中随机选择 1000 个查询,并分别通过 BM25-QS、BE-QS 和 CFC-QS 模型检索 Top-1 响应。三名众包工作者被要求对回答进行评分。对于每个查询,注释器将对三个模型的检索响应进行严格排名。我们在成对比较中报告平均排名分数(1 到 3 之间,越小越好)和获胜率。每两个注释器都有一定数量(大约 200 个)重叠的注释样本。为了评估评分者间的可靠性,采用 Cohen 的 kappa 系数。
下表分别报告每个模型的平均排名得分和模型之间的成对比较。CFC-QS 的平均排名得分最高,在大多数情况下 CFC-QS 可以击败 BE-QS 和 BM25(74.7%81.6%),这表明 CFC-QS 在 Top - 中占据明显优势 1 检索。所有 Cohen 的 Kappa 系数都在 0.6 和 0.7 之间,表明注释者达到中等一致性。人工评估的结果进一步验证了蒸馏给模型带来的性能提升。

5.3 效率评估

我们还比较了 BM25-QS 和 BE-QS 在 reddit MC 测试集上的检索延迟,它们分别代表了稀疏和密集检索器的效率。我们将批量大小固定为 32,并检索前 100 个最相似的候选者。在 FAISS 索引的帮助下,BE-QS 每批次的平均检索时间为 581.8ms。相比之下,BM25 系统使用文件索引的平均检索时间为 1882.6ms,约为 BE-QS 的三倍。这表明密集检索器在检索效率上也具有优势。
密集检索器相对较差的是它需要计算候选数据库的嵌入并建立 FAISS 索引,这相当耗时,BE-QS 使用 8 个 GPU 处理 1000 万个候选大约需要 9 个小时,而构建一个 BM25 索引只需要 10 分钟左右。

总结

本文的主要贡献有三方面:1)我们探索了开放域对话中的粗粒度响应选择的问题,并提出了上下文细到粗蒸馏模型;2)我们基于 Reddit 评论和 Twitter 语料库构建了两个新数据集,作为评估粗粒度响应选择任务的新基准;3)我们构建了广泛的实验来证明我们提出的模型在粗粒度响应选择中的有效性和潜力。



供稿丨 陈   伟
编辑丨 赵丽敏
责编丨 张霁雯

供稿人:陈伟 丨博士生 4 年级丨研究方向:对话系统 丨邮箱:chenwei18@fudan.edu.cn
继续滑动看下一个
复旦DISC
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存