speechbrain.lobes.models.BESTRQ 模块

支持原始论文中描述的 BEST RQ 训练的少量组件:https://arxiv.org/pdf/2202.01855

作者 * Ryan Whetten 2024 * Titouan Parcollet 2025

摘要

函数

brq_mask_collate_fn

这从样本列表中创建一个批次,并创建将用于掩码 BEST-RQ 输入的掩码。

compute_mask

此函数生成 BEST-RQ 的掩码。

参考

speechbrain.lobes.models.BESTRQ.compute_mask(shape, sample_lens, mask_prob, mask_length)[源代码]

此函数生成 BEST-RQ 的掩码。

它为整个批次生成一个唯一的掩码,并基于较短的utterance。这很重要,因为如果批次包含一个短句子和许多长句子,训练可能会受到影响,因为只有少量帧会被掩码。

特别是,从传递给 sample_lens 的较小长度中,我们将生成 N 个掩码,其中 N = mask_prob * 最小长度。因此,mask_prob 是一个帧开始一个掩码的概率,而不是被掩码的概率。

如果一个句子长度是 100 个时间步,mask_prob 为 0.15,掩码大小为 4,结果将有 100*0.15*4=60% 的帧被掩码。

参数:
  • shape (tuple) – 要被掩码的输入张量的形状。通常是 (Batch, Time, Fea)。

  • sample_lens (list) – 批次中每个样本的帧数对应的整数列表。例如 (12,13,14,20)

  • mask_prob (float) – 一个帧生成掩码的概率。已经被掩码的帧不能生成新的掩码。

  • mask_length (int) – 一个掩码覆盖的帧数。

返回类型:

计算出的掩码

示例

>>> compute_mask((2,50,60), [40, 50], 0.15, 2).shape
torch.Size([12])
speechbrain.lobes.models.BESTRQ.brq_mask_collate_fn(samples_lst, get_out_len_fn, mask_prob, mask_length, n_mels)[源代码]

这从样本列表中创建一个批次,并创建将用于掩码 BEST-RQ 输入的掩码。为了创建掩码,我们需要知道潜在提取器后的输出形状,因此需要参数 get_out_len_fn。也可以为每个样本创建掩码(加载音频文件时),然后进行整理,但在那个时候不知道批次中最短样本的长度(它决定了掩码帧的数量),所以最好用这种方式。

参数:
  • samples_lst (list) – 由 audio_pipeline 返回的样本列表。

  • get_out_len_fn (function) – 计算样本通过特征提取器后长度的函数。

  • mask_prob (float) – 一个帧生成掩码的概率。已经被掩码的帧不能生成新的掩码。

  • mask_length (int) – 将被掩码的连续帧数。

  • n_mels (int) – 输入张量最后一个维度中的 Mels 滤波器组数量。

返回:

  • wavs_padded (torch.Tensor, 形状 (B, T)) – 右侧填充的音频数组。

  • wav_lens (torch.Tensor, 形状 (B,)) – 每个样本非填充部分的百分比。

  • mask (torch.Tensor, 形状 (T)) – 输入张量中需要掩码索引的掩码。