asteroid.losses.mixit_wrapper module
class asteroid.losses.mixit_wrapper.MixITLossWrapper(loss_func, generalized=True)
Bases: torch.nn.Module
Mixture invariant loss wrapper.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs).
- generalized (bool) – Determines how MixIT is applied. If False, apply MixIT to any number of mixtures, as long as they all contain the same number of sources (see best_part_mixit()). If True (default), apply MixIT to exactly two mixtures, which do not necessarily contain the same number of sources (see best_part_mixit_generalized()).
In both modes, the best partition and the corresponding reordering are computed automatically.
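For reference, [1] formulates MixIT as a search over binary mixing matrices \(A\) that assign each of the \(M\) estimated sources \(\hat{s}\) to exactly one reference mixture \(x_i\). The block below restates that objective for \(N\) mixtures and adds our own count of the candidate assignments each mode searches; it is a summary, not a transcription of the library code. Note that the wrapper passes all estimated mixtures to loss_func in a single call rather than summing per-mixture losses.

```latex
% M = nsrc estimated sources \hat{s}, N = nmix reference mixtures x_1, ..., x_N
\mathcal{L}_{\mathrm{MixIT}}(x, \hat{s}) = \min_{A \in \mathcal{A}} \sum_{i=1}^{N} \mathcal{L}\big(x_i, [A\hat{s}]_i\big),
\qquad
\mathcal{A} = \Big\{ A \in \{0,1\}^{N \times M} : \textstyle\sum_{i=1}^{N} A_{ij} = 1 \ \forall j \Big\}

% generalized=False: every mixture also receives exactly M/N sources
\sum_{j=1}^{M} A_{ij} = \frac{M}{N} \ \forall i
\quad \Longrightarrow \quad
|\mathcal{A}| = \frac{M!}{\big((M/N)!\big)^{N}}

% generalized=True (N = 2): no row constraint, so a row of A may be all zeros (a silent mixture)
|\mathcal{A}| = 2^{M}
```

For example, with \(M = 4\) sources and \(N = 2\) mixtures, the first mode checks \(4!/(2!)^2 = 6\) assignments and the generalized mode checks \(2^4 = 16\).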
Examples
>>> import torch
>>> from asteroid.losses import multisrc_mse
>>> from asteroid.losses.mixit_wrapper import MixITLossWrapper
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute the MixIT loss with a multi-source MSE
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)
References
[1] Scott Wisdom et al. “Unsupervised sound separation using mixtures of mixtures.” arXiv:2006.12701 (2020).
forward(est_targets, targets, return_est=False, **kwargs)
Find the best partition and return the loss.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
- return_est – Boolean. Whether to return the estimated mixtures (to compute metrics or to save examples).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- The best-partition loss of each batch sample, averaged over the batch: a scalar torch.Tensor.
- If return_est is True, also the estimated mixtures (estimated sources summed according to the best partition): a torch.Tensor of shape \((batch, nmix, ...)\).
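The pure-Python sketch below walks through what forward does in one simple configuration (two mixtures, each built from half of the estimated sources): try every partition, keep the lowest loss per batch item, average over the batch, and optionally return the winning estimated mixtures. Nested lists stand in for tensors, and the helper names mse and mixit_forward are hypothetical; the real wrapper works on torch.Tensor batches with a user-supplied loss_func.

```python
# Hypothetical pure-Python sketch of MixIT's forward pass; not asteroid's implementation.
from itertools import combinations


def mse(est_mixes, mixtures):
    """Mean squared error over mixtures and samples (a stand-in for loss_func)."""
    errs = [(e - m) ** 2 for a, b in zip(est_mixes, mixtures) for e, m in zip(a, b)]
    return sum(errs) / len(errs)


def mixit_forward(est_batch, mix_batch, return_est=False):
    """MixIT for two mixtures that each contain half of the estimated sources."""
    losses, best_mixes = [], []
    for est_sources, mixtures in zip(est_batch, mix_batch):
        n_src, n_samples = len(est_sources), len(est_sources[0])
        best = None
        # Choose the sources for mixture 0; the remaining ones form mixture 1.
        for group in combinations(range(n_src), n_src // 2):
            part = [list(group), [s for s in range(n_src) if s not in group]]
            est_mixes = [[sum(est_sources[s][t] for s in g) for t in range(n_samples)]
                         for g in part]
            loss = mse(est_mixes, mixtures)
            if best is None or loss < best[0]:
                best = (loss, est_mixes)
        losses.append(best[0])
        best_mixes.append(best[1])
    mean_loss = sum(losses) / len(losses)  # average of the per-item best losses
    return (mean_loss, best_mixes) if return_est else mean_loss


# Batch of two items, four estimated sources and two mixtures each.
est_batch = [
    [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 2.0]],
    [[1.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]],
]
mix_batch = [
    [[3.0, 0.0], [0.0, 3.0]],
    [[2.0, 2.0], [0.0, 2.0]],
]
loss, mixes = mixit_forward(est_batch, mix_batch, return_est=True)  # loss == 0.5
```

The first item is matched exactly by the partition [[0, 2], [1, 3]] (loss 0.0), while no partition matches the second item exactly (best loss 1.0), so the batch mean is 0.5.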
static best_part_mixit(loss_func, est_targets, targets, **kwargs)
Find the partition of the estimated sources that minimizes the loss in the MixIT training paradigm of [1]. Valid for any number of mixtures, as long as they all contain the same number of sources.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- torch.Tensor – The loss corresponding to the best partition, of shape (batch,).
- torch.LongTensor – The indices of the best partition.
- list – The possible partitions of the sources.
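As a sketch of the search this method performs, the pure-Python code below enumerates the ordered, equal-size partitions (one group per mixture) and picks the one whose summed sources best match the mixtures. The names equal_size_parts and best_part_index are hypothetical, and a squared error stands in for loss_func; the library version operates on batched tensors and returns per-item losses and indices.

```python
# Hypothetical pure-Python sketch of the equal-size partition search.
from itertools import combinations
from math import factorial


def equal_size_parts(n_src, n_mix):
    """Ordered partitions of range(n_src) into n_mix groups of n_src // n_mix sources."""
    if n_src % n_mix != 0:
        raise ValueError("the mixtures must contain the same number of sources")
    size = n_src // n_mix

    def split(remaining, groups_left):
        if groups_left == 0:
            yield []
            return
        for group in combinations(remaining, size):
            rest = [s for s in remaining if s not in group]
            for tail in split(rest, groups_left - 1):
                yield [list(group)] + tail

    return list(split(list(range(n_src)), n_mix))


def best_part_index(est_sources, mixtures, parts):
    """Index of the partition whose summed sources have the lowest squared error."""
    n_samples = len(est_sources[0])

    def error(part):
        return sum(
            (sum(est_sources[s][t] for s in group) - mixtures[i][t]) ** 2
            for i, group in enumerate(part)
            for t in range(n_samples)
        )

    return min(range(len(parts)), key=lambda p: error(parts[p]))


# Six sources, three mixtures of two sources each: 6! / (2!)^3 = 90 partitions.
parts = equal_size_parts(6, 3)
assert len(parts) == factorial(6) // factorial(2) ** 3
est_sources = [[1.0], [2.0], [4.0], [8.0], [16.0], [32.0]]
mixtures = [[9.0], [34.0], [20.0]]  # 1 + 8, 2 + 32, 4 + 16
best = parts[best_part_index(est_sources, mixtures, parts)]  # [[0, 3], [1, 5], [2, 4]]
```

Powers of two are used as source values so that each mixture value has exactly one decomposition, which makes the best partition unique.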
static best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs)
Find the partition of the estimated sources that minimizes the loss in the MixIT training paradigm of [1]. Valid only for two mixtures, but those mixtures do not necessarily contain the same number of sources (e.g., one mixture may be silent).
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- torch.Tensor – The loss corresponding to the best partition, of shape (batch,).
- torch.LongTensor – The indices of the best partition.
- list – The possible partitions of the sources.
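To show why the generalized mode can handle a silent mixture, here is a pure-Python sketch that enumerates every split of the sources between two mixtures, including splits where one group is empty, and picks the best one. The names generalized_parts and best_generalized_partition are hypothetical, and the squared error stands in for loss_func.

```python
# Hypothetical pure-Python sketch of the two-mixture (generalized) search.
from itertools import combinations


def generalized_parts(n_src):
    """All splits of the sources between two mixtures; either group may be empty."""
    parts = []
    for k in range(n_src + 1):
        for group in combinations(range(n_src), k):
            rest = [s for s in range(n_src) if s not in group]
            parts.append([list(group), rest])
    return parts


def squared_error(est_mixes, mixtures):
    """Total squared error between estimated and reference mixtures."""
    return sum((e - m) ** 2 for a, b in zip(est_mixes, mixtures) for e, m in zip(a, b))


def best_generalized_partition(est_sources, mixtures):
    """Return the two-group split whose summed sources best match the two mixtures."""
    n_samples = len(est_sources[0])

    def est_mixes(part):
        # An empty group sums to an all-zero estimate, i.e. a silent mixture.
        return [[sum((est_sources[s][t] for s in g), 0.0) for t in range(n_samples)]
                for g in part]

    return min(generalized_parts(len(est_sources)),
               key=lambda part: squared_error(est_mixes(part), mixtures))


# The second reference mixture is silent: all three sources belong to the first one.
est_sources = [[1.0, 2.0], [0.5, 0.0], [0.0, 1.0]]
mixtures = [[1.5, 3.0], [0.0, 0.0]]
best = best_generalized_partition(est_sources, mixtures)  # [[0, 1, 2], []]
```

With three sources and two mixtures, an equal-size split is impossible (3 is not divisible by 2), so only a search like this one, which allows groups of any size, can represent the silent second mixture.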
static loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs)
Common loop shared by best_part_mixit() and best_part_mixit_generalized(): computes the loss of every candidate partition in parts.
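A minimal pure-Python sketch of such a shared loop, assuming the candidate partitions are already given: for every batch item and every partition, sum the sources of each group into an estimated mixture and score it. The result is a batch-by-partition table whose row-wise minimum gives the best partition index. The helper abs_error is a hypothetical stand-in for loss_func.

```python
# Hypothetical pure-Python sketch of the shared loop; nested lists stand in for tensors.


def loss_set_from_parts(loss_func, est_batch, mix_batch, parts):
    """Return a (batch, len(parts)) table: the loss of every candidate partition."""
    loss_set = []
    for est_sources, mixtures in zip(est_batch, mix_batch):
        n_samples = len(est_sources[0])
        row = []
        for part in parts:
            # Sum the sources of each group into one estimated mixture.
            est_mixes = [[sum(est_sources[s][t] for s in group) for t in range(n_samples)]
                         for group in part]
            row.append(loss_func(est_mixes, mixtures))
        loss_set.append(row)
    return loss_set


def abs_error(est_mixes, mixtures):
    """Total absolute error, a stand-in for loss_func."""
    return sum(abs(e - m) for a, b in zip(est_mixes, mixtures) for e, m in zip(a, b))


# Three estimated sources, two mixtures, three candidate partitions.
parts = [[[0, 1], [2]], [[0, 2], [1]], [[1, 2], [0]]]
est_batch = [[[1.0], [2.0], [4.0]], [[2.0], [0.0], [1.0]]]
mix_batch = [[[5.0], [2.0]], [[1.0], [2.0]]]
loss_set = loss_set_from_parts(abs_error, est_batch, mix_batch, parts)
# loss_set == [[4.0, 0.0, 2.0], [2.0, 4.0, 0.0]]
min_idx = [row.index(min(row)) for row in loss_set]  # [1, 2]
```

Keeping this loop separate from the enumeration step lets both search modes reuse it: only the list of candidate partitions differs.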
static reorder_source(est_targets, targets, min_loss_idx, parts)
Reorder sources according to the best partition.
Parameters: - est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.
- min_loss_idx – torch.LongTensor. The indices of the best partitions.
- parts – list of the possible partitions of the sources.
Returns: torch.Tensor – Reordered sources (estimated sources summed according to the best partition), of shape \((batch, nmix, ...)\).
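The reordering step can be sketched in pure Python as follows: for each batch item, look up its best partition and sum the estimated sources of each group, giving one estimated mixture per reference mixture. This hypothetical sketch omits the targets argument of the documented method and uses nested lists instead of tensors.

```python
# Hypothetical pure-Python sketch; the targets argument of the real method is omitted.


def reorder_source(est_batch, min_loss_idx, parts):
    """Sum each item's estimated sources according to its best partition."""
    reordered = []
    for est_sources, idx in zip(est_batch, min_loss_idx):
        n_samples = len(est_sources[0])
        # One estimated mixture per group; an empty group yields all zeros.
        reordered.append([[sum((est_sources[s][t] for s in group), 0.0)
                           for t in range(n_samples)]
                          for group in parts[idx]])
    return reordered


parts = [[[0, 2], [1]], [[0], [1, 2]], [[0, 1, 2], []]]
est = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
est_mixes = reorder_source([est, est], [0, 2], parts)
# est_mixes == [[[3.0, 2.0], [0.0, 1.0]], [[3.0, 3.0], [0.0, 0.0]]]
```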