speechbrain.decoders.seq2seq 模块

seq2seq 自回归模型的解码方法。

作者
  • Adel Moumen 2022, 2023, 2024

  • Ju-Chieh Chou 2020

  • Peter Plantinga 2020

  • Mirco Ravanelli 2020

  • Sung-Lin Yeh 2020

摘要

AlivedHypotheses

该类处理解码过程中假设的数据。

S2SBaseSearcher

S2SBaseSearcher 类,用于 seq2seq 模型的其他解码方法的继承。

S2SBeamSearcher

该类实现了 seq2seq 模型的束搜索算法。

S2SGreedySearcher

该类实现了贪婪解码方法的通用前向传递。

S2SHFTextBasedBeamSearcher

该类实现了文本类 HF seq2seq 模型的束搜索解码,例如 mBART 或 NLLB。

S2SRNNBeamSearcher

该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的束搜索解码。

S2SRNNGreedySearcher

该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的贪婪解码。

S2STransformerBeamSearcher

该类实现了 Transformer 的束搜索解码。

S2STransformerGreedySearcher

该类实现了 Transformer 的贪婪解码。

S2SWhisperBeamSearcher

该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的束搜索解码。

S2SWhisperGreedySearcher

该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的贪婪解码。

参考

class speechbrain.decoders.seq2seq.AlivedHypotheses(alived_seq, alived_log_probs, sequence_scores)[source]

基类: Module

该类处理解码过程中假设的数据。

参数:
  • alived_seq (torch.Tensor) – 每个假设的 token 序列。

  • alived_log_probs (torch.Tensor) – 每个假设的每个 token 的对数概率。

  • sequence_scores (torch.Tensor) – 每个假设的对数概率之和。

class speechbrain.decoders.seq2seq.S2SBaseSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]

基类: Module

S2SBaseSearcher 类,用于 seq2seq 模型的其他解码方法的继承。

参数:
  • bos_index (int) – 序列开始 (bos) token 的索引。

  • eos_index (int) – 序列结束 (eos) token 的索引。

  • min_decode_ratio (float) – 最小解码步数与编码器状态长度的比例。

  • max_decode_ratio (float) – 最大解码步数与编码器状态长度的比例。

forward(enc_states, wav_len)[source]

此方法应实现解码方法的前向算法。

参数:
  • enc_states (torch.Tensor) – 解码时使用的预计算编码器状态。(例如,需要注意的编码语音表示)。

  • wav_len (torch.Tensor) – SpeechBrain 风格的相对长度。

返回:

  • hyps – 预测的 token,作为列表的列表,或者如果 return_topk 为 True,则为形状为 (batch, topk, token_id 序列的最大长度) 的 Tensor。

  • top_lengths – 批次中每个 topk 序列的长度。

  • top_scores – topk 假设的最终得分。

  • top_log_probs – 每个假设的对数概率。

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

此方法应实现自回归模型中的一步前向操作。

参数:
  • inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。

  • memory (No limit) – 此步骤的输入内存变量。(例如,RNN 隐藏状态)。

  • enc_states (torch.Tensor) – 需要注意的编码器状态。

  • enc_lens (torch.Tensor) – 每个 enc_states 序列的实际长度。

返回:

  • log_probs (torch.Tensor) – 当前步骤输出的对数概率。

  • memory (No limit) – 此步骤生成的内存变量。(例如,RNN 隐藏状态)。

  • attn (torch.Tensor) – 用于惩罚的注意力权重。

reset_mem(batch_size, device)[source]

此方法应实现 seq2seq 模型内存变量的重置。例如,将零向量初始化为初始隐藏状态。

参数:
  • batch_size (int) – 批次大小。

  • device (torch.device) – 放置初始变量的设备。

返回:

memory – 初始内存变量。

返回类型:

无限制

change_max_decoding_length(min_decode_steps, max_decode_steps)[source]

设置需要注意的 enc_states 的最小/最大长度。

set_n_out()[source]

