speechbrain.dataio.batch 模块
批次整理
- 作者
Aku Rouhe 2020
摘要
类
尝试找出批次大小,但绝不报错 |
|
当 examples 是字典且具有变长序列时的 Collate_fn。 |
|
PaddedData(data, lengths) |
参考
- class speechbrain.dataio.batch.PaddedData(data, lengths)
基类:
tuple
- data
字段编号 0 的别名
- lengths
字段编号 1 的别名
- class speechbrain.dataio.batch.PaddedBatch(examples, padded_keys=None, device_prep_keys=None, padding_func=<function batch_pad_right>, padding_kwargs={}, apply_default_convert=True, nonpadded_stack=True)[source]
基类:
object
当 examples 是字典且具有变长序列时的 Collate_fn。
示例中的不同元素按键匹配。所有 numpy 张量都会转换为 Torch(PyTorch default_convert)。然后,默认情况下,所有值为 torch.Tensor 的元素都会进行填充并支持集体调用 pin_memory() 和 to()。常规的 Python 数据类型仅收集在一个列表中。
- 参数:
examples (list) – example 字典列表,由 Dataloader 生成。
padded_keys (list, None) – (可选)要填充的键列表。如果为 None,则填充所有 torch.Tensor。
device_prep_keys (list, None) – (可选)只有这些键参与集体内存固定(pinning)和使用 to() 进行移动。如果为 None,则默认为所有值为 torch.Tensor 的项。
padding_func (callable, optional) – 使用要一起填充的张量列表调用。需要返回两个张量:填充后的数据,以及另一个用于数据长度的张量。
padding_kwargs (dict) – (可选)传递给 padding_func 的额外关键字参数。例如 mode, value。
apply_default_convert (bool) – 是否对所有数据应用 PyTorch default_convert(例如,递归地将 numpy 转换为 torch 等)。默认值:True,通常能做正确的事情。
nonpadded_stack (bool) – 是否对未填充的值应用类似 PyTorch-default_collate 的堆叠。如果可以,则进行堆叠,如果不能,则不会报错。默认值:True,通常能做正确的事情。
示例
>>> batch = PaddedBatch([ ... {"id": "ex1", "foo": torch.Tensor([1.])}, ... {"id": "ex2", "foo": torch.Tensor([2., 1.])}]) >>> # Attribute or key-based access: >>> batch.id ['ex1', 'ex2'] >>> batch["id"] ['ex1', 'ex2'] >>> # torch.Tensors get padded >>> type(batch.foo) <class 'speechbrain.dataio.batch.PaddedData'> >>> batch.foo.data tensor([[1., 0.], [2., 1.]]) >>> batch.foo.lengths tensor([0.5000, 1.0000]) >>> # Batch supports collective operations: >>> _ = batch.to(dtype=torch.half) >>> batch.foo.data tensor([[1., 0.], [2., 1.]], dtype=torch.float16) >>> batch.foo.lengths tensor([0.5000, 1.0000], dtype=torch.float16) >>> # Numpy tensors get converted to torch and padded as well: >>> import numpy as np >>> batch = PaddedBatch([ ... {"wav": np.asarray([1,2,3,4])}, ... {"wav": np.asarray([1,2,3])}]) >>> batch.wav # +ELLIPSIS PaddedData(data=tensor([[1, 2,... >>> # Basic stacking collation deals with non padded data: >>> batch = PaddedBatch([ ... {"spk_id": torch.tensor([1]), "wav": torch.tensor([.1,.0,.3])}, ... {"spk_id": torch.tensor([2]), "wav": torch.tensor([.2,.3,-.1])}], ... padded_keys=["wav"]) >>> batch.spk_id tensor([[1], [2]]) >>> # And some data is left alone: >>> batch = PaddedBatch([ ... {"text": ["Hello"]}, ... {"text": ["How", "are", "you?"]}]) >>> batch.text [['Hello'], ['How', 'are', 'you?']]
- __iter__()[source]
遍历批次中的不同元素。
- 返回类型:
批次的迭代器。
示例
>>> batch = PaddedBatch([ ... {"id": "ex1", "val": torch.Tensor([1.])}, ... {"id": "ex2", "val": torch.Tensor([2., 1.])}]) >>> ids, vals = batch >>> ids ['ex1', 'ex2']
- property batchsize
返回批次大小
- class speechbrain.dataio.batch.BatchsizeGuesser[source]
基类:
object
尝试找出批次大小,但绝不报错
如果无法确定其他任何值,将回退到猜测为 1
示例
>>> guesser = BatchsizeGuesser() >>> # Works with simple tensors: >>> guesser(torch.randn((2,3))) 2 >>> # Works with sequences of tensors: >>> guesser((torch.randn((2,3)), torch.randint(high=5, size=(2,)))) 2 >>> # Works with PaddedBatch: >>> guesser(PaddedBatch([{"wav": [1.,2.,3.]}, {"wav": [4.,5.,6.]}])) 2 >>> guesser("Even weird non-batches have a fallback") 1