Source code for asteroid.dsp.consistency
import torch
def mixture_consistency(mixture, est_sources, src_weights=None, dim=1):
    """Apply mixture consistency to a tensor of estimated sources.

    The residual ``mixture - sum(est_sources)`` is redistributed over the
    estimated sources according to per-source weights, so that the corrected
    sources sum (up to a small epsilon in the weight normalization) to the
    mixture.

    Args:
        mixture (torch.Tensor): Mixture waveform or TF representation. Must
            have either the same number of dimensions as `est_sources`
            (with a singleton source axis) or exactly one fewer (no source
            axis).
        est_sources (torch.Tensor): Estimated sources waveforms or TF
            representations.
        src_weights (torch.Tensor, optional): Consistency weight for each
            source. Shape needs to be broadcastable to `est_sources`.
            The weights are always re-normalized so that they sum up to 1
            along dim `dim`. If `src_weights` is None, they are computed
            from the relative power of each estimated source.
        dim (int): Axis which contains the sources in `est_sources`.

    Returns:
        torch.Tensor with same shape as `est_sources`, after applying mixture
        consistency.

    Raises:
        RuntimeError: If the number of dimensions of `mixture` is neither
            `est_sources.ndim` nor `est_sources.ndim - 1`.

    Notes:
        This method can be used only in 'complete' separation tasks, otherwise
        the residual error will contain unwanted sources. For example, this
        won't work with the task `sep_noisy` from WHAM.

    Examples:
        >>> # Works on waveforms
        >>> mix = torch.randn(10, 16000)
        >>> est_sources = torch.randn(10, 2, 16000)
        >>> new_est_sources = mixture_consistency(mix, est_sources, dim=1)
        >>> # Also works on spectrograms
        >>> mix = torch.randn(10, 514, 400)
        >>> est_sources = torch.randn(10, 2, 514, 400)
        >>> new_est_sources = mixture_consistency(mix, est_sources, dim=1)

    References:
        Scott Wisdom, John R Hershey, Kevin Wilson, Jeremy Thorpe, Michael
        Chinen, Brian Patton, and Rif A Saurous. "Differentiable consistency
        constraints for improved deep speech enhancement", ICASSP 2019.
    """
    # If the source weights are not specified, the weights are the relative
    # power of each source to the sum. w_i = P_i / (P_all), P for power.
    if src_weights is None:
        all_dims = list(range(est_sources.ndim))
        all_dims.pop(dim)  # Remove source axis
        all_dims.pop(0)  # Remove batch dim
        # Mean power over every remaining axis, kept as singleton dims so the
        # result broadcasts against `est_sources`.
        src_weights = torch.mean(est_sources ** 2, dim=all_dims, keepdim=True)

    # Make sure that the weights sum up to 1 along the source axis; the small
    # constant avoids a division by zero when all weights are zero.
    norm_weights = torch.sum(src_weights, dim=dim, keepdim=True) + 1e-8
    src_weights = src_weights / norm_weights

    # Compute residual: mix - sum(est_sources), with a singleton source axis.
    if mixture.ndim == est_sources.ndim - 1:
        # mixture (batch, *), est_sources (batch, n_src, *)
        residual = (mixture - est_sources.sum(dim=dim)).unsqueeze(dim)
    elif mixture.ndim == est_sources.ndim:
        # mixture (batch, 1, *), est_sources (batch, n_src, *)
        residual = mixture - est_sources.sum(dim=dim, keepdim=True)
    else:
        n, m = est_sources.ndim, mixture.ndim
        raise RuntimeError(
            "The size of the mixture tensor should match the "
            "size of the est_sources tensor. Expected mixture "
            f"tensor to have {n} or {n - 1} dimension, found {m}."
        )

    # Distribute the residual over the sources according to their weights.
    new_sources = est_sources + src_weights * residual
    return new_sources