speechbrain.nnet.transducer.transducer_joint 模块

实现 transducer_joint 的库。

作者

Abdelwahab HEBA 2020

摘要

Transducer_joint

计算 Transcription network (TN) 与 Prediction network (PN) 之间的联合张量

参考

class speechbrain.nnet.transducer.transducer_joint.Transducer_joint(joint_network=None, joint='sum', nonlinearity=<class 'torch.nn.modules.activation.LeakyReLU'>)[source]

基类: Module

计算 Transcription network (TN) 与 Prediction network (PN) 之间的联合张量

参数:
  • joint_network (torch.class (神经网络模块)) – 如果 joint == “concat”,我们将在 TN 和 PN 连接后调用此网络;如果为 None,则不使用此网络。

  • joint (str) – 通过 (“sum” 或 “concat”) 选项连接两个张量。

  • nonlinearity (torch class) – 用于 TN 和 PN 联合后的激活函数。非线性类型 (tanh, relu)。

示例

>>> from speechbrain.nnet.transducer.transducer_joint import Transducer_joint
>>> from speechbrain.nnet.linear import Linear
>>> input_TN = torch.rand(8, 200, 1, 40)
>>> input_PN = torch.rand(8, 1, 12, 40)
>>> joint_network = Linear(input_size=80, n_neurons=80)
>>> TJoint = Transducer_joint(joint_network, joint="concat")
>>> output = TJoint(input_TN, input_PN)
>>> output.shape
torch.Size([8, 200, 12, 80])
init_params(first_input)[source]
参数:

first_input (tensor) – 用于初始化参数的第一个输入。

forward(input_TN, input_PN)[source]

返回输入张量的融合结果。

参数:
  • input_TN (torch.Tensor) – 来自 Transcription Network 的输入。

  • input_PN (torch.Tensor) – 来自 Prediction Network 的输入。

返回类型:

输入张量的融合。