Quora 重复问题
此文件夹包含演示如何训练用于信息检索的 SentenceTransformers 的脚本。作为一个简单示例,我们将使用 Quora 重复问题数据集。它包含超过 500,000 个句子,以及超过 400,000 对关于两个问题是否重复的成对标注。
在此数据集上训练的模型可用于挖掘重复问题,即给定大量句子(本例中为问题),识别所有重复的对。有关如何使用句子转换器挖掘重复问题/释义的示例,请参阅释义挖掘。这种方法可以扩展到数十万个句子。
您也可以为此任务训练和使用 CrossEncoder
模型。有关更多详细信息,请参阅Cross Encoder > 训练示例 > Quora 重复问题。
训练
选择正确的损失函数对于微调有用的模型至关重要。对于给定任务,两种损失函数特别适用:OnlineContrastiveLoss
和 MultipleNegativesRankingLoss
。
对比损失
有关完整的训练示例,请参阅 training_OnlineContrastiveLoss.py。
Quora Duplicates 数据集包含一个pair-class 子集,该子集由问题对和标签组成:1 表示重复,0 表示不同。
正如我们的损失概述所示,这允许我们使用ContrastiveLoss
。标签为 1 的相似对被拉近,使其在向量空间中靠近,而距离小于预定义裕度的不相似对则在向量空间中被推开。
OnlineContrastiveLoss
是一个改进版本。该损失函数会查找批处理中哪些负对的距离低于最大的正对,以及哪些正对的距离高于负对的最小距离。也就是说,该损失函数会自动检测批处理中的困难案例,并仅针对这些案例计算损失。
该损失函数可以这样使用
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train")
# => Dataset({
# features: ['sentence1', 'sentence2', 'label'],
# num_rows: 404290
# })
print(train_dataset[0])
# => {'sentence1': 'What is the step by step guide to invest in share market in india?', 'sentence2': 'What is the step by step guide to invest in share market?', 'label': 0}
train_loss = losses.OnlineContrastiveLoss(model=model, margin=0.5)
MultipleNegativesRankingLoss
有关完整示例,请参阅 training_MultipleNegativesRankingLoss.py。
MultipleNegativesRankingLoss
特别适用于信息检索/语义搜索。一个很好的优点是它只需要正对,即我们只需要重复问题的示例。有关该损失函数如何工作的更多信息,请参阅NLI > MultipleNegativesRankingLoss。
使用该损失函数很简单,不需要调整任何超参数
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train")
# => Dataset({
# features: ['anchor', 'positive'],
# num_rows: 149263
# })
print(train_dataset[0])
# => {'anchor': 'Astrology: I am a Capricorn Sun Cap moon and cap rising...what does that say about me?', 'positive': "I'm a triple Capricorn (Sun, Moon and ascendant in Capricorn) What does this say about me?"}
train_loss = losses.MultipleNegativesRankingLoss(model)
由于“is_duplicate”是一个对称关系,我们不仅可以使用(锚点,正例),还可以使用(正例,锚点)来扩充我们的训练样本集。
from datasets import concatenate_datasets
train_dataset = concatenate_datasets([
train_dataset,
train_dataset.rename_columns({"anchor": "positive", "positive": "anchor"})
])
# Dataset({
# features: ['anchor', 'positive'],
# num_rows: 298526
# })
注意
增加批次大小通常会产生更好的结果,因为任务变得更难。从 100 个问题中识别正确的重复问题比从 10 个问题中识别更困难。因此,建议尽可能大地设置训练批次大小。我在 32 GB GPU 内存上以 350 的批次大小进行了训练。
注意
MultipleNegativesRankingLoss
仅在 (a_i, b_j) 且 j != i 实际上是负的、非重复问题对时才有效。在少数情况下,这个假设是错误的。但大多数情况下,如果我们随机抽取两个问题,它们不是重复的。如果您的数据集不能满足此属性,MultipleNegativesRankingLoss
可能无法很好地工作。
多任务学习
ContrastiveLoss
非常适用于对分类,即给定两个对,它们是否重复。它将负对在向量空间中推开,从而使重复和非重复对之间的区分效果良好。
MultipleNegativesRankingLoss
则主要从大量可能的候选集中减小正对之间的距离。然而,非重复问题之间的距离并没有那么大,因此这种损失函数在对分类方面效果不佳。
在 training_multi-task-learning.py 中,我演示了如何使用这两种损失函数来训练网络。核心代码是定义这两种损失函数并将它们传递给 fit 方法。
from datasets import load_dataset
from sentence_transformers.losses import ContrastiveLoss, MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformer
model_name = "stsb-distilbert-base"
model = SentenceTransformer(model_name)
# https://huggingface.co/datasets/sentence-transformers/quora-duplicates
mnrl_dataset = load_dataset(
"sentence-transformers/quora-duplicates", "triplet", split="train"
) # The "pair" subset also works
mnrl_train_dataset = mnrl_dataset.select(range(100000))
mnrl_eval_dataset = mnrl_dataset.select(range(100000, 101000))
mnrl_train_loss = MultipleNegativesRankingLoss(model=model)
# https://huggingface.co/datasets/sentence-transformers/quora-duplicates
cl_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train")
cl_train_dataset = cl_dataset.select(range(100000))
cl_eval_dataset = cl_dataset.select(range(100000, 101000))
cl_train_loss = ContrastiveLoss(model=model, margin=0.5)
# Create the trainer & start training
trainer = SentenceTransformerTrainer(
model=model,
train_dataset={
"mnrl": mnrl_train_dataset,
"cl": cl_train_dataset,
},
eval_dataset={
"mnrl": mnrl_eval_dataset,
"cl": cl_eval_dataset,
},
loss={
"mnrl": mnrl_train_loss,
"cl": cl_train_loss,
},
)
trainer.train()
预训练模型
目前,有以下在 Quora 重复问题数据集上训练的模型可用
distilbert-base-nli-stsb-quora-ranking:我们扩展了 distilbert-base-nli-stsb-mean-tokens 模型,并使用 OnlineContrastiveLoss 和 MultipleNegativesRankingLoss 在 Quora 重复问题数据集上对其进行训练。有关代码,请参阅 training_multi-task-learning.py
distilbert-multilingual-nli-stsb-quora-ranking:distilbert-base-nli-stsb-quora-ranking 的多语言扩展。在 50 种语言的并行数据上训练。
您可以像这样加载和使用预训练模型
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("distilbert-base-nli-stsb-quora-ranking")