speechbrain.lobes.models.transformer.TransformerASR 模块

SpeechBrain 风格的 ASR Transformer。

作者 * Jianyuan Zhong 2020 * Titouan Parcollet 2024 * Luca Della Libera 2024 * Shucong Zhang 2024

概要

EncoderWrapper

这是任何 ASR Transformer 编码器的包装器。

TransformerASR

这是 ASR 的 Transformer 模型实现。

TransformerASRStreamingContext

TransformerASR 实例的流式元数据和状态。

函数

make_transformer_src_mask

准备源 Transformer 掩码,根据因果或其他简单的受限注意力方法,限制哪些帧可以关注哪些帧。

make_transformer_src_tgt_masks

此函数生成用于训练 Transformer 模型的掩码,适用于 ASR 上下文,包含编码掩码,如果指定了 tgt,则可选包含解码掩码。

参考

class speechbrain.lobes.models.transformer.TransformerASR.TransformerASRStreamingContext(dynchunktrain_config: DynChunkTrainConfig, encoder_context: Any)[source]

基类: object

TransformerASR 实例的流式元数据和状态。

dynchunktrain_config: DynChunkTrainConfig

动态分块训练配置,包含分块大小和上下文大小信息。

encoder_context: Any

不透明的编码器上下文信息。它由编码器的 make_streaming_context 方法构建,并在使用 encode_streaming 时传递给编码器。

speechbrain.lobes.models.transformer.TransformerASR.make_transformer_src_mask(src: Tensor, causal: bool = False, dynchunktrain_config: DynChunkTrainConfig | None = None) Tensor | None[source]

准备源 Transformer 掩码,根据因果或其他简单的受限注意力方法,限制哪些帧可以关注哪些帧。

参数:
  • src (torch.Tensor) – 用于构建掩码的源张量。目前实际上不使用张量的内容;仅使用其形状和其他元数据(例如设备)。

  • causal (bool) – 是否使用严格因果性。帧将无法关注未来的任何帧。

  • dynchunktrain_config (DynChunkTrainConfig, 可选) – 动态分块训练配置。这实现了一种简单的分块注意力形式。与 causal 不兼容。

返回值:

一个形状为 (timesteps, timesteps) 的布尔掩码 Tensor。

返回类型:

torch.Tensor

speechbrain.lobes.models.transformer.TransformerASR.make_transformer_src_tgt_masks(src, tgt=None, wav_len=None, pad_idx=0, causal: bool = False, dynchunktrain_config: DynChunkTrainConfig | None = None)[source]

此函数生成用于训练 Transformer 模型的掩码,适用于 ASR 上下文,包含编码掩码,如果指定了 tgt,则可选包含解码掩码。

参数:
  • src (torch.Tensor) – 传递给编码器的序列(必需)。

  • tgt (torch.Tensor) – 传递给解码器的序列。

  • wav_len (torch.Tensor) – 输入的长度。

  • pad_idx (int) – <pad> token 的索引(默认=0)。

  • causal (bool) – 是否使用严格因果性。参见 make_asr_src_mask

  • dynchunktrain_config (DynChunkTrainConfig, 可选) – 动态分块训练配置。参见 make_asr_src_mask

返回值:

  • src_key_padding_mask (torch.Tensor) – 用于忽略填充的键填充掩码

  • tgt_key_padding_mask (torch.Tensor) – 用于忽略填充的键填充掩码

  • src_mask (torch.Tensor) – 用于忽略无效(例如未来)时间步长的掩码

  • tgt_mask (torch.Tensor) – 用于忽略无效(例如未来)时间步长的掩码

class speechbrain.lobes.models.transformer.TransformerASR.TransformerASR(tgt_vocab, input_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding='fixed_abs_sine', normalize_before=False, kernel_size: int | None = 31, bias: bool = True, encoder_module: str = 'transformer', conformer_activation: type = <class 'speechbrain.nnet.activations.Swish'>, branchformer_activation: type = <class 'torch.nn.modules.activation.GELU'>, attention_type: str = 'regularMHA', max_length: int = 2500, causal: bool | None = None, csgu_linear_units: int = 3072, gate_activation: type = <class 'torch.nn.modules.linear.Identity'>, use_linear_after_conv: bool = False, output_hidden_states=False, layerdrop_prob=0.0)[source]

