speechbrain.dataio.dataloader 模块
PyTorch 兼容的 DataLoaders
本质上,我们扩展了 PyTorch DataLoader,增加了保存数据加载状态的功能,以便可以在一个 epoch 的中间保存检查点。
示例
>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> # Setup the checkpointer:
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
... # Here you would process the data:
... rainfall_amount_prediction = data_point * 4.
... # Now, imagine the experiment gets killed on the fifth batch:
... if i == 4:
... break
... # Luckily, you had just saved a checkpoint:
... if i == 3:
... _ = checkpointer.save_checkpoint(end_of_epoch = False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers = 3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]
- 作者
Aku Rouhe 2020
摘要
类
无限期地循环底层可迭代对象,具有名义的 epoch 长度 |
|
一个可保存版本的 PyTorch DataLoader。 |
函数
必要时为 DDP 准备 loader_kwargs。 |
|
使用 SpeechBrain 默认设置创建一个基本的 DataLoader。 |
参考
- speechbrain.dataio.dataloader.distributed_loader_specifics(distributed_launch, rank, dataset, loader_kwargs)[source]
必要时为 DDP 准备 loader_kwargs。
- speechbrain.dataio.dataloader.make_dataloader(dataset, looped_nominal_epoch=None, **loader_kwargs)[source]
使用 SpeechBrain 默认设置创建一个基本的 DataLoader。
对于 DynamicItemDatasets (返回字典),使用 PaddedBatch 作为默认的 collate_fn。
洗牌通过 ReproducibleRandomSampler 实现。
如果 Dataset 不是 IterableDataset,则 DataLoader 是 SaveableDataLoader。
如果 Dataset 是 webdataset.dataset.Composable,则设置默认 batch_size = None。
也可以连续循环底层 dataloader,并在名义 epoch 长度处停止迭代。
- 参数:
- 返回值:
DataLoader – 如果 looped_nominal_epoch 为 None
LoopedLoader – 如果 looped_nominal_epoch 不为 None
- class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)[source]
基类:
DataLoader
一个可保存版本的 PyTorch DataLoader。
有关用法,请参阅
torch.utils.data.DataLoader
。此类的功能应与 PyTorch 基本 DataLoader 完全一致,但可以使用 SpeechBrain 的 Checkpointer 进行检查点保存。注意
1. 可保存性是通过一些不幸的、略带魔幻的方式实现的。2. 数据加载器在进入 __iter__ 后无法恢复。通常这不是问题,因为恢复应该在训练开始之前发生。然而,在评估之前,通常也会恢复性能最佳的检查点。因此,如果在进入 __iter__ 后加载了检查点,我们只假设是出于这个原因。会记录一个警告,但仅此而已。
- class speechbrain.dataio.dataloader.LoopedLoader(loader, epoch_length, batchsize_fn=None)[source]
基类:
object
无限期地循环底层可迭代对象,具有名义的 epoch 长度
这对于使用 IterableDatasets,特别是 webdataset 风格的加载非常有用。我们建议在 webdataset IterableDataset 实例上使用
.repeat()
,这样底层的 dataloader 就会自然地无限持续下去。- 参数:
loader (iterable) – 一个 DataLoader 或其他重复循环的可迭代对象。
epoch_length (int) – 名义 epoch 的长度。达到这么多步骤后,会引发 StopIteration
batchsize_fn (callable) – 用于确定批次大小的函数,默认为
BatchsizeGuesser