设置输出 token 的数量。如果 fc 层嵌入在模型中(例如 Whisper),则覆盖此函数。

class speechbrain.decoders.seq2seq.S2SGreedySearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]

基类: S2SBaseSearcher

该类实现了贪婪解码方法的通用前向传递。另请参阅 S2SBaseSearcher()。

forward(enc_states, wav_len)[source]

此方法执行贪婪搜索。

参数:
  • enc_states (torch.Tensor) – 解码时使用的预计算编码器状态。(例如,需要注意的编码语音表示)。

  • wav_len (torch.Tensor) – SpeechBrain 风格的相对长度。

返回:

  • hyps (List[List[int]]) – 包含假设的列表。

  • top_lengths (torch.Tensor (batch)) – 此 tensor 包含每个假设的长度。

  • top_scores (torch.Tensor (batch)) – 每个假设的得分。

  • top_log_probs (torch.Tensor (batch, max length of token_id sequences)) – 每个假设的对数概率。

class speechbrain.decoders.seq2seq.S2STransformerGreedySearcher(modules, temperature=0.0, **kwargs)[source]

基类: S2SGreedySearcher

该类实现了 Transformer 的贪婪解码。

参数:
  • modules (包含以下内容的列表:) –

    modeltorch.nn.Module

    一个 TransformerASR 模型。

    seq_lintorch.nn.Module

    seq2seq 模型的线性输出层。

  • temperature (float) – 解码时使用的温度。

  • **kwargs – 传递给 S2SGreedySearcher 的参数

reset_mem(batch_size, device)[source]

在贪婪搜索期间重置内存所需。

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

在实现的贪婪搜索器中执行一步。

class speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher(model, temperature=0.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]

基类: S2SGreedySearcher

该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的贪婪解码。

参数:
set_lang_tokens(lang_tokens)[source]

设置解码时使用的语言。

set_task(task)[source]

设置解码时使用的任务。

set_prompt(prompt)[source]

设置解码时使用的提示。

property get_tokens_to_suppress

如果 self.config.suppress_tokens 为 None,获取解码时需要抑制的 token。

reset_mem(batch_size, device)[source]

此方法将搜索期间的第一个 token 设置为 decoder_input_tokens。

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

在实现的贪婪搜索器中执行一步。

class speechbrain.decoders.seq2seq.S2SRNNGreedySearcher(embedding, decoder, linear, temperature=0.0, **kwargs)[source]

基类: S2SGreedySearcher

该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的贪婪解码。另请参阅 S2SBaseSearcher() 和 S2SGreedySearcher()。

参数:
  • embedding (torch.nn.Module) – 嵌入层。

  • decoder (torch.nn.Module) – 带注意力的 RNN 解码器。

  • linear (torch.nn.Module) – 线性输出层。

  • temperature (float) – 解码时使用的温度。

  • **kwargs – 参见 S2SBaseSearcher,参数直接传递。

示例

>>> import speechbrain as sb
>>> from speechbrain.decoders import S2SRNNGreedySearcher
>>> emb = torch.nn.Embedding(5, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3)
>>> searcher = S2SRNNGreedySearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=0,
...     eos_index=1,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
... )
>>> batch_size = 2
>>> enc = torch.rand([batch_size, 6, 7])
>>> wav_len = torch.ones([batch_size])
>>> top_hyps, top_lengths, _, _ = searcher(enc, wav_len)
reset_mem(batch_size, device)[source]

执行贪婪搜索时,将隐藏状态 (hs) 和上下文向量 (c) 作为内存保存。

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

在实现的贪婪搜索器中执行一步。

class speechbrain.decoders.seq2seq.S2SBeamSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio, beam_size, scorer=None, return_topk=False, topk=1, using_eos_threshold=True, eos_threshold=1.5, length_normalization=True, using_max_attn_shift=False, max_attn_shift=60, minus_inf=-1e+20)[source]

基类: S2SBaseSearcher

该类实现了 seq2seq 模型的束搜索算法。另请参阅 S2SBaseSearcher()。

