
Source code for asteroid.dsp.consistency

import torch

def mixture_consistency(mixture, est_sources, src_weights=None, dim=1):
    """Applies mixture consistency to a tensor of estimated sources.

    The residual ``mixture - sum(est_sources)`` is redistributed over the
    estimated sources according to ``src_weights`` so that the corrected
    sources sum back to the mixture (up to the ``1e-8`` stabiliser used when
    normalising the weights).

    Args:
        mixture (torch.Tensor): Mixture waveform or TF representation. Must
            have either one dimension fewer than `est_sources` (no source
            axis) or the same number of dimensions (typically with a
            size-1 source axis so that it broadcasts).
        est_sources (torch.Tensor): Estimated sources waveforms or TF
            representations.
        src_weights (torch.Tensor): Consistency weight for each source.
            Shape needs to be broadcastable to `est_sources`. The weights
            are always normalised so that they sum up to 1 along dim `dim`.
            If `src_weights` is None, they are computed from the relative
            power of each estimated source.
        dim (int): Axis which contains the sources in `est_sources`.

    Returns:
        torch.Tensor with same shape as `est_sources`, after applying
        mixture consistency.

    Raises:
        RuntimeError: If the number of dimensions of `mixture` is neither
            ``est_sources.ndim`` nor ``est_sources.ndim - 1``.

    Notes:
        This method can be used only in 'complete' separation tasks,
        otherwise the residual error will contain unwanted sources.
        For example, this won't work with the task `sep_noisy` from WHAM.

    Examples:
        >>> # Works on waveforms
        >>> mix = torch.randn(10, 16000)
        >>> est_sources = torch.randn(10, 2, 16000)
        >>> new_est_sources = mixture_consistency(mix, est_sources, dim=1)
        >>> # Also works on spectrograms
        >>> mix = torch.randn(10, 514, 400)
        >>> est_sources = torch.randn(10, 2, 514, 400)
        >>> new_est_sources = mixture_consistency(mix, est_sources, dim=1)

    References:
        Scott Wisdom, John R Hershey, Kevin Wilson, Jeremy Thorpe, Michael
        Chinen, Brian Patton, and Rif A Saurous. "Differentiable consistency
        constraints for improved deep speech enhancement", ICASSP 2019.
    """
    # If the source weights are not specified, the weights are the relative
    # power of each source to the sum. w_i = P_i / (P_all), P for power.
    if src_weights is None:
        reduce_dims = list(range(est_sources.ndim))
        reduce_dims.pop(dim)  # Remove source axis (negative `dim` works too)
        reduce_dims.pop(0)  # Remove batch dim (assumed to be the leading axis)
        power = est_sources ** 2
        if reduce_dims:
            src_weights = torch.mean(power, dim=reduce_dims, keepdim=True)
        else:
            # Only batch and source axes exist: there is nothing to average
            # over. Passing an empty `dim` list to `torch.mean` may reduce
            # over *all* axes, which would destroy the per-source weights.
            src_weights = power

    # Make sure that the weights sum up to 1 along the source axis. The small
    # constant guards against a division by zero when all weights are zero.
    norm_weights = torch.sum(src_weights, dim=dim, keepdim=True) + 1e-8
    src_weights = src_weights / norm_weights

    # Compute residual: mix - sum(est_sources), with a kept source axis so
    # that it broadcasts against `est_sources`.
    if mixture.ndim == est_sources.ndim - 1:
        # mixture (batch, *), est_sources (batch, n_src, *)
        residual = (mixture - est_sources.sum(dim=dim)).unsqueeze(dim)
    elif mixture.ndim == est_sources.ndim:
        # mixture (batch, 1, *), est_sources (batch, n_src, *)
        residual = mixture - est_sources.sum(dim=dim, keepdim=True)
    else:
        n, m = est_sources.ndim, mixture.ndim
        raise RuntimeError(
            f"The size of the mixture tensor should match the "
            f"size of the est_sources tensor. Expected mixture "
            f"tensor to have {n} or {n-1} dimension, found {m}."
        )

    # Distribute the residual over the sources according to their weights.
    new_sources = est_sources + src_weights * residual
    return new_sources
Read the Docs v: v0.3.5
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.