speechbrain.lobes.models.EnhanceResnet 模块
用于语音增强的 Wide ResNet。
- 作者
Peter Plantinga 2022
摘要
类
卷积块,包含 squeeze-and-excitation。 |
|
基于 Wide ResNet 的增强模型。 |
|
Squeeze-and-excitation 块。 |
参考
- class speechbrain.lobes.models.EnhanceResnet.EnhanceResnet(n_fft=512, win_length=32, hop_length=16, sample_rate=16000, channel_counts=[128, 128, 256, 256, 512, 512], dense_count=2, dense_nodes=1024, activation=GELU(approximate='none'), normalization=<class 'speechbrain.nnet.normalization.BatchNorm2d'>, dropout=0.1, mask_weight=0.99)[source]
基类:
Module
基于 Wide ResNet 的增强模型。
完整模型描述见:https://arxiv.org/pdf/2112.06068.pdf
- 参数:
n_fft (int) – 傅里叶变换的点数,详见
speechbrain.processing.features.STFT
win_length (int) – STFT 窗口长度(毫秒),详见
speechbrain.processing.features.STFT
hop_length (int) – 窗口之间的间隔时间(毫秒),详见
speechbrain.processing.features.STFT
sample_rate (int) – 输入音频每秒的采样点数。
channel_counts (list of ints) – 每个 CNN 块的输出通道数。决定块的数量。
dense_count (int) – 全连接层数量。
dense_nodes (int) – 全连接层中的节点数。
activation (function) – 应用于卷积层之前的函数。
normalization (class) – 用于构建归一化层的类名。
dropout (float) – 训练期间要丢弃的层输出比例(介于 0 和 1 之间)。
mask_weight (float) – 给予掩码的权重。0 - 无掩码,1 - 完全掩码。
示例
>>> inputs = torch.rand([10, 16000]) >>> model = EnhanceResnet() >>> outputs, feats = model(inputs) >>> outputs.shape torch.Size([10, 15872]) >>> feats.shape torch.Size([10, 63, 257])
- class speechbrain.lobes.models.EnhanceResnet.ConvBlock(input_shape, channels, activation=GELU(approximate='none'), normalization=<class 'speechbrain.nnet.normalization.LayerNorm'>, dropout=0.1)[source]
基类:
Module
卷积块,包含 squeeze-and-excitation。
- 参数:
示例
>>> inputs = torch.rand([10, 20, 30, 128]) >>> block = ConvBlock(input_shape=inputs.shape, channels=256) >>> outputs = block(inputs) >>> outputs.shape torch.Size([10, 20, 15, 256])
- class speechbrain.lobes.models.EnhanceResnet.SEblock(input_size)[source]
基类:
Module
Squeeze-and-excitation 块。
定义于:https://arxiv.org/abs/1709.01507
- 参数:
input_size (tuple of ints) – 输入张量的期望大小
示例
>>> inputs = torch.rand([10, 20, 30, 256]) >>> se_block = SEblock(input_size=inputs.shape[-1]) >>> outputs = se_block(inputs) >>> outputs.shape torch.Size([10, 1, 1, 256])