speechbrain.processing.NMF 模块

非负矩阵分解

作者
  • Cem Subakan

摘要

函数

NMF_separate_spectra

此函数根据给定的 NMF 模板矩阵分离混合信号。

reconstruct_results

此函数将分离的频谱重构为波形。

spectral_phase

返回复数频谱图的相位。

参考

speechbrain.processing.NMF.spectral_phase(stft)[source]

返回复数频谱图的相位。

参数:

stft (torch.Tensor) – 一个张量,stft 函数的输出。

返回:

phase

返回类型:

torch.Tensor

示例

>>> BS, nfft, T = 10, 20, 300
>>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2)
>>> phase_mix = spectral_phase(X_stft)
speechbrain.processing.NMF.NMF_separate_spectra(Whats, Xmix)[source]

此函数根据给定的 NMF 模板矩阵分离混合信号。

参数:
  • Whats (list) – 此列表包含列表 [W1, W2],其中 W1 W2 分别对应于源 1 和源 2 的 NMF 模板矩阵。W1、W2 的大小为 [nfft/2 + 1, K],其中 nfft 是 STFT 的 fft 大小,K 是 W 中向量(模板)的数量。

  • Xmix (torch.Tensor) – 这是混合的幅度谱。大小为 [BS x T x nfft//2 + 1],其中 BS = 批次大小,nfft = fft 大小,T = 频谱中的时间步数。

返回:

  • X1hat (源 1 的分离频谱) – 大小 = [BS x (nfft/2 +1) x T],其中 BS = 批次大小,nfft = fft 大小,T = 频谱中的时间步数。

  • X2hat (源 2 的分离频谱) – 大小定义与上述相同。

示例

>>> BS, nfft, T = 4, 20, 400
>>> K1, K2 = 10, 10
>>> W1hat = torch.randn(nfft//2 + 1, K1)
>>> W2hat = torch.randn(nfft//2 + 1, K2)
>>> Whats = [W1hat, W2hat]
>>> Xmix = torch.randn(BS, T, nfft//2 + 1)
>>> X1hat, X2hat = NMF_separate_spectra(Whats, Xmix)
speechbrain.processing.NMF.reconstruct_results(X1hat, X2hat, X_stft, sample_rate, win_length, hop_length)[source]

此函数将分离的频谱重构为波形。

参数:
  • X1hat (torch.Tensor) – 源 1 的分离频谱,大小为 [BS, nfft/2 + 1, T],其中 BS = 批次大小,nfft = fft 大小,T = 频谱长度。

  • X2hat (torch.Tensor) – 源 2 的分离频谱,大小为 [BS, nfft/2 + 1, T]。大小定义与 Xhat1 相同。

  • X_stft (torch.Tensor) – 这是混合的幅度谱。大小为 [BS x nfft//2 + 1 x T x 2],其中 BS = 批次大小,nfft = fft 大小,T = 频谱中的时间步数。最后一维表示复数。

  • sample_rate (int) – 我们希望保存结果的采样率(单位:Hz)。

  • win_length (int) – stft 窗口的长度(单位:ms)。

  • hop_length (int) – STFT 窗口的移位长度(单位:ms)。

返回:

  • x1hats (list) – 源 1 的波形列表。

  • x2hats (list) – 源 2 的波形列表。

示例

>>> BS, nfft, T = 10, 512, 16000
>>> sample_rate, win_length, hop_length = 16000, 25, 10
>>> X1hat = torch.randn(BS, nfft//2 + 1, T)
>>> X2hat = torch.randn(BS, nfft//2 + 1, T)
>>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2)
>>> x1hats, x2hats = reconstruct_results(X1hat, X2hat, X_stft, sample_rate, win_length, hop_length)