# Source code for asteroid.losses.pit_wrapper

from itertools import permutations
import torch
from torch import nn

class PITLossWrapper(nn.Module):
    r"""Permutation invariant loss wrapper.

    Args:
        loss_func: function with signature (est_targets, targets, **kwargs).
        pit_from (str): Determines how PIT is applied.

            * ``'pw_mtx'`` (pairwise matrix): `loss_func` computes pairwise
              losses and returns a torch.Tensor of shape
              :math:`(batch, n\_src, n\_src)`. Each element
              :math:`[batch, i, j]` corresponds to the loss between
              :math:`est\_targets[:, i]` and :math:`targets[:, j]`
              (same layout as :meth:`~PITLossWrapper.get_pw_losses`).
            * ``'pw_pt'`` (pairwise point): `loss_func` computes the loss for
              a batch of single source and single estimates (tensors won't
              have the source axis). Output shape : :math:`(batch)`.
              See :meth:`~PITLossWrapper.get_pw_losses`.
            * ``'perm_avg'`` (permutation average): `loss_func` computes the
              average loss for a given permutation of the sources and
              estimates. Output shape : :math:`(batch)`.
              See :meth:`~PITLossWrapper.best_perm_from_perm_avg_loss`.

            In terms of efficiency, ``'perm_avg'`` is the least efficient.

        perm_reduce (Callable): torch function to reduce permutation losses.
            Defaults to None (equivalent to mean). Signature of the func
            (pwl_set, **kwargs) : (B, n_src!, n_src) --> (B, n_src!).
            `perm_reduce` can receive **kwargs during forward using the
            `reduce_kwargs` argument (dict). If those arguments are static,
            consider defining a small function or using `functools.partial`.
            Only used in `'pw_mtx'` and `'pw_pt'` `pit_from` modes.

    For each of these modes, the best permutation and reordering will be
    automatically computed.

    Raises:
        ValueError: if `pit_from` is not one of the supported modes.

    Examples:
        >>> import torch
        >>> from asteroid.losses import pairwise_neg_sisdr
        >>> sources = torch.randn(10, 3, 16000)
        >>> est_sources = torch.randn(10, 3, 16000)
        >>> # Compute PIT loss based on pairwise losses
        >>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
        >>> loss_val = loss_func(est_sources, sources)
        >>>
        >>> # Using reduce: weight each target's loss by the target norm.
        >>> def reduce(perm_loss, src):
        ...     # perm_loss: (B, n_src!, n_src), src.norm(dim=-1): (B, n_src)
        ...     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
        ...     return torch.mean(weighted, dim=-1)
        >>>
        >>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
        ...                            perm_reduce=reduce)
        >>> reduce_kwargs = {'src': sources}
        >>> loss_val = loss_func(est_sources, sources,
        ...                      reduce_kwargs=reduce_kwargs)
    """

    def __init__(self, loss_func, pit_from="pw_mtx", perm_reduce=None):
        super().__init__()
        self.loss_func = loss_func
        self.pit_from = pit_from
        self.perm_reduce = perm_reduce
        if self.pit_from not in ["pw_mtx", "pw_pt", "perm_avg"]:
            raise ValueError(
                "Unsupported loss function type for now. Expected "
                "one of [`pw_mtx`, `pw_pt`, `perm_avg`]"
            )

    def forward(self, est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs):
        """Find the best permutation and return the loss.

        Args:
            est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of target estimates.
            targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets.
            return_est: Boolean. Whether to return the reordered targets
                estimates (To compute metrics or to save example).
            reduce_kwargs (dict or None): kwargs that will be passed to the
                pairwise losses reduce function (`perm_reduce`).
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            - Best permutation loss for each batch sample, average over
              the batch. torch.Tensor(loss_value)
            - The reordered targets estimates if return_est is True.
              torch.Tensor of shape [batch, nsrc, *].

        Raises:
            ValueError: if `self.pit_from` was changed to an unsupported
                value after construction.
        """
        n_src = targets.shape[1]
        # All n_src! permutations are enumerated below; a large value here
        # most likely means the source axis is not along dim 1.
        assert n_src < 10, f"Expected source axis along dim 1, found {n_src}"

        if self.pit_from == "perm_avg":
            # Cannot get pairwise losses from this type of loss.
            # Find best permutation directly.
            min_loss, min_loss_idx = self.best_perm_from_perm_avg_loss(
                self.loss_func, est_targets, targets, **kwargs
            )
        else:
            if self.pit_from == "pw_mtx":
                # Loss function already returns pairwise losses
                pw_losses = self.loss_func(est_targets, targets, **kwargs)
            elif self.pit_from == "pw_pt":
                # Compute pairwise losses with a for loop.
                pw_losses = self.get_pw_losses(self.loss_func, est_targets, targets, **kwargs)
            else:
                # Only reachable if `pit_from` was mutated after __init__;
                # fail loudly instead of silently returning None.
                raise ValueError(
                    "Unsupported loss function type for now. Expected "
                    "one of [`pw_mtx`, `pw_pt`, `perm_avg`]"
                )
            assert pw_losses.ndim == 3, (
                "Something went wrong with the loss " "function, please read the docs."
            )
            assert (
                pw_losses.shape[0] == targets.shape[0]
            ), "PIT loss needs same batch dim as input"

            reduce_kwargs = {} if reduce_kwargs is None else reduce_kwargs
            min_loss, min_loss_idx = self.find_best_perm(
                pw_losses, n_src, perm_reduce=self.perm_reduce, **reduce_kwargs
            )

        # Take the mean over the batch
        mean_loss = torch.mean(min_loss)
        if not return_est:
            return mean_loss
        reordered = self.reorder_source(est_targets, n_src, min_loss_idx)
        return mean_loss, reordered

    @staticmethod
    def get_pw_losses(loss_func, est_targets, targets, **kwargs):
        """Get pair-wise losses between the training targets and its estimate
        for a given loss function.

        Args:
            loss_func: function with signature (est_targets, targets, **kwargs)
                The loss function to get pair-wise losses from.
            est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of target estimates.
            targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets.
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            torch.Tensor of size [batch, nsrc, nsrc], losses computed for
            all pairs of est_targets and targets. Element [:, i, j] is the
            loss between estimate i and target j.

        This function can be called on a loss function which returns a tensor
        of size [batch]. There are more efficient ways to compute pair-wise
        losses using broadcasting.
        """
        batch_size, n_src, *_ = targets.shape
        pair_wise_losses = targets.new_empty(batch_size, n_src, n_src)
        # Iterate over the source axis: transpose gives [nsrc, batch, *].
        targets_by_src = targets.transpose(0, 1)
        for est_idx, est_src in enumerate(est_targets.transpose(0, 1)):
            for target_idx, target_src in enumerate(targets_by_src):
                pair_wise_losses[:, est_idx, target_idx] = loss_func(
                    est_src, target_src, **kwargs
                )
        return pair_wise_losses

    @staticmethod
    def best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs):
        """Find best permutation from loss function with source axis.

        Args:
            loss_func: function with signature (est_targets, targets, **kwargs)
                The loss function batch losses from.
            est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of target estimates.
            targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets.
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            tuple:
                :class:`torch.Tensor`: The loss corresponding to the best
                permutation, of size (batch, 1) (the reduced dim is kept).

                :class:`torch.LongTensor`: The indexes of the best
                permutations, of size (batch,).
        """
        n_src = targets.shape[1]
        perms = list(permutations(range(n_src)))
        # One loss per permutation of the estimates: [batch, n_src!]
        loss_set = torch.stack(
            [loss_func(est_targets[:, list(perm)], targets, **kwargs) for perm in perms],
            dim=1,
        )
        # Indexes and values of min losses for each batch element
        min_loss, min_loss_idx = torch.min(loss_set, dim=1, keepdim=True)
        return min_loss, min_loss_idx[:, 0]

    @staticmethod
    def find_best_perm(pair_wise_losses, n_src, perm_reduce=None, **kwargs):
        """Find the best permutation, given the pair-wise losses.

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape [batch, n_src, n_src]. Pairwise losses.
            n_src (int): Number of sources.
            perm_reduce (Callable): torch function to reduce permutation
                losses. Defaults to None (equivalent to mean). Signature of
                the func (pwl_set, **kwargs) :
                (B, n_src!, n_src) --> (B, n_src!)
            **kwargs: additional keyword argument that will be passed to the
                permutation reduce function.

        Returns:
            tuple:
                :class:`torch.Tensor`: The loss corresponding to the best
                permutation, of size (batch, 1) (the reduced dim is kept).

                :class:`torch.LongTensor`: The indexes of the best
                permutations, of size (batch,).

        MIT Copyright (c) 2018 Kaituo XU.
        See `Original code <>`__ and `License <>`__.
        """
        # After transposition, dim 1 corresp. to sources and dim 2 to estimates
        pwl = pair_wise_losses.transpose(-1, -2)
        # [n_src!, n_src]: row p holds, for each source, the estimate index.
        perms = pwl.new_tensor(list(permutations(range(n_src))), dtype=torch.long)
        if perm_reduce is None:
            # one-hot, [n_src!, n_src, n_src]
            perms_one_hot = pwl.new_zeros((*perms.size(), n_src)).scatter_(
                2, perms.unsqueeze(2), 1
            )
            # Sum the selected pairwise losses of each permutation, then
            # divide by n_src to get the mean loss per permutation.
            loss_set = torch.einsum("bij,pij->bp", pwl, perms_one_hot)
            loss_set /= n_src
        else:
            # [batch, n_src!, n_src] : Pairwise losses for each permutation.
            src_idx = torch.arange(n_src, device=pwl.device)
            pwl_set = pwl[:, src_idx, perms]
            # Apply reduce [batch, n_src!, n_src] --> [batch, n_src!]
            loss_set = perm_reduce(pwl_set, **kwargs)
        # Indexes and values of min losses for each batch element
        min_loss, min_loss_idx = torch.min(loss_set, dim=1, keepdim=True)
        return min_loss, min_loss_idx[:, 0]

    @staticmethod
    def reorder_source(source, n_src, min_loss_idx):
        """Reorder sources according to the best permutation.

        Args:
            source (torch.Tensor): Tensor of shape [batch, n_src, *].
            n_src (int): Number of sources.
            min_loss_idx (torch.LongTensor): Tensor of shape [batch],
                each item is in [0, n_src!).

        Returns:
            :class:`torch.Tensor`:
                Reordered sources of shape [batch, n_src, *].

        MIT Copyright (c) 2018 Kaituo XU.
        See `Original code <>`__ and `License <>`__.
        """
        perms = source.new_tensor(list(permutations(range(n_src))), dtype=torch.long)
        # Best permutation of each batch element: [batch, n_src]
        min_loss_perm = torch.index_select(perms, dim=0, index=min_loss_idx)
        # Vectorized gather: out[b, c] = source[b, min_loss_perm[b, c]].
        batch_idx = torch.arange(source.shape[0], device=source.device).unsqueeze(1)
        return source[batch_idx, min_loss_perm]