speechbrain.lobes.models.transformer.Branchformer 模块
Branchformer 实现。
参考:“Branchformer: Parallel MLP-Attention Architectures to Capture Local and Global Context for Speech Recognition and Understanding”
来源:部分代码可能改编自 ESPNet。
作者 * Titouan Parcollet 2023
摘要
类
此类实现了 Branchformer encoder。 |
|
这是 Branchformer encoder 层的实现。 |
|
这是 Branchformer 中卷积分支的实现。 |
参考
- class speechbrain.lobes.models.transformer.Branchformer.ConvolutionBranch(input_size, linear_units=3072, kernel_size=31, activation=<class 'torch.nn.modules.activation.GELU'>, gate_activation=<class 'torch.nn.modules.linear.Identity'>, dropout=0.0, use_linear_after_conv=False)[source]
基类:
Module
这是 Branchformer 中卷积分支的实现。
默认结构为:LN -> Channel Proj -> GeLU -> (CNN Spatial Gating) -> Channel Proj -> Dropout
- 参数:
input_size (int) – 特征(通道)维度的预期大小。
linear_units (int, optional) – 隐藏线性单元中的神经元数量。
kernel_size (int, optional) – 非瓶颈卷积层的核大小。
activation (torch.nn.Module, optional) – 预投影后使用的激活函数。
gate_activation (torch.nn.Module, optional) – 在 CSGU 模块门控处使用的激活函数。
dropout (float, optional) – Dropout 率。
use_linear_after_conv (bool, optional) – 如果为 True,将应用 input_size//2 大小的线性变换
示例
>>> x = torch.rand((8, 60, 512)) >>> net = ConvolutionBranch(512, 1024) >>> output = net(x) >>> output.shape torch.Size([8, 60, 512])
- class speechbrain.lobes.models.transformer.Branchformer.BranchformerEncoderLayer(d_model, nhead, kernel_size=31, kdim=None, vdim=None, activation=<class 'torch.nn.modules.activation.GELU'>, dropout=0.0, attention_type='RelPosMHAXL', csgu_linear_units=3072, gate_activation=<class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv=False)[source]
基类:
Module
这是 Branchformer encoder 层的实现。
- 参数:
d_model (int) – 输入嵌入的预期大小。
nhead (int) – Attention 头数量。
kernel_size (int, optional) – 卷积模型的核大小。
kdim (int, optional) – 键的维度。
vdim (int, optional) – 值的维度。
activation (torch.nn.Module) – 在每个 Conformer 层中使用的激活函数。
dropout (int, optional) – Encoder 的 Dropout。
attention_type (str, optional) – Attention 层类型,例如 regularMHA 用于常规 MultiHeadAttention。
csgu_linear_units (int, optional) – CSGU 模块隐藏线性单元中的神经元数量。
gate_activation (torch.nn.Module, optional) – 在 CSGU 模块门控处使用的激活函数。
use_linear_after_conv (bool, optional) – 如果为 True,将应用 input_size//2 大小的线性变换
示例
>>> import torch >>> x = torch.rand((8, 60, 512)) >>> pos_embs = torch.rand((1, 2*60-1, 512)) >>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3) >>> output = net(x, pos_embs=pos_embs) >>> output[0].shape torch.Size([8, 60, 512])
- forward(x, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None, dynchunktrain_config=None)[source]
- 参数:
x (torch.Tensor) – 输入到 encoder 层的序列。
src_mask (torch.Tensor, optional) – src 序列的掩码。
src_key_padding_mask (torch.Tensor, optional) – 每批次 src 键的掩码。
pos_embs (torch.Tensor, torch.nn.Module, optional) – 包含输入序列位置嵌入的模块或 tensor
- class speechbrain.lobes.models.transformer.Branchformer.BranchformerEncoder(num_layers, d_model, nhead, kernel_size=31, kdim=None, vdim=None, activation=<class 'torch.nn.modules.activation.GELU'>, dropout=0.0, attention_type='RelPosMHAXL', csgu_linear_units=3072, gate_activation=<class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv=False, output_hidden_states=False, layerdrop_prob=0.0)[source]
基类:
Module
此类实现了 Branchformer encoder。
- 参数:
num_layers (int) – 层数。
d_model (int) – 嵌入维度大小。
nhead (int) – Attention 头数量。
kernel_size (int, optional) – 卷积模型的核大小。
kdim (int, optional) – 键的维度。
vdim (int, optional) – 值的维度。
activation (torch.nn.Module) – 在每个 Confomer 层中使用的激活函数。
dropout (int, optional) – Encoder 的 Dropout。
attention_type (str, optional) – Attention 层类型,例如 regularMHA 用于常规 MultiHeadAttention。
csgu_linear_units (int, optional) – CSGU 模块隐藏线性单元中的神经元数量。
gate_activation (torch.nn.Module, optional) – 在 CSGU 模块门控处使用的激活函数。
use_linear_after_conv (bool, optional) – 如果为 True,将应用 input_size//2 大小的线性变换。
output_hidden_states (bool, optional) – 模型是否应将隐藏状态作为 tensor 列表输出。
layerdrop_prob (float) – 丢弃整个层的概率。
示例
>>> import torch >>> x = torch.rand((8, 60, 512)) >>> pos_emb = torch.rand((1, 2*60-1, 512)) >>> net = BranchformerEncoder(1, 512, 8) >>> output, _ = net(x, pos_embs=pos_emb) >>> output.shape torch.Size([8, 60, 512])
>>> import torch >>> x = torch.rand((8, 60, 512)) >>> pos_emb = torch.rand((1, 2*60-1, 512)) >>> net = BranchformerEncoder(1, 512, 8, output_hidden_states=True) >>> output, attn_list, hidden_list = net(x, pos_embs=pos_emb) >>> hidden_list[0].shape torch.Size([8, 60, 512]) >>> len(hidden_list) 2
- forward(src, src_mask: Tensor | None = None, src_key_padding_mask: Tensor | None = None, pos_embs: Tensor | None = None, dynchunktrain_config=None)[source]
- 参数:
src (torch.Tensor) – 输入到 encoder 层的序列。
src_mask (torch.Tensor, optional) – src 序列的掩码。
src_key_padding_mask (torch.Tensor, optional) – 每批次 src 键的掩码。
pos_embs (torch.Tensor, torch.nn.Module,) – 包含输入序列位置嵌入的模块或 tensor。如果提供了自定义 pos_embs,其形状需要为 (1, 2*S-1, E),其中 S 是序列长度,E 是嵌入维度。
dynchunktrain_config (None) – 此 encoder 不支持此配置。
- 返回:
output (torch.Tensor) – Conformer 的输出。
attention_lst (list) – attention 值。
hidden_state_lst (list, optional) – encoder 隐藏层的输出。仅当 output_hidden_states 设置为 true 时有效。