speechbrain.lobes.models.PIQ 模块
该文件实现了通过量化进行事后解释所需类和函数。
作者 * Cem Subakan 2023 * Francesco Paissan 2023
摘要
类
此类实现了一个卷积编码器,用于从日志频谱中提取分类嵌入。 |
|
此类实现了残差块。 |
|
实现了 VQ 字典。 |
|
此类定义了向量量化的 forward 方法。 |
|
此类定义了向量量化的 forward 方法。 |
|
此类从 FocalNet 分类器的表示中重建对数功率频谱图。 |
|
此类从 ViT 分类器的表示中重建对数功率频谱图。 |
|
此类从分类器的表示中重建对数功率频谱图。 |
函数
此类返回一个二进制矩阵,指示给定标签数组在 VQ 字典中的无关区域 |
|
对网络权重应用 Xavier 初始化。 |
参考
- speechbrain.lobes.models.PIQ.get_irrelevant_regions(labels, K, num_classes, N_shared=5, stage='TRAIN')[源代码]
此类返回一个二进制矩阵,指示给定标签数组在 VQ 字典中的无关区域
- 参数:
- 返回:
irrelevant_regions
- 返回类型:
torch.Tensor
示例
>>> labels = torch.Tensor([1, 0, 2]) >>> irrelevant_regions = get_irrelevant_regions(labels, 20, 3, 5) >>> print(irrelevant_regions.shape) torch.Size([3, 20])
- class speechbrain.lobes.models.PIQ.VectorQuantization(*args, **kwargs)[源代码]
基类:
Function
此类定义了向量量化的 forward 方法。由于 VQ 不可微分,如果在调用
.grad()
时会返回 RuntimeError。有关 VQ 操作的 Straight-Through 梯度估计,请参阅VectorQuantizationStraightThrough
。- static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[源代码]
使用
codebook
作为 VQ 字典对向量input
应用 VQ。- 参数:
ctx (torch context) – 用于存储反向传播信息的上下文对象。
inputs (torch.Tensor) – 要量化的隐藏表示。期望形状为
torch.Size([B, W, H, C])
。codebook (torch.Tensor) – 用于量化的 VQ 字典。期望形状为
torch.Size([K, C])
,其中 K 是字典元素数量。labels (torch.Tensor) – 分类标签。用于定义无关区域并根据预测类别划分潜在空间。形状应为
torch.Size([B])
。num_classes (int) – 可能的类别数
activate_class_partitioning (bool) –
True
表示应为不同类别进行潜在空间量化。shared_keys (int) – 类别间共享的键数。
training (bool) –
True
表示阶段为 TRAIN。
- 返回:
量化表示的 codebook 索引
- 返回类型:
torch.Tensor
示例
>>> inputs = torch.ones(3, 14, 25, 256) >>> codebook = torch.randn(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> print(VectorQuantization.apply(inputs, codebook, labels).shape) torch.Size([3, 14, 25])
- class speechbrain.lobes.models.PIQ.VectorQuantizationStraightThrough(*args, **kwargs)[源代码]
基类:
Function
此类定义了向量量化的 forward 方法。由于 VQ 不可微分,它使用 https://arxiv.org/abs/1711.00937 中的 straight-through 估计来近似梯度。
- static forward(ctx, inputs, codebook, labels=None, num_classes=10, activate_class_partitioning=True, shared_keys=10, training=True)[源代码]
使用
codebook
作为 VQ 字典对向量input
应用 VQ,并使用 Straight-Through (id) 近似量化步骤来估计梯度。- 参数:
ctx (torch context) – 用于存储反向传播信息的上下文对象。
inputs (torch.Tensor) – 要量化的隐藏表示。期望形状为
torch.Size([B, W, H, C])
。codebook (torch.Tensor) – 用于量化的 VQ 字典。期望形状为
torch.Size([K, C])
,其中 K 是字典元素数量。labels (torch.Tensor) – 分类标签。用于定义无关区域并根据预测类别划分潜在空间。形状应为
torch.Size([B])
。num_classes (int) – 可能的类别数
activate_class_partitioning (bool) –
True
表示应为不同类别进行潜在空间量化。shared_keys (int) – 类别间共享的键数。
training (bool) –
True
表示阶段为 TRAIN。
- 返回:
量化表示和量化表示的 codebook 索引
- 返回类型:
示例
>>> inputs = torch.ones(3, 14, 25, 256) >>> codebook = torch.randn(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> quant, quant_ind = VectorQuantizationStraightThrough.apply(inputs, codebook, labels) >>> print(quant.shape, quant_ind.shape) torch.Size([3, 14, 25, 256]) torch.Size([1050])
- static backward(ctx, grad_output, grad_indices, labels=None, num_classes=None, activate_class_partitioning=True, shared_keys=10, training=True)[源代码]
估计假设向量量化为恒等函数的梯度。(https://arxiv.org/abs/1711.00937)
- class speechbrain.lobes.models.PIQ.Conv2dEncoder_v2(dim=256)[源代码]
基类:
Module
此类实现了一个卷积编码器,用于从日志频谱中提取分类嵌入。
- 参数:
dim (int) – 提取的嵌入的通道数。
示例
>>> inputs = torch.ones(3, 431, 513) >>> model = Conv2dEncoder_v2() >>> print(model(inputs).shape) torch.Size([3, 256, 26, 32])
- class speechbrain.lobes.models.PIQ.ResBlockAudio(dim)[源代码]
基类:
Module
此类实现了残差块。
- 参数:
dim (int) – 要处理的张量的输入通道数。与残差块的输出通道数匹配。
示例
>>> res = ResBlockAudio(128) >>> x = torch.randn(2, 128, 16, 16) >>> print(x.shape) torch.Size([2, 128, 16, 16])
- class speechbrain.lobes.models.PIQ.VectorQuantizedPSI_Audio(dim=128, K=512, numclasses=50, activate_class_partitioning=True, shared_keys=0, use_adapter=True, adapter_reduce_dim=True)[源代码]
基类:
Module
此类从分类器的表示中重建对数功率频谱图。
- 参数:
示例
>>> psi = VectorQuantizedPSI_Audio(dim=256, K=1024) >>> x = torch.randn(2, 256, 16, 16) >>> labels = torch.Tensor([0, 2]) >>> logspectra, hcat, z_q_x = psi(x, labels) >>> print(logspectra.shape, hcat.shape, z_q_x.shape) torch.Size([2, 1, 257, 257]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
- class speechbrain.lobes.models.PIQ.VectorQuantizedPSIFocalNet_Audio(dim=1024, **kwargs)[源代码]
-
此类从 FocalNet 分类器的表示中重建对数功率频谱图。
示例
>>> psi = VectorQuantizedPSIFocalNet_Audio(dim=256, K=1024) >>> x = torch.randn(2, 256, 16, 16) >>> labels = torch.Tensor([0, 2]) >>> logspectra, hcat, z_q_x = psi(x, labels) >>> print(logspectra.shape, hcat.shape, z_q_x.shape) torch.Size([2, 1, 495, 593]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
- class speechbrain.lobes.models.PIQ.VectorQuantizedPSIViT_Audio(dim=768, **kwargs)[源代码]
-
此类从 ViT 分类器的表示中重建对数功率频谱图。
示例
>>> psi = VectorQuantizedPSIViT_Audio(dim=256, K=1024) >>> x = torch.randn(2, 256, 16, 16) >>> labels = torch.Tensor([0, 2]) >>> logspectra, hcat, z_q_x = psi(x, labels) >>> print(logspectra.shape, hcat.shape, z_q_x.shape) torch.Size([2, 1, 495, 593]) torch.Size([2, 256, 8, 8]) torch.Size([2, 256, 8, 8])
- class speechbrain.lobes.models.PIQ.VQEmbedding(K, D, numclasses=50, activate_class_partitioning=True, shared_keys=0)[源代码]
基类:
Module
实现了 VQ 字典。包装了
VectorQuantization
和VectorQuantizationStraightThrough
。更多详情请参阅特定类。- 参数:
- forward(z_e_x, labels=None)[源代码]
包装了 VectorQuantization。计算输入量化的 VQ 字典索引。请注意,此 forward 步骤不可微分。
- 参数:
z_e_x (torch.Tensor) – 要量化的输入张量。
labels (torch.Tensor) – 输入表示的预测类别(用于潜在空间量化)。
- 返回:
量化表示的 codebook 索引
- 返回类型:
torch.Tensor
示例
>>> inputs = torch.ones(3, 256, 14, 25) >>> codebook = VQEmbedding(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> print(codebook(inputs, labels).shape) torch.Size([3, 14, 25])
- straight_through(z_e_x, labels=None)[源代码]
实现了向量量化,并使用 straight through 近似梯度。
- 参数:
z_e_x (torch.Tensor) – 要量化的输入张量。
labels (torch.Tensor) – 输入表示的预测类别(用于潜在空间量化)。
- 返回:
Straight through 量化表示和量化表示
- 返回类型:
示例
>>> inputs = torch.ones(3, 256, 14, 25) >>> codebook = VQEmbedding(1024, 256) >>> labels = torch.Tensor([1, 0, 2]) >>> quant, quant_ind = codebook.straight_through(inputs, labels) >>> print(quant.shape, quant_ind.shape) torch.Size([3, 256, 14, 25]) torch.Size([3, 256, 14, 25])