参数:
  • bos_index (int) – 序列开始 token 的索引。

  • eos_index (int) – 序列结束 token 的索引。

  • min_decode_ratio (float) – 最小解码步数与编码器状态长度的比例。

  • max_decode_ratio (float) – 最大解码步数与编码器状态长度的比例。

  • beam_size (int) – 束的大小。

  • scorer (speechbrain.decoders.scorers.ScorerBuilder) – Scorer 实例。默认值: None。

  • return_topk (bool) – 是否返回 topk 假设。topk 假设将被填充到相同长度。默认值: False。

  • topk (int) – 如果 return_topk 为 True,则返回 topk 假设。默认值: 1。

  • using_eos_threshold (bool) – 是否使用 eos 阈值。默认值: True。

  • eos_threshold (float) – eos token 的阈值系数。默认值: 1.5。参见参考文献 3.1.2: https://arxiv.org/abs/1904.02619

  • length_normalization (bool) – 是否将得分除以长度。默认值: True。

  • using_max_attn_shift (bool) – 是否使用 max_attn_shift 约束。默认值: False。

  • max_attn_shift (int) – 束搜索将阻止注意力偏移超过 max_attn_shift 的束。默认值: 60。参考文献: https://arxiv.org/abs/1904.02619

  • minus_inf (float) – 负无穷大的值,用于阻止搜索的某些路径。默认值: -1e20。

init_hypotheses()[source]

此方法初始化 AlivedHypotheses 对象。

返回:

用初始值填充的活跃假设。

返回类型:

AlivedHypotheses

init_beam_search_data(enc_states, wav_len)[source]

初始化束搜索数据。

参数:
  • enc_states (torch.Tensor) – 需要注意的编码器状态。

  • wav_len (torch.Tensor) – 每个 enc_states 序列的实际长度。

返回:

  • alived_hyps (AlivedHypotheses) – 活跃假设。

  • inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。

  • log_probs (torch.Tensor) – 当前步骤输出的对数概率。

  • eos_hyps_and_log_probs_scores (list) – 生成的假设(已到达 eos 的假设)和对数概率得分。

  • memory (No limit) – 此步骤生成的内存变量。

  • scorer_memory (No limit) – 此步骤生成的 scorer 内存变量。

  • attn (torch.Tensor) – 注意力权重。

  • prev_attn_peak (torch.Tensor) – 先前注意力峰值位置。

  • enc_states (torch.Tensor) – 需要注意的编码器状态。

  • enc_lens (torch.Tensor) – 每个 enc_states 序列的实际长度。

search_step(alived_hyps, inp_tokens, log_probs, eos_hyps_and_log_probs_scores, memory, scorer_memory, attn, prev_attn_peak, enc_states, enc_lens, step)[source]

搜索下一步最有可能的 token。

参数:
  • alived_hyps (AlivedHypotheses) – 活跃假设。

  • inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。

  • log_probs (torch.Tensor) – 当前步骤输出的对数概率。

  • eos_hyps_and_log_probs_scores (list) – 生成的假设(已到达 eos 的假设)和对数概率得分。

  • memory (No limit) – 此步骤的输入内存变量。(例如,RNN 隐藏状态)。

  • scorer_memory (No limit) – 此步骤的输入 scorer 内存变量。(例如,RNN 隐藏状态)。

  • attn (torch.Tensor) – 注意力权重。

  • prev_attn_peak (torch.Tensor) – 先前注意力峰值位置。

  • enc_states (torch.Tensor) – 需要注意的编码器状态。

  • enc_lens (torch.Tensor) – 每个 enc_states 序列的实际长度。

  • step (int) – 当前解码步骤。

