speechbrain.lobes.models.ContextNet 模块

SpeechBrain 实现的 ContextNet,参考 https://arxiv.org/pdf/2005.03191.pdf

作者
  • Jianyuan Zhong 2020

摘要

ContextNet

此类实现了 ContextNet。

ContextNetBlock

此类实现了 ContextNet 中的一个块。

SEmodule

此类实现了 Squeeze-and-Excitation 模块。

参考

class speechbrain.lobes.models.ContextNet.ContextNet(input_shape, out_channels=640, conv_channels=None, kernel_size=3, strides=None, num_blocks=21, num_layers=5, inner_dim=12, alpha=1, beta=1, dropout=0.15, activation=<class 'speechbrain.nnet.activations.Swish'>, se_activation=<class 'torch.nn.modules.activation.Sigmoid'>, norm=<class 'speechbrain.nnet.normalization.BatchNorm1d'>, residuals=None)[source]

基类:Sequential

此类实现了 ContextNet。

参考论文:https://arxiv.org/pdf/2005.03191.pdf

参数:
  • input_shape (tuple) – 输入的预期形状。

  • out_channels (int) – 此模型的输出通道数(默认为 640)。

  • conv_channels (Optional (list[int])) – 每个 contextnet 块的输出通道数。如果未提供,将使用上述论文的默认设置进行初始化。

  • kernel_size (int) – 卷积层的核大小(默认为 3)。

  • strides (Optional (list[int])) – 每个上下文块的步进因子。此步进应用于每个上下文块的最后一个卷积层。如果未提供,将使用上述论文的默认设置进行初始化。

  • num_blocks (int) – 上下文块的数量(默认为 21)。

  • num_layers (int) – 每个上下文块的深度可分离卷积层数(默认为 5)。

  • inner_dim (int) – SE 模块瓶颈网络的内部维度(默认为 12)。

  • alpha (float) – 用于缩放网络输出通道的因子(默认为 1)。

  • beta (float) – 用于缩放 Swish 激活的 Beta 值(默认为 1)。

  • dropout (float) – Dropout 值(默认为 0.15)。

  • activation (torch class) – 每个上下文块的激活函数(默认为 Swish)。

  • se_activation (torch class) – SE 模块的激活函数(默认为 torch.nn.Sigmoid)。

  • norm (torch class) – 用于正则化模型的归一化方法(默认为 BatchNorm1d)。

  • residuals (Optional (list[bool])) – 是否在每个上下文块应用残差连接(默认为 None)。

示例

>>> inp = torch.randn([8, 48, 40])
>>> block = ContextNet(input_shape=inp.shape, num_blocks=14)
>>> out = block(inp)
>>> out.shape
torch.Size([8, 6, 640])
class speechbrain.lobes.models.ContextNet.SEmodule(input_shape, inner_dim, activation=<class 'torch.nn.modules.activation.Sigmoid'>, norm=<class 'speechbrain.nnet.normalization.BatchNorm1d'>)[source]

基类: Module

此类实现了 Squeeze-and-Excitation 模块。

参数:
  • input_shape (tuple) – 输入的预期形状。

  • inner_dim (int) – SE 模块瓶颈网络的内部维度(默认为 12)。

  • activation (torch class) – SE 模块的激活函数(默认为 torch.nn.Sigmoid)。

  • norm (torch class) – 用于正则化模型的归一化方法(默认为 BatchNorm1d)。

示例

>>> inp = torch.randn([8, 120, 40])
>>> net = SEmodule(input_shape=inp.shape, inner_dim=64)
>>> out = net(inp)
>>> out.shape
torch.Size([8, 120, 40])
forward(x)[source]

处理输入张量 x 并返回输出张量。

class speechbrain.lobes.models.ContextNet.ContextNetBlock(out_channels, kernel_size, num_layers, inner_dim, input_shape, stride=1, beta=1, dropout=0.15, activation=<class 'speechbrain.nnet.activations.Swish'>, se_activation=<class 'torch.nn.modules.activation.Sigmoid'>, norm=<class 'speechbrain.nnet.normalization.BatchNorm1d'>, residual=True)[source]

基类: Module

此类实现了 ContextNet 中的一个块。

参数:
  • out_channels (int) – 此模型的输出通道数(默认为 640)。

  • kernel_size (int) – 卷积层的核大小(默认为 3)。

  • num_layers (int) – 此上下文块的深度可分离卷积层数(默认为 5)。

  • inner_dim (int) – SE 模块瓶颈网络的内部维度(默认为 12)。

  • input_shape (tuple) – 输入的预期形状。

  • stride (int) – 此上下文块的步进因子(默认为 1)。

  • beta (float) – 用于缩放 Swish 激活的 Beta 值(默认为 1)。

  • dropout (float) – Dropout 值(默认为 0.15)。

  • activation (torch class) – 此上下文块的激活函数(默认为 Swish)。

  • se_activation (torch class) – SE 模块的激活函数(默认为 torch.nn.Sigmoid)。

  • norm (torch class) – 用于正则化模型的归一化方法(默认为 BatchNorm1d)。

  • residual (bool) – 是否在此上下文块应用残差连接(默认为 None)。

示例

>>> inp = torch.randn([8, 120, 40])
>>> block = ContextNetBlock(256, 3, 5, 12, input_shape=inp.shape, stride=2)
>>> out = block(inp)
>>> out.shape
torch.Size([8, 60, 256])
forward(x)[source]

处理输入张量 x 并返回输出张量。