speechbrain.nnet.quantisers 模块

Gumbel Softmax 实现,支持多个组。

作者
  • Rudolf A. Braun 2022

摘要

GumbelVectorQuantizer

使用 Gumbel Softmax 的向量量化。

RandomProjectionQuantizer

使用投影和随机初始化码本的向量量化,例如对于 BEST-RQ 等模型很有用。

参考

class speechbrain.nnet.quantisers.GumbelVectorQuantizer(input_dim, num_vars, temp_tuple, groups, vq_dim)[源码]

基类:Module

使用 gumbel softmax 的向量量化。复制自 fairseq 实现。 :param input_dim: 输入维度(通道)。 :type input_dim: int :param num_vars: 每组的量化向量数量。 :type num_vars: int :param temp_tuple: 训练温度。应为包含 3 个元素的元组:(开始值, 结束值, 衰减因子)。 :type temp_tuple: float :param groups: 向量量化的组数量。 :type groups: int :param vq_dim: 结果量化向量的维度。 :type vq_dim: int

示例

>>> quantiser = GumbelVectorQuantizer(128, 100, (2.0, 0.25, 0.999995,), 2, 50 )
>>> inputs = torch.rand(10, 12, 128)
>>> output = quantiser(inputs)
>>> output["x"].shape
torch.Size([10, 12, 50])
update_temp(steps)[源码]

根据当前步数更新温度

forward(x)[源码]

前向传播潜在向量以获得量化输出

class speechbrain.nnet.quantisers.RandomProjectionQuantizer(input_dim, cb_dim, cb_vocab)[源码]

基类:Module

使用投影和随机初始化码本的向量量化,例如对于 BEST-RQ 等模型很有用。

输出是输入中每个时间步在码本中最接近的码的索引。

参考:https://arxiv.org/pdf/2202.01855

参数
  • input_dim (int) – 输入维度(通道)。

  • cb_dim (int) – 码本中每个码的尺寸。

  • cb_vocab (int) – 码本中的码数量

示例

>>> quantiser = RandomProjectionQuantizer(16, 16, 32)
>>> inputs = torch.rand(10, 12, 16)
>>> output = quantiser(inputs)
>>> output.shape
torch.Size([10, 12])
forward(x)[源码]

前向传播潜在向量以获得量化输出