asteroid.losses.pit_wrapper module
class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)
Bases: torch.nn.Module
Permutation invariant loss wrapper.
Parameters:
- loss_func – function with signature (targets, est_targets, **kwargs).
- pit_from (str) – Determines how PIT is applied.
  - 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape (batch, n_src, n_src). Each element [batch, i, j] corresponds to the loss between targets[:, i] and est_targets[:, j].
  - 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (the tensors won't have the source axis). Output shape: (batch). See get_pw_losses().
  - 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: (batch). See best_perm_from_perm_avg_loss().
  In terms of efficiency, 'perm_avg' is the least efficient: loss_func has to be evaluated once for each of the n_src! permutations, whereas the pairwise modes only need the n_src x n_src pairwise losses.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) -> (B, n_src!). perm_reduce can receive **kwargs during forward through the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in the 'pw_mtx' and 'pw_pt' pit_from modes.
For each of these modes, the best permutation and reordering will be automatically computed.
Examples
>>> import torch
>>> from asteroid.losses import pairwise_neg_sisdr, PITLossWrapper
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
...     # perm_loss: (batch, n_src!, n_src); target norms reshaped to (batch, 1, n_src)
...     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
...     return torch.mean(weighted, dim=-1)
...
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)
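The perm_reduce description recommends functools.partial when the extra arguments are static. Below is a minimal sketch in plain PyTorch, not asteroid code: weighted_reduce and its weights argument are hypothetical, and building pwl_set as pwl_set[b, p, i] = pw_losses[b, i, perms[p, i]] (with permutations in itertools order) is an assumption chosen to match the documented (B, n_src!, n_src) -> (B, n_src!) shapes.

```python
import functools
from itertools import permutations

import torch


def weighted_reduce(pwl_set, weights):
    """Hypothetical perm_reduce: weighted sum over the source axis.

    pwl_set: [batch, n_src!, n_src], one row of pairwise losses per permutation.
    weights: [n_src] static per-source weights (here they sum to 1).
    Returns: [batch, n_src!]
    """
    return (pwl_set * weights).sum(dim=-1)


# Bind the static weights once, so no reduce_kwargs are needed at call time.
weights = torch.tensor([0.5, 0.3, 0.2])
perm_reduce = functools.partial(weighted_reduce, weights=weights)

batch, n_src = 4, 3
pw_losses = torch.rand(batch, n_src, n_src)
perms = torch.tensor(list(permutations(range(n_src))))  # [n_src!, n_src]
# Assumed layout: pwl_set[b, p, i] = pw_losses[b, i, perms[p, i]]
pwl_set = pw_losses[:, torch.arange(n_src), perms]
loss_set = perm_reduce(pwl_set)  # [batch, n_src!]
```

With uniform weights of 1 / n_src, this reduction falls back to the default mean.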
static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)
Find the best permutation from a loss function with a source axis.
Parameters:
- loss_func – function with signature (targets, est_targets, **kwargs). The loss function to get batch losses from.
- est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
- targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns: tuple –
- torch.Tensor: the loss corresponding to the best permutation, of size (batch,).
- torch.LongTensor: the indexes of the best permutations.
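To show where the n_src! cost of 'perm_avg' comes from, here is a brute-force sketch in plain PyTorch: it evaluates the loss once per permutation of the estimates and keeps the lowest. perm_avg_mse and best_perm_perm_avg are hypothetical names, not asteroid functions, and reordering the estimates with est_targets[:, perm] is an assumed convention.

```python
from itertools import permutations

import torch


def perm_avg_mse(est_targets, targets):
    """Hypothetical 'perm_avg'-style loss: MSE averaged over sources -> [batch]."""
    return ((est_targets - targets) ** 2).mean(dim=(1, 2))


def best_perm_perm_avg(loss_func, est_targets, targets):
    """Brute-force sketch: one loss_func call per permutation (n_src! calls)."""
    n_src = targets.shape[1]
    perms = list(permutations(range(n_src)))
    # [batch, n_src!]: loss of every candidate ordering of the estimates
    loss_set = torch.stack(
        [loss_func(est_targets[:, list(p)], targets) for p in perms], dim=1
    )
    min_loss, min_loss_idx = loss_set.min(dim=1)
    return min_loss, min_loss_idx, perms


targets = torch.randn(2, 3, 100)
est = targets[:, [2, 0, 1]]  # estimates are a shuffled copy of the targets
min_loss, idx, perms = best_perm_perm_avg(perm_avg_mse, est, targets)
```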
static find_best_perm(pair_wise_losses, n_src, perm_reduce=None, **kwargs)
Find the best permutation, given the pair-wise losses.
Parameters:
- pair_wise_losses (torch.Tensor) – Tensor of shape [batch, n_src, n_src]. Pairwise losses.
- n_src (int) – Number of sources.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) -> (B, n_src!).
- **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
Returns: tuple –
- torch.Tensor: the loss corresponding to the best permutation, of size (batch,).
- torch.LongTensor: the indexes of the best permutations.
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
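The search described above can be sketched in plain PyTorch under two assumptions: permutations are enumerated in itertools.permutations order, and entry [b, i, j] is the loss between target i and estimate j (the 'pw_mtx' convention). find_best_perm_sketch is a hypothetical name; this illustrates the technique and is not the library's implementation. Note that enumerating all n_src! permutations grows factorially with the number of sources.

```python
from itertools import permutations

import torch


def find_best_perm_sketch(pair_wise_losses, perm_reduce=None, **kwargs):
    """Pick, per batch item, the permutation with the lowest reduced loss.

    pair_wise_losses: [batch, n_src, n_src]; [b, i, j] = loss(target i, estimate j).
    """
    n_src = pair_wise_losses.shape[-1]
    device = pair_wise_losses.device
    perms = torch.tensor(list(permutations(range(n_src))), device=device)
    # [batch, n_src!, n_src]: the n_src pairwise losses selected by each permutation
    pwl_set = pair_wise_losses[:, torch.arange(n_src, device=device), perms]
    if perm_reduce is None:
        loss_set = pwl_set.mean(dim=-1)  # default reduction: mean over sources
    else:
        loss_set = perm_reduce(pwl_set, **kwargs)  # [batch, n_src!]
    min_loss, min_loss_idx = loss_set.min(dim=1)
    return min_loss, min_loss_idx


# Two sources whose estimates come out swapped: permutation (1, 0) is best.
pwl = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
min_loss, idx = find_best_perm_sketch(pwl)
```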
forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)
Find the best permutation and return the loss.
Parameters:
- est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
- targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
- return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
- reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- The best-permutation loss of each batch sample, averaged over the batch: torch.Tensor(loss_value).
- The reordered target estimates if return_est is True: torch.Tensor of shape [batch, nsrc, *].
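The forward steps can be strung together as a plain PyTorch sketch: pairwise losses, best permutation per sample, batch average, and optionally the reordered estimates. pairwise_mse and pit_forward_sketch are hypothetical stand-ins (not asteroid's loss functions or method), and the sketch assumes 3-D inputs [batch, n_src, time] and the 'pw_mtx' convention that [b, i, j] pairs target i with estimate j.

```python
from itertools import permutations

import torch


def pairwise_mse(est_targets, targets):
    """Hypothetical pairwise loss: [b, i, j] = MSE(targets[b, i], est_targets[b, j])."""
    diff = targets.unsqueeze(2) - est_targets.unsqueeze(1)  # [batch, n_src, n_src, time]
    return (diff ** 2).mean(dim=-1)


def pit_forward_sketch(est_targets, targets, return_est=False):
    n_src = targets.shape[1]
    pw_losses = pairwise_mse(est_targets, targets)
    perms = torch.tensor(list(permutations(range(n_src))))
    loss_set = pw_losses[:, torch.arange(n_src), perms].mean(dim=-1)  # [batch, n_src!]
    min_loss, min_loss_idx = loss_set.min(dim=1)
    mean_loss = min_loss.mean()  # average over the batch
    if not return_est:
        return mean_loss
    # Estimate perms[k, i] was matched to target i, so move it to slot i.
    reordered = torch.stack(
        [est_targets[b, perms[k]] for b, k in enumerate(min_loss_idx.tolist())]
    )
    return mean_loss, reordered


targets = torch.randn(2, 3, 50)
# Each batch item holds the targets in a different shuffled order.
est = torch.stack([targets[0, [2, 0, 1]], targets[1, [1, 2, 0]]])
loss, reordered = pit_forward_sketch(est, targets, return_est=True)
```

Because the estimates are exact shuffles of the targets, the best-permutation loss is zero and the reordered estimates line up with the targets again.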
static get_pw_losses(loss_func, est_targets, targets, **kwargs)
Get pair-wise losses between the training targets and their estimates for a given loss function.
Parameters:
- loss_func – function with signature (targets, est_targets, **kwargs). The loss function to get pair-wise losses from.
- est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
- targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns: torch.Tensor of size [batch, nsrc, nsrc], the losses computed for all pairs of targets and est_targets.
This function can be called on a loss function which returns a tensor of size [batch]. There are more efficient ways to compute pair-wise losses using broadcasting.
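The remark about broadcasting can be made concrete. The sketch fills the [batch, n_src, n_src] matrix with n_src**2 calls to a per-pair loss, then builds the same matrix with a single broadcasted expression. point_mse and get_pw_losses_sketch are hypothetical, and the [b, i, j] = (target i, estimate j) layout follows the 'pw_mtx' description rather than any confirmed internal detail of get_pw_losses.

```python
import torch


def point_mse(est, target):
    """Hypothetical 'pw_pt' loss: one estimate vs one target, returns [batch]."""
    return ((est - target) ** 2).mean(dim=-1)


def get_pw_losses_sketch(loss_func, est_targets, targets):
    """Loop over all (target, estimate) pairs: n_src**2 calls of loss_func."""
    batch, n_src = targets.shape[:2]
    pw = targets.new_empty(batch, n_src, n_src)
    for i in range(n_src):
        for j in range(n_src):
            # Entry [b, i, j]: target i against estimate j.
            pw[:, i, j] = loss_func(est_targets[:, j], targets[:, i])
    return pw


est, tgt = torch.randn(4, 3, 100), torch.randn(4, 3, 100)
looped = get_pw_losses_sketch(point_mse, est, tgt)
# Same matrix in one call: [b, i, 1, t] - [b, 1, j, t] -> [b, i, j, t]
broadcast = ((tgt.unsqueeze(2) - est.unsqueeze(1)) ** 2).mean(dim=-1)
```

The looped version works with any loss that returns [batch]; the broadcasted version needs a loss written to handle the extra pair axis, but avoids the Python-level loop.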
static reorder_source(source, n_src, min_loss_idx)
Reorder sources according to the best permutation.
Parameters:
- source (torch.Tensor) – Tensor of shape [batch, n_src, time].
- n_src (int) – Number of sources.
- min_loss_idx (torch.LongTensor) – Tensor of shape [batch], each item is in [0, n_src!).
Returns: torch.Tensor – Reordered sources of shape [batch, n_src, time].
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
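A vectorised sketch of the reordering, assuming that min_loss_idx indexes permutations in itertools.permutations order and that output row c takes source perm[c]. reorder_source_sketch is a hypothetical name; it uses torch.gather to copy whole rows at once and is not the library's implementation.

```python
from itertools import permutations

import torch


def reorder_source_sketch(source, n_src, min_loss_idx):
    """reordered[b, c] = source[b, perm_b[c]], perm_b = permutation min_loss_idx[b]."""
    perms = torch.tensor(list(permutations(range(n_src))), device=source.device)
    best = perms[min_loss_idx]  # [batch, n_src]
    # Expand the indices over the time axis so gather copies whole rows.
    index = best.unsqueeze(-1).expand(-1, -1, source.shape[-1])
    return torch.gather(source, dim=1, index=index)


source = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
# Permutation 0 is (0, 1, 2) (identity); permutation 5 is (2, 1, 0) (reversal).
out = reorder_source_sketch(source, 3, torch.tensor([0, 5]))
```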