speechbrain.decoders.utils 模块
解码模块的实用函数。
- 作者
Adel Moumen 2023
Ju-Chieh Chou 2020
Peter Plantinga 2020
Mirco Ravanelli 2020
Sung-Lin Yeh 2020
摘要
函数
batch_size 次调用 filter_seq2seq_output。 |
|
过滤输出直到第一个 eos 出现(不包含)。 |
|
此函数沿指定维度将张量扩展指定次数。 |
|
如果条件为 False,此函数将用 fill_value 掩盖张量中的某些元素。 |
参考
- speechbrain.decoders.utils.inflate_tensor(tensor, times, dim)[source]
此函数沿指定维度将张量扩展指定次数。
- 参数:
- 返回:
扩展后的张量。
- 返回类型:
torch.Tensor
示例
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]]) >>> new_tensor = inflate_tensor(tensor, 2, dim=0) >>> new_tensor tensor([[1., 2., 3.], [1., 2., 3.], [4., 5., 6.], [4., 5., 6.]])
- speechbrain.decoders.utils.mask_by_condition(tensor, cond, fill_value)[source]
如果条件为 False,此函数将用 fill_value 掩盖张量中的某些元素。
- 参数:
tensor (torch.Tensor) – 要掩盖的张量。
cond (torch.BoolTensor) – 此张量必须与 tensor 大小相同。每个元素表示是否保留 tensor 中的值。
fill_value (float) – 用于填充被掩盖元素的值。
- 返回:
掩盖后的张量。
- 返回类型:
torch.Tensor
示例
>>> tensor = torch.Tensor([[1,2,3], [4,5,6]]) >>> cond = torch.BoolTensor([[True, True, False], [True, False, False]]) >>> mask_by_condition(tensor, cond, 0) tensor([[1., 2., 0.], [4., 0., 0.]])
- speechbrain.decoders.utils.batch_filter_seq2seq_output(prediction, eos_id=-1)[source]
batch_size 次调用 filter_seq2seq_output。
- 参数:
- 返回:
seq2seq 模型预测的输出。
- 返回类型:
示例
>>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])] >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4) >>> predictions [[1, 2, 3], [2, 3]]