返回:

  • alived_hyps (AlivedHypotheses) – 活跃假设。

  • inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。

  • log_probs (torch.Tensor) – 当前步骤输出的对数概率。

  • eos_hyps_and_log_probs_scores (list) – 生成的假设(已到达 eos 的假设)和对数概率得分。

  • memory (No limit) – 此步骤生成的内存变量。

  • scorer_memory (No limit) – 此步骤生成的 scorer 内存变量。

  • attn (torch.Tensor) – 注意力权重。

  • prev_attn_peak (torch.Tensor) – 先前注意力峰值位置。

  • scores (torch.Tensor) – 当前步骤输出的得分。

forward(enc_states, wav_len)[source]

应用束搜索并返回预测的 token。

参数:
  • enc_states (torch.Tensor) – 需要注意的编码器状态。

  • wav_len (torch.Tensor) – 每个 enc_states 序列的实际长度。

返回:

  • hyps (list) – 预测的 token。

  • best_lens (torch.Tensor) – 每个预测 token 的长度。

  • best_scores (torch.Tensor) – 每个预测 token 的得分。

  • best_log_probs (torch.Tensor) – 每个预测 token 的对数概率。

permute_mem(memory, index)[source]

此方法置换 seq2seq 模型内存,以使内存索引与当前输出同步。

参数:
  • memory (No limit) – 需要置换的内存变量。

  • index (torch.Tensor) – 上一个路径的索引。

返回类型:

正在置换的内存变量。

class speechbrain.decoders.seq2seq.S2SRNNBeamSearcher(embedding, decoder, linear, temperature=1.0, **kwargs)[source]

基类: S2SBeamSearcher

该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的束搜索解码。另请参阅 S2SBaseSearcher(),S2SBeamSearcher()。

参数:
  • embedding (torch.nn.Module) – 嵌入层。

  • decoder (torch.nn.Module) – 带注意力的 RNN 解码器。

  • linear (torch.nn.Module) – 线性输出层。

  • temperature (float) – 应用于 softmax 的温度因子。它改变了概率分布,当 T>1 时更平缓,当 T<1 时更尖锐。

  • **kwargs – 参见 S2SBeamSearcher,参数直接传递。

示例

>>> import speechbrain as sb
>>> vocab_size = 5
>>> emb = torch.nn.Embedding(vocab_size, 3)
>>> dec = sb.nnet.RNN.AttentionalRNNDecoder(
...     "gru", "content", 3, 3, 1, enc_dim=7, input_size=3
... )
>>> lin = sb.nnet.linear.Linear(n_neurons=vocab_size, input_size=3)
>>> coverage_scorer = sb.decoders.scorer.CoverageScorer(vocab_size)
>>> scorer = sb.decoders.scorer.ScorerBuilder(
...     full_scorers = [coverage_scorer],
...     partial_scorers = [],
...     weights= dict(coverage=1.5)
... )
>>> searcher = S2SRNNBeamSearcher(
...     embedding=emb,
...     decoder=dec,
...     linear=lin,
...     bos_index=4,
...     eos_index=4,
...     min_decode_ratio=0,
...     max_decode_ratio=1,
...     beam_size=2,
...     scorer=scorer,
... )
>>> batch_size = 2
>>> enc = torch.rand([batch_size, 6, 7])
>>> wav_len = torch.ones([batch_size])
>>> hyps, _, _, _ = searcher(enc, wav_len)
reset_mem(batch_size, device)[source]

在束搜索期间重置内存所需。

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

在实现的贪婪搜索器中执行一步。

permute_mem(memory, index)[source]

Beamsearch 期间的内存置换。

class speechbrain.decoders.seq2seq.S2STransformerBeamSearcher(modules, temperature=1.0, **kwargs)[source]

基类: S2SBeamSearcher

此类实现了 Transformer 的束搜索解码。另请参阅 S2SBaseSearcher(), S2SBeamSearcher().

参数:
  • modules (包含以下内容的列表:) –

    modeltorch.nn.Module

    Transformer 模型。

    seq_lintorch.nn.Module

    一个线性输出层。

  • temperature (float) – 应用于 softmax 的温度因子。它改变了概率分布,当 T>1 时更平缓,当 T<1 时更尖锐。

  • **kwargs – 传递给 S2SBeamSearcher 的参数

