Cross-Encoders

SentenceTransformers 也支持训练 Cross-Encoders 用于句子对评分和句子对分类任务的选项。有关 Cross-Encoders 是什么以及 Cross- 和 Bi-Encoders 之间差异的更多详细信息,请参阅 Cross-Encoders

示例

请参阅以下示例,了解如何训练 Cross-Encoders

训练 CrossEncoders

CrossEncoder 类是 Hugging Face AutoModelForSequenceClassification 的包装器,但带有一些使其训练和预测分数更简单的方法。保存的模型与 Hugging Face 100% 兼容,也可以使用它们的类加载。

首先,您需要一些句子对数据。您可以具有连续的分数,例如

from sentence_transformers import InputExample

train_samples = [
    InputExample(texts=["sentence1", "sentence2"], label=0.3),
    InputExample(texts=["Another", "pair"], label=0.8),
]

或者您具有不同的类,如 training_nli.py 示例中所示

from sentence_transformers import InputExample

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
train_samples = [
    InputExample(texts=["sentence1", "sentence2"], label=label2int["neutral"]),
    InputExample(texts=["Another", "pair"], label=label2int["entailment"]),
]

然后,您定义基础模型和标签数量。您可以采用任何与 AutoModel 兼容的 Hugging Face 预训练模型

model = CrossEncoder('distilroberta-base', num_labels=1)

对于二元任务和具有连续分数的任务(如 STS),我们将 num_labels 设置为 1。对于分类任务,我们将其设置为我们拥有的标签数量。

我们通过调用 model.fit() 开始训练

model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)