损失表
损失函数在微调的交叉编码器模型性能中起着关键作用。遗憾的是,没有“一劳永逸”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配来帮助您缩小选择范围。
注意
您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,带有 class
标签的 (sentence_A, sentence_B) 对
可以通过采样相同或不同类别的句子转换为 (anchor, positive, negative) 三元组
。
此外,mine_hard_negatives()
可以轻松地将 (anchor, positive)
转换为
带有 output_format="triplet"
的 (anchor, positive, negative) 三元组
,
带有
output_format="n-tuple"
的(anchor, positive, negative_1, …, negative_n) 元组
。带有
output_format="labeled-pair"
的(anchor, passage, label) 标记对
,其中负样本标签为 0,正样本标签为 1,带有
output_format="labeled-list"
的(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) 三元组
,其中负样本标签为 0,正样本标签为 1,输入
标签 | 模型输出标签数量 | 适当的损失函数 | (句子A,句子B) 对 |
---|---|---|---|
类别 |
类别数量 |
(锚点,正样本) 对 |
CrossEntropyLoss |
无 |
(锚点,正样本/负样本) 对 |
1 |
MultipleNegativesRankingLoss CachedMultipleNegativesRankingLoss |
正样本为 1,负样本为 0 |
0 到 1 之间的浮点相似度分数 |
1 |
BinaryCrossEntropyLoss |
类别 |
(锚点,正样本,负样本) 三元组 |
1 |
BinaryCrossEntropyLoss |
(锚点,正样本,负样本_1, ..., 负样本_n) |
(锚点,正样本/负样本) 对 |
1 |
MultipleNegativesRankingLoss CachedMultipleNegativesRankingLoss |
(查询,[文档1, 文档2, ..., 文档N]) |
(锚点,正样本/负样本) 对 |
1 |
MultipleNegativesRankingLoss CachedMultipleNegativesRankingLoss |
[分数1, 分数2, ..., 分数N] |
蒸馏 |
1 |
这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。例如,在微调小型模型使其表现更像大型且更强大的模型时,或者在微调模型使其成为多语言模型时。
文本
相似度分数 | 模型输出标签数量 | (句子A,句子B) 对 |
---|---|---|
类别 |
(查询,段落一,段落二) 三元组 |
MSELoss |
gold_sim(查询,段落一) - gold_sim(查询,段落二) |
常用损失函数 |
MarginMSELoss |
在实践中,并非所有损失函数都同样常用。最常见的场景是:
带有 浮点相似度分数
或 正样本为 1,负样本为 0
的 (句子A,句子B) 对
:BinaryCrossEntropyLoss
是一个传统选项,仍然非常难以超越。
无标签的
(锚点,正样本) 对
:与mine_hard_negatives
结合使用当
output_format=”labeled-list”
时,则LambdaLoss
常用于学习排序任务。当
output_format=”labeled-pair”
时,则BinaryCrossEntropyLoss
仍然是一个强大的选择。自定义损失函数
高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求:
它们必须是 torch.nn.Module
的子类。
它们在构造函数中必须将
model
作为第一个参数。它们必须实现一个
forward
方法,接受inputs
和labels
。前者是批次中文本的嵌套列表,外部列表中的每个元素代表训练数据集中的一列。您必须将这些文本组合成可以 1) 进行分词和 2) 输入到模型的对。后者是数据集中label
、labels
、score
或scores
列中的可选(列表形式的)标签张量。该方法必须返回单个损失值或损失组件字典(组件名称到损失值),这些组件将求和以生成最终损失值。当返回字典时,除了总损失外,还会单独记录各个组件,从而允许您监控损失的各个组件。为了获得自动模型卡生成的全面支持,您可能还希望实现:
一个返回损失参数字典的 get_config_dict
方法。
一个
citation
属性,以便您的工作在使用该损失训练的所有模型中被引用。考虑检查现有损失函数,以了解损失函数通常是如何实现的。