speechbrain.dataio.dataloader 模块

PyTorch 兼容的 DataLoaders

本质上,我们扩展了 PyTorch DataLoader,增加了保存数据加载状态的功能,以便可以在一个 epoch 的中间保存检查点。

示例

>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> # Setup the checkpointer:
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
...     # Here you would process the data:
...     rainfall_amount_prediction = data_point * 4.
...     # Now, imagine the experiment gets killed on the fifth batch:
...     if i == 4:
...         break
...     # Luckily, you had just saved a checkpoint:
...     if i == 3:
...         _ = checkpointer.save_checkpoint(end_of_epoch = False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]
作者
  • Aku Rouhe 2020

摘要

LoopedLoader

无限期地循环底层可迭代对象,具有名义的 epoch 长度

SaveableDataLoader

一个可保存版本的 PyTorch DataLoader。

函数

distributed_loader_specifics

必要时为 DDP 准备 loader_kwargs。

make_dataloader

使用 SpeechBrain 默认设置创建一个基本的 DataLoader。

参考

speechbrain.dataio.dataloader.distributed_loader_specifics(distributed_launch, rank, dataset, loader_kwargs)[source]

必要时为 DDP 准备 loader_kwargs。

参数:
  • distributed_launch (bool) – DDP 标志

  • rank (int) – DDP 中的节点排名

  • dataset (Dataset) – 用于创建 DataLoader 的数据集。

  • loader_kwargs (dict) – 传递给 DataLoader 的关键字参数,请参阅 PyTorch DataLoader 获取选项。

返回值:

增强的 DataLoader 关键字参数

返回类型:

loader_kwargs

speechbrain.dataio.dataloader.make_dataloader(dataset, looped_nominal_epoch=None, **loader_kwargs)[source]

使用 SpeechBrain 默认设置创建一个基本的 DataLoader。

对于 DynamicItemDatasets (返回字典),使用 PaddedBatch 作为默认的 collate_fn。

洗牌通过 ReproducibleRandomSampler 实现。

如果 Dataset 不是 IterableDataset,则 DataLoader 是 SaveableDataLoader。

如果 Dataset 是 webdataset.dataset.Composable,则设置默认 batch_size = None。

也可以连续循环底层 dataloader,并在名义 epoch 长度处停止迭代。

参数:
  • dataset (Dataset) – 用于创建 DataLoader 的数据集。

  • looped_nominal_epoch (None, int) – 如果给定一个整数,则无限循环底层 DataLoader,并以批次(或 DataLoader 生成的任何内容)为单位设置名义 epoch 长度。

  • **loader_kwargs (dict) – 传递给 DataLoader 的关键字参数,请参阅 PyTorch DataLoader 获取选项。

返回值:

  • DataLoader – 如果 looped_nominal_epoch 为 None

  • LoopedLoader – 如果 looped_nominal_epoch 不为 None

class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)[source]

基类: DataLoader

一个可保存版本的 PyTorch DataLoader。

有关用法,请参阅 torch.utils.data.DataLoader。此类的功能应与 PyTorch 基本 DataLoader 完全一致,但可以使用 SpeechBrain 的 Checkpointer 进行检查点保存。

注意

1. 可保存性是通过一些不幸的、略带魔幻的方式实现的。2. 数据加载器在进入 __iter__ 后无法恢复。通常这不是问题,因为恢复应该在训练开始之前发生。然而,在评估之前,通常也会恢复性能最佳的检查点。因此,如果在进入 __iter__ 后加载了检查点,我们只假设是出于这个原因。会记录一个警告,但仅此而已。

class speechbrain.dataio.dataloader.LoopedLoader(loader, epoch_length, batchsize_fn=None)[source]

基类: object

无限期地循环底层可迭代对象,具有名义的 epoch 长度

这对于使用 IterableDatasets,特别是 webdataset 风格的加载非常有用。我们建议在 webdataset IterableDataset 实例上使用 .repeat(),这样底层的 dataloader 就会自然地无限持续下去。

参数:
  • loader (iterable) – 一个 DataLoader 或其他重复循环的可迭代对象。

  • epoch_length (int) – 名义 epoch 的长度。达到这么多步骤后,会引发 StopIteration

  • batchsize_fn (callable) – 用于确定批次大小的函数,默认为 BatchsizeGuesser

save(path)[source]

保存所需信息。

load(path, end_of_epoch=True)[source]

加载所需信息。