
asteroid.losses.pit_wrapper module

class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]

Bases: torch.nn.Module

Permutation invariant loss wrapper.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs).
  • pit_from (str) –

    Determines how PIT is applied.

    • 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape (batch, n_src, n_src). Each element [batch, i, j] corresponds to the loss between targets[:, i] and est_targets[:, j].
    • 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single-source targets and single-source estimates (the tensors don't have the source axis). Output shape: (batch). See get_pw_losses().
    • 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: (batch). See best_perm_from_perm_avg_loss().

    In terms of efficiency, 'perm_avg' is the least efficient.

  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): (B, n_src!, n_src) -> (B, n_src!). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in 'pw_mtx' and 'pw_pt' pit_from modes.

For each of these modes, the best permutation and reordering will be automatically computed.

Examples

>>> import torch
>>> from asteroid.losses import pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using a custom permutation reduce function
>>> def reduce(perm_loss, src):
...     weighted = perm_loss * src.norm(dim=-1, keepdim=True)
...     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)
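
The 'pw_pt' mode works the same way with a loss defined on single-source tensors. A minimal sketch, assuming singlesrc_neg_sisdr (a loss taking (batch, time) tensors) is exposed by asteroid.losses in the installed version:

>>> from asteroid.losses import singlesrc_neg_sisdr
>>> # 'pw_pt': the loss function sees one (target, estimate) pair of shape (batch, time)
>>> # at a time; the wrapper builds the full pairwise matrix via get_pw_losses().
>>> loss_func = PITLossWrapper(singlesrc_neg_sisdr, pit_from='pw_pt')
>>> loss_val = loss_func(est_sources, sources)
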
static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)[source]

Find best permutation from loss function with source axis.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs). The loss function to get batch losses from.
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

tuple:

  torch.Tensor: The loss corresponding to the best permutation, of size (batch,).

  torch.LongTensor: The indexes of the best permutations.
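
As a conceptual illustration only (not the library implementation), the search boils down to enumerating all n_src! orderings of the estimates and keeping the per-batch minimum of the averaged loss; the order-symmetric toy MSE loss below is hypothetical:

>>> import torch
>>> from itertools import permutations
>>> def perm_avg_mse(targets, est_targets):
...     # Averaged over sources and time, one loss value per batch item: shape (batch,)
...     return ((targets - est_targets) ** 2).mean(dim=(1, 2))
>>> targets = torch.randn(4, 3, 8000)
>>> est_targets = torch.randn(4, 3, 8000)
>>> perms = list(permutations(range(3)))  # all 3! = 6 candidate orderings
>>> losses = torch.stack(
...     [perm_avg_mse(targets, est_targets[:, list(p)]) for p in perms], dim=1)
>>> min_loss, min_loss_idx = torch.min(losses, dim=1)  # both of shape (batch,)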

static find_best_perm(pair_wise_losses, n_src, perm_reduce=None, **kwargs)[source]

Find the best permutation, given the pair-wise losses.

Parameters:
  • pair_wise_losses (torch.Tensor) – Tensor of shape [batch, n_src, n_src]. Pairwise losses.
  • n_src (int) – Number of sources.
  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): (B, n_src!, n_src) -> (B, n_src!).
  • **kwargs – additional keyword argument that will be passed to the permutation reduce function.
Returns:

tuple:

  torch.Tensor: The loss corresponding to the best permutation, of size (batch,).

  torch.LongTensor: The indexes of the best permutations.

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
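
Conceptually (a sketch, not the library code), the default mean reduction gathers, for each candidate permutation, the matching entries of the pairwise matrix, averages them over sources, and takes the per-batch minimum:

>>> import torch
>>> from itertools import permutations
>>> batch, n_src = 4, 3
>>> pair_wise_losses = torch.randn(batch, n_src, n_src)  # [batch, target i, estimate j]
>>> rows = torch.arange(n_src)
>>> perm_losses = torch.stack(
...     [pair_wise_losses[:, rows, list(p)].mean(dim=1)
...      for p in permutations(range(n_src))], dim=1)  # shape (batch, n_src!)
>>> min_loss, min_loss_idx = torch.min(perm_losses, dim=1)
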

forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the loss.

Parameters:
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
  • return_est – Boolean. Whether to return the reordered target estimates (e.g. to compute metrics or save examples).
  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

  • torch.Tensor – Best permutation loss for each batch sample, averaged over
    the batch.
  • torch.Tensor of shape [batch, nsrc, *] – The reordered target estimates,
    only if return_est is True.
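
For example, continuing the class example above, the reordered estimates can be retrieved alongside the loss for metric computation:

>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val, reordered_sources = loss_func(est_sources, sources, return_est=True)
>>> # reordered_sources has the same shape as est_sources: [batch, nsrc, *]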

static get_pw_losses(loss_func, est_targets, targets, **kwargs)[source]

Get the pair-wise losses between the training targets and their estimates for a given loss function.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs) The loss function to get pair-wise losses from.
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

torch.Tensor of size [batch, nsrc, nsrc], the losses computed for all pairs of targets and est_targets.

This function can be called on a loss function which returns a tensor of size [batch]. There are more efficient ways to compute pair-wise losses using broadcasting.
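
A conceptual sketch of that brute-force loop (not the library implementation), following the (targets, est_targets, **kwargs) signature documented above:

>>> import torch
>>> def pairwise_losses_sketch(loss_func, est_targets, targets):
...     batch, n_src = targets.shape[:2]
...     pair_wise_losses = targets.new_empty(batch, n_src, n_src)
...     # Entry [batch, i, j]: loss between targets[:, i] and est_targets[:, j]
...     for i in range(n_src):
...         for j in range(n_src):
...             pair_wise_losses[:, i, j] = loss_func(targets[:, i], est_targets[:, j])
...     return pair_wise_losses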

static reorder_source(source, n_src, min_loss_idx)[source]

Reorder sources according to the best permutation.

Parameters:
  • source (torch.Tensor) – Tensor of shape [batch, n_src, time]
  • n_src (int) – Number of sources.
  • min_loss_idx (torch.LongTensor) – Tensor of shape [batch], each item is in [0, n_src!).
Returns:

torch.Tensor – Reordered sources of shape [batch, n_src, time].

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
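
Conceptually (a sketch, not the library code), the permutation index selects one of the n_src! orderings, which is then used to index the source axis of each batch item:

>>> import torch
>>> from itertools import permutations
>>> def reorder_source_sketch(source, n_src, min_loss_idx):
...     # All n_src! permutations, as a LongTensor of shape (n_src!, n_src)
...     perms = torch.tensor(list(permutations(range(n_src))), dtype=torch.long)
...     batch_perms = perms[min_loss_idx]  # chosen ordering per batch item: (batch, n_src)
...     return torch.stack([s[p] for s, p in zip(source, batch_perms)], dim=0)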