基类: TransformerInterface

这是 ASR 的 Transformer 模型实现。

此架构基于论文“Attention Is All You Need”:https://arxiv.org/pdf/1706.03762.pdf

参数:
  • tgt_vocab (int) – 词汇表大小。

  • input_size (int) – 输入特征大小。

  • d_model (int, 可选) – Embedding 维度大小。(默认=512)。

  • nhead (int, 可选) – 多头注意力模型中的头数(默认=8)。

  • num_encoder_layers (int, 可选) – 编码器中的子编码器层数(默认=6)。

  • num_decoder_layers (int, 可选) – 解码器中的子解码器层数(默认=6)。

  • d_ffn (int, 可选) – 前馈网络模型的维度(默认=2048)。

  • dropout (int, 可选) – Dropout 值(默认=0.1)。

  • activation (torch.nn.Module, 可选) – FFN 层的激活函数。推荐:relu 或 gelu(默认=relu)。

  • positional_encoding (str, 可选) – 使用的 positional encoding 类型。例如,‘fixed_abs_sine’ 用于固定绝对 positional encoding。

  • normalize_before (bool, 可选) – 在 Transformer 层中,归一化是应用于 MHA 或 FFN 之前还是之后。默认为 True,因为这已被证明可以带来更好的性能和训练稳定性。

  • kernel_size (int, 可选) – 使用 Conformer 时卷积层中的 kernel size。

  • bias (bool, 可选) – 在 Conformer 卷积层中是否使用偏置。

  • encoder_module (str, 可选) – 选择 Conformer 或 Transformer 作为编码器。解码器固定为 Transformer。

  • conformer_activation (torch.nn.Module, 可选) – Conformer 卷积层后使用的激活模块。例如 Swish、ReLU 等。它必须是一个 torch Module。

  • branchformer_activation (torch.nn.Module, 可选) – Branchformer 编码器中使用的激活模块。例如 Swish、ReLU 等。它必须是一个 torch Module。

  • attention_type (str, 可选) – 所有 Transformer 或 Conformer 层中使用的注意力层类型。例如 regularMHA 或 RelPosMHA。

  • max_length (int, 可选) – 输入中目标和源序列的最大长度。用于 positional encoding。

  • causal (bool, 可选) – 编码器是否应该具有因果性(解码器始终具有因果性)。如果具有因果性,Conformer 卷积层也是因果的。

  • csgu_linear_units (int, 可选) – CSGU 模块隐藏线性单元中的神经元数量。-> Branchformer

  • gate_activation (torch.nn.Module, 可选) – CSGU 模块门控处使用的激活函数。-> Branchformer

  • use_linear_after_conv (bool, 可选) – 如果为 True,将应用大小为 input_size//2 的线性变换。-> Branchformer

  • output_hidden_states (bool, 可选) – 模型是否应将隐藏状态作为张量列表输出。

  • layerdrop_prob (float) – 丢弃整个层的概率。

示例

>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> enc_out, dec_out = net.forward(src, tgt)
>>> enc_out.shape
torch.Size([8, 120, 512])
>>> dec_out.shape
torch.Size([8, 120, 512])
forward(src, tgt, wav_len=None, pad_idx=0)[source]
参数:
  • src (torch.Tensor) – 传递给编码器的序列。

  • tgt (torch.Tensor) – 传递给解码器的序列。

  • wav_len (torch.Tensor, 可选) – Torch Tensor,形状为 (batch, ),包含每个示例相对于填充长度的相对长度。

  • pad_idx (int, 可选) – <pad> token 的索引(默认=0)。

返回值:

  • encoder_out (torch.Tensor) – 编码器的输出。

  • decoder_out (torch.Tensor) – 解码器的输出

  • hidden_state_lst (list, 可选) – 编码器隐藏层的输出。仅在 output_hidden_states 设置为 true 时有效。

decode(tgt, encoder_out, enc_len=None)[source]

此方法实现了 Transformer 模型的解码步骤。

