Source code for asteroid.losses.mse

from ..utils.deprecation_utils import DeprecationMixin
from torch.nn.modules.loss import _Loss


class PairwiseMSE(_Loss):
    """Measure pairwise mean square error on a batch.

    Shape:
        est_targets (:class:`torch.Tensor`): Expected shape [batch, nsrc, *].
            The batch of target estimates.
        targets (:class:`torch.Tensor`): Expected shape [batch, nsrc, *].
            The batch of training targets.

    Returns:
        :class:`torch.Tensor`: with shape [batch, nsrc, nsrc]. Entry
        ``[b, i, j]`` is the squared error between estimate ``i`` and
        target ``j`` of batch item ``b``, averaged over all trailing (``*``)
        dimensions. When there are no trailing dimensions, the per-pair
        squared error is returned as is.

    Examples:
        >>> import torch
        >>> from asteroid.losses import PITLossWrapper
        >>> targets = torch.randn(10, 2, 32000)
        >>> est_targets = torch.randn(10, 2, 32000)
        >>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pairwise')
        >>> loss = loss_func(est_targets, targets)
    """

    def forward(self, est_targets, targets):
        # [batch, 1, nsrc, *] - [batch, nsrc, 1, *] broadcasts to
        # [batch, nsrc (estimates), nsrc (targets), *].
        pw_loss = (targets.unsqueeze(1) - est_targets.unsqueeze(2)) ** 2
        # Average over every trailing dimension to get [batch, nsrc, nsrc].
        mean_over = list(range(3, pw_loss.ndim))
        if not mean_over:
            # Nothing to average. Passing ``dim=[]`` to ``mean`` can reduce
            # over *all* dimensions (legacy PyTorch behaviour) and would
            # collapse the result to a scalar.
            return pw_loss
        return pw_loss.mean(dim=mean_over)
class SingleSrcMSE(_Loss):
    """Measure mean square error on a batch.

    Supports both tensors with and without source axis.

    Shape:
        est_targets (:class:`torch.Tensor`): Expected shape [batch, *].
            The batch of target estimates.
        targets (:class:`torch.Tensor`): Expected shape [batch, *].
            The batch of training targets.

    Returns:
        :class:`torch.Tensor`: with shape [batch]. The squared error is
        averaged over every dimension except the batch one; for 1-D
        ``[batch]`` inputs the per-item squared error is returned as is.

    Examples:
        >>> import torch
        >>> from asteroid.losses import PITLossWrapper
        >>> targets = torch.randn(10, 2, 32000)
        >>> est_targets = torch.randn(10, 2, 32000)
        >>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
        >>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
        >>> loss = loss_func(est_targets, targets)
    """

    def forward(self, est_targets, targets):
        loss = (targets - est_targets) ** 2
        # Reduce everything but the batch axis -> [batch].
        mean_over = list(range(1, loss.ndim))
        if not mean_over:
            # 1-D input: nothing to average. Passing ``dim=[]`` to ``mean``
            # can reduce over *all* dimensions (legacy PyTorch behaviour),
            # which would break the documented [batch] output shape.
            return loss
        return loss.mean(dim=mean_over)
# aliases
# Ready-to-use loss instances.
pairwise_mse = PairwiseMSE()
singlesrc_mse = SingleSrcMSE()
# The multi-source variant shares the single-source implementation.
MultiSrcMSE = SingleSrcMSE
multisrc_mse = MultiSrcMSE()
# Legacy
class NoSrcMSE(SingleSrcMSE, DeprecationMixin):
    """Deprecated name for :class:`SingleSrcMSE`.

    Computes the same loss as ``SingleSrcMSE``. After the base classes have
    been initialised, construction calls ``warn_deprecated`` (presumably
    supplied by ``DeprecationMixin``).
    """

    def __init__(self, *init_args, **init_kwargs):
        super().__init__(*init_args, **init_kwargs)
        # Warn only after the base initialisation has completed.
        self.warn_deprecated()
# Legacy instance aliases kept for older callers.
nosrc_mse = singlesrc_mse
nonpit_mse = multisrc_mse
# Legacy class alias.
NonPitMSE = NoSrcMSE