创建自定义模型

Sentence Transformer 模型的结构

Sentence Transformer 模型由一系列模块(文档)组成,这些模块按顺序执行。最常见的架构是 Transformer 模块、一个 Pooling 模块,以及可选的 Dense 模块和/或一个 Normalize 模块的组合。

  • Transformer:此模块负责处理输入文本并生成语境化嵌入。

  • Pooling:此模块通过聚合嵌入来降低 Transformer 模块输出的维度。常见的池化策略包括均值池化和 CLS 池化。

  • Dense:此模块包含一个线性层,用于后处理 Pooling 模块的嵌入输出。

  • Normalize:此模块规范化来自前一层的嵌入。

例如,流行的 all-MiniLM-L6-v2 模型也可以通过初始化构成该模型的 3 个特定模块来加载

from sentence_transformers import models, SentenceTransformer

transformer = models.Transformer("sentence-transformers/all-MiniLM-L6-v2", max_seq_length=256)
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
normalize = models.Normalize()

model = SentenceTransformer(modules=[transformer, pooling, normalize])

保存 Sentence Transformer 模型

每当保存 Sentence Transformer 模型时,会生成三种类型的文件

  • modules.json:此文件包含用于重建模型的模块名称、路径和类型列表。

  • config_sentence_transformers.json:此文件包含 Sentence Transformer 模型的一些配置选项,包括保存的提示、模型的相似度函数,以及模型作者使用的 Sentence Transformer 包版本。

  • 模块特定文件:每个模块都保存在单独的子文件夹中,以模块索引和模型名称命名(例如,1_Pooling, 2_Normalize),但第一个模块如果其 save_in_root 属性设置为 True,则可以保存在根目录中。在 Sentence Transformers 中,TransformerCLIPModel 模块就是这种情况。大多数模块文件夹都包含一个 config.json(或对于 Transformer 模块,则是 sentence_bert_config.json)文件,该文件存储传递给该模块的关键字参数的默认值。因此,一个 sentence_bert_config.json 文件,其内容如下:

    {
      "max_seq_length": 4096,
      "do_lower_case": false
    }
    

    意味着 Transformer 模块将使用 max_seq_length=4096do_lower_case=False 进行初始化。

因此,如果我在上一个片段中的 model 上调用 SentenceTransformer.save_pretrained("local-all-MiniLM-L6-v2"),则会生成以下文件

local-all-MiniLM-L6-v2/
├── 1_Pooling
│   └── config.json
├── 2_Normalize
├── README.md
├── config.json
├── config_sentence_transformers.json
├── model.safetensors
├── modules.json
├── sentence_bert_config.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.txt

这包含一个 modules.json 文件,内容如下:

[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]

以及一个 config_sentence_transformers.json 文件,内容如下:

{
  "__version__": {
    "sentence_transformers": "3.0.1",
    "transformers": "4.43.4",
    "pytorch": "2.5.0"
  },
  "prompts": {},
  "default_prompt_name": null,
  "similarity_fn_name": null
}

此外,1_Pooling 目录包含 Pooling 模块的配置文件,而 2_Normalize 目录是空的,因为 Normalize 模块不需要任何配置。sentence_bert_config.json 文件包含 Transformer 模块的配置,并且此模块还在根目录中保存了许多与分词器和模型本身相关的文件。

加载 Sentence Transformer 模型

要从已保存的模型目录加载 Sentence Transformer 模型,会读取 modules.json 文件以确定构成模型的模块。每个模块都会使用相应模块目录中存储的配置进行初始化,然后使用加载的模块实例化 SentenceTransformer 类。

基于 Transformers 模型的 Sentence Transformer 模型

当您使用纯 Transformers 模型(例如 BERT、RoBERTa、DistilBERT、T5)初始化 Sentence Transformer 模型时,Sentence Transformers 默认会创建一个 Transformer 模块和一个 Mean Pooling 模块。这提供了一种简单的方法来利用预训练语言模型进行句子嵌入。

具体来说,这两个代码片段是相同的

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("bert-base-uncased")
from sentence_transformers import models, SentenceTransformer

transformer = models.Transformer("bert-base-uncased")
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])

高级:自定义模块

输入模块

管道中的第一个模块称为输入模块。它负责对输入文本进行分词,并为后续模块生成输入特征。输入模块可以是任何实现了 InputModule 类(它是 Module 类的子类)的模块。

它有三个你需要实现抽象方法

  • 一个 forward() 方法,该方法接受一个 features 字典,其中包含 input_idsattention_masktoken_type_idstoken_embeddingssentence_embedding 等键,具体取决于模块在模型管道中的位置。

  • 一个 save() 方法,用于将模块的配置和可选的权重保存到指定目录。

  • 一个 tokenize() 方法,该方法接受输入列表并返回一个字典,其中包含 input_idsattention_masktoken_type_idspixel_values 等键。此字典将传递给模块的 forward 方法。

另外,您还可以实现以下方法

  • 一个 load() 静态方法,该方法接受 model_name_or_path 参数、用于从 Hugging Face 加载的关键字参数(subfoldertokencache_folder 等)和模块 kwargs(model_kwargstrust_remote_codebackend 等),并根据该目录或模型名称中的模块配置初始化模块。

  • 一个 get_sentence_embedding_dimension() 方法,返回模块生成的句子嵌入的维度。如果模块生成嵌入或更新嵌入的维度,则需要此方法。

  • 一个 get_max_seq_length() 方法,返回模块可以处理的最大序列长度。仅当模块处理输入文本时才需要。

后续模块

管道中的后续模块称为非输入模块。它们负责处理输入模块生成的输入特征并生成最终的句子嵌入。非输入模块可以是任何实现了 Module 类的模块。