参数:
  • tgt (torch.Tensor) – 传递给解码器的序列。

  • encoder_out (torch.Tensor) – 编码器的隐藏层输出。

  • enc_len (torch.LongTensor) – 编码器状态的实际长度。

返回类型:

预测

encode(src, wav_len=None, pad_idx=0, dynchunktrain_config: DynChunkTrainConfig | None = None)[source]

编码器 forward 传递

参数:
  • src (torch.Tensor) – 传递给编码器的序列。

  • wav_len (torch.Tensor, 可选) – Torch Tensor,形状为 (batch, ),包含每个示例相对于填充长度的相对长度。

  • pad_idx (int) – 用于填充的索引。

  • dynchunktrain_config (DynChunkTrainConfig) – 动态分块配置。

返回值:

encoder_out

返回类型:

torch.Tensor

encode_streaming(src, context: TransformerASRStreamingContext)[source]

流式编码器 forward 传递

参数:
  • src (torch.Tensor) – 传递给编码器的序列(分块)。

  • context (TransformerASRStreamingContext) – 对流式上下文的可变引用。它保存了在分块推理中需要持久化的状态,可以使用 make_streaming_context 构建。此函数会修改它。

返回类型:

此分块的编码器输出。

示例

>>> import torch
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
>>> net = TransformerASR(
...     tgt_vocab=100,
...     input_size=64,
...     d_model=64,
...     nhead=8,
...     num_encoder_layers=1,
...     num_decoder_layers=0,
...     d_ffn=128,
...     attention_type="RelPosMHAXL",
...     positional_encoding=None,
...     encoder_module="conformer",
...     normalize_before=True,
...     causal=False,
... )
>>> ctx = net.make_streaming_context(DynChunkTrainConfig(16, 1))
>>> src1 = torch.rand([8, 16, 64])
>>> src2 = torch.rand([8, 16, 64])
>>> out1 = net.encode_streaming(src1, ctx)
>>> out1.shape
torch.Size([8, 16, 64])
>>> ctx.encoder_context.layers[0].mha_left_context.shape
torch.Size([8, 16, 64])
>>> out2 = net.encode_streaming(src2, ctx)
>>> out2.shape
torch.Size([8, 16, 64])
>>> ctx.encoder_context.layers[0].mha_left_context.shape
torch.Size([8, 16, 64])
>>> combined_out = torch.concat((out1, out2), dim=1)
>>> combined_out.shape
torch.Size([8, 32, 64])
make_streaming_context(dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={})[source]

为此 Transformer 及其编码器创建一个空白的流式上下文。

参数:
  • dynchunktrain_config (DynChunkTrainConfig) – 运行时分块注意力配置。

  • encoder_kwargs (dict) – 传递给编码器 make_streaming_context 的参数。编码器所需的元数据可能因编码器而异。

返回类型:

TransformerASRStreamingContext

class speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper(transformer, *args, **kwargs)[source]

基类: Module

这是任何 ASR Transformer 编码器的包装器。默认情况下,TransformerASR 的 .forward() 函数既编码也解码。使用此包装器后,.forward() 函数仅执行 .encode()。

重要:TransformerASR 类必须包含一个 .encode() 函数。

参数:
  • transformer (sb.lobes.models.TransformerInterface) – 一个包含 .encode() 函数的 Transformer 实例。

  • *args (tuple)

  • **kwargs (dict) – 传递给父类的参数。

示例

>>> src = torch.rand([8, 120, 512])
>>> tgt = torch.randint(0, 720, [8, 120])
>>> net = TransformerASR(
...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> encoder = EncoderWrapper(net)
>>> enc_out = encoder(src)
>>> enc_out.shape
torch.Size([8, 120, 512])
forward(x, wav_lens=None, pad_idx=0, **kwargs)[source]

处理输入张量 x 并返回一个输出张量。

forward_streaming(x, context)[source]

处理输入的音频分块张量 x,使用并更新可变的编码器 context

make_streaming_context(*args, **kwargs)[source]

初始化流式上下文。将所有参数转发到底层 Transformer。参见 speechbrain.lobes.models.transformer.TransformerASR.make_streaming_context()