损失概述
损失表
损失函数在微调的 Cross Encoder 模型的性能中起着至关重要的作用。遗憾的是,没有“一劳永逸”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配,帮助您缩小损失函数的选择范围。
注意
您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,带有 class
标签的 (sentence_A, sentence_B) pairs
可以通过对具有相同或不同类别的句子进行采样来转换为 (anchor, positive, negative) triplets
。
此外,mine_hard_negatives()
可以轻松地用于将 (anchor, positive)
转换为
(anchor, positive, negative) triplets
,其中output_format="triplet"
,(anchor, positive, negative_1, …, negative_n) tuples
,其中output_format="n-tuple"
。(anchor, passage, label) labeled pairs
,标签 0 表示负例,1 表示正例,其中output_format="labeled-pair"
,(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets
,标签 0 表示负例,1 表示正例,其中output_format="labeled-list"
,
输入 | 标签 | 模型输出标签数量 | 合适的损失函数 |
---|---|---|---|
(sentence_A, sentence_B) 对 |
类别 |
num_classes |
CrossEntropyLoss |
(anchor, positive) 对 |
无 |
1 |
MultipleNegativesRankingLoss CachedMultipleNegativesRankingLoss |
(anchor, positive/negative) 对 |
如果是正例,则为 1,如果是负例,则为 0 |
1 |
BinaryCrossEntropyLoss |
(sentence_A, sentence_B) 对 |
介于 0 和 1 之间的浮点相似度分数 |
1 |
BinaryCrossEntropyLoss |
(anchor, positive, negative) 三元组 |
无 |
1 |
MultipleNegativesRankingLoss CachedMultipleNegativesRankingLoss |
(anchor, positive, negative_1, ..., negative_n) |
无 |
1 |
MultipleNegativesRankingLoss CachedMultipleNegativesRankingLoss |
(query, [doc1, doc2, ..., docN]) |
[score1, score2, ..., scoreN] |
1 |
蒸馏
这些损失函数专门设计用于将知识从一个模型提炼到另一个模型时使用。例如,当微调一个小模型以使其行为更像更大更强的模型时,或者当微调模型以使其成为多语言模型时。
文本 | 标签 | 合适的损失函数 |
---|---|---|
(sentence_A, sentence_B) 对 |
相似度分数 |
MSELoss |
(query, passage_one, passage_two) 三元组 |
gold_sim(query, passage_one) - gold_sim(query, passage_two) |
MarginMSELoss |
常用损失函数
实际上,并非所有损失函数的使用频率都相同。最常见的场景是
带有
float similarity score
或1 if positive, 0 if negative
的(sentence_A, sentence_B) pairs
:BinaryCrossEntropyLoss
是一个传统的选项,仍然非常难以超越。不带任何标签的
(anchor, positive) pairs
:与mine_hard_negatives
结合使用使用
output_format=”labeled-list”
,然后LambdaLoss
经常用于 learning-to-rank 任务。使用
output_format=”labeled-pair”
,然后BinaryCrossEntropyLoss
仍然是一个强大的选项。
自定义损失函数
高级用户可以使用自己的损失函数创建和训练。自定义损失函数只有几个要求
它们必须是
torch.nn.Module
的子类。它们必须将
model
作为构造函数中的第一个参数。它们必须实现一个接受
inputs
和labels
的forward
方法。前者是批次中文本的嵌套列表,其中外层列表中的每个元素代表训练数据集中的一列。您必须将这些文本组合成对,这些文本可以 1) 被标记化和 2) 被馈送到模型。后者是来自数据集中的label
、labels
、score
或scores
列的可选(列表)张量标签。该方法必须返回单个损失值。
为了获得对自动模型卡生成的完全支持,您可能还希望实现
一个
get_config_dict
方法,该方法返回损失参数的字典。一个
citation
属性,以便您的工作在所有使用该损失训练的模型中被引用。
考虑检查现有的损失函数,以了解损失函数通常是如何实现的。