
Source code for asteroid.filterbanks.griffin_lim

import torch
import math

from . import Encoder, Decoder, STFTFB  # noqa
from .stft_fb import perfect_synthesis_window
from . import transforms
from ..dsp.consistency import mixture_consistency

def griffin_lim(mag_specgram, stft_enc, angles=None, istft_dec=None, n_iter=6, momentum=0.9):
    """Recover a waveform from a magnitude spectrogram with 'fast' Griffin-Lim [1].

    The phase is estimated by alternating between the time domain and the
    time-frequency domain while re-imposing the given magnitude at every step.

    Args:
        mag_specgram (torch.Tensor): (any, dim, ension, freq, frames) as
            returned by `Encoder(STFTFB)`, the magnitude spectrogram to invert.
        stft_enc (Encoder[STFTFB]): The `Encoder(STFTFB())` object that was
            used to compute the input `mag_specgram`.
        angles (None or Tensor): Initial phase estimate. It is viewed with the
            shape of `mag_specgram`. If None (default), the phase is drawn
            from a uniform distribution over [0, 2 * pi).
        istft_dec (None or Decoder[STFTFB]): Optional Decoder used to go back
            to the time domain. If None (default), a perfect reconstruction
            Decoder is built from `stft_enc`.
        n_iter (int): Number of Griffin-Lim iterations to run.
        momentum (float): The momentum of fast Griffin-Lim. The original
            Griffin-Lim algorithm [2] is obtained for momentum=0.

    Returns:
        torch.Tensor: estimated waveforms of shape (any, dim, ension, time).

    Examples:
        >>> stft = Encoder(STFTFB(n_filters=256, kernel_size=256, stride=128))
        >>> wav = torch.randn(2, 1, 8000)
        >>> spec = stft(wav)
        >>> masked_spec = spec * torch.sigmoid(torch.randn_like(spec))
        >>> mag = transforms.take_mag(masked_spec, -2)
        >>> est_wav = griffin_lim(mag, stft, n_iter=32)

    References:
        [1] Perraudin et al. "A fast Griffin-Lim algorithm," WASPAA 2013.
        [2] D. W. Griffin and J. S. Lim: "Signal estimation from modified
        short-time Fourier transform," ASSP 1984.
    """
    # Without a user-supplied decoder, derive the synthesis window that
    # pairs with the encoder's analysis window and build the iSTFT from it.
    if istft_dec is None:
        synthesis_window = perfect_synthesis_window(
            stft_enc.filterbank.window, stft_enc.stride
        )
        istft_dec = Decoder(STFTFB(**stft_enc.get_config(), window=synthesis_window))

    # Starting phase: reuse the caller's estimate, else sample it uniformly.
    if angles is not None:
        angles = angles.view(*mag_specgram.shape)
    else:
        angles = 2 * math.pi * torch.rand_like(mag_specgram, device=mag_specgram.device)

    # The scalar 0.0 stands in for the "previous" projection on the first
    # pass, so the momentum term vanishes there.
    current_tf = 0.0
    for _ in range(n_iter):
        previous_tf = current_tf
        # Given magnitude + current phase -> time domain -> TF domain.
        estimate = istft_dec(transforms.from_mag_and_phase(mag_specgram, angles))
        current_tf = stft_enc(estimate)
        # Momentum-corrected projection provides the next phase estimate.
        angles = transforms.angle(current_tf - momentum / (1 + momentum) * previous_tf)

    # Final synthesis with the given magnitude and the estimated phase.
    return istft_dec(transforms.from_mag_and_phase(mag_specgram, angles))
