采样器

批次采样器

sentence_transformers.training_args.BatchSamplers(value)[source]

存储批次采样器的可接受字符串标识符。

批次采样器负责确定训练期间如何将样本分组到批次中。有效选项包括:

如果您想使用自定义批次采样器,可以继承 DefaultBatchSampler 类,并将该类(而非实例)传递给 SentenceTransformerTrainingArguments(或 CrossEncoderTrainingArguments 等)中的 batch_sampler 参数。或者,您可以传递一个函数,该函数接受 datasetbatch_sizedrop_lastvalid_label_columnsgeneratorseed 参数,并返回一个 DefaultBatchSampler 实例。

使用
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.losses import MultipleNegativesRankingLoss
from datasets import Dataset

model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "anchor": ["It's nice weather outside today.", "He drove to work."],
    "positive": ["It's so sunny.", "He took the car to the office."],
})
loss = MultipleNegativesRankingLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
sentence_transformers.sampler.DefaultBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[source]

该采样器是 SentenceTransformer 库中使用的默认批次采样器。它等同于 PyTorch 的 BatchSampler。

参数:
  • sampler (SamplerIterable) – 用于从数据集中采样元素的采样器,例如 SubsetRandomSampler。

  • batch_size (int) – 每批次的样本数量。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 用于检查标签的列名列表。数据集中找到的来自 valid_label_columns 的第一个列名将被用作标签列。

  • generator (torch.Generator, 可选) – 用于打乱索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[source]

该采样器创建批次时,确保每个批次中的样本值是唯一的,即使跨列也是如此。当损失函数将批次中的其他样本视为批内负例,并且您希望确保负例不是锚点/正样本的重复时,这很有用。

推荐用于
参数:
  • dataset (Dataset) – 要从中采样的D数据集。

  • batch_size (int) – 每批次的样本数量。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 用于检查标签的列名列表。数据集中找到的来自 valid_label_columns 的第一个列名将被用作标签列。

  • generator (torch.Generator, 可选) – 用于打乱索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[source]

该采样器根据样本的标签进行分组,旨在创建批次,使每个批次中的标签尽可能同质。此采样器旨在与 Batch...TripletLoss 类一起使用,这些类要求每个批次至少包含每个标签类别的 2 个示例。

推荐用于
参数:
  • dataset (Dataset) – 要从中采样的D数据集。

  • batch_size (int) – 每批次的样本数量。必须能被 2 整除。

  • drop_last (bool) – 如果为 True,当数据集大小不能被批次大小整除时,丢弃最后一个不完整的批次。

  • valid_label_columns (List[str], 可选) – 用于检查标签的列名列表。数据集中找到的来自 valid_label_columns 的第一个列名将被用作标签列。

  • generator (torch.Generator, 可选) – 用于打乱索引的可选随机数生成器。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

多数据集批次采样器

sentence_transformers.training_args.MultiDatasetBatchSamplers(value)[source]

存储多数据集批次采样器的可接受字符串标识符。

多数据集批次采样器负责确定训练期间从多个数据集中采样批次的顺序。有效选项包括:

  • MultiDatasetBatchSamplers.ROUND_ROBIN:使用 RoundRobinBatchSampler,它从每个数据集中进行轮询采样,直到其中一个数据集耗尽。使用此策略,可能不会使用每个数据集中的所有样本,但每个数据集的采样频率是相同的。

  • MultiDatasetBatchSamplers.PROPORTIONAL[默认] 使用 ProportionalBatchSampler,它按数据集大小比例从每个数据集中采样。使用此策略,每个数据集中的所有样本都会被使用,并且较大数据集的采样频率更高。

如果您想使用自定义多数据集批次采样器,可以继承 MultiDatasetDefaultBatchSampler 类,并将该类(而非实例)传递给 SentenceTransformerTrainingArguments 中的 multi_dataset_batch_sampler 参数(或 CrossEncoderTrainingArguments 等)。或者,您可以传递一个函数,该函数接受 dataset(一个 ConcatDataset)、batch_samplers(即 ConcatDataset 中每个数据集的批次采样器列表)、generatorseed 参数,并返回一个 MultiDatasetDefaultBatchSampler 实例。

使用
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import MultiDatasetBatchSamplers
from sentence_transformers.losses import CoSENTLoss
from datasets import Dataset, DatasetDict

model = SentenceTransformer("microsoft/mpnet-base")
train_general = Dataset.from_dict({
    "sentence_A": ["It's nice weather outside today.", "He drove to work."],
    "sentence_B": ["It's so sunny.", "He took the car to the bank."],
    "score": [0.9, 0.4],
})
train_medical = Dataset.from_dict({
    "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
    "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
    "score": [0.8, 0.6, 0.7],
})
train_legal = Dataset.from_dict({
    "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
    "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
    "score": [0.7, 0.8],
})
train_dataset = DatasetDict({
    "general": train_general,
    "medical": train_medical,
    "legal": train_legal,
})

loss = CoSENTLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
sentence_transformers.sampler.MultiDatasetDefaultBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[source]

从多个批次采样器生成批次的抽象基础批次采样器。此类别必须被子类化以实现特定的采样策略,不能直接使用。

参数:
  • dataset (ConcatDataset) – 多个数据集的连接。

  • batch_samplers (List[BatchSampler]) – 批次采样器列表,ConcatDataset 中每个数据集对应一个。

  • generator (torch.Generator, 可选) – 用于可复现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[source]

批次采样器,以轮询方式从多个批次采样器中生成批次,直到其中一个耗尽。使用此采样器,不太可能使用每个数据集中的所有样本,但我们确实确保每个数据集的采样频率是相同的。

参数:
  • dataset (ConcatDataset) – 多个数据集的连接。

  • batch_samplers (List[BatchSampler]) – 批次采样器列表,ConcatDataset 中每个数据集对应一个。

  • generator (torch.Generator, 可选) – 用于可复现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。

sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)[source]

批次采样器,按数据集大小比例从每个数据集中采样,直到所有数据集同时耗尽。使用此采样器,每个数据集中的所有样本都会被使用,并且较大数据集的采样频率更高。

参数:
  • dataset (ConcatDataset) – 多个数据集的连接。

  • batch_samplers (List[BatchSampler]) – 批次采样器列表,ConcatDataset 中每个数据集对应一个。

  • generator (torch.Generator, 可选) – 用于可复现采样的生成器。默认为 None。

  • seed (int) – 随机数生成器的种子,以确保可复现性。默认为 0。