speechbrain.nnet.hypermixing 模块
该模块通过 HyperMixing 混合来自不同 token 的信息。它可以被视为 (自) 注意力的线性时间直接替代品。
来源: https://arxiv.org/abs/2203.03691
- 作者
Florian Mai 2023
Juan Pablo Zuluaga 2023
摘要
类
该类实现了多头 HyperMixing。 |
|
该类实现了 HyperNetwork。 |
|
实现了 MultiHead HyperMixer 或 HyperConformer 的类。 |
参考
- class speechbrain.nnet.hypermixing.HyperMixing(input_output_dim: int, hypernet_size: int, tied: bool = False, num_heads: int = 1, fix_tm_hidden_size: bool = False, max_length: int = 3000)[source]
基类:
Module
该类实现了多头 HyperMixing。它是 HyperMixer 中的 token 混合组件的实现,HyperMixer 是自注意力的线性时间直接替代品。与原始 HyperMixer 不同,该模块支持多个头,这提高了模型的表达能力,同时减少了参数数量。
参考:https://arxiv.org/abs/2203.03691
- 参数:
input_output_dim (int) – keys、queries 和 values 中的特征数量
hypernet_size (int) – 决定 token 混合 MLP 隐藏层的大小。
tied (bool) – 如果为 True,则 token 混合 MLP 生成的权重矩阵是 tied(绑定)的。
num_heads (int) – 并行 token 混合 MLP 的数量。
fix_tm_hidden_size (bool) – 如果为 True,则隐藏层大小等于 hypernet_size,而不是 hypernet_size / num_heads。
max_length (int) – 最大输入 token 数量。用于生成足够大的位置嵌入。
示例
>>> import torch >>> inputs = torch.rand([8, 60, 512]) >>> net = HyperMixing(512, 2048, num_heads=8) >>> outputs, attn = net(inputs, inputs, inputs) >>> outputs.shape torch.Size([8, 60, 512])
- forward(query, key, value, attn_mask: Tensor | None = None, key_padding_mask: Tensor | None = None, return_attn_weights: bool | None = True, pos_embs: Tensor | None = None)[source]
为了在 SpeechBrain 中保持兼容性,该方法的签名特意与 sb.nnet.attention.MultiHeadAttention 保持一致。
注意:key, value, attn_mask 和 pos_embs 没有效果。Query 用于所有这三者。因此,目前该模块应仅用于替换自注意力。
- 参数:
query (torch.Tensor) – (B, L, E),其中 L 是目标序列长度,B 是批量大小,E 是嵌入维度。
key (torch.Tensor) – (B, S, E),其中 S 是源序列长度,B 是批量大小,E 是嵌入维度。目前未使用。
value (torch.Tensor) – (B, S, E),其中 S 是源序列长度,B 是批量大小,E 是嵌入维度。目前未使用。
attn_mask (torch.Tensor, optional) – 注意:目前没有效果。
key_padding_mask (torch.Tensor, optional) – (B, S),其中 B 是批量大小,S 是源序列长度。如果提供的是 ByteTensor,则非零位置将被忽略,而零位置将保持不变。如果提供的是 BoolTensor,则值为 True 的位置将被忽略,而值为 False 的位置将保持不变。
return_attn_weights (torch.Tensor, optional) – 注意:目前没有效果。
pos_embs (torch.Tensor, optional) – 注意:目前没有效果。
输出
-------
attn_output (torch.Tensor) – (B, L, E),其中 L 是目标序列长度,B 是批量大小,E 是嵌入维度。
attn_output_weights (torch.Tensor) – (B, L, S),其中 B 是批量大小,L 是目标序列长度,S 是源序列长度。注意:始终返回全零。
- class speechbrain.nnet.hypermixing.HyperNetwork(input_output_dim: int, hypernet_size: int, tied=False, num_heads=1, keep_output_size=True)[source]
基类:
Module
该类实现了 HyperNetwork。这是一种使用一个网络(也称为超网络)为另一个网络生成权重的方法。在这里,它用于生成线性层的标签。
参考:https://arxiv.org/abs/1609.09106
- 参数: