asteroid.losses.pit_wrapper module
-
class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)
Bases: torch.nn.Module
Permutation invariant loss wrapper.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs).
- pit_from (str) – Determines how PIT is applied:
  - 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).
  - 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (the tensors have no source axis). Output shape: \((batch)\). See get_pw_losses().
  - 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch)\). See best_perm_from_perm_avg_loss().
  In terms of efficiency, 'perm_avg' is the least efficient.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): \((B, n\_src!, n\_src) \to (B, n\_src!)\). perm_reduce can receive **kwargs during forward through the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in the 'pw_mtx' and 'pw_pt' pit_from modes.
For each of these modes, the best permutation and the reordering are computed automatically. When either 'pw_mtx' or 'pw_pt' is used and the number of sources is larger than three, the Hungarian algorithm is used to find the best permutation (unless perm_reduce is given, in which case all permutations are enumerated).
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
...     # perm_loss: (batch, n_src!, n_src); norms reshaped to (batch, 1, n_src) to broadcast
...     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
...     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)
-
forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)
Find the best permutation and return the loss.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of training targets.
- return_est – Boolean. Whether to also return the reordered target estimates (e.g. to compute metrics or to save examples).
- reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- The best-permutation loss of each batch sample, averaged over the batch.
- The reordered target estimates, a torch.Tensor of shape \((batch, n\_src, \dots)\), if return_est is True.
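The steps behind forward can be illustrated with a framework-free sketch. This is not asteroid's code: nested Python lists stand in for tensors, `mse` and `pit_forward` are hypothetical names, and the best permutation is found by exhaustive search. It follows the 'pw_mtx' convention stated for the class, where entry (i, j) of the pairwise matrix holds the loss between target i and estimate j.

```python
from itertools import permutations


def mse(est, tgt):
    # Hypothetical point loss: mean squared error between two 1-D signals.
    return sum((e - t) ** 2 for e, t in zip(est, tgt)) / len(tgt)


def pit_forward(est_targets, targets, return_est=False):
    """Sketch of a PIT forward pass on nested lists of shape (batch, n_src, time)."""
    n_src = len(targets[0])
    min_losses, reordered = [], []
    for est, tgt in zip(est_targets, targets):
        # Pairwise matrix: pw[i][j] = loss between target i and estimate j.
        pw = [[mse(est[j], tgt[i]) for j in range(n_src)] for i in range(n_src)]
        # best[i] is the index of the estimate assigned to target i.
        best = min(
            permutations(range(n_src)),
            key=lambda perm: sum(pw[i][perm[i]] for i in range(n_src)) / n_src,
        )
        min_losses.append(sum(pw[i][best[i]] for i in range(n_src)) / n_src)
        reordered.append([est[k] for k in best])
    # Average the per-sample best-permutation losses over the batch.
    mean_loss = sum(min_losses) / len(min_losses)
    return (mean_loss, reordered) if return_est else mean_loss


# One sample, two sources; the estimates arrive in swapped order.
targets = [[[1.0, 2.0], [5.0, 5.0]]]
est_targets = [[[5.0, 5.0], [1.0, 2.0]]]
loss, reordered = pit_forward(est_targets, targets, return_est=True)
print(loss, reordered)  # 0.0 [[[1.0, 2.0], [5.0, 5.0]]]
```

Returning the reordered estimates alongside the loss lets a caller compute per-source metrics without repeating the permutation search.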
-
static get_pw_losses(loss_func, est_targets, targets, **kwargs)
Get pair-wise losses between the training targets and their estimates for a given loss function.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get pair-wise losses from.
- est_targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of training targets.
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns: torch.Tensor of size \((batch, n\_src, n\_src)\), the losses computed for all pairs of targets and est_targets.
This function can be called on a loss function which returns a tensor of size \((batch)\). There are more efficient ways to compute pair-wise losses using broadcasting.
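A minimal sketch of the pair-wise construction, assuming a hypothetical 'pw_pt'-style loss `batch_mse` that maps two (batch, time) inputs to a (batch,) output. Plain lists replace tensors, and `get_pw_losses_sketch` is an illustrative name, not asteroid's function. Entry [b][i][j] holds the loss between target i and estimate j, as described for 'pw_mtx'.

```python
def batch_mse(est, tgt, **kwargs):
    # Hypothetical 'pw_pt'-style loss: inputs (batch, time), output (batch,).
    return [sum((e - t) ** 2 for e, t in zip(es, ts)) / len(ts) for es, ts in zip(est, tgt)]


def get_pw_losses_sketch(loss_func, est_targets, targets, **kwargs):
    """Build a (batch, n_src, n_src) matrix by calling loss_func n_src**2 times."""
    batch, n_src = len(targets), len(targets[0])
    pw = [[[0.0] * n_src for _ in range(n_src)] for _ in range(batch)]
    for i in range(n_src):
        tgt_i = [targets[b][i] for b in range(batch)]  # (batch, time): no source axis
        for j in range(n_src):
            est_j = [est_targets[b][j] for b in range(batch)]
            for b, value in enumerate(loss_func(est_j, tgt_i, **kwargs)):
                pw[b][i][j] = value  # loss between target i and estimate j
    return pw


targets = [[[0.0, 0.0], [1.0, 1.0]]]
est_targets = [[[1.0, 1.0], [0.0, 2.0]]]
pw = get_pw_losses_sketch(batch_mse, est_targets, targets)
print(pw)  # [[[1.0, 2.0], [0.0, 1.0]]]
```

The double loop calls the loss n_src² times; a loss written with broadcasting can produce the whole matrix in a single call.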
-
static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)
Find the best permutation from a loss function with a source axis.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
- est_targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of training targets.
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
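The 'perm_avg' search can be sketched as follows, assuming a hypothetical loss `avg_mse` that takes inputs with a source axis and returns one value per batch element (`best_perm_from_perm_avg_loss_sketch` is likewise illustrative). Each permutation costs one full call to the loss function, so the search performs n_src! calls, which is why 'perm_avg' is the least efficient mode.

```python
from itertools import permutations


def avg_mse(est, tgt):
    # Hypothetical 'perm_avg'-style loss: inputs (batch, n_src, time), output (batch,).
    out = []
    for es, ts in zip(est, tgt):
        per_src = [
            sum((e - t) ** 2 for e, t in zip(e_s, t_s)) / len(t_s)
            for e_s, t_s in zip(es, ts)
        ]
        out.append(sum(per_src) / len(per_src))  # average over the source axis
    return out


def best_perm_from_perm_avg_loss_sketch(loss_func, est_targets, targets):
    n_src = len(targets[0])
    perms = list(permutations(range(n_src)))
    # loss_set[p][b]: loss of sample b when the estimates are reordered by perms[p].
    loss_set = [
        loss_func([[est[k] for k in perm] for est in est_targets], targets) for perm in perms
    ]
    min_loss, batch_indices = [], []
    for b in range(len(targets)):
        p_best = min(range(len(perms)), key=lambda p: loss_set[p][b])
        min_loss.append(loss_set[p_best][b])
        batch_indices.append(list(perms[p_best]))
    return min_loss, batch_indices


targets = [[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]]
est_targets = [[[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]
min_loss, batch_indices = best_perm_from_perm_avg_loss_sketch(avg_mse, est_targets, targets)
print(min_loss, batch_indices)  # [0.0, 0.0] [[1, 0], [0, 1]]
```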
-
static find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs)
Find the best permutation, given the pair-wise losses.
Dispatch between the factorial method if the number of sources is small (three or fewer) and the Hungarian method for more sources. If perm_reduce is not None, the factorial method is always used.
Parameters:
- pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): \((B, n\_src!, n\_src) \to (B, n\_src!)\).
- **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
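The dispatch rule can be written down as a small predicate. This sketch assumes the threshold given in the class description (the Hungarian algorithm only for more than three sources); `choose_perm_method` is a hypothetical helper, not part of asteroid. The number of permutations the factorial method must score grows as n_src!, which motivates the switch.

```python
import math


def choose_perm_method(n_src, perm_reduce=None):
    # Hypothetical helper mirroring the dispatch rule described above.
    if perm_reduce is not None or n_src <= 3:
        return "factorial"
    return "hungarian"


# Number of permutations the factorial method has to score.
print([math.factorial(n) for n in range(2, 7)])  # [2, 6, 24, 120, 720]
print(choose_perm_method(3), choose_perm_method(4), choose_perm_method(4, perm_reduce=sum))
# factorial hungarian factorial
```

A custom perm_reduce forces the factorial path because the Hungarian algorithm minimizes a sum of assigned pairwise losses, and an arbitrary reduction need not decompose that way.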
-
static reorder_source(source, batch_indices)
Reorder sources according to the best permutation.
Parameters:
- source (torch.Tensor) – Tensor of shape \((batch, n\_src, time)\).
- batch_indices (torch.Tensor) – Tensor of shape \((batch, n\_src)\). Contains the optimal permutation indices for each batch element.
Returns: torch.Tensor – Reordered sources.
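Reordering is a per-sample gather along the source axis: output source i of sample b is input source batch_indices[b][i]. A list-based sketch, with string labels standing in for the time axis (`reorder_source_sketch` is a hypothetical name):

```python
def reorder_source_sketch(source, batch_indices):
    """Return reordered[b][i] = source[b][batch_indices[b][i]]."""
    return [[src[k] for k in idx] for src, idx in zip(source, batch_indices)]


# Two samples, three sources each; labels stand in for waveforms.
source = [[["a0"], ["a1"], ["a2"]], [["b0"], ["b1"], ["b2"]]]
batch_indices = [[2, 0, 1], [0, 1, 2]]
print(reorder_source_sketch(source, batch_indices))
# [[['a2'], ['a0'], ['a1']], [['b0'], ['b1'], ['b2']]]
```

With PyTorch tensors, the same gather can be written with advanced indexing, for example `source[torch.arange(source.shape[0])[:, None], batch_indices]`, which returns a tensor of shape \((batch, n\_src, time)\).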
-
static find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs)
Find the best permutation given the pair-wise losses by looping through all the permutations.
Parameters:
- pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): \((B, n\_src!, n\_src) \to (B, n\_src!)\).
- **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
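The factorial search and the role of perm_reduce can be sketched with plain lists. The sketch assumes the pairwise convention described for 'pw_mtx' (row i is target i, column j is estimate j); `mean_reduce` and `weighted_reduce` are hypothetical reductions with the \((B, n\_src!, n\_src) \to (B, n\_src!)\) signature, and `weights` plays the role of an entry in reduce_kwargs. The example shows that a non-uniform reduction can change which permutation wins.

```python
from itertools import permutations


def mean_reduce(pwl_set):
    # Default behaviour: average the n_src assigned losses of each permutation.
    return [[sum(row) / len(row) for row in per_sample] for per_sample in pwl_set]


def weighted_reduce(pwl_set, weights):
    # Hypothetical reduction: weighted average with per-sample source weights (B, n_src).
    return [
        [sum(w * loss for w, loss in zip(ws, row)) / sum(ws) for row in per_sample]
        for per_sample, ws in zip(pwl_set, weights)
    ]


def find_best_perm_factorial_sketch(pair_wise_losses, perm_reduce=None, **kwargs):
    n_src = len(pair_wise_losses[0])
    perms = list(permutations(range(n_src)))
    # pwl_set[b][p][i]: loss between target i and estimate perms[p][i]; (B, n_src!, n_src).
    pwl_set = [
        [[pw[i][perm[i]] for i in range(n_src)] for perm in perms] for pw in pair_wise_losses
    ]
    reduce = mean_reduce if perm_reduce is None else perm_reduce
    loss_set = reduce(pwl_set, **kwargs)  # shape (B, n_src!)
    min_loss, batch_indices = [], []
    for losses in loss_set:
        p_best = min(range(len(perms)), key=losses.__getitem__)
        min_loss.append(losses[p_best])
        batch_indices.append(list(perms[p_best]))
    return min_loss, batch_indices


pw = [[[0.0, 2.0], [2.0, 3.0]]]  # one sample, two sources
print(find_best_perm_factorial_sketch(pw))  # ([1.5], [[0, 1]])
print(find_best_perm_factorial_sketch(pw, perm_reduce=weighted_reduce, weights=[[0.0, 1.0]]))
# ([2.0], [[1, 0]])
```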
-
static find_best_perm_hungarian(pair_wise_losses: torch.Tensor)
Find the best permutation given the pair-wise losses, using the Hungarian algorithm.
Parameters:
- pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
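The Hungarian search solves, for each batch element, a linear assignment problem on the \((n\_src, n\_src)\) loss matrix in polynomial time instead of scoring all n_src! permutations. The sketch below is a self-contained O(n^3) version of the Hungarian (Kuhn-Munkres) algorithm with row and column potentials, written for illustration; it is not asteroid's implementation, and in practice libraries commonly delegate this step to scipy.optimize.linear_sum_assignment. `hungarian` and `find_best_perm_hungarian_sketch` are hypothetical names, and rows are treated as targets and columns as estimates.

```python
import math
import random
from itertools import permutations


def hungarian(cost):
    """Return assign, with assign[i] the column matched to row i, minimizing total cost.

    O(n^3) Hungarian (Kuhn-Munkres) algorithm with potentials; square matrices only.
    """
    n = len(cost)
    u = [0.0] * (n + 1)  # row potentials (1-indexed; index 0 is a sentinel)
    v = [0.0] * (n + 1)  # column potentials
    p = [0] * (n + 1)  # p[j]: row matched to column j (0 means unmatched)
    way = [0] * (n + 1)  # way[j]: previous column on the shortest augmenting path
    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [math.inf] * (n + 1)
        used = [False] * (n + 1)
        while True:
            used[j0] = True
            i0, delta, j1 = p[j0], math.inf, 0
            for j in range(1, n + 1):
                if not used[j]:
                    cur = cost[i0 - 1][j - 1] - u[i0] - v[j]  # reduced cost
                    if cur < minv[j]:
                        minv[j], way[j] = cur, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            # Shift potentials so that at least one new edge becomes tight.
            for j in range(n + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:  # reached a free column: augmenting path found
                break
        # Flip the matching along the augmenting path.
        while j0:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
    assign = [0] * n
    for j in range(1, n + 1):
        assign[p[j] - 1] = j - 1
    return assign


def find_best_perm_hungarian_sketch(pair_wise_losses):
    min_loss, batch_indices = [], []
    for pw in pair_wise_losses:  # pw[i][j]: loss between target i and estimate j
        perm = hungarian(pw)  # perm[i]: estimate assigned to target i
        min_loss.append(sum(pw[i][perm[i]] for i in range(len(pw))) / len(pw))
        batch_indices.append(perm)
    return min_loss, batch_indices


print(find_best_perm_hungarian_sketch([[[4.0, 1.0, 3.0], [2.0, 0.0, 5.0], [3.0, 2.0, 2.0]]]))
# ([1.6666666666666667], [[1, 0, 2]])

# Cross-check against exhaustive search on random 5x5 cost matrices.
rng = random.Random(0)
for _ in range(20):
    cost = [[rng.random() for _ in range(5)] for _ in range(5)]
    assign = hungarian(cost)
    best = min(permutations(range(5)), key=lambda q: sum(cost[i][q[i]] for i in range(5)))
    total = sum(cost[i][assign[i]] for i in range(5))
    assert math.isclose(total, sum(cost[i][best[i]] for i in range(5)))
```

The exhaustive cross-check is only practical for small matrices, which is exactly the regime where the factorial method remains competitive.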
-
class asteroid.losses.pit_wrapper.PITReorder(loss_func, pit_from='pw_mtx', perm_reduce=None)
Bases: asteroid.losses.pit_wrapper.PITLossWrapper
Permutation invariant reorderer. Only returns the reordered estimates. See asteroid.losses.PITLossWrapper.
-
forward(est_targets, targets, reduce_kwargs=None, **kwargs)
Find the best permutation and return the reordered estimates.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, n\_src, \dots)\). The batch of training targets.
- reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns: torch.Tensor of shape \((batch, n\_src, \dots)\), the target estimates reordered according to the best permutation.
-