speechbrain.decoders.utils 模块

解码模块的实用函数。

作者
  • Adel Moumen 2023

  • Ju-Chieh Chou 2020

  • Peter Plantinga 2020

  • Mirco Ravanelli 2020

  • Sung-Lin Yeh 2020

摘要

函数

batch_filter_seq2seq_output

batch_size 次调用 filter_seq2seq_output。

filter_seq2seq_output

过滤输出直到第一个 eos 出现(不包含)。

inflate_tensor

此函数沿指定维度将张量扩展指定次数。

mask_by_condition

如果条件为 False,此函数将用 fill_value 掩盖张量中的某些元素。

参考

speechbrain.decoders.utils.inflate_tensor(tensor, times, dim)[source]

此函数沿指定维度将张量扩展指定次数。

参数:
  • tensor (torch.Tensor) – 要扩展的张量。

  • times (int) – 张量将按此次数进行扩展。

  • dim (int) – 要扩展的维度。

返回:

扩展后的张量。

返回类型:

torch.Tensor

示例

>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> new_tensor = inflate_tensor(tensor, 2, dim=0)
>>> new_tensor
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [4., 5., 6.],
        [4., 5., 6.]])
speechbrain.decoders.utils.mask_by_condition(tensor, cond, fill_value)[source]

如果条件为 False,此函数将用 fill_value 掩盖张量中的某些元素。

参数:
  • tensor (torch.Tensor) – 要掩盖的张量。

  • cond (torch.BoolTensor) – 此张量必须与 tensor 大小相同。每个元素表示是否保留 tensor 中的值。

  • fill_value (float) – 用于填充被掩盖元素的值。

返回:

掩盖后的张量。

返回类型:

torch.Tensor

示例

>>> tensor = torch.Tensor([[1,2,3], [4,5,6]])
>>> cond = torch.BoolTensor([[True, True, False], [True, False, False]])
>>> mask_by_condition(tensor, cond, 0)
tensor([[1., 2., 0.],
        [4., 0., 0.]])
speechbrain.decoders.utils.batch_filter_seq2seq_output(prediction, eos_id=-1)[source]

batch_size 次调用 filter_seq2seq_output。

参数:
  • prediction (list of torch.Tensor) – 包含 seq2seq 系统预测的整数输出的列表。

  • eos_id (int, string) – eos 的 ID。

返回:

seq2seq 模型预测的输出。

返回类型:

list

示例

>>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
>>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
>>> predictions
[[1, 2, 3], [2, 3]]
speechbrain.decoders.utils.filter_seq2seq_output(string_pred, eos_id=-1)[source]

过滤输出直到第一个 eos 出现(不包含)。

参数:
  • string_pred (list) – 包含 seq2seq 系统预测的字符串/整数输出的列表。

  • eos_id (int, string) – eos 的 ID。

返回:

seq2seq 模型预测的输出。

返回类型:

list

示例

>>> string_pred = ['a','b','c','d','eos','e']
>>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
>>> string_out
['a', 'b', 'c', 'd']