它有两个你需要实现抽象方法

  • 一个 forward() 方法,该方法接受一个 features 字典,其中包含 input_idsattention_masktoken_type_idstoken_embeddingssentence_embedding 等键,具体取决于模块在模型管道中的位置。

  • 一个 save() 方法,用于将模块的配置和可选的权重保存到指定目录。

另外,您还可以实现以下方法

  • 一个 load() 静态方法,该方法接受 model_name_or_path 参数、用于从 Hugging Face 加载的关键字参数(subfoldertokencache_folder 等)和模块 kwargs(model_kwargstrust_remote_codebackend 等),并根据该目录或模型名称中的模块配置初始化模块。

  • 一个 get_sentence_embedding_dimension() 方法,返回模块生成的句子嵌入的维度。如果模块生成嵌入或更新嵌入的维度,则需要此方法。

示例模块

例如,我们可以通过实现一个自定义模块来创建自定义池化方法。

# decay_pooling.py

import torch
from sentence_transformers.models import Module


class DecayMeanPooling(Module):
    config_keys: list[str] = ["dimension", "decay"]

    def __init__(self, dimension: int, decay: float = 0.95, **kwargs) -> None:
        super(DecayMeanPooling, self).__init__()
        self.dimension = dimension
        self.decay = decay

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        # This module is expected to be used after some modules that provide "token_embeddings"
        # and "attention_mask" in the features dictionary.
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"].unsqueeze(-1)

        # Apply the attention mask to filter away padding tokens
        token_embeddings = token_embeddings * attention_mask
        # Calculate mean of token embeddings
        sentence_embeddings = token_embeddings.sum(1) / attention_mask.sum(1)
        # Apply exponential decay
        importance_per_dim = self.decay ** torch.arange(
            sentence_embeddings.size(1), device=sentence_embeddings.device
        )
        features["sentence_embedding"] = sentence_embeddings * importance_per_dim
        return features

    def get_sentence_embedding_dimension(self) -> int:
        return self.dimension

    def save(self, output_path, *args, safe_serialization=True, **kwargs) -> None:
        self.save_config(output_path)

    # The `load` method by default loads the config.json file from the model directory
    # and initializes the class with the loaded parameters, i.e. the `config_keys`.
    # This works for us, so no need to override it.

注意

建议在 __init__forwardsaveloadtokenize 方法中添加 **kwargs,以确保这些方法与 Sentence Transformers 库未来的更新保持兼容。

现在可以将其用作 Sentence Transformer 模型中的模块

from sentence_transformers import models, SentenceTransformer
from decay_pooling import DecayMeanPooling

transformer = models.Transformer("bert-base-uncased", max_seq_length=256)
decay_mean_pooling = DecayMeanPooling(transformer.get_word_embedding_dimension(), decay=0.99)
normalize = models.Normalize()

model = SentenceTransformer(modules=[transformer, decay_mean_pooling, normalize])
print(model)
"""
SentenceTransformer(
    (0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
    (1): DecayMeanPooling()
    (2): Normalize()
)
"""

texts = [
    "Hello, World!",
    "The quick brown fox jumps over the lazy dog.",
    "I am a sentence that is used for testing purposes.",
    "This is a test sentence.",
    "This is another test sentence.",
]
embeddings = model.encode(texts)
print(embeddings.shape)
# [5, 768]

您可以使用 SentenceTransformer.save_pretrained 保存此模型,这将生成一个 modules.json 文件,内容如下:

[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_DecayMeanPooling",
    "type": "decay_pooling.DecayMeanPooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]

为了确保能够导入 decay_pooling.DecayMeanPooling,您应该将 decay_pooling.py 文件复制到保存模型的目录中。如果您将模型推送到 Hugging Face Hub,那么您也应该将 decay_pooling.py 文件上传到模型的仓库中。然后,每个人都可以通过调用 SentenceTransformer("your-username/your-model-id", trust_remote_code=True) 来使用您的自定义模块。

注意

使用存储在 Hugging Face Hub 上的远程代码的自定义模块要求您的用户在加载模型时将 trust_remote_code 指定为 True。这是一项安全措施,旨在防止远程代码执行攻击。

如果您的模型和自定义建模代码位于 Hugging Face Hub 上,那么将自定义模块分离到单独的仓库中可能更有意义。这样,您只需维护一个自定义模块的实现,并且可以在多个模型中重复使用它。您可以通过更新 modules.json 文件中的 type 来实现这一点,使其包含自定义模块存储的仓库路径,例如 {repository_id}--{dot_path_to_module}。例如,如果 decay_pooling.py 文件存储在名为 my-user/my-model-implementation 的仓库中,并且模块名为 DecayMeanPooling,那么 modules.json 文件可能如下所示:

[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_DecayMeanPooling",
    "type": "my-user/my-model-implementation--decay_pooling.DecayMeanPooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]

高级:自定义模块中的关键字参数直传

如果您希望您的用户能够通过 SentenceTransformer.encode 方法指定自定义关键字参数,那么您可以将它们的名称添加到 modules.json 文件中。例如,如果您的模块在用户指定 task 关键字参数时应表现不同,那么您的 modules.json 可能会如下所示:

[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "custom_transformer.CustomTransformer",
    "kwargs": ["task"]
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]

然后,您可以在自定义模块的 forward 方法中访问 task 关键字参数

from sentence_transformers.models import Transformer

class CustomTransformer(Transformer):
    def forward(self, features: dict[str, torch.Tensor], task: Optional[str] = None, **kwargs) -> dict[str, torch.Tensor]:
        if task == "default":
            # Do something
        else:
            # Do something else
        return features

这样,用户在调用 SentenceTransformer.encode 时,可以指定 task 关键字参数。

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("your-username/your-model-id", trust_remote_code=True)
texts = [...]
model.encode(texts, task="default")