speechbrain.lobes.models.BESTRQ 模块
支持原始论文中描述的 BEST RQ 训练的少量组件:https://arxiv.org/pdf/2202.01855。
作者 * Ryan Whetten 2024 * Titouan Parcollet 2025
摘要
函数
这从样本列表中创建一个批次,并创建将用于掩码 BEST-RQ 输入的掩码。 |
|
此函数生成 BEST-RQ 的掩码。 |
参考
- speechbrain.lobes.models.BESTRQ.compute_mask(shape, sample_lens, mask_prob, mask_length)[源代码]
此函数生成 BEST-RQ 的掩码。
它为整个批次生成一个唯一的掩码,并基于较短的utterance。这很重要,因为如果批次包含一个短句子和许多长句子,训练可能会受到影响,因为只有少量帧会被掩码。
特别是,从传递给 sample_lens 的较小长度中,我们将生成 N 个掩码,其中 N = mask_prob * 最小长度。因此,mask_prob 是一个帧开始一个掩码的概率,而不是被掩码的概率。
如果一个句子长度是 100 个时间步,mask_prob 为 0.15,掩码大小为 4,结果将有 100*0.15*4=60% 的帧被掩码。
- 参数:
- 返回类型:
计算出的掩码
示例
>>> compute_mask((2,50,60), [40, 50], 0.15, 2).shape torch.Size([12])
- speechbrain.lobes.models.BESTRQ.brq_mask_collate_fn(samples_lst, get_out_len_fn, mask_prob, mask_length, n_mels)[源代码]
这从样本列表中创建一个批次,并创建将用于掩码 BEST-RQ 输入的掩码。为了创建掩码,我们需要知道潜在提取器后的输出形状,因此需要参数
get_out_len_fn
。也可以为每个样本创建掩码(加载音频文件时),然后进行整理,但在那个时候不知道批次中最短样本的长度(它决定了掩码帧的数量),所以最好用这种方式。- 参数:
- 返回:
wavs_padded (torch.Tensor, 形状 (B, T)) – 右侧填充的音频数组。
wav_lens (torch.Tensor, 形状 (B,)) – 每个样本非填充部分的百分比。
mask (torch.Tensor, 形状 (T)) – 输入张量中需要掩码索引的掩码。