speechbrain.lobes.models.transformer.TransformerST 模块

SpeechBrain 风格的 ST Transformer。

作者 * YAO FEI, CHENG 2021

摘要

TransformerST

这是 ST 的 Transformer 模型实现。

参考

class speechbrain.lobes.models.transformer.TransformerST.TransformerST(tgt_vocab, input_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding='fixed_abs_sine', normalize_before=False, kernel_size: int | None = 31, bias: bool | None = True, encoder_module: str | None = 'transformer', conformer_activation: ~torch.nn.modules.module.Module | None = <class 'speechbrain.nnet.activations.Swish'>, attention_type: str | None = 'regularMHA', max_length: int | None = 2500, causal: bool | None = True, ctc_weight: float = 0.0, asr_weight: float = 0.0, mt_weight: float = 0.0, asr_tgt_vocab: int = 0, mt_src_vocab: int = 0)[source]

基类: TransformerASR

这是 ST 的 Transformer 模型实现。

该架构基于论文“Attention Is All You Need”:https://arxiv.org/pdf/1706.03762.pdf

参数:
  • tgt_vocab (int) – 词汇表大小。

  • input_size (int) – 输入特征大小。

  • d_model (int, optional) – 嵌入维度大小。(默认为 512)。

  • nhead (int, optional) – 多头注意力模型中的头数。(默认为 8)。

  • num_encoder_layers (int, optional) – 编码器中的子编码器层数。(默认为 6)。

  • num_decoder_layers (int, optional) – 解码器中的子解码器层数。(默认为 6)。

  • d_ffn (int, optional) – 前馈网络模型的维度。(默认为 2048)。

  • dropout (int, optional) – dropout 值。(默认为 0.1)。

  • activation (torch.nn.Module, optional) – FFN 层的激活函数。推荐:relu 或 gelu (默认为 relu)。

  • positional_encoding (str, optional) – 使用的位置编码类型。例如,'fixed_abs_sine' 表示固定的绝对位置编码。

  • normalize_before (bool, optional) – 在 Transformer 层中,归一化是应用于 MHA 或 FFN 之前还是之后。默认为 True,因为这已被证明可以带来更好的性能和训练稳定性。

  • kernel_size (int, optional) – 使用 Conformer 时卷积层中的核大小。

  • bias (bool, optional) – 在 Conformer 卷积层中是否使用偏置。

  • encoder_module (str, optional) – 选择编码器使用 Conformer 还是 Transformer。解码器固定为 Transformer。

  • conformer_activation (torch.nn.Module, optional) – Conformer 卷积层之后使用的激活模块。例如 Swish, ReLU 等。必须是一个 torch Module。

  • attention_type (str, optional) – 所有 Transformer 或 Conformer 层中使用的注意力层类型。例如 regularMHA 或 RelPosMHA。

  • max_length (int, optional) – 输入中目标序列和源序列的最大长度。用于位置编码。

  • causal (bool, optional) – 编码器是否应该是因果的(解码器总是因果的)。如果是因果的,则 Conformer 卷积层是因果的。

  • ctc_weight (float) – asr 任务的 ctc 权重

  • asr_weight (float) – 用于计算损失的 asr 任务权重

  • mt_weight (float) – 用于计算损失的 mt 任务权重

  • asr_tgt_vocab (int) – asr 目标语言的词汇表大小

  • mt_src_vocab (int) – mt 源语言的词汇表大小

示例

>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerST(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU,
...     ctc_weight=1, asr_weight=0.3,
... )
>>> enc_out, dec_out = net.forward(src, tgt)
>>> enc_out.shape
torch.Size([8, 120, 512])
>>> dec_out.shape
torch.Size([8, 120, 512])
forward_asr(encoder_out, src, tgt, wav_len, pad_idx=0)[source]

此方法实现了 asr 任务的解码步骤

参数:
  • encoder_out (torch.Tensor) – 编码器的表示 (必需)。

  • src (torch.Tensor) – 输入序列 (必需)。

  • tgt (torch.Tensor) – 解码器的序列 (转录) (必需)。

  • wav_len (torch.Tensor) – 输入张量的长度 (必需)。

  • pad_idx (int) – <pad> token 的索引 (默认为 0)。

返回:

asr_decoder_out – asr 解码器的一个步骤。

返回类型:

torch.Tensor

forward_mt(src, tgt, pad_idx=0)[source]

此方法实现了 mt 任务的前向传播步骤

参数:
  • src (torch.Tensor) – 编码器的序列 (转录) (必需)。

  • tgt (torch.Tensor) – 解码器的序列 (翻译) (必需)。

  • pad_idx (int) – <pad> token 的索引 (默认为 0)。

返回:

  • encoder_out (torch.Tensor) – 编码器的输出

  • decoder_out (torch.Tensor) – 解码器的输出

forward_mt_decoder_only(src, tgt, pad_idx=0)[source]

此方法实现了使用 wav2vec 编码器进行 mt 任务的前向传播步骤 (与上述相同,但没有编码器堆栈)

参数:
  • (转录) (src) – w2v2 编码器的输出特征

  • (翻译) (tgt) – 解码器的序列 (必需)。

  • pad_idx (int) – <pad> token 的索引 (默认为 0)。

decode_asr(tgt, encoder_out)[source]

此方法实现了 Transformer 模型的解码步骤。

参数:
  • tgt (torch.Tensor) – 解码器的序列。

  • encoder_out (torch.Tensor) – 编码器的隐藏层输出。

返回:

  • prediction (torch.Tensor) – 预测输出。

  • multihead_attns (torch.Tensor) – 注意力的最后一步。

make_masks_for_mt(src, tgt, pad_idx=0)[source]

此方法生成用于训练 Transformer 模型的掩码。

参数:
  • src (torch.Tensor) – 编码器的序列 (必需)。

  • tgt (torch.Tensor) – 解码器的序列 (必需)。

  • pad_idx (int) – <pad> token 的索引 (默认为 0)。

返回:

  • src_key_padding_mask (torch.Tensor) – 因填充需要掩码的时间步

  • tgt_key_padding_mask (torch.Tensor) – 因填充需要掩码的时间步

  • src_mask (torch.Tensor) – 因因果性需要掩码的时间步

  • tgt_mask (torch.Tensor) – 因因果性需要掩码的时间步