
asteroid.losses.mixit_wrapper module

class asteroid.losses.mixit_wrapper.MixITLossWrapper(loss_func, generalized=True)

Bases: torch.nn.Module

Mixture invariant loss wrapper.

Parameters:
  • loss_func – function with signature (est_targets, targets, **kwargs). It must return one loss value per batch sample, i.e. a tensor of shape (batch,).
  • generalized (bool) – Determines how MixIT is applied. If False, MixIT is applied to any number of mixtures, as long as they all contain the same number of sources (see best_part_mixit()). If True (default), MixIT is applied to exactly two mixtures, which do not necessarily contain the same number of sources (see best_part_mixit_generalized()).

For each of these modes, the best partition and reordering will be automatically computed.
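The selection performed in both modes can be summarized by the objective of [1], written here in matrix form as an illustration: \(\mathbf{X}\) stacks the \(N\) (= nmix) target mixtures, \(\hat{\mathbf{S}}\) stacks the \(M\) (= nsrc) estimated sources, \(\mathcal{L}\) is loss_func, and the binary mixing matrix \(\mathbf{A}\) assigns each estimated source to exactly one mixture.

```latex
\mathcal{L}_{\mathrm{MixIT}}(\mathbf{X}, \hat{\mathbf{S}})
  = \min_{\mathbf{A} \in \{0,1\}^{N \times M}} \mathcal{L}\big(\mathbf{X}, \mathbf{A}\hat{\mathbf{S}}\big),
  \qquad \text{subject to } \sum_{n=1}^{N} A_{nm} = 1 \;\; \text{for all } m.
```

With generalized=True, \(N = 2\) and every such \(\mathbf{A}\) is a candidate (\(2^M\) in total, including assignments where one mixture receives no source). With generalized=False, each row of \(\mathbf{A}\) must additionally sum to \(M/N\), so every estimated mixture is built from the same number of sources.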

Examples

>>> import torch
>>> from asteroid.losses import multisrc_mse
>>> from asteroid.losses.mixit_wrapper import MixITLossWrapper
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute the MixIT loss with a multi-source (non-pairwise) loss function
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)

References

[1] Scott Wisdom et al. “Unsupervised sound separation using mixtures of mixtures.” arXiv:2006.12701 (2020).

forward(est_targets, targets, return_est=False, **kwargs)

Find the best partition and return the loss.

Parameters:
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
  • return_est – Boolean. Whether to also return the estimated mixtures (e.g. to compute metrics or to save examples).
  • **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:

  • The best-partition loss of each batch sample, averaged over the batch: a scalar torch.Tensor.
  • The estimated mixtures (estimated sources summed according to the partition) if return_est is True. torch.Tensor of shape \((batch, nmix, ...)\).
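The final selection step of forward() can be sketched in plain Python as below. This is an illustration, not Asteroid's code: the hypothetical helper best_partition_loss takes a nested list loss_set[b][p] (standing in for a (batch, n_partitions) tensor of per-partition losses), keeps the lowest loss of each batch sample, and averages those minima over the batch.

```python
def best_partition_loss(loss_set):
    """Keep the lowest-loss partition per batch sample and average over the batch.

    loss_set[b][p] is the loss of candidate partition p for batch sample b
    (a plain-list stand-in for a (batch, n_partitions) tensor).
    Returns (mean_loss, best_idx), where best_idx[b] indexes the chosen partition.
    """
    best_idx = []
    best_losses = []
    for sample_losses in loss_set:
        # argmin over the candidate partitions of this sample (first one wins on ties)
        idx = min(range(len(sample_losses)), key=lambda p: sample_losses[p])
        best_idx.append(idx)
        best_losses.append(sample_losses[idx])
    mean_loss = sum(best_losses) / len(best_losses)
    return mean_loss, best_idx


mean_loss, best_idx = best_partition_loss([[0.75, 0.25, 0.5], [0.5, 1.0, 0.125]])
print(best_idx)   # [1, 2]
print(mean_loss)  # 0.1875
```

Each batch sample may pick a different partition; only the selected losses contribute to the average.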

static best_part_mixit(loss_func, est_targets, targets, **kwargs)

Find the best partition of the estimated sources, i.e. the one that gives the minimum loss for the MixIT training paradigm in [1]. Valid for any number of mixtures, as long as they all contain the same number of sources.

Parameters:
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function used to compute the batch losses.
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
  • **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:

  • torch.Tensor – The loss corresponding to the best partition, of shape (batch,).
  • torch.LongTensor – The indices of the best partition for each batch sample.
  • list – The list of possible partitions of the sources.
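A minimal sketch of the candidate partitions for this mode, assuming they are ordered (group n is matched to target mixture n) and represented as lists of source indices. equal_size_partitions is a hypothetical helper written for illustration, not part of Asteroid's API.

```python
from itertools import combinations


def equal_size_partitions(nsrc, nmix):
    """Enumerate ordered partitions of range(nsrc) into nmix groups of equal size.

    Group n collects the estimated sources summed to approximate mixture n,
    so [[0, 1], [2, 3]] and [[2, 3], [0, 1]] count as different candidates.
    """
    if nsrc % nmix != 0:
        raise ValueError("nsrc must be divisible by nmix")
    k = nsrc // nmix  # sources per mixture

    def recurse(remaining, groups_left):
        if groups_left == 0:
            yield []
            return
        # choose the sources of the next mixture, then split the rest
        for group in combinations(remaining, k):
            rest = [s for s in remaining if s not in group]
            for tail in recurse(rest, groups_left - 1):
                yield [list(group)] + tail

    return list(recurse(list(range(nsrc)), nmix))


parts = equal_size_partitions(4, 2)
print(len(parts))  # 6
print(parts[0])    # [[0, 1], [2, 3]]
```

With k = nsrc // nmix sources per mixture, this yields nsrc! / (k!)^nmix candidates, so the search grows quickly with the number of sources.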

static best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs)

Find the best partition of the estimated sources, i.e. the one that gives the minimum loss for the MixIT training paradigm in [1]. Valid only for two mixtures, which do not necessarily contain the same number of sources; for example, the case where one mixture is silent is allowed.

Parameters:
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function used to compute the batch losses.
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
  • **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:

  • torch.Tensor – The loss corresponding to the best partition, of shape (batch,).
  • torch.LongTensor – The indices of the best partition for each batch sample.
  • list – The list of possible partitions of the sources.
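For the generalized two-mixture case, a sketch of the candidate set: every way of splitting the estimated sources into two groups of any size, including an empty group (a silent estimated mixture). two_mixture_partitions is a hypothetical helper for illustration, not Asteroid's implementation.

```python
from itertools import combinations


def two_mixture_partitions(nsrc):
    """Enumerate every assignment of nsrc estimated sources to two mixtures.

    Groups may have any size, including zero, so one estimated mixture can
    be silent. There are 2 ** nsrc candidates in total.
    """
    sources = list(range(nsrc))
    parts = []
    for k in range(nsrc + 1):
        # sources assigned to the first mixture; the rest go to the second
        for first in combinations(sources, k):
            second = [s for s in sources if s not in first]
            parts.append([list(first), second])
    return parts


parts = two_mixture_partitions(3)
print(len(parts))  # 8
print(parts[0])    # [[], [0, 1, 2]]
```

Each split appears in both orders (e.g. [[0], [1, 2]] and [[1, 2], [0]]), which matters because the two target mixtures are not interchangeable.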

static loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs)

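Common loop shared by best_part_mixit() and best_part_mixit_generalized(): for each candidate partition, sum the estimated sources into estimated mixtures and compute the loss against the targets.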
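The shared loop can be sketched in plain Python as follows, assuming nested lists in place of tensors (est_targets[b][m] is a list of samples) and a per-sample loss function; in the real wrapper, loss_func is applied to the whole batch at once and returns a tensor of shape (batch,). The helpers mse_loss and loss_set_sketch are illustrative stand-ins, not Asteroid functions.

```python
def mse_loss(est_mixes, mixes):
    """Mean squared error over all mixtures and samples (stand-in for loss_func)."""
    errs = [(e - t) ** 2 for em, tm in zip(est_mixes, mixes) for e, t in zip(em, tm)]
    return sum(errs) / len(errs)


def loss_set_sketch(loss_func, est_targets, targets, parts):
    """Return losses[b][p]: the loss of partition p for batch sample b.

    est_targets[b][m] is estimated source m and targets[b][n] is target
    mixture n of batch sample b, each a list of samples.
    """
    losses = []
    for est_srcs, mixes in zip(est_targets, targets):
        length = len(mixes[0])
        row = []
        for partition in parts:
            # sum the sources of each group; an empty group yields a silent mixture
            est_mixes = [
                [sum(est_srcs[m][t] for m in group) for t in range(length)]
                for group in partition
            ]
            row.append(loss_func(est_mixes, mixes))
        losses.append(row)
    return losses


est = [[[1.0, 0.0], [0.0, 1.0]]]    # one batch sample, two estimated sources
mixes = [[[1.0, 0.0], [0.0, 1.0]]]  # one batch sample, two target mixtures
parts = [[[0], [1]], [[1], [0]], [[0, 1], []]]
print(loss_set_sketch(mse_loss, est, mixes, parts))  # [[0.0, 1.0, 0.5]]
```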

static reorder_source(est_targets, targets, min_loss_idx, parts)

Reorder sources according to the best partition.

Parameters:
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.
  • min_loss_idx – torch.LongTensor. The indices of the best partitions.
  • parts – list of the possible partitions of the sources.
Returns:

torch.Tensor – The estimated mixtures (estimated sources summed according to the best partition), of shape \((batch, nmix, time)\).
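A plain-Python sketch of this reordering, assuming nested lists instead of tensors: for each batch sample, the sources in each group of its best partition are summed to form the estimated mixtures. reorder_source_sketch is a hypothetical stand-in; unlike the real method, it does not take targets, which this sketch does not need.

```python
def reorder_source_sketch(est_targets, min_loss_idx, parts):
    """Build the estimated mixtures of each batch sample from its best partition.

    est_targets[b][m] is estimated source m of sample b (a list of samples);
    min_loss_idx[b] indexes that sample's best partition in parts.
    """
    est_mixtures = []
    for est_srcs, idx in zip(est_targets, min_loss_idx):
        length = len(est_srcs[0])
        # one estimated mixture per group of the chosen partition
        est_mixtures.append([
            [sum(est_srcs[m][t] for m in group) for t in range(length)]
            for group in parts[idx]
        ])
    return est_mixtures


parts = [[[0, 1], [2]], [[0], [1, 2]]]
est = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
print(reorder_source_sketch(est, [1], parts))  # [[[1.0, 2.0], [8.0, 10.0]]]
```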

Documentation version: v0.4.4