创建自定义模型
Sentence Transformer 模型的结构
Sentence Transformer 模型由一系列模块(文档)组成,这些模块按顺序执行。最常见的架构是 Transformer
模块、一个 Pooling
模块,以及可选的 Dense
模块和/或一个 Normalize
模块的组合。
Transformer
:此模块负责处理输入文本并生成语境化嵌入。Pooling
:此模块通过聚合嵌入来降低 Transformer 模块输出的维度。常见的池化策略包括均值池化和 CLS 池化。Dense
:此模块包含一个线性层,用于后处理 Pooling 模块的嵌入输出。Normalize
:此模块规范化来自前一层的嵌入。
例如,流行的 all-MiniLM-L6-v2 模型也可以通过初始化构成该模型的 3 个特定模块来加载
from sentence_transformers import models, SentenceTransformer
transformer = models.Transformer("sentence-transformers/all-MiniLM-L6-v2", max_seq_length=256)
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
normalize = models.Normalize()
model = SentenceTransformer(modules=[transformer, pooling, normalize])
保存 Sentence Transformer 模型
每当保存 Sentence Transformer 模型时,会生成三种类型的文件
modules.json
:此文件包含用于重建模型的模块名称、路径和类型列表。config_sentence_transformers.json
:此文件包含 Sentence Transformer 模型的一些配置选项,包括保存的提示、模型的相似度函数,以及模型作者使用的 Sentence Transformer 包版本。模块特定文件:每个模块都保存在单独的子文件夹中,以模块索引和模型名称命名(例如,
1_Pooling
,2_Normalize
),但第一个模块如果其save_in_root
属性设置为True
,则可以保存在根目录中。在 Sentence Transformers 中,Transformer
和CLIPModel
模块就是这种情况。大多数模块文件夹都包含一个config.json
(或对于Transformer
模块,则是sentence_bert_config.json
)文件,该文件存储传递给该模块的关键字参数的默认值。因此,一个sentence_bert_config.json
文件,其内容如下:{ "max_seq_length": 4096, "do_lower_case": false }
意味着
Transformer
模块将使用max_seq_length=4096
和do_lower_case=False
进行初始化。
因此,如果我在上一个片段中的 model
上调用 SentenceTransformer.save_pretrained("local-all-MiniLM-L6-v2")
,则会生成以下文件
local-all-MiniLM-L6-v2/
├── 1_Pooling
│ └── config.json
├── 2_Normalize
├── README.md
├── config.json
├── config_sentence_transformers.json
├── model.safetensors
├── modules.json
├── sentence_bert_config.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.txt
这包含一个 modules.json
文件,内容如下:
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_Pooling",
"type": "sentence_transformers.models.Pooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
以及一个 config_sentence_transformers.json
文件,内容如下:
{
"__version__": {
"sentence_transformers": "3.0.1",
"transformers": "4.43.4",
"pytorch": "2.5.0"
},
"prompts": {},
"default_prompt_name": null,
"similarity_fn_name": null
}
此外,1_Pooling
目录包含 Pooling
模块的配置文件,而 2_Normalize
目录是空的,因为 Normalize
模块不需要任何配置。sentence_bert_config.json
文件包含 Transformer
模块的配置,并且此模块还在根目录中保存了许多与分词器和模型本身相关的文件。
加载 Sentence Transformer 模型
要从已保存的模型目录加载 Sentence Transformer 模型,会读取 modules.json
文件以确定构成模型的模块。每个模块都会使用相应模块目录中存储的配置进行初始化,然后使用加载的模块实例化 SentenceTransformer 类。
基于 Transformers 模型的 Sentence Transformer 模型
当您使用纯 Transformers 模型(例如 BERT、RoBERTa、DistilBERT、T5)初始化 Sentence Transformer 模型时,Sentence Transformers 默认会创建一个 Transformer 模块和一个 Mean Pooling 模块。这提供了一种简单的方法来利用预训练语言模型进行句子嵌入。
具体来说,这两个代码片段是相同的
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("bert-base-uncased")
from sentence_transformers import models, SentenceTransformer
transformer = models.Transformer("bert-base-uncased")
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])
高级:自定义模块
输入模块
管道中的第一个模块称为输入模块。它负责对输入文本进行分词,并为后续模块生成输入特征。输入模块可以是任何实现了 InputModule
类(它是 Module
类的子类)的模块。
它有三个你需要实现抽象方法
一个
forward()
方法,该方法接受一个features
字典,其中包含input_ids
、attention_mask
、token_type_ids
、token_embeddings
和sentence_embedding
等键,具体取决于模块在模型管道中的位置。一个
save()
方法,用于将模块的配置和可选的权重保存到指定目录。一个
tokenize()
方法,该方法接受输入列表并返回一个字典,其中包含input_ids
、attention_mask
、token_type_ids
、pixel_values
等键。此字典将传递给模块的forward
方法。
另外,您还可以实现以下方法
一个
load()
静态方法,该方法接受model_name_or_path
参数、用于从 Hugging Face 加载的关键字参数(subfolder
、token
、cache_folder
等)和模块 kwargs(model_kwargs
、trust_remote_code
、backend
等),并根据该目录或模型名称中的模块配置初始化模块。一个
get_sentence_embedding_dimension()
方法,返回模块生成的句子嵌入的维度。如果模块生成嵌入或更新嵌入的维度,则需要此方法。一个
get_max_seq_length()
方法,返回模块可以处理的最大序列长度。仅当模块处理输入文本时才需要。
后续模块
管道中的后续模块称为非输入模块。它们负责处理输入模块生成的输入特征并生成最终的句子嵌入。非输入模块可以是任何实现了 Module
类的模块。
它有两个你需要实现抽象方法
一个
forward()
方法,该方法接受一个features
字典,其中包含input_ids
、attention_mask
、token_type_ids
、token_embeddings
和sentence_embedding
等键,具体取决于模块在模型管道中的位置。一个
save()
方法,用于将模块的配置和可选的权重保存到指定目录。
另外,您还可以实现以下方法
一个
load()
静态方法,该方法接受model_name_or_path
参数、用于从 Hugging Face 加载的关键字参数(subfolder
、token
、cache_folder
等)和模块 kwargs(model_kwargs
、trust_remote_code
、backend
等),并根据该目录或模型名称中的模块配置初始化模块。一个
get_sentence_embedding_dimension()
方法,返回模块生成的句子嵌入的维度。如果模块生成嵌入或更新嵌入的维度,则需要此方法。
示例模块
例如,我们可以通过实现一个自定义模块来创建自定义池化方法。
# decay_pooling.py
import torch
from sentence_transformers.models import Module
class DecayMeanPooling(Module):
config_keys: list[str] = ["dimension", "decay"]
def __init__(self, dimension: int, decay: float = 0.95, **kwargs) -> None:
super(DecayMeanPooling, self).__init__()
self.dimension = dimension
self.decay = decay
def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
# This module is expected to be used after some modules that provide "token_embeddings"
# and "attention_mask" in the features dictionary.
token_embeddings = features["token_embeddings"]
attention_mask = features["attention_mask"].unsqueeze(-1)
# Apply the attention mask to filter away padding tokens
token_embeddings = token_embeddings * attention_mask
# Calculate mean of token embeddings
sentence_embeddings = token_embeddings.sum(1) / attention_mask.sum(1)
# Apply exponential decay
importance_per_dim = self.decay ** torch.arange(
sentence_embeddings.size(1), device=sentence_embeddings.device
)
features["sentence_embedding"] = sentence_embeddings * importance_per_dim
return features
def get_sentence_embedding_dimension(self) -> int:
return self.dimension
def save(self, output_path, *args, safe_serialization=True, **kwargs) -> None:
self.save_config(output_path)
# The `load` method by default loads the config.json file from the model directory
# and initializes the class with the loaded parameters, i.e. the `config_keys`.
# This works for us, so no need to override it.
注意
建议在 __init__
、forward
、save
、load
和 tokenize
方法中添加 **kwargs
,以确保这些方法与 Sentence Transformers 库未来的更新保持兼容。
现在可以将其用作 Sentence Transformer 模型中的模块
from sentence_transformers import models, SentenceTransformer
from decay_pooling import DecayMeanPooling
transformer = models.Transformer("bert-base-uncased", max_seq_length=256)
decay_mean_pooling = DecayMeanPooling(transformer.get_word_embedding_dimension(), decay=0.99)
normalize = models.Normalize()
model = SentenceTransformer(modules=[transformer, decay_mean_pooling, normalize])
print(model)
"""
SentenceTransformer(
(0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
(1): DecayMeanPooling()
(2): Normalize()
)
"""
texts = [
"Hello, World!",
"The quick brown fox jumps over the lazy dog.",
"I am a sentence that is used for testing purposes.",
"This is a test sentence.",
"This is another test sentence.",
]
embeddings = model.encode(texts)
print(embeddings.shape)
# [5, 768]
您可以使用 SentenceTransformer.save_pretrained
保存此模型,这将生成一个 modules.json
文件,内容如下:
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_DecayMeanPooling",
"type": "decay_pooling.DecayMeanPooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
为了确保能够导入 decay_pooling.DecayMeanPooling
,您应该将 decay_pooling.py
文件复制到保存模型的目录中。如果您将模型推送到 Hugging Face Hub,那么您也应该将 decay_pooling.py
文件上传到模型的仓库中。然后,每个人都可以通过调用 SentenceTransformer("your-username/your-model-id", trust_remote_code=True)
来使用您的自定义模块。
注意
使用存储在 Hugging Face Hub 上的远程代码的自定义模块要求您的用户在加载模型时将 trust_remote_code
指定为 True
。这是一项安全措施,旨在防止远程代码执行攻击。
如果您的模型和自定义建模代码位于 Hugging Face Hub 上,那么将自定义模块分离到单独的仓库中可能更有意义。这样,您只需维护一个自定义模块的实现,并且可以在多个模型中重复使用它。您可以通过更新 modules.json
文件中的 type
来实现这一点,使其包含自定义模块存储的仓库路径,例如 {repository_id}--{dot_path_to_module}
。例如,如果 decay_pooling.py
文件存储在名为 my-user/my-model-implementation
的仓库中,并且模块名为 DecayMeanPooling
,那么 modules.json
文件可能如下所示:
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "sentence_transformers.models.Transformer"
},
{
"idx": 1,
"name": "1",
"path": "1_DecayMeanPooling",
"type": "my-user/my-model-implementation--decay_pooling.DecayMeanPooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
高级:自定义模块中的关键字参数直传
如果您希望您的用户能够通过 SentenceTransformer.encode
方法指定自定义关键字参数,那么您可以将它们的名称添加到 modules.json
文件中。例如,如果您的模块在用户指定 task
关键字参数时应表现不同,那么您的 modules.json
可能会如下所示:
[
{
"idx": 0,
"name": "0",
"path": "",
"type": "custom_transformer.CustomTransformer",
"kwargs": ["task"]
},
{
"idx": 1,
"name": "1",
"path": "1_Pooling",
"type": "sentence_transformers.models.Pooling"
},
{
"idx": 2,
"name": "2",
"path": "2_Normalize",
"type": "sentence_transformers.models.Normalize"
}
]
然后,您可以在自定义模块的 forward
方法中访问 task
关键字参数
from sentence_transformers.models import Transformer
class CustomTransformer(Transformer):
def forward(self, features: dict[str, torch.Tensor], task: Optional[str] = None, **kwargs) -> dict[str, torch.Tensor]:
if task == "default":
# Do something
else:
# Do something else
return features
这样,用户在调用 SentenceTransformer.encode
时,可以指定 task
关键字参数。
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("your-username/your-model-id", trust_remote_code=True)
texts = [...]
model.encode(texts, task="default")