speechbrain.decoders.seq2seq 模块
seq2seq 自回归模型的解码方法。
- 作者
Adel Moumen 2022, 2023, 2024
Ju-Chieh Chou 2020
Peter Plantinga 2020
Mirco Ravanelli 2020
Sung-Lin Yeh 2020
摘要
类
该类处理解码过程中假设的数据。 |
|
S2SBaseSearcher 类,用于 seq2seq 模型的其他解码方法的继承。 |
|
该类实现了 seq2seq 模型的束搜索算法。 |
|
该类实现了贪婪解码方法的通用前向传递。 |
|
该类实现了文本类 HF seq2seq 模型的束搜索解码,例如 mBART 或 NLLB。 |
|
该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的束搜索解码。 |
|
该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的贪婪解码。 |
|
该类实现了 Transformer 的束搜索解码。 |
|
该类实现了 Transformer 的贪婪解码。 |
|
该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的束搜索解码。 |
|
该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的贪婪解码。 |
参考
- class speechbrain.decoders.seq2seq.AlivedHypotheses(alived_seq, alived_log_probs, sequence_scores)[source]
基类:
Module
该类处理解码过程中假设的数据。
- 参数:
alived_seq (torch.Tensor) – 每个假设的 token 序列。
alived_log_probs (torch.Tensor) – 每个假设的每个 token 的对数概率。
sequence_scores (torch.Tensor) – 每个假设的对数概率之和。
- class speechbrain.decoders.seq2seq.S2SBaseSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
基类:
Module
S2SBaseSearcher 类,用于 seq2seq 模型的其他解码方法的继承。
- 参数:
- forward(enc_states, wav_len)[source]
此方法应实现解码方法的前向算法。
- 参数:
enc_states (torch.Tensor) – 解码时使用的预计算编码器状态。(例如,需要注意的编码语音表示)。
wav_len (torch.Tensor) – SpeechBrain 风格的相对长度。
- 返回:
hyps – 预测的 token,作为列表的列表,或者如果 return_topk 为 True,则为形状为 (batch, topk, token_id 序列的最大长度) 的 Tensor。
top_lengths – 批次中每个 topk 序列的长度。
top_scores – topk 假设的最终得分。
top_log_probs – 每个假设的对数概率。
- forward_step(inp_tokens, memory, enc_states, enc_lens)[source]
此方法应实现自回归模型中的一步前向操作。
- 参数:
inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。
memory (No limit) – 此步骤的输入内存变量。(例如,RNN 隐藏状态)。
enc_states (torch.Tensor) – 需要注意的编码器状态。
enc_lens (torch.Tensor) – 每个 enc_states 序列的实际长度。
- 返回:
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
memory (No limit) – 此步骤生成的内存变量。(例如,RNN 隐藏状态)。
attn (torch.Tensor) – 用于惩罚的注意力权重。
- reset_mem(batch_size, device)[source]
此方法应实现 seq2seq 模型内存变量的重置。例如,将零向量初始化为初始隐藏状态。
- 参数:
batch_size (int) – 批次大小。
device (torch.device) – 放置初始变量的设备。
- 返回:
memory – 初始内存变量。
- 返回类型:
无限制
- class speechbrain.decoders.seq2seq.S2SGreedySearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio)[source]
基类:
S2SBaseSearcher
该类实现了贪婪解码方法的通用前向传递。另请参阅 S2SBaseSearcher()。
- forward(enc_states, wav_len)[source]
此方法执行贪婪搜索。
- 参数:
enc_states (torch.Tensor) – 解码时使用的预计算编码器状态。(例如,需要注意的编码语音表示)。
wav_len (torch.Tensor) – SpeechBrain 风格的相对长度。
- 返回:
hyps (List[List[int]]) – 包含假设的列表。
top_lengths (torch.Tensor (batch)) – 此 tensor 包含每个假设的长度。
top_scores (torch.Tensor (batch)) – 每个假设的得分。
top_log_probs (torch.Tensor (batch, max length of token_id sequences)) – 每个假设的对数概率。
- class speechbrain.decoders.seq2seq.S2STransformerGreedySearcher(modules, temperature=0.0, **kwargs)[source]
-
该类实现了 Transformer 的贪婪解码。
- 参数:
modules (包含以下内容的列表:) –
- modeltorch.nn.Module
一个 TransformerASR 模型。
- seq_lintorch.nn.Module
seq2seq 模型的线性输出层。
temperature (float) – 解码时使用的温度。
**kwargs – 传递给 S2SGreedySearcher 的参数
- class speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher(model, temperature=0.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]
-
该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的贪婪解码。
- 参数:
model (HuggingFaceWhisper) – Whisper 模型。
temperature (float) – 解码时使用的温度。
use_kv_cache (bool (默认值: True)) – 是否使用 key-value 缓存。
suppress_blank (bool (默认值: True)) – 这将抑制空白输出。
suppress_tokens (str 或 list (默认值: "-1")) – 需要抑制的 token ID 列表(或逗号分隔的 token ID),“-1” 将抑制
model.non_speech_tokens()
中定义的一组符号sample_len (int (默认值: None)) – 最大采样 token 数。
prefix (str 或 list (默认值: None)) – 添加到输入 token 的前缀。参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt (str 或 list (默认值: None)) – 添加到输入 token 的提示。参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
**kwargs – 参见 S2SBaseSearcher,参数直接传递。
- property get_tokens_to_suppress
如果 self.config.suppress_tokens 为 None,获取解码时需要抑制的 token。
- class speechbrain.decoders.seq2seq.S2SRNNGreedySearcher(embedding, decoder, linear, temperature=0.0, **kwargs)[source]
-
该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的贪婪解码。另请参阅 S2SBaseSearcher() 和 S2SGreedySearcher()。
- 参数:
embedding (torch.nn.Module) – 嵌入层。
decoder (torch.nn.Module) – 带注意力的 RNN 解码器。
linear (torch.nn.Module) – 线性输出层。
temperature (float) – 解码时使用的温度。
**kwargs – 参见 S2SBaseSearcher,参数直接传递。
示例
>>> import speechbrain as sb >>> from speechbrain.decoders import S2SRNNGreedySearcher >>> emb = torch.nn.Embedding(5, 3) >>> dec = sb.nnet.RNN.AttentionalRNNDecoder( ... "gru", "content", 3, 3, 1, enc_dim=7, input_size=3 ... ) >>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3) >>> searcher = S2SRNNGreedySearcher( ... embedding=emb, ... decoder=dec, ... linear=lin, ... bos_index=0, ... eos_index=1, ... min_decode_ratio=0, ... max_decode_ratio=1, ... ) >>> batch_size = 2 >>> enc = torch.rand([batch_size, 6, 7]) >>> wav_len = torch.ones([batch_size]) >>> top_hyps, top_lengths, _, _ = searcher(enc, wav_len)
- class speechbrain.decoders.seq2seq.S2SBeamSearcher(bos_index, eos_index, min_decode_ratio, max_decode_ratio, beam_size, scorer=None, return_topk=False, topk=1, using_eos_threshold=True, eos_threshold=1.5, length_normalization=True, using_max_attn_shift=False, max_attn_shift=60, minus_inf=-1e+20)[source]
基类:
S2SBaseSearcher
该类实现了 seq2seq 模型的束搜索算法。另请参阅 S2SBaseSearcher()。
- 参数:
bos_index (int) – 序列开始 token 的索引。
eos_index (int) – 序列结束 token 的索引。
min_decode_ratio (float) – 最小解码步数与编码器状态长度的比例。
max_decode_ratio (float) – 最大解码步数与编码器状态长度的比例。
beam_size (int) – 束的大小。
scorer (speechbrain.decoders.scorers.ScorerBuilder) – Scorer 实例。默认值: None。
return_topk (bool) – 是否返回 topk 假设。topk 假设将被填充到相同长度。默认值: False。
topk (int) – 如果 return_topk 为 True,则返回 topk 假设。默认值: 1。
using_eos_threshold (bool) – 是否使用 eos 阈值。默认值: True。
eos_threshold (float) – eos token 的阈值系数。默认值: 1.5。参见参考文献 3.1.2: https://arxiv.org/abs/1904.02619
length_normalization (bool) – 是否将得分除以长度。默认值: True。
using_max_attn_shift (bool) – 是否使用 max_attn_shift 约束。默认值: False。
max_attn_shift (int) – 束搜索将阻止注意力偏移超过 max_attn_shift 的束。默认值: 60。参考文献: https://arxiv.org/abs/1904.02619
minus_inf (float) – 负无穷大的值,用于阻止搜索的某些路径。默认值: -1e20。
- init_beam_search_data(enc_states, wav_len)[source]
初始化束搜索数据。
- 参数:
enc_states (torch.Tensor) – 需要注意的编码器状态。
wav_len (torch.Tensor) – 每个 enc_states 序列的实际长度。
- 返回:
alived_hyps (AlivedHypotheses) – 活跃假设。
inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
eos_hyps_and_log_probs_scores (list) – 生成的假设(已到达 eos 的假设)和对数概率得分。
memory (No limit) – 此步骤生成的内存变量。
scorer_memory (No limit) – 此步骤生成的 scorer 内存变量。
attn (torch.Tensor) – 注意力权重。
prev_attn_peak (torch.Tensor) – 先前注意力峰值位置。
enc_states (torch.Tensor) – 需要注意的编码器状态。
enc_lens (torch.Tensor) – 每个 enc_states 序列的实际长度。
- search_step(alived_hyps, inp_tokens, log_probs, eos_hyps_and_log_probs_scores, memory, scorer_memory, attn, prev_attn_peak, enc_states, enc_lens, step)[source]
搜索下一步最有可能的 token。
- 参数:
alived_hyps (AlivedHypotheses) – 活跃假设。
inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
eos_hyps_and_log_probs_scores (list) – 生成的假设(已到达 eos 的假设)和对数概率得分。
memory (No limit) – 此步骤的输入内存变量。(例如,RNN 隐藏状态)。
scorer_memory (No limit) – 此步骤的输入 scorer 内存变量。(例如,RNN 隐藏状态)。
attn (torch.Tensor) – 注意力权重。
prev_attn_peak (torch.Tensor) – 先前注意力峰值位置。
enc_states (torch.Tensor) – 需要注意的编码器状态。
enc_lens (torch.Tensor) – 每个 enc_states 序列的实际长度。
step (int) – 当前解码步骤。
- 返回:
alived_hyps (AlivedHypotheses) – 活跃假设。
inp_tokens (torch.Tensor) – 当前步骤的输入 tensor。
log_probs (torch.Tensor) – 当前步骤输出的对数概率。
eos_hyps_and_log_probs_scores (list) – 生成的假设(已到达 eos 的假设)和对数概率得分。
memory (No limit) – 此步骤生成的内存变量。
scorer_memory (No limit) – 此步骤生成的 scorer 内存变量。
attn (torch.Tensor) – 注意力权重。
prev_attn_peak (torch.Tensor) – 先前注意力峰值位置。
scores (torch.Tensor) – 当前步骤输出的得分。
- forward(enc_states, wav_len)[source]
应用束搜索并返回预测的 token。
- 参数:
enc_states (torch.Tensor) – 需要注意的编码器状态。
wav_len (torch.Tensor) – 每个 enc_states 序列的实际长度。
- 返回:
hyps (list) – 预测的 token。
best_lens (torch.Tensor) – 每个预测 token 的长度。
best_scores (torch.Tensor) – 每个预测 token 的得分。
best_log_probs (torch.Tensor) – 每个预测 token 的对数概率。
- class speechbrain.decoders.seq2seq.S2SRNNBeamSearcher(embedding, decoder, linear, temperature=1.0, **kwargs)[source]
基类:
S2SBeamSearcher
该类实现了 AttentionalRNNDecoder (speechbrain/nnet/RNN.py) 的束搜索解码。另请参阅 S2SBaseSearcher(),S2SBeamSearcher()。
- 参数:
embedding (torch.nn.Module) – 嵌入层。
decoder (torch.nn.Module) – 带注意力的 RNN 解码器。
linear (torch.nn.Module) – 线性输出层。
temperature (float) – 应用于 softmax 的温度因子。它改变了概率分布,当 T>1 时更平缓,当 T<1 时更尖锐。
**kwargs – 参见 S2SBeamSearcher,参数直接传递。
示例
>>> import speechbrain as sb >>> vocab_size = 5 >>> emb = torch.nn.Embedding(vocab_size, 3) >>> dec = sb.nnet.RNN.AttentionalRNNDecoder( ... "gru", "content", 3, 3, 1, enc_dim=7, input_size=3 ... ) >>> lin = sb.nnet.linear.Linear(n_neurons=vocab_size, input_size=3) >>> coverage_scorer = sb.decoders.scorer.CoverageScorer(vocab_size) >>> scorer = sb.decoders.scorer.ScorerBuilder( ... full_scorers = [coverage_scorer], ... partial_scorers = [], ... weights= dict(coverage=1.5) ... ) >>> searcher = S2SRNNBeamSearcher( ... embedding=emb, ... decoder=dec, ... linear=lin, ... bos_index=4, ... eos_index=4, ... min_decode_ratio=0, ... max_decode_ratio=1, ... beam_size=2, ... scorer=scorer, ... ) >>> batch_size = 2 >>> enc = torch.rand([batch_size, 6, 7]) >>> wav_len = torch.ones([batch_size]) >>> hyps, _, _, _ = searcher(enc, wav_len)
- class speechbrain.decoders.seq2seq.S2STransformerBeamSearcher(modules, temperature=1.0, **kwargs)[source]
基类:
S2SBeamSearcher
此类实现了 Transformer 的束搜索解码。另请参阅 S2SBaseSearcher(), S2SBeamSearcher().
- 参数:
modules (包含以下内容的列表:) –
- modeltorch.nn.Module
Transformer 模型。
- seq_lintorch.nn.Module
一个线性输出层。
temperature (float) – 应用于 softmax 的温度因子。它改变了概率分布,当 T>1 时更平缓,当 T<1 时更尖锐。
**kwargs – 传递给 S2SBeamSearcher 的参数
示例
>>> from speechbrain.nnet.linear import Linear >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR >>> from speechbrain.decoders import S2STransformerBeamSearcher >>> batch_size=8 >>> n_channels=6 >>> input_size=40 >>> d_model=128 >>> tgt_vocab=140 >>> src = torch.rand([batch_size, n_channels, input_size]) >>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels]) >>> net = TransformerASR( ... tgt_vocab, input_size, d_model, 8, 1, 1, 1024, activation=torch.nn.GELU ... ) >>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) >>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) >>> searcher = S2STransformerBeamSearcher( ... modules=[net, lin], ... bos_index=1, ... eos_index=2, ... min_decode_ratio=0.0, ... max_decode_ratio=1.0, ... using_eos_threshold=False, ... beam_size=7, ... temperature=1.15, ... ) >>> enc, dec = net.forward(src, tgt) >>> hyps, _, _, _ = searcher(enc, torch.ones(batch_size))
- class speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher(module, temperature=1.0, use_kv_cache=True, suppress_blank=True, suppress_tokens='-1', sample_len=None, prefix=None, prompt=None, **kwargs)[source]
基类:
S2SBeamSearcher
该类实现了 OpenAI 在 https://cdn.openai.com/papers/whisper.pdf 中创建的 Whisper 神经网络的束搜索解码。
束搜索是有状态的,这意味着一些变量存储在搜索器中。如果您想在不同上下文中使用搜索器,应确保相应地更新这些变量。
- 参数:
module (包含以下内容的列表:) –
- modeltorch.nn.Module
一个 Whisper 模型。它应该有一个 decode() 方法。
temperature (float) – 解码时使用的温度。
use_kv_cache (bool (默认值: True)) – 是否使用 key-value 缓存。
suppress_blank (bool (默认值: True)) – 这将抑制空白输出。
suppress_tokens (str 或 list (默认值: "-1")) – 需要抑制的 token ID 列表(或逗号分隔的 token ID),“-1” 将抑制
model.non_speech_tokens()
中定义的一组符号sample_len (int (默认值: None)) – 最大采样 token 数。
prefix (str 或 list (默认值: None)) – 添加到输入 token 的前缀。参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt (str 或 list (默认值: None)) – 添加到输入 token 的提示。参见: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
**kwargs – 参见 S2SBeamSearcher,参数直接传递。
- property get_tokens_to_suppress
如果 self.config.suppress_tokens 为 None,获取解码时需要抑制的 token。
- class speechbrain.decoders.seq2seq.S2SHFTextBasedBeamSearcher(modules, vocab_size, **kwargs)[source]
基类:
S2STransformerBeamSearcher
此类实现了基于文本的 HF seq2seq 模型(如 mBART 或 NLLB)的束搜索解码。它与 S2STransformerBeamSearcher 没有显著差异。这就是为什么它继承了 S2STransformerBeamSearcher。主要区别可能在于,用户希望直接使用基于文本的 HF 模型的 lm_head,而不是创建一个新的投影层 (self.fc = None)。
- 参数:
modules (包含以下内容的列表:) –
- modeltorch.nn.Module
Transformer 模型。
- seq_lintorch.nn.Module
一个线性输出层。在此用例中通常设为 None。
vocab_size (int) – lm_head 的维度。
**kwargs – 传递给 S2SBeamSearcher 的参数