损失概述

损失表

损失函数在微调的 Cross Encoder 模型的性能中起着至关重要的作用。遗憾的是,没有“一劳永逸”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配,帮助您缩小损失函数的选择范围。

注意

您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,带有 class 标签的 (sentence_A, sentence_B) pairs 可以通过对具有相同或不同类别的句子进行采样来转换为 (anchor, positive, negative) triplets

此外,mine_hard_negatives() 可以轻松地用于将 (anchor, positive) 转换为

  • (anchor, positive, negative) triplets,其中 output_format="triplet"

  • (anchor, positive, negative_1, …, negative_n) tuples,其中 output_format="n-tuple"

  • (anchor, passage, label) labeled pairs,标签 0 表示负例,1 表示正例,其中 output_format="labeled-pair"

  • (anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets,标签 0 表示负例,1 表示正例,其中 output_format="labeled-list"

输入 标签 模型输出标签数量 合适的损失函数
(sentence_A, sentence_B) 对 类别 num_classes CrossEntropyLoss
(anchor, positive) 对 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
(anchor, positive/negative) 对 如果是正例,则为 1,如果是负例,则为 0 1 BinaryCrossEntropyLoss
(sentence_A, sentence_B) 对 介于 0 和 1 之间的浮点相似度分数 1 BinaryCrossEntropyLoss
(anchor, positive, negative) 三元组 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
(anchor, positive, negative_1, ..., negative_n) 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
(query, [doc1, doc2, ..., docN]) [score1, score2, ..., scoreN] 1
  1. LambdaLoss
  2. PListMLELoss
  3. ListNetLoss
  4. RankNetLoss
  5. ListMLELoss

蒸馏

这些损失函数专门设计用于将知识从一个模型提炼到另一个模型时使用。例如,当微调一个小模型以使其行为更像更大更强的模型时,或者当微调模型以使其成为多语言模型时。

文本 标签 合适的损失函数
(sentence_A, sentence_B) 对 相似度分数 MSELoss
(query, passage_one, passage_two) 三元组 gold_sim(query, passage_one) - gold_sim(query, passage_two) MarginMSELoss

常用损失函数

实际上,并非所有损失函数的使用频率都相同。最常见的场景是

  • 带有 float similarity score1 if positive, 0 if negative(sentence_A, sentence_B) pairsBinaryCrossEntropyLoss 是一个传统的选项,仍然非常难以超越。

  • 不带任何标签的 (anchor, positive) pairs:与 mine_hard_negatives 结合使用

    • 使用 output_format=”labeled-list”,然后 LambdaLoss 经常用于 learning-to-rank 任务。

    • 使用 output_format=”labeled-pair”,然后 BinaryCrossEntropyLoss 仍然是一个强大的选项。

自定义损失函数

高级用户可以使用自己的损失函数创建和训练。自定义损失函数只有几个要求

  • 它们必须是 torch.nn.Module 的子类。

  • 它们必须将 model 作为构造函数中的第一个参数。

  • 它们必须实现一个接受 inputslabelsforward 方法。前者是批次中文本的嵌套列表,其中外层列表中的每个元素代表训练数据集中的一列。您必须将这些文本组合成对,这些文本可以 1) 被标记化和 2) 被馈送到模型。后者是来自数据集中的 labellabelsscorescores 列的可选(列表)张量标签。该方法必须返回单个损失值。

为了获得对自动模型卡生成的完全支持,您可能还希望实现

  • 一个 get_config_dict 方法,该方法返回损失参数的字典。

  • 一个 citation 属性,以便您的工作在所有使用该损失训练的模型中被引用。

考虑检查现有的损失函数,以了解损失函数通常是如何实现的。