util
sentence_transformers.util
定义了不同的辅助函数来处理文本嵌入。
辅助函数
- sentence_transformers.util.community_detection(embeddings: Tensor | ndarray, threshold: float = 0.75, min_community_size: int = 10, batch_size: int = 1024, show_progress_bar: bool = False) list[list[int]] [源代码]
用于快速社区检测的函数。
在嵌入中查找所有社区,即彼此接近(比阈值更近)的嵌入。仅返回大于 min_community_size 的社区。社区按降序返回。每个列表中的第一个元素是社区的中心点。
- 参数:
embeddings (torch.Tensor 或 numpy.ndarray) – 输入嵌入。
threshold (float) – 确定两个嵌入是否接近的阈值。默认为 0.75。
min_community_size (int) – 要考虑的社区的最小大小。默认为 10。
batch_size (int) – 用于计算余弦相似度分数的批大小。默认为 1024。
show_progress_bar (bool) – 是否在计算期间显示进度条。默认为 False。
- 返回:
社区列表,其中每个社区表示为索引列表。
- 返回类型:
List[List[int]]
- sentence_transformers.util.http_get(url: str, path: str) None [源代码]
将 URL 下载到磁盘上的给定路径。
- 参数:
url (str) – 要下载的 URL。
path (str) – 保存下载文件的路径。
- 引发:
requests.HTTPError – 如果 HTTP 请求返回非 200 状态代码。
- 返回:
None
- sentence_transformers.util.is_training_available() bool [源代码]
如果我们拥有训练 Sentence Transformers 模型所需的依赖项,即 Huggingface datasets 和 Huggingface accelerate,则返回 True。
- sentence_transformers.util.mine_hard_negatives(dataset: Dataset, model: SentenceTransformer, anchor_column_name: str | None = None, positive_column_name: str | None = None, corpus: list[str] | None = None, cross_encoder: CrossEncoder | None = None, range_min: int = 0, range_max: int | None = None, max_score: float | None = None, min_score: float | None = None, absolute_margin: float | None = None, relative_margin: float | None = None, num_negatives: int = 3, sampling_strategy: Literal['random', 'top'] = 'top', include_positives: bool = False, output_format: Literal['triplet', 'n-tuple', 'labeled-pair', 'labeled-list'] = 'triplet', batch_size: int = 32, faiss_batch_size: int = 16384, use_faiss: bool = False, use_multi_process: list[str] | bool = False, verbose: bool = True, as_triplets: bool | None = None, margin: float | None = None) Dataset [源代码]
将难负例添加到 (anchor, positive) 对的数据集中,以创建 (anchor, positive, negative) 三元组或 (anchor, positive, negative_1, …, negative_n) 元组。
难负例挖掘是一种通过添加难负例来提高数据集质量的技术,难负例是指可能看起来与 anchor 相似但不相同的文本。使用难负例可以提高在数据集上训练的模型的性能。
此函数使用 SentenceTransformer 模型来嵌入数据集中的句子,然后在数据集中找到与每个 anchor 句子最接近的匹配项。然后,它从最接近的匹配项中采样负例,可以选择使用 CrossEncoder 模型重新评分候选对象。
您可以通过多种方式影响候选负例选择
range_min: 要考虑作为负例的最接近匹配项的最小排名:用于跳过最相似的文本,以避免将实际上是正例的文本标记为负例。
range_max: 要考虑作为负例的最接近匹配项的最大排名:用于限制从中采样负例的候选对象数量。较低的值会使处理速度更快,但可能导致较少的候选负例满足边距或 max_score 条件。
max_score: 要考虑作为负例的最大分数:用于跳过与 anchor 太相似的候选对象。
min_score: 要考虑作为负例的最小分数:用于跳过与 anchor 太不相似的候选对象。
absolute_margin: 难负例挖掘的绝对边距:用于跳过与 anchor 的相似度在正例对的特定边距内的候选负例。值 0 可用于强制负例始终比正例更远离 anchor。
relative_margin: 难负例挖掘的相对边距:用于跳过与 anchor 的相似度在正例对的特定边距内的候选负例。值 0.05 表示负例与 anchor 的相似度最多为正例的 95%。
sampling_strategy: 负例的采样策略:“top” 或 “random”。“top” 将始终采样前 n 个候选对象作为负例,而 “random” 将从满足边距或 max_score 条件的候选对象中随机采样 n 个负例。
提示
优秀的 NV-Retriever 论文是理解难负例挖掘的细节以及如何有效使用它的绝佳资源。值得注意的是,它使用以下设置达到了最强的性能
dataset = mine_hard_negatives( dataset=dataset, model=model, relative_margin=0.05, # 0.05 means that the negative is at most 95% as similar to the anchor as the positive num_negatives=num_negatives, # 10 or less is recommended sampling_strategy="top", # "top" means that we sample the top candidates as negatives batch_size=batch_size, # Adjust as needed use_faiss=True, # Optional: Use faiss/faiss-gpu for faster similarity search )
这与 TopK-PercPos (95%) 挖掘方法相对应。
示例
>>> from sentence_transformers.util import mine_hard_negatives >>> from sentence_transformers import SentenceTransformer >>> from datasets import load_dataset >>> # Load a Sentence Transformer model >>> model = SentenceTransformer("all-MiniLM-L6-v2") >>> >>> # Load a dataset to mine hard negatives from >>> dataset = load_dataset("sentence-transformers/natural-questions", split="train") >>> dataset Dataset({ features: ['query', 'answer'], num_rows: 100231 }) >>> dataset = mine_hard_negatives( ... dataset=dataset, ... model=model, ... range_min=10, ... range_max=50, ... max_score=0.8, ... relative_margin=0.05, ... num_negatives=5, ... sampling_strategy="random", ... batch_size=128, ... use_faiss=True, ... ) Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 588/588 [00:32<00:00, 18.07it/s] Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 784/784 [00:08<00:00, 96.41it/s] Querying FAISS index: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:06<00:00, 1.06it/s] Metric Positive Negative Difference Count 100,231 487,865 Mean 0.6866 0.4194 0.2752 Median 0.7010 0.4102 0.2760 Std 0.1125 0.0719 0.1136 Min 0.0303 0.1702 0.0209 25% 0.6221 0.3672 0.1899 50% 0.7010 0.4102 0.2760 75% 0.7667 0.4647 0.3590 Max 0.9584 0.7621 0.7073 Skipped 427,503 potential negatives (8.36%) due to the relative_margin of 0.05. Skipped 978 potential negatives (0.02%) due to the max_score of 0.8. Could not find enough negatives for 13290 samples (2.65%). Consider adjusting the range_max, range_min, relative_margin and max_score parameters if you'd like to find more valid negatives. >>> dataset Dataset({ features: ['query', 'answer', 'negative'], num_rows: 487865 }) >>> dataset[0] { 'query': 'when did richmond last play in a preliminary final', 'answer': "Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next.", 'negative': "2018 NRL Grand Final The 2018 NRL Grand Final was the conclusive and premiership-deciding game of the 2018 National Rugby League season and was played on Sunday September 30 at Sydney's ANZ Stadium.[1] The match was contested between minor premiers the Sydney Roosters and defending premiers the Melbourne Storm. In front of a crowd of 82,688, Sydney won the match 21–6 to claim their 14th premiership title and their first since 2013. Roosters five-eighth Luke Keary was awarded the Clive Churchill Medal as the game's official man of the match." } >>> dataset.push_to_hub("natural-questions-hard-negatives", "triplet-all")
- 参数:
dataset (Dataset) – 包含 (anchor, positive) 对的数据集。
model (SentenceTransformer) – 用于嵌入句子的 SentenceTransformer 模型。
anchor_column_name (str, optional) – dataset 中包含 anchor/query 的列名。默认为 None,在这种情况下,将使用 dataset 中的第一列。
positive_column_name (str, optional) – dataset 中包含正例候选对象的列名。默认为 None,在这种情况下,将专门使用 dataset 中的第二列作为正例候选对象。
corpus (List[str], optional) – 包含文档字符串的列表,除了 dataset 中的第二列之外,还将用作候选负例。默认为 None,在这种情况下,dataset 中的第二列将专门用作负例候选语料库。
cross_encoder (CrossEncoder, optional) – 用于重新评分候选对象的 CrossEncoder 模型。默认为 None。
range_min (int) – 要考虑作为负例的最接近匹配项的最小排名。默认为 0。
range_max (int, optional) – 要考虑作为负例的最接近匹配项的最大排名。默认为 None。
max_score (float, optional) – 要考虑作为负例的最大分数。默认为 None。
min_score (float, optional) – 要考虑作为负例的最小分数。默认为 None。
absolute_margin (float, optional) – 难负例挖掘的绝对边距,即正例相似度和负例相似度之间的最小距离。默认为 None。
relative_margin (float, optional) – 难负例挖掘的相对边距,即正例相似度和负例相似度之间的最大比率。值 0.05 表示负例与 anchor 的相似度最多为正例的 95%。默认为 None。
num_negatives (int) – 要采样的负例数量。默认为 3。
sampling_strategy (Literal["random", "top"]) – 负例的采样策略:“top” 或 “random”。默认为 “top”。
include_positives (bool) – 是否将正例包含在负例候选中。将此设置为 True 主要用于为 CrossEncoder 模型创建 Reranking 评估数据集,在其中,从第一阶段检索模型获得完整排名(包括正例)可能很有用。默认为 False。
output_format (Literal["triplet", "n-tuple", "labeled-pair", "labeled-list"]) –
用于 datasets.Dataset 的输出格式。选项包括
”triplet”:(anchor,positive,negative)三元组,即 3 列。例如,适用于
CachedMultipleNegativesRankingLoss
。”n-tuple”:(anchor,positive,negative_1,…,negative_n)元组,即 2 + num_negatives 列。例如,适用于
CachedMultipleNegativesRankingLoss
。”labeled-pair”:(anchor,passage,label)文本元组,负例标签为 0,正例标签为 1,即 3 列。例如,适用于
BinaryCrossEntropyLoss
。”labeled-list”:(anchor,[doc1,doc2,…,docN],[label1,label2,…,labelN])三元组,负例标签为 0,正例标签为 1,即 3 列。例如,适用于
LambdaLoss
。
默认为 “triplet”。
batch_size (int) – 用于编码数据集的批大小。默认为 32。
faiss_batch_size (int) – 用于 FAISS top-k 搜索的批大小。默认为 16384。
use_faiss (bool) – 是否使用 FAISS 进行相似度搜索。对于大型数据集,建议使用。默认为 False。
use_multi_process (bool | List[str], optional) – 是否使用多 GPU/CPU 处理。如果为 True,则在使用 CUDA 时使用所有 GPU,在不可用时使用 4 个 CPU 进程。您还可以传递 PyTorch 设备列表,如 [“cuda:0”, “cuda:1”, …] 或 [“cpu”, “cpu”, “cpu”, “cpu”]。
verbose (bool) – 是否打印统计信息和日志记录。默认为 True。
as_triplets (bool, optional) – 已弃用。请改用 output_format。默认为 None。
margin (float, optional) – 已弃用。请使用 absolute_margin 或 relative_margin 代替。默认为 None。
- 返回:
一个包含 (anchor, positive, negative) 三元组,带有标签的 (anchor, passage, label) 文本元组,或 (anchor, positive, negative_1, …, negative_n) 元组的数据集。
- 返回类型:
数据集
- sentence_transformers.util.normalize_embeddings(embeddings: Tensor) Tensor [source]
归一化嵌入矩阵,使每个句子嵌入都具有单位长度。
- 参数:
embeddings (Tensor) – 输入嵌入矩阵。
- 返回:
归一化后的嵌入矩阵。
- 返回类型:
Tensor
- sentence_transformers.util.paraphrase_mining(model, sentences: list[str], show_progress_bar: bool = False, batch_size: int = 32, query_chunk_size: int = 5000, corpus_chunk_size: int = 100000, max_pairs: int = 500000, top_k: int = 100, score_function: ~typing.Callable[[~torch.Tensor, ~torch.Tensor], ~torch.Tensor] = <function cos_sim>) list[list[float | int]] [source]
给定一个句子/文本列表,此函数执行释义挖掘。它将所有句子与所有其他句子进行比较,并返回余弦相似度得分最高的句子对列表。
- 参数:
model (SentenceTransformer) – 用于嵌入计算的 SentenceTransformer 模型
sentences (List[str]) – 字符串(文本或句子)列表
show_progress_bar (bool, optional) – 绘制进度条。默认为 False。
batch_size (int, optional) – 模型同时编码的文本数量。默认为 32。
query_chunk_size (int, optional) – 同时搜索 #query_chunk_size 个最相似的对。减小此值可降低内存占用(增加运行时间)。默认为 5000。
corpus_chunk_size (int, optional) – 同时将一个句子与 #corpus_chunk_size 个其他句子进行比较。减小此值可降低内存占用(增加运行时间)。默认为 100000。
max_pairs (int, optional) – 返回的最大文本对数。默认为 500000。
top_k (int, optional) – 对于每个句子,我们最多检索 top_k 个其他句子。默认为 100。
score_function (Callable[[Tensor, Tensor], Tensor], optional) – 用于计算分数的函数。默认情况下,为余弦相似度。默认为 cos_sim。
- 返回:
返回格式为 [score, id1, id2] 的三元组列表
- 返回类型:
List[List[Union[float, int]]]
- sentence_transformers.util.semantic_search(query_embeddings: ~torch.Tensor, corpus_embeddings: ~torch.Tensor, query_chunk_size: int = 100, corpus_chunk_size: int = 500000, top_k: int = 10, score_function: ~typing.Callable[[~torch.Tensor, ~torch.Tensor], ~torch.Tensor] = <function cos_sim>) list[list[dict[str, int | float]]] [source]
此函数在查询嵌入列表和语料库嵌入列表之间执行余弦相似度搜索。它可用于信息检索/语义搜索,适用于最多约 100 万条目的语料库。
- 参数:
query_embeddings (
Tensor
) – 包含查询嵌入的二维张量。corpus_embeddings (
Tensor
) – 包含语料库嵌入的二维张量。query_chunk_size (int, optional) – 同时处理 100 个查询。增加此值会提高速度,但需要更多内存。默认为 100。
corpus_chunk_size (int, optional) – 每次扫描语料库 10 万条目。增加此值会提高速度,但需要更多内存。默认为 500000。
top_k (int, optional) – 检索前 k 个匹配条目。默认为 10。
score_function (Callable[[
Tensor
,Tensor
],Tensor
], optional) – 用于计算分数的函数。默认情况下,为余弦相似度。
- 返回:
一个列表,每个查询包含一个条目。每个条目都是一个字典列表,其中包含键 ‘corpus_id’ 和 ‘score’,并按余弦相似度得分降序排序。
- 返回类型:
List[List[Dict[str, Union[int, float]]]]
- sentence_transformers.util.truncate_embeddings(embeddings: ndarray, truncate_dim: int | None) ndarray [source]
- sentence_transformers.util.truncate_embeddings(embeddings: Tensor, truncate_dim: int | None) Tensor
截断嵌入矩阵。
- 参数:
embeddings (Union[np.ndarray, torch.Tensor]) – 要截断的嵌入。
truncate_dim (Optional[int]) – 将句子嵌入截断到的维度。None 表示不截断。
示例
>>> from sentence_transformers import SentenceTransformer >>> from sentence_transformers.util import truncate_embeddings >>> model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka") >>> embeddings = model.encode(["It's so nice outside!", "Today is a beautiful day.", "He drove to work earlier"]) >>> embeddings.shape (3, 768) >>> model.similarity(embeddings, embeddings) tensor([[1.0000, 0.8100, 0.1426], [0.8100, 1.0000, 0.2121], [0.1426, 0.2121, 1.0000]]) >>> truncated_embeddings = truncate_embeddings(embeddings, 128) >>> truncated_embeddings.shape >>> model.similarity(truncated_embeddings, truncated_embeddings) tensor([[1.0000, 0.8092, 0.1987], [0.8092, 1.0000, 0.2716], [0.1987, 0.2716, 1.0000]])
- 返回:
截断后的嵌入。
- 返回类型:
Union[np.ndarray, torch.Tensor]
模型优化
- sentence_transformers.backend.export_dynamic_quantized_onnx_model(model: SentenceTransformer | CrossEncoder, quantization_config: QuantizationConfig | Literal['arm64', 'avx2', 'avx512', 'avx512_vnni'], model_name_or_path: str, push_to_hub: bool = False, create_pr: bool = False, file_suffix: str | None = None) None [source]
从 SentenceTransformer 或 CrossEncoder 模型导出量化的 ONNX 模型。
此函数应用动态量化,即无需校准数据集。每个默认量化配置都将模型量化为 int8,从而可以在 CPU 上实现更快的推理,但在 GPU 上可能会更慢。
有关更多信息和基准测试,请参见以下页面
- 参数:
model (SentenceTransformer | CrossEncoder) – 要量化的 SentenceTransformer 或 CrossEncoder 模型。必须使用 backend=”onnx” 加载。
quantization_config (QuantizationConfig) – 量化配置。
model_name_or_path (str) – 量化模型将保存的路径或 Hugging Face Hub 仓库名称。
push_to_hub (bool, optional) – 是否将量化模型推送到 Hugging Face Hub。默认为 False。
create_pr (bool, optional) – 推送到 Hugging Face Hub 时是否创建拉取请求。默认为 False。
file_suffix (str | None, optional) – 要添加到量化模型文件名中的后缀。默认为 None。
- 引发:
ImportError – 如果未安装所需的软件包 optimum 和 onnxruntime。
ValueError – 如果提供的模型不是使用 backend=”onnx” 加载的有效 SentenceTransformer 或 CrossEncoder 模型。
ValueError – 如果提供的 quantization_config 无效。
- 返回:
None
- sentence_transformers.backend.export_optimized_onnx_model(model: SentenceTransformer | CrossEncoder, optimization_config: OptimizationConfig | Literal['O1', 'O2', 'O3', 'O4'], model_name_or_path: str, push_to_hub: bool = False, create_pr: bool = False, file_suffix: str | None = None) None [source]
从 SentenceTransformer 或 CrossEncoder 模型导出优化的 ONNX 模型。
O1-O4 优化级别由 Optimum 定义,并在此处记录: https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/optimization
优化级别为
O1:基本通用优化。
O2:基本和扩展的通用优化,特定于 transformers 的融合。
O3:与 O2 相同,但使用 GELU 近似。
O4:与 O3 相同,但使用混合精度(fp16,仅限 GPU)
有关更多信息和基准测试,请参见以下页面
- 参数:
model (SentenceTransformer | CrossEncoder) – 要优化的 SentenceTransformer 或 CrossEncoder 模型。必须使用 backend=”onnx” 加载。
optimization_config (OptimizationConfig | Literal["O1", "O2", "O3", "O4"]) – 优化配置或级别。
model_name_or_path (str) – 优化模型将保存的路径或 Hugging Face Hub 仓库名称。
push_to_hub (bool, optional) – 是否将优化模型推送到 Hugging Face Hub。默认为 False。
create_pr (bool, optional) – 推送到 Hugging Face Hub 时是否创建拉取请求。默认为 False。
file_suffix (str | None, optional) – 要添加到优化模型文件名中的后缀。默认为 None。
- 引发:
ImportError – 如果未安装所需的软件包 optimum 和 onnxruntime。
ValueError – 如果提供的模型不是使用 backend=”onnx” 加载的有效 SentenceTransformer 或 CrossEncoder 模型。
ValueError – 如果提供的 optimization_config 无效。
- 返回:
None
- sentence_transformers.backend.export_static_quantized_openvino_model(model: SentenceTransformer | CrossEncoder, quantization_config: OVQuantizationConfig | dict | None, model_name_or_path: str, dataset_name: str | None = None, dataset_config_name: str | None = None, dataset_split: str | None = None, column_name: str | None = None, push_to_hub: bool = False, create_pr: bool = False, file_suffix: str = 'qint8_quantized') None [source]
从 SentenceTransformer 或 CrossEncoder 模型导出量化的 OpenVINO 模型。
此函数使用校准数据集应用训练后静态量化 (PTQ),校准数据集校准量化常数,而无需重新训练模型。每个默认量化配置都将模型转换为 int8 精度,从而在保持精度的同时实现更快的推理。
有关更多信息和基准测试,请参见以下页面
- 参数:
model (SentenceTransformer | CrossEncoder) – 要量化的 SentenceTransformer 或 CrossEncoder 模型。必须使用 backend=”openvino” 加载。
quantization_config (OVQuantizationConfig | dict | None) – 量化配置。如果为 None,则使用默认值。
model_name_or_path (str) – 量化模型将保存的路径或 Hugging Face Hub 仓库名称。
dataset_name (str, optional) – 要加载以进行校准的数据集名称。如果未指定,默认将使用 glue 数据集的 sst2 子集。
dataset_config_name (str, optional) – 要加载的数据集的特定配置。
dataset_split (str, optional) – 要加载的数据集拆分(例如,“train”、“test”)。默认为 None。
column_name (str, optional) – 数据集中用于校准的列名。默认为 None。
push_to_hub (bool, optional) – 是否将量化模型推送到 Hugging Face Hub。默认为 False。
create_pr (bool, optional) – 推送到 Hugging Face Hub 时是否创建拉取请求。默认为 False。
file_suffix (str, optional) – 要添加到量化模型文件名中的后缀。默认为 qint8_quantized。
- 引发:
ImportError – 如果未安装所需的软件包 optimum 和 openvino。
ValueError – 如果提供的模型不是使用 backend=”openvino” 加载的有效 SentenceTransformer 或 CrossEncoder 模型。
ValueError – 如果提供的 quantization_config 无效。
- 返回:
None
相似性度量
- sentence_transformers.util.cos_sim(a: list | ndarray | Tensor, b: list | ndarray | Tensor) Tensor [source]
计算两个张量之间的余弦相似度。
- 参数:
a (Union[list, np.ndarray, Tensor]) – 第一个张量。
b (Union[list, np.ndarray, Tensor]) – 第二个张量。
- 返回:
矩阵,其中 res[i][j] = cos_sim(a[i], b[j])
- 返回类型:
Tensor
- sentence_transformers.util.dot_score(a: list | ndarray | Tensor, b: list | ndarray | Tensor) Tensor [source]
计算所有 i 和 j 的点积 dot_prod(a[i], b[j])。
- 参数:
a (Union[list, np.ndarray, Tensor]) – 第一个张量。
b (Union[list, np.ndarray, Tensor]) – 第二个张量。
- 返回:
矩阵,其中 res[i][j] = dot_prod(a[i], b[j])
- 返回类型:
Tensor
- sentence_transformers.util.euclidean_sim(a: list | ndarray | Tensor, b: list | ndarray | Tensor) Tensor [source]
计算两个张量之间的欧几里得相似度(即,负距离)。
- 参数:
a (Union[list, np.ndarray, Tensor]) – 第一个张量。
b (Union[list, np.ndarray, Tensor]) – 第二个张量。
- 返回:
矩阵,其中 res[i][j] = -euclidean_distance(a[i], b[j])
- 返回类型:
Tensor
- sentence_transformers.util.manhattan_sim(a: list | ndarray | Tensor, b: list | ndarray | Tensor) Tensor [source]
计算两个张量之间的曼哈顿相似度(即,负距离)。
- 参数:
a (Union[list, np.ndarray, Tensor]) – 第一个张量。
b (Union[list, np.ndarray, Tensor]) – 第二个张量。
- 返回:
矩阵,其中 res[i][j] = -manhattan_distance(a[i], b[j])
- 返回类型:
Tensor
- sentence_transformers.util.pairwise_cos_sim(a: Tensor, b: Tensor) Tensor [source]
计算成对余弦相似度 cos_sim(a[i], b[i])。
- 参数:
a (Union[list, np.ndarray, Tensor]) – 第一个张量。
b (Union[list, np.ndarray, Tensor]) – 第二个张量。
- 返回:
向量,其中 res[i] = cos_sim(a[i], b[i])
- 返回类型:
Tensor
- sentence_transformers.util.pairwise_dot_score(a: Tensor, b: Tensor) Tensor [source]
计算成对点积 dot_prod(a[i], b[i])。
- 参数:
a (Union[list, np.ndarray, Tensor]) – 第一个张量。
b (Union[list, np.ndarray, Tensor]) – 第二个张量。
- 返回:
向量 res[i] = dot_prod(a[i], b[i])
- 返回类型:
Tensor