speechbrain.augment.augmenter 模块

实现数据增强流程的类。

作者
  • Mirco Ravanelli 2022

摘要

Augmenter

应用数据增强流程。

参考

class speechbrain.augment.augmenter.Augmenter(parallel_augment=False, parallel_augment_fixed_bs=False, concat_original=False, min_augmentations=None, max_augmentations=None, shuffle_augmentations=False, repeat_augment=1, augment_start_index=0, augment_end_index=None, concat_start_index=0, concat_end_index=None, augment_prob=1.0, augmentations=[], enable_augmentations=None)[source]

基类: Module

应用数据增强流程。

参数
  • parallel_augment (bool) – 如果为 False,则按照管道参数指定的顺序顺序应用增强。如果为 True,则所有 N 个增强在输出中沿批次轴连接。

  • parallel_augment_fixed_bs (bool) – 如果为 False,则每个增强器(并行执行)生成的增强示例数量等于批次大小。因此,总的来说,使用此选项会生成 N*批次大小的人工数据,其中 N 是增强器的数量。如果为 True,则增强示例的总数固定为批次大小,因此,对于每个增强器,固定为 batch size // N 个示例。此选项有助于控制与原始数据分布相关的合成示例数量,因为它始终保留 50% 的原始数据和 50% 的增强数据。

  • concat_original (bool) – 如果为 True,则原始输入将与增强输出(沿批次轴)连接。

  • min_augmentations (int) – 应用于输入信号的增强次数在 min_augmentations 和 max_augmentations 之间随机采样。例如,如果增强字典包含 N=6 个增强,我们设置 min_augmentations=1 和 max_augmentations=4,则最多应用 M=4 个增强。选定的增强按照增强字典中指定的顺序应用。如果 shuffle_augmentations = True,则会选择 M 个随机增强。

  • max_augmentations (int) – 应用的最大增强次数。详见 min_augmentations。

  • shuffle_augmentations (bool) – 如果为 True,则打乱增强字典的条目。其作用是随机选择要应用的增强顺序。

  • repeat_augment (int) – 应用增强算法 N 次。这可用于执行更多数据增强。

  • augment_start_index (int) – 输入批次中开始进行数据增强的第一个元素的索引。此参数允许你指定应用数据增强的起始点。

  • augment_end_index (int) – 输入批次中数据增强应停止的最后一个元素的索引。你可以使用此参数定义在批次内应用数据增强的结束点。

  • concat_start_index (int) – 如果将 concat_original 设置为 True,你可以指定原始批次的一部分在输出中连接。使用此参数选择从原始输入批次开始复制的第一个元素的索引。

  • concat_end_index (int) – 如果将 concat_original 设置为 True,你可以指定原始批次的一部分在输出中连接。使用此参数选择从原始输入批次结束复制的最后一个元素的索引。

  • augment_prob (float) – 应用数据增强的概率(0.0 到 1.0)。设置为 0.0 时,返回原始信号而不进行任何增强。设置为 1.0 时,总是应用增强。中间的值决定了增强的可能性。

  • augmentations (list) – 组合以执行数据增强的增强器对象列表。

  • enable_augmentations (list) – 一个布尔值列表,用于选择性地启用或禁用 'augmentations' 列表中的特定增强技术。每个布尔值对应于 'augmentations' 列表中的一个增强对象,并且应具有相同的长度和顺序。此功能对于对增强技术进行消融研究以针对特定任务进行定制非常有用。

示例

>>> from speechbrain.augment.time_domain import DropFreq, DropChunk
>>> freq_dropper = DropFreq()
>>> chunk_dropper = DropChunk(drop_start=100, drop_end=16000)
>>> augment = Augmenter(parallel_augment=False, concat_original=False, augmentations=[freq_dropper, chunk_dropper])
>>> signal = torch.rand([4, 16000])
>>> output_signal, lengths = augment(signal, lengths=torch.tensor([0.2,0.5,0.7,1.0]))
augment(x, lengths, selected_augmentations)[source]

对选定的增强应用数据增强。

参数
  • x (torch.Tensor (batch, time, channel)) – 要增强的输入。

  • lengths (torch.Tensor) – 批次中每个序列的长度。

  • selected_augmentations (dict) – 包含要应用的选定增强的字典。

返回

  • output (torch.Tensor) – 增强后的输出。

  • output_lengths (torch.Tensor) – 每个输出对应的长度。

forward(x, lengths)[source]

应用数据增强。

参数
  • x (torch.Tensor (batch, time, channel)) – 要增强的输入。

  • lengths (torch.Tensor) – 批次中每个序列的长度。

返回

  • output (torch.Tensor) – 增强后的输出。

  • output_lengths (torch.Tensor) – 每个输出对应的长度。

concatenate_outputs(augment_lst, augment_len_lst)[source]

连接增强信号列表,同时考虑不同的时间长度。应用填充以确保所有信号都可以连接。

参数
  • augment_lst (List of torch.Tensor) – 要连接的增强信号列表。

  • augment_len_lst (List of torch.Tensor) – 对应于增强信号的长度列表。

返回

  • concatenated_signals (torch.Tensor) – 包含连接后信号的张量。

  • concatenated_lengths (torch.Tensor) – 包含连接后信号长度的张量。

注意事项

此函数接受一个增强信号列表,这些信号由于速度变化等原因可能具有不同的时间长度。它将信号填充到与输入信号中找到的最大时间维度匹配,并相应地重新缩放长度,然后再进行连接。

replicate_multiple_labels(*args)[source]

沿批次轴复制标签多次,次数与增强次数对应。实际上,并行和连接增强会改变时间维度。

参数

*args (tuple) – 要复制的输入标签张量。可以是单个或 torch.Tensor 列表。

返回

augmented_labels – 对应于增强输入的标签。返回与输入中给出的相同数量的 torch.Tensor。

返回类型

torch.Tensor

replicate_labels(labels)[source]

沿批次轴复制标签多次,次数与增强次数对应。实际上,并行和连接增强会改变时间维度。

参数

labels (torch.Tensor) – 要复制的输入标签张量。

返回

augmented_labels – 对应于增强输入的标签。返回与输入中给出的相同数量的 torch.Tensor。

返回类型

torch.Tensor

check_min_max_augmentations()[source]

检查 min_augmentations 和 max_augmentations 参数。