示例

>>> from speechbrain.nnet.linear import Linear
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
>>> from speechbrain.decoders import S2STransformerBeamSearcher
>>> batch_size=8
>>> n_channels=6
>>> input_size=40
>>> d_model=128
>>> tgt_vocab=140
>>> src = torch.rand([batch_size, n_channels, input_size])
>>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels])
>>> net = TransformerASR(
...    tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU
... )
>>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
>>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab)
>>> searcher = S2STransformerBeamSearcher(
...     modules=[net, lin],
...     bos_index=1,
...     eos_index=2,
...     min_decode_ratio=0.0,
...     max_decode_ratio=1.0,
...     using_eos_threshold=False,
...     beam_size=7,
...     temperature=1.15,
... )
>>> enc, dec = net.forward(src, tgt)
>>> hyps, _, _, _  = searcher(enc, torch.ones(batch_size))
reset_mem(batch_size, device)[source]

在束搜索期间重置内存所需。

permute_mem(memory, index)[source]

Beamsearch 期间的内存置换。

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

在实现的贪婪搜索器中执行一步。

class speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher(module, temperature=1.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]

基类: S2SBeamSearcher

该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的束搜索解码。

束搜索是有状态的,这意味着一些变量存储在搜索器中。如果您想在不同上下文中使用搜索器,应确保相应地更新这些变量。

参数:
  • module (包含以下内容的列表:) –

    modeltorch.nn.Module

    一个 Whisper 模型。它应该有一个 decode() 方法。

  • temperature (float) – 解码时使用的温度。

  • use_kv_cache (bool (默认值: True)) – 是否使用 key-value 缓存。

  • suppress_blank (bool (默认值: True)) – 这将抑制空白输出。

  • suppress_tokens (strlist (默认值: "-1")) – 需要抑制的 token ID 列表(或逗号分隔的 token ID),“-1” 将抑制 model.non_speech_tokens() 中定义的一组符号

  • sample_len (int (默认值: None)) – 最大采样 token 数。

  • prefix (strlist (默认值: None)) – 添加到输入 token 的前缀。参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051

  • prompt (strlist (默认值: None)) – 添加到输入 token 的提示。参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051

  • **kwargs – 参见 S2SBeamSearcher,参数直接传递。

set_lang_tokens(lang_tokens)[source]

设置解码时使用的语言。

set_task(task)[source]

设置解码时使用的任务。

set_prompt(prompt)[source]

设置解码时使用的提示。

property get_tokens_to_suppress

如果 self.config.suppress_tokens 为 None,获取解码时需要抑制的 token。

reset_mem(batch_size, device)[source]

此方法将搜索期间的第一个 token 设置为 decoder_input_tokens。

permute_mem(memory, index)[source]

置换内存。

set_n_out()[source]

设置输出 token 的数量。

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

在实现的贪婪搜索器中执行一步。

class speechbrain.decoders.seq2seq.S2SHFTextBasedBeamSearcher(modules, vocab_size, **kwargs)[source]

基类: S2STransformerBeamSearcher

此类实现了基于文本的 HF seq2seq 模型(如 mBART 或 NLLB)的束搜索解码。它与 S2STransformerBeamSearcher 没有显著差异。这就是为什么它继承了 S2STransformerBeamSearcher。主要区别可能在于,用户希望直接使用基于文本的 HF 模型的 lm_head,而不是创建一个新的投影层 (self.fc = None)。

参数:
  • modules (包含以下内容的列表:) –

    modeltorch.nn.Module

    Transformer 模型。

    seq_lintorch.nn.Module

    一个线性输出层。在此用例中通常设为 None。

  • vocab_size (int) – lm_head 的维度。

  • **kwargs – 传递给 S2SBeamSearcher 的参数

forward_step(inp_tokens, memory, enc_states, enc_lens)[source]

在实现的贪婪搜索器中执行一步。

set_n_out()[source]

设置输出 token 的数量。