损失概述
警告
要训练一个 SparseEncoder
,您需要 SpladeLoss
或 CSRLoss
,具体取决于架构。这些都是包装损失,它们在主损失函数之上添加了稀疏性正则化,主损失函数必须作为参数提供。唯一可以独立使用的损失是 SparseMSELoss
,因为它执行嵌入级蒸馏,通过直接复制教师模型的稀疏嵌入来确保稀疏性。
稀疏专用损失函数
SPLADE 损失
SpladeLoss
为 SPLADE (Sparse Lexical and Expansion) 模型实现了一个专用损失函数。它将一个主损失函数与正则化项结合起来以控制效率
支持下方提及的所有损失作为主损失,但主要支持三种损失类型:
SparseMultipleNegativesRankingLoss
、SparseMarginMSELoss
和SparseDistillKLDivLoss
。默认使用
FlopsLoss
进行正则化以控制稀疏性,但也支持自定义正则化器。通过对查询和文档表示进行正则化,平衡了有效性(通过主损失)与效率。
允许通过
query_regularizer
和document_regularizer
参数为查询和文档使用不同的正则化器,从而实现对不同类型输入的稀疏性模式的精细控制。通过
query_regularizer_threshold
和document_regularizer_threshold
参数支持查询和文档的单独阈值,允许每种输入类型具有不同的稀疏性严格程度。
CSR 损失
如果您正在使用 SparseAutoEncoder
模块,那么您必须使用 CSRLoss
(对比稀疏表示损失)。它结合了两个组件
一个重建损失
CSRReconstructionLoss
,确保稀疏表示能够忠实地重建原始嵌入。一个主损失,在论文中是一个使用
SparseMultipleNegativesRankingLoss
的对比学习组件,确保语义相似的句子具有相似的表示。但理论上,可以像SpladeLoss
一样,将下方提及的所有损失用作主损失。
损失表
损失函数对微调模型的性能起着至关重要的作用。遗憾的是,没有“一刀切”的损失函数。理想情况下,此表应通过将损失函数与您的数据格式匹配来帮助缩小您的选择范围。
注意
您通常可以将一种训练数据格式转换为另一种,从而使更多损失函数适用于您的场景。例如,带有
class
标签的(sentence_A, sentence_B) 对
可以通过采样具有相同或不同类别的句子转换为(anchor, positive, negative) 三元组
。
注意
SentenceTransformer > 损失概述 中此处出现的带有 Sparse
前缀的损失函数与其密集版本相同。该前缀仅用于指示哪些损失可以用作主损失来训练 SparseEncoder
输入 | 标签 | 合适的损失函数 |
---|---|---|
(anchor, positive) 对 |
无 |
SparseMultipleNegativesRankingLoss |
(sentence_A, sentence_B) 对 |
0 到 1 之间的浮点相似度分数 |
SparseCoSENTLoss SparseAnglELoss SparseCosineSimilarityLoss |
(anchor, positive, negative) 三元组 |
无 |
SparseMultipleNegativesRankingLoss SparseTripletLoss |
(anchor, positive, negative_1, ..., negative_n) |
无 |
SparseMultipleNegativesRankingLoss |
蒸馏
这些损失函数专门设计用于将知识从一个模型蒸馏到另一个模型。这在训练稀疏嵌入模型时相当常用。
文本 | 标签 | 合适的损失函数 |
---|---|---|
句子 |
模型句子嵌入 |
SparseMSELoss |
sentence_1, sentence_2, ..., sentence_N |
模型句子嵌入 |
SparseMSELoss |
(query, passage_one, passage_two) 三元组 |
gold_sim(query, passage_one) - gold_sim(query, passage_two) |
SparseMarginMSELoss |
(query, positive, negative) 三元组 |
[gold_sim(query, positive), gold_sim(query, negative)] |
SparseDistillKLDivLoss SparseMarginMSELoss |
(query, positive, negative_1, ..., negative_n) |
[gold_sim(query, positive) - gold_sim(query, negative_i) for i in 1..n] |
SparseMarginMSELoss |
(query, positive, negative_1, ..., negative_n) |
[gold_sim(query, positive), gold_sim(query, negative_i)...] |
SparseDistillKLDivLoss SparseMarginMSELoss |
常用损失函数
实际上,并非所有损失函数的使用频率都相同。最常见的场景是
没有标签的
(anchor, positive) 对
:SparseMultipleNegativesRankingLoss
(也称为 InfoNCE 或批内负样本损失)常用于训练表现最佳的嵌入模型。这种数据通常获取成本相对较低,且模型通常表现非常好。对于我们的稀疏检索任务,这种格式与SpladeLoss
或CSRLoss
配合使用效果良好,两者通常都使用 InfoNCE 作为其底层损失函数。(query, positive, negative_1, ..., negative_n)
格式:这种包含多个负样本的结构在配置了SparseMarginMSELoss
的SpladeLoss
中特别有效,尤其是在教师模型提供相似度分数的知识蒸馏场景中。最强的模型是使用SparseDistillKLDivLoss
或SparseMarginMSELoss
等蒸馏损失进行训练的。
自定义损失函数
高级用户可以创建并使用自己的损失函数进行训练。自定义损失函数只有几个要求
它们必须是
torch.nn.Module
的子类。它们在构造函数中必须将
model
作为第一个参数。它们必须实现一个
forward
方法,该方法接受sentence_features
和labels
。前者是一个标记化批次的列表,每列一个元素。这些标记化批次可以直接馈送到正在训练的model
以产生嵌入。后者是一个可选的标签张量。该方法必须返回一个单一的损失值或一个损失组件的字典(组件名称到损失值),这些损失组件将被求和以产生最终损失值。当返回字典时,除了求和的损失之外,各个组件将单独记录,以便您能够监控损失的各个组件。
为了获得自动模型卡生成的完全支持,您可能还希望实现
一个
get_config_dict
方法,它返回一个损失参数的字典。一个
citation
属性,以便您的工作在所有使用该损失训练的模型中被引用。
考虑检查现有损失函数,以了解损失函数通常是如何实现的。