Samplers
BatchSamplers
- class sentence_transformers.training_args.BatchSamplers(value)[source]
存储批采样器的可接受字符串标识符。
批采样器负责确定在训练期间如何将样本分组到批次中。有效选项包括
BatchSamplers.BATCH_SAMPLER
: [默认] 使用DefaultBatchSampler
,默认的 PyTorch 批采样器。BatchSamplers.NO_DUPLICATES
: 使用NoDuplicatesBatchSampler
,确保批次中没有重复样本。推荐用于使用批内负样本的损失函数,例如BatchSamplers.GROUP_BY_LABEL
: 使用GroupByLabelBatchSampler
,确保每个批次至少有 2 个来自相同标签的样本。推荐用于需要来自相同标签的多个样本的损失函数,例如
如果要使用自定义批采样器,可以创建一个新的 Trainer 类,该类继承自
SentenceTransformerTrainer
并覆盖get_batch_sampler()
方法。该方法必须返回一个类实例,该实例支持__iter__
和__len__
方法。前者应为每个批次生成索引列表,后者应返回批次数量。- 用法
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments from sentence_transformers.training_args import BatchSamplers from sentence_transformers.losses import MultipleNegativesRankingLoss from datasets import Dataset model = SentenceTransformer("microsoft/mpnet-base") train_dataset = Dataset.from_dict({ "anchor": ["It's nice weather outside today.", "He drove to work."], "positive": ["It's so sunny.", "He took the car to the office."], }) loss = MultipleNegativesRankingLoss(model) args = SentenceTransformerTrainingArguments( output_dir="checkpoints", batch_sampler=BatchSamplers.NO_DUPLICATES, ) trainer = SentenceTransformerTrainer( model=model, args=args, train_dataset=train_dataset, loss=loss, ) trainer.train()
- class sentence_transformers.sampler.DefaultBatchSampler(*args, **kwargs)[source]
此采样器是 SentenceTransformer 库中使用的默认批采样器。它等效于 PyTorch BatchSampler。
- 参数:
sampler (Sampler 或 Iterable) – 用于从数据集中采样元素的采样器,例如 SubsetRandomSampler。
batch_size (int) – 每批样本数。
drop_last (bool) – 如果为 True,则当数据集大小不能被批大小整除时,丢弃最后一个不完整的批次。
- class sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] = [], generator: Generator | None = None, seed: int = 0)[source]
此采样器创建批次,使得每个批次包含值唯一的样本,即使跨列也是如此。当损失函数将批次中的其他样本视为批内负样本,并且您想要确保负样本不是锚/正样本的重复项时,这非常有用。
- 推荐用于
- 参数:
dataset (Dataset) – 要从中采样的数据集。
batch_size (int) – 每批样本数。
drop_last (bool) – 如果为 True,则当数据集大小不能被批大小整除时,丢弃最后一个不完整的批次。
valid_label_columns (List[str]) – 要检查标签的列名列表。在数据集中找到的来自
valid_label_columns
的第一个列名将用作标签列。generator (torch.Generator, optional) – 用于洗牌索引的可选随机数生成器。
seed (int, optional) – 随机数生成器的种子,以确保可重复性。
- class sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)[source]
此采样器按标签对样本进行分组,并旨在创建批次,使得每个批次都包含标签尽可能同质的样本。此采样器旨在与
Batch...TripletLoss
类一起使用,这些类要求每个批次包含每个标签类别至少 2 个示例。- 推荐用于
- 参数:
dataset (Dataset) – 要从中采样的数据集。
batch_size (int) – 每批样本数。必须能被 2 整除。
drop_last (bool) – 如果为 True,则当数据集大小不能被批大小整除时,丢弃最后一个不完整的批次。
valid_label_columns (List[str]) – 要检查标签的列名列表。在数据集中找到的来自
valid_label_columns
的第一个列名将用作标签列。generator (torch.Generator, optional) – 用于洗牌索引的可选随机数生成器。
seed (int, optional) – 随机数生成器的种子,以确保可重复性。
MultiDatasetBatchSamplers
- class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)[source]
存储多数据集批采样器的可接受字符串标识符。
多数据集批采样器负责确定在训练期间以何种顺序从多个数据集中采样批次。有效选项包括
MultiDatasetBatchSamplers.ROUND_ROBIN
: 使用RoundRobinBatchSampler
,它使用从每个数据集轮询采样直到其中一个耗尽为止。使用此策略,可能并非使用每个数据集中的所有样本,但可以确保每个数据集都被平等地采样。MultiDatasetBatchSamplers.PROPORTIONAL
: [默认] 使用ProportionalBatchSampler
,它根据每个数据集的大小比例进行采样。使用此策略,将使用每个数据集中的所有样本,并且更频繁地从较大的数据集中采样。
- 用法
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments from sentence_transformers.training_args import MultiDatasetBatchSamplers from sentence_transformers.losses import CoSENTLoss from datasets import Dataset, DatasetDict model = SentenceTransformer("microsoft/mpnet-base") train_general = Dataset.from_dict({ "sentence_A": ["It's nice weather outside today.", "He drove to work."], "sentence_B": ["It's so sunny.", "He took the car to the bank."], "score": [0.9, 0.4], }) train_medical = Dataset.from_dict({ "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."], "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."], "score": [0.8, 0.6, 0.7], }) train_legal = Dataset.from_dict({ "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."], "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."], "score": [0.7, 0.8], }) train_dataset = DatasetDict({ "general": train_general, "medical": train_medical, "legal": train_legal, }) loss = CoSENTLoss(model) args = SentenceTransformerTrainingArguments( output_dir="checkpoints", multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL, ) trainer = SentenceTransformerTrainer( model=model, args=args, train_dataset=train_dataset, loss=loss, ) trainer.train()
- class sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int | None = None)[source]
批采样器,它以轮询方式从多个批采样器中生成批次,直到其中一个耗尽为止。使用此采样器,不太可能使用每个数据集中的所有样本,但我们确实确保每个数据集都被平等地采样。
- 参数:
dataset (ConcatDataset) – 多个数据集的串联。
batch_samplers (List[BatchSampler]) – 批采样器列表,ConcatDataset 中的每个数据集对应一个。
generator (torch.Generator, optional) – 用于可重复采样的生成器。默认为 None。
seed (int, optional) – 生成器的种子。默认为 None。
- class sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator, seed: int)[source]
批采样器,它根据每个数据集的大小比例从每个数据集中采样,直到所有数据集同时耗尽。使用此采样器,将使用每个数据集中的所有样本,并且更频繁地从较大的数据集中采样。
- 参数:
dataset (ConcatDataset) – 多个数据集的串联。
batch_samplers (List[BatchSampler]) – 批采样器列表,ConcatDataset 中的每个数据集对应一个。
generator (torch.Generator, optional) – 用于可重复采样的生成器。默认为 None。
seed (int, optional) – 生成器的种子。默认为 None。