损失表

损失函数在微调的交叉编码器模型性能中起着关键作用。遗憾的是,没有“一劳永逸”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配来帮助您缩小选择范围。

注意

您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,带有 class 标签的 (sentence_A, sentence_B) 对 可以通过采样相同或不同类别的句子转换为 (anchor, positive, negative) 三元组

此外,mine_hard_negatives() 可以轻松地将 (anchor, positive) 转换为

带有 output_format="triplet"(anchor, positive, negative) 三元组

  • 带有 output_format="n-tuple"(anchor, positive, negative_1, …, negative_n) 元组

  • 带有 output_format="labeled-pair"(anchor, passage, label) 标记对,其中负样本标签为 0,正样本标签为 1,

  • 带有 output_format="labeled-list"(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) 三元组,其中负样本标签为 0,正样本标签为 1,

  • 输入

标签 模型输出标签数量 适当的损失函数 (句子A,句子B) 对
类别 类别数量 (锚点,正样本) 对 CrossEntropyLoss
(锚点,正样本/负样本) 对 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
正样本为 1,负样本为 0 0 到 1 之间的浮点相似度分数 1 BinaryCrossEntropyLoss
类别 (锚点,正样本,负样本) 三元组 1 BinaryCrossEntropyLoss
(锚点,正样本,负样本_1, ..., 负样本_n) (锚点,正样本/负样本) 对 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
(查询,[文档1, 文档2, ..., 文档N]) (锚点,正样本/负样本) 对 1 MultipleNegativesRankingLoss
CachedMultipleNegativesRankingLoss
[分数1, 分数2, ..., 分数N] 蒸馏 1
  1. LambdaLoss
  2. PListMLELoss
  3. ListNetLoss
  4. RankNetLoss
  5. ListMLELoss

这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。例如,在微调小型模型使其表现更像大型且更强大的模型时,或者在微调模型使其成为多语言模型时。

文本

相似度分数 模型输出标签数量 (句子A,句子B) 对
类别 (查询,段落一,段落二) 三元组 MSELoss
gold_sim(查询,段落一) - gold_sim(查询,段落二) 常用损失函数 MarginMSELoss

在实践中,并非所有损失函数都同样常用。最常见的场景是:

带有 浮点相似度分数正样本为 1,负样本为 0(句子A,句子B) 对BinaryCrossEntropyLoss 是一个传统选项,仍然非常难以超越。

  • 无标签的 (锚点,正样本) 对:与 mine_hard_negatives 结合使用

  • output_format=”labeled-list” 时,则 LambdaLoss 常用于学习排序任务。

    • output_format=”labeled-pair” 时,则 BinaryCrossEntropyLoss 仍然是一个强大的选择。

    • 自定义损失函数

高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:

它们必须是 torch.nn.Module 的子类。

  • 它们在构造函数中必须将 model 作为第一个参数。

  • 它们必须实现一个 forward 方法,接受 inputslabels。前者是批次中文本的嵌套列表,外部列表中的每个元素代表训练数据集中的一列。您必须将这些文本组合成可以 1) 进行分词和 2) 输入到模型的对。后者是数据集中 labellabelsscorescores 列中的可选(列表形式的)标签张量。该方法必须返回单个损失值或损失组件字典(组件名称到损失值),这些组件将求和以生成最终损失值。当返回字典时,除了总损失外,还会单独记录各个组件,从而允许您监控损失的各个组件。

  • 为了获得自动模型卡生成的全面支持,您可能还希望实现:

一个返回损失参数字典的 get_config_dict 方法。

  • 一个 citation 属性,以便您的工作在使用该损失训练的所有模型中被引用。

  • 考虑检查现有损失函数,以了解损失函数通常是如何实现的。