
asteroid.losses.pit_wrapper module

class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]

Bases: torch.nn.Module

Permutation invariant loss wrapper.

Parameters:
  • loss_func – function with signature (est_targets, targets, **kwargs).
  • pit_from (str) –

    Determines how PIT is applied.

    • 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).
    • 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single-source targets and single-source estimates (the tensors do not have a source axis). Output shape: \((batch)\). See get_pw_losses().
    • 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch)\). See best_perm_from_perm_avg_loss().

    In terms of efficiency, 'perm_avg' is the least efficient.

  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : \((B, n\_src!, n\_src) -> (B, n\_src!)\). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in 'pw_mtx' and 'pw_pt' pit_from modes.

For each of these modes, the best permutation and reordering are computed automatically. When either 'pw_mtx' or 'pw_pt' is used and the number of sources is larger than three, the Hungarian algorithm is used to find the best permutation.

Examples
>>> import torch
>>> from asteroid.losses import pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
...     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
...     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)
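
The wrapper can also be built from a single-source loss ('pw_pt' mode). A minimal sketch, assuming singlesrc_neg_sisdr is exported by asteroid.losses (as in recent versions):

>>> from asteroid.losses import singlesrc_neg_sisdr
>>> # 'pw_pt': the loss function receives single-source tensors of shape (batch, time)
>>> loss_func = PITLossWrapper(singlesrc_neg_sisdr, pit_from='pw_pt')
>>> loss_val = loss_func(est_sources, sources)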
forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the loss.

Parameters:
  • est_targets – torch.Tensor. Expected shape $(batch, nsrc, …)$. The batch of target estimates.
  • targets – torch.Tensor. Expected shape $(batch, nsrc, …)$. The batch of training targets.
  • return_est – Boolean. Whether to return the reordered target estimates (e.g. to compute metrics or to save examples).
  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
  • **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:

  • Best permutation loss for each batch sample, averaged over the batch.
  • The reordered target estimates if return_est is True. torch.Tensor of shape $(batch, nsrc, …)$.
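
For instance, reusing the tensors from the example above, the reordered estimates can be retrieved together with the loss (illustrative sketch based on the signature above):

>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val, reordered = loss_func(est_sources, sources, return_est=True)
>>> assert reordered.shape == est_sources.shape  # (batch, nsrc, ...)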

static get_pw_losses(loss_func, est_targets, targets, **kwargs)[source]

Get pair-wise losses between the training targets and their estimates for a given loss function.

Parameters:
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get pair-wise losses from.
  • est_targets – torch.Tensor. Expected shape $(batch, nsrc, …)$. The batch of target estimates.
  • targets – torch.Tensor. Expected shape $(batch, nsrc, …)$. The batch of training targets.
  • **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:

torch.Tensor of size $(batch, nsrc, nsrc)$, the losses computed for all pairs of targets and est_targets.

This function can be called on a loss function which returns a tensor of size \((batch)\). There are more efficient ways to compute pair-wise losses using broadcasting.
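
As an illustration only, a minimal sketch of the loop this helper performs, assuming loss_func takes single-source tensors and returns a tensor of shape \((batch)\) (pw_losses_sketch is a hypothetical name, not part of the API):

import torch

def pw_losses_sketch(loss_func, est_targets, targets):
    # Build a (batch, n_src, n_src) matrix where entry (i, j) is the loss
    # between targets[:, i] and est_targets[:, j], as described above.
    batch, n_src = targets.shape[:2]
    pair_wise_losses = targets.new_empty(batch, n_src, n_src)
    for i in range(n_src):        # target index
        for j in range(n_src):    # estimate index
            pair_wise_losses[:, i, j] = loss_func(est_targets[:, j], targets[:, i])
    return pair_wise_losses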

static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)[source]

Find the best permutation from a loss function that operates on tensors with a source axis.

Parameters:
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
  • est_targets – torch.Tensor. Expected shape $(batch, nsrc, *)$. The batch of target estimates.
  • targets – torch.Tensor. Expected shape $(batch, nsrc, *)$. The batch of training targets.
  • **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:

  • torch.Tensor – The loss corresponding to the best permutation of size $(batch,)$.
  • torch.Tensor – The indices of the best permutations.
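
As an illustration only, a hedged sketch of this search, looping over every permutation of the estimates with a loss function that returns a per-batch average (perm_avg_search_sketch is a hypothetical name, not part of the API):

import torch
from itertools import permutations

def perm_avg_search_sketch(loss_func, est_targets, targets):
    # Evaluate the averaged loss for every permutation of the estimates
    # and keep the best one for each batch sample.
    n_src = targets.shape[1]
    perms = list(permutations(range(n_src)))
    loss_set = torch.stack(
        [loss_func(est_targets[:, list(p)], targets) for p in perms], dim=1
    )  # (batch, n_src!)
    min_loss, min_idx = torch.min(loss_set, dim=1)
    return min_loss, min_idx  # best loss and index of the best permutation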

static find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs)[source]

Find the best permutation, given the pair-wise losses.

Dispatches to the factorial method when the number of sources is small (three or fewer) and to the Hungarian method for more sources. If perm_reduce is not None, the factorial method is always used.

Parameters:
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs) : \((B, n\_src!, n\_src) -> (B, n\_src!)\)
  • **kwargs – additional keyword argument that will be passed to the permutation reduce function.
Returns:

  • torch.Tensor – The loss corresponding to the best permutation of size $(batch,)$.
  • torch.Tensor – The indices of the best permutations.

static reorder_source(source, batch_indices)[source]

Reorder sources according to the best permutation.

Parameters:
  • source (torch.Tensor) – Tensor of shape \((batch, n\_src, time)\)
  • batch_indices (torch.Tensor) – Tensor of shape \((batch, n\_src)\). Contains the optimal permutation indices for each batch element.
Returns:

torch.Tensor – Reordered sources.
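
Reusing the tensors from the example at the top, the helpers above can be chained as follows (illustrative only):

>>> pwl = pairwise_neg_sisdr(est_sources, sources)  # (batch, n_src, n_src)
>>> min_loss, batch_indices = PITLossWrapper.find_best_perm(pwl)
>>> reordered = PITLossWrapper.reorder_source(est_sources, batch_indices)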

static find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs)[source]

Find the best permutation given the pair-wise losses by looping through all the permutations.

Parameters:
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs) : \((B, n\_src!, n\_src) -> (B, n\_src!)\)
  • **kwargs – additional keyword argument that will be passed to the permutation reduce function.
Returns:

  • torch.Tensor – The loss corresponding to the best permutation of size $(batch,)$.
  • torch.Tensor – The indices of the best permutations.

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
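
As an illustration only, a minimal sketch of this exhaustive search (factorial_search_sketch is a hypothetical name, not the library's exact implementation):

import torch
from itertools import permutations

def factorial_search_sketch(pair_wise_losses):
    # pair_wise_losses: (batch, n_src, n_src), entry (b, i, j) is the loss between
    # target i and estimate j. Average the matched entries for every permutation.
    batch, n_src, _ = pair_wise_losses.shape
    perms = torch.tensor(list(permutations(range(n_src))), dtype=torch.long)
    idx = torch.arange(n_src)
    loss_set = torch.stack(
        [pair_wise_losses[:, idx, p].mean(dim=-1) for p in perms], dim=1
    )  # (batch, n_src!)
    min_loss, min_idx = torch.min(loss_set, dim=1)
    return min_loss, perms[min_idx]  # best loss and permutation per batch sample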

static find_best_perm_hungarian(pair_wise_losses: torch.Tensor)[source]

Find the best permutation given the pair-wise losses, using the Hungarian algorithm.

Parameters:
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
Returns:

  • torch.Tensor – The loss corresponding to the best permutation of size $(batch,)$.
  • torch.Tensor – The indices of the best permutations.
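
A hedged sketch of the same idea using scipy.optimize.linear_sum_assignment (hungarian_search_sketch is a hypothetical name, not the library's exact code):

import torch
from scipy.optimize import linear_sum_assignment

def hungarian_search_sketch(pair_wise_losses):
    # Solve one linear assignment problem per batch element of the
    # (batch, n_src, n_src) cost matrix.
    losses, indices = [], []
    for cost in pair_wise_losses.detach().cpu():
        rows, cols = linear_sum_assignment(cost.numpy())
        losses.append(cost[torch.as_tensor(rows), torch.as_tensor(cols)].mean())
        indices.append(torch.as_tensor(cols))
    return torch.stack(losses), torch.stack(indices)
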
class asteroid.losses.pit_wrapper.PITReorder(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]

Bases: asteroid.losses.pit_wrapper.PITLossWrapper

Permutation invariant reorderer. Only returns the reordered estimates. See asteroid.losses.PITLossWrapper.

forward(est_targets, targets, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the reordered estimates.

Parameters:
  • est_targets – torch.Tensor. Expected shape $(batch, nsrc, …)$. The batch of target estimates.
  • targets – torch.Tensor. Expected shape $(batch, nsrc, …)$. The batch of training targets.
  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
  • **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:

  The reordered target estimates. torch.Tensor of shape $(batch, nsrc, …)$.
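
A short usage sketch (illustrative only), reusing the tensors from the example above:

>>> from asteroid.losses.pit_wrapper import PITReorder
>>> from asteroid.losses import pairwise_neg_sisdr
>>> reorder = PITReorder(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> reordered_sources = reorder(est_sources, sources)  # (batch, nsrc, ...)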