def misi(
    mixture_wav,
    mag_specgrams,
    stft_enc,
    angles=None,
    istft_dec=None,
    n_iter=6,
    momentum=0.0,
    src_weights=None,
    dim=1,
):
    """Jointly invert source magnitude spectrograms with MISI [1].

    Multiple Input Spectrogram Inversion estimates the phases of several
    sources together, forcing the source estimates to sum up to the mixture
    at each iteration.

    Args:
        mixture_wav (torch.Tensor): (batch, time)
        mag_specgrams (torch.Tensor): (batch, n_src, freq, frames) as returned
            by `Encoder(STFTFB)`, the magnitude spectrograms to be jointly
            inverted using MISI (modified or not).
        stft_enc (Encoder[STFTFB]): The `Encoder(STFTFB())` object that was
            used to compute the input `mag_specgrams`.
        angles (None or Tensor): Initial phase estimate. If None (default),
            the phase is drawn from a uniform distribution over [0, 2 * pi).
        istft_dec (None or Decoder[STFTFB]): Optional Decoder used to go back
            to the time domain. If None (default), a perfect reconstruction
            Decoder is built from `stft_enc`.
        n_iter (int): Number of MISI iterations to run.
        momentum (float): Momentum on updates (this argument comes from
            Griffin-Lim). Defaults to 0 as it was never proposed anywhere.
        src_weights (None or torch.Tensor): Consistency weight for each
            source, forwarded to `mixture_consistency`. Shape needs to be
            broadcastable to `istft_dec(mag_specgrams)`. We make sure that the
            weights sum up to 1 along dim `dim`. If `src_weights` is None,
            compute them based on relative power.
        dim (int): Axis which contains the sources in `mag_specgrams`.
            Used for the consistency constraint.

    Returns:
        torch.Tensor: estimated waveforms of shape (batch, n_src, time).

    Examples:
        >>> stft = Encoder(STFTFB(n_filters=256, kernel_size=256, stride=128))
        >>> wav = torch.randn(2, 3, 8000)
        >>> specs = stft(wav)
        >>> masked_specs = specs * torch.sigmoid(torch.randn_like(specs))
        >>> mag = transforms.take_mag(masked_specs, -2)
        >>> est_wav = misi(wav.sum(1), mag, stft, n_iter=32)

    References:
        [1] Gunawan and Sen, "Iterative Phase Estimation for the Synthesis of
        Separated Sources From Single-Channel Mixtures," in IEEE Signal
        Processing Letters, 2010.
        [2] Wang, LeRoux et al. "End-to-End Speech Separation with Unfolded
        Iterative Phase Reconstruction." Interspeech 2018 (2018)
    """
    # Without a user-supplied decoder, derive the synthesis window that
    # pairs with the encoder's analysis window and build the iSTFT from it.
    if istft_dec is None:
        synthesis_window = perfect_synthesis_window(
            stft_enc.filterbank.window, stft_enc.stride
        )
        istft_dec = Decoder(STFTFB(**stft_enc.get_config(), window=synthesis_window))

    # No starting phase given: sample it uniformly.
    if angles is None:
        angles = 2 * math.pi * torch.rand_like(mag_specgrams, device=mag_specgrams.device)

    # Source axis as seen from the waveforms. Spectrograms carry one more
    # trailing axis than waveforms, so a negative index moves by one while a
    # non-negative index is unchanged.
    if dim >= 0:
        wav_dim = dim
    else:
        wav_dim = dim + 1

    # Round-trip the mixture through STFT/iSTFT so that its shape matches the
    # source estimates and so that it undergoes the same modulation if the
    # window does not enable perfect reconstruction.
    mixture_wav = istft_dec(stft_enc(mixture_wav))

    # The scalar 0.0 stands in for the "previous" projection on the first
    # pass, so the momentum term vanishes there.
    current_tf = 0.0
    for _ in range(n_iter):
        previous_tf = current_tf
        # Given magnitudes + current phases -> time-domain source estimates.
        source_wavs = istft_dec(transforms.from_mag_and_phase(mag_specgrams, angles))
        # Constrain the estimates to sum up to the mixture, then go back to
        # the TF domain.
        current_tf = stft_enc(
            mixture_consistency(
                mixture_wav, source_wavs, src_weights=src_weights, dim=wav_dim
            )
        )
        # Phase update. Momentum is kept as an option because it proved
        # useful for Griffin-Lim; it is off by default here.
        angles = transforms.angle(current_tf - momentum / (1 + momentum) * previous_tf)

    # Final source estimates from the given magnitudes and estimated phases.
    return istft_dec(transforms.from_mag_and_phase(mag_specgrams, angles))
Read the Docs v: v0.3.4
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.