speechbrain.nnet.loss.guidedattn_loss 模块
引导注意力损失的实现
此损失函数可用于加速模型的训练,在这些模型中,输入和输出之间的对应关系大致是线性的,并且注意力对齐预计近似对角线,例如 Grapheme-to-Phoneme 和 Text-to-Speech
作者 * Artem Ploujnikov 2021
总结
类
一种损失实现,它强制注意力矩阵接近对角线,对远离对角线区域的注意力施加逐渐增大的惩罚。 |
参考
- class speechbrain.nnet.loss.guidedattn_loss.GuidedAttentionLoss(sigma=0.2)[source]
基类:
Module
一种损失实现,它强制注意力矩阵接近对角线,对远离对角线区域的注意力施加逐渐增大的惩罚。这对于期望输出序列与输入序列紧密对应的序列到序列模型很有用,例如 TTS 或 G2P
https://arxiv.org/abs/1710.08969
该实现受到 R9Y9 DeepVoice3 模型 https://github.com/r9y9/deepvoice3_pytorch 的启发
它应该大致等效,但已完全向量化。
- 参数:
sigma (float) – 引导注意力权重
示例
注意:在实际场景中,input_lengths 和 target_lengths 来自数据批次,而 alignments 来自模型 >>> import torch >>> from speechbrain.nnet.loss.guidedattn_loss import GuidedAttentionLoss >>> loss = GuidedAttentionLoss(sigma=0.2) >>> input_lengths = torch.tensor([2, 3]) >>> target_lengths = torch.tensor([3, 4]) >>> alignments = torch.tensor( … [ … [ … [0.8, 0.2, 0.0], … [0.4, 0.6, 0.0], … [0.2, 0.8, 0.0], … [0.0, 0.0, 0.0], … ], … [ … [0.6, 0.2, 0.2], … [0.1, 0.7, 0.2], … [0.3, 0.4, 0.3], … [0.2, 0.3, 0.5], … ], … ] … ) >>> loss(alignments, input_lengths, target_lengths) tensor(0.1142)
- forward(attention, input_lengths, target_lengths, max_input_len=None, max_target_len=None)[source]
计算单个批次的引导注意力损失
- 参数:
attention (torch.Tensor) – 填充后的注意力/对齐矩阵(批次,目标,输入)
input_lengths (torch.tensor) – 输入长度的张量(批次,长度)
target_lengths (torch.tensor) – 目标长度的张量(批次,长度)
max_input_len (int) – 最大输入长度 - 可选,如果未计算,将设置为 target_lengths 的最大值。在使用数据并行时,可能需要显式设置它
max_target_len (int) – 最大目标长度 - 可选,如果未计算,将设置为 target_lengths 的最大值。在使用数据并行时,可能需要显式设置它
- 返回:
loss – 包含损失值的单元素张量
- 返回类型:
torch.Tensor