asteroid.losses.mse module
- class asteroid.losses.mse.PairwiseMSE(*args, **kwargs)

Bases: torch.nn.modules.loss._Loss
Measure pairwise mean square error on a batch: the MSE between every estimated source and every target source.

Shape:
- est_targets: \((batch, nsrc, ...)\).
- targets: \((batch, nsrc, ...)\).

Returns: torch.Tensor with shape \((batch, nsrc, nsrc)\).

Examples:
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import PairwiseMSE
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
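The pairwise matrix can be computed in one step with broadcasting. The sketch below is plain PyTorch and does not import asteroid; `pairwise_mse_sketch` is a hypothetical name, and the convention that entry [b, i, j] compares estimate i with target j is an assumption of this sketch rather than a documented guarantee of PairwiseMSE.

```python
import torch


def pairwise_mse_sketch(est_targets: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pairwise MSE with shape (batch, nsrc, nsrc).

    In this sketch, entry [b, i, j] is the MSE between estimate i and target j.
    """
    if est_targets.shape != targets.shape or targets.ndim < 3:
        raise ValueError("expected two tensors of identical shape (batch, nsrc, ...)")
    # (batch, nsrc, 1, ...) - (batch, 1, nsrc, ...) broadcasts to (batch, nsrc, nsrc, ...).
    sq_err = (est_targets.unsqueeze(2) - targets.unsqueeze(1)) ** 2
    # Average over every axis after the two source axes.
    return sq_err.mean(dim=tuple(range(3, sq_err.ndim)))


est_targets = torch.randn(10, 2, 32000)
targets = torch.randn(10, 2, 32000)
pw = pairwise_mse_sketch(est_targets, targets)
print(pw.shape)  # torch.Size([10, 2, 2])
```

Broadcasting trades memory for speed: the intermediate tensor holds nsrc times as many elements as the inputs, which is usually acceptable for the small source counts used in separation.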
- class asteroid.losses.mse.SingleSrcMSE(*args, **kwargs)

Bases: torch.nn.modules.loss._Loss
Measure mean square error on a batch. Supports tensors both with and without a source axis; every dimension except the batch dimension is averaged.

Shape:
- est_targets: \((batch, ...)\).
- targets: \((batch, ...)\).

Returns: torch.Tensor with shape \((batch)\).

Examples:
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
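A minimal sketch of the reduction, assuming the loss simply averages the squared error over every axis except the batch axis (the name `mse_over_non_batch_dims` is hypothetical, not an asteroid function). It shows why one function accepts inputs with or without a source axis: only the batch axis survives the reduction.

```python
import torch


def mse_over_non_batch_dims(est_targets: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one MSE value per batch item."""
    if est_targets.shape != targets.shape or targets.ndim < 2:
        raise ValueError("expected two tensors of identical shape (batch, ...)")
    sq_err = (est_targets - targets) ** 2
    # Average over every axis except the batch axis (dim 0).
    return sq_err.mean(dim=tuple(range(1, sq_err.ndim)))


# Without a source axis: (batch, time) -> (batch)
no_src = mse_over_non_batch_dims(torch.zeros(10, 32000), torch.ones(10, 32000))
# With a source axis: (batch, nsrc, time) -> (batch)
with_src = mse_over_non_batch_dims(torch.zeros(10, 2, 32000), torch.ones(10, 2, 32000))
print(no_src.shape, with_src.shape)  # torch.Size([10]) torch.Size([10])
```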
- asteroid.losses.mse.MultiSrcMSE

Alias of asteroid.losses.mse.SingleSrcMSE.
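One class can serve both names because, for tensors of shape (batch, nsrc, time), averaging the squared error jointly over sources and time gives the same value as averaging the per-source MSEs: every source contributes the same number of samples. The plain-PyTorch check below illustrates this identity; the helper `mse_per_batch` is hypothetical and not part of asteroid.

```python
import torch


def mse_per_batch(est: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: squared error averaged over every axis except batch.
    sq_err = (est - tgt) ** 2
    return sq_err.mean(dim=tuple(range(1, sq_err.ndim)))


est = torch.randn(5, 3, 200)
tgt = torch.randn(5, 3, 200)

# Multi-source input: (batch, nsrc, time) -> (batch)
joint = mse_per_batch(est, tgt)
# One single-source MSE per source: (batch, nsrc)
per_src = torch.stack([mse_per_batch(est[:, s], tgt[:, s]) for s in range(3)], dim=1)
print(torch.allclose(joint, per_src.mean(dim=1)))  # True
```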
- asteroid.losses.mse.pairwise_mse

Ready-made instance of PairwiseMSE. Measures pairwise mean square error on a batch.
Shape:
- est_targets: \((batch, nsrc, ...)\).
- targets: \((batch, nsrc, ...)\).

Returns: torch.Tensor with shape \((batch, nsrc, nsrc)\).

Examples:
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import pairwise_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
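How a (batch, nsrc, nsrc) matrix can drive permutation-invariant training: the brute-force sketch below scores every permutation by averaging the matched entries and keeps the minimum per example. It is a hypothetical illustration in plain PyTorch (`pit_from_pairwise_matrix` is not an asteroid function) and not necessarily the search PITLossWrapper performs; it assumes entry [b, i, j] compares estimate i with target j, and enumerating all nsrc! permutations only suits a handful of sources.

```python
import itertools

import torch


def pit_from_pairwise_matrix(pw_losses: torch.Tensor):
    """Hypothetical brute-force PIT over a (batch, nsrc, nsrc) loss matrix.

    Assumes entry [b, i, j] is the loss between estimate i and target j.
    Returns the minimum average loss per example and the best permutation,
    where perm[j] is the index of the estimate assigned to target j.
    """
    n_src = pw_losses.shape[1]
    perms = list(itertools.permutations(range(n_src)))
    cols = torch.arange(n_src)
    # scores[:, k]: mean loss of the pairs (perms[k][j], j) for j = 0..nsrc-1.
    scores = torch.stack(
        [pw_losses[:, torch.tensor(p), cols].mean(dim=1) for p in perms], dim=1
    )
    best_loss, best_idx = scores.min(dim=1)
    return best_loss, torch.tensor(perms)[best_idx]


targets = torch.randn(4, 2, 100)
est_targets = torch.flip(targets, dims=[1])  # estimates in swapped order
# Pairwise MSE: entry [b, i, j] compares estimate i with target j.
pw = ((est_targets.unsqueeze(2) - targets.unsqueeze(1)) ** 2).mean(dim=3)
loss, perm = pit_from_pairwise_matrix(pw)
print(loss)              # tensor([0., 0., 0., 0.])
print(perm[0].tolist())  # [1, 0]
```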
- asteroid.losses.mse.singlesrc_mse

Ready-made instance of SingleSrcMSE. Measures mean square error on a batch and supports tensors both with and without a source axis.
Shape:
- est_targets: \((batch, ...)\).
- targets: \((batch, ...)\).

Returns: torch.Tensor with shape \((batch)\).

Examples:
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
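With pit_from='pw_pt', the wrapper evaluates a loss that has no source axis on individual estimate/target pairs. The sketch below reproduces that idea by calling a single-source MSE for every pair and stacking the results into a (batch, nsrc, nsrc) matrix; both helper names are hypothetical, and the double loop is an illustration rather than asteroid's code.

```python
import torch


def single_src_mse(est: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: (batch, ...) -> (batch), averaging every non-batch axis.
    sq_err = (est - tgt) ** 2
    return sq_err.mean(dim=tuple(range(1, sq_err.ndim)))


def pairwise_from_point_loss(loss_fn, est_targets: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the 'pw_pt' idea: call a single-source loss on every
    (estimate i, target j) pair and stack the results into (batch, nsrc, nsrc)."""
    n_src = targets.shape[1]
    rows = [
        # Row i: losses of estimate i against each target j -> (batch, nsrc).
        torch.stack([loss_fn(est_targets[:, i], targets[:, j]) for j in range(n_src)], dim=1)
        for i in range(n_src)
    ]
    return torch.stack(rows, dim=1)


est_targets = torch.randn(10, 2, 32000)
targets = torch.randn(10, 2, 32000)
pw = pairwise_from_point_loss(single_src_mse, est_targets, targets)
print(pw.shape)  # torch.Size([10, 2, 2])
```

The pair-by-pair loop needs nsrc squared loss calls but never materializes the broadcast (batch, nsrc, nsrc, time) tensor, which is the usual reason to prefer a point-wise loss.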
- asteroid.losses.mse.multisrc_mse

Ready-made instance of MultiSrcMSE, itself an alias of SingleSrcMSE. Measures mean square error on a batch and supports tensors both with and without a source axis.
Shape:
- est_targets: \((batch, ...)\).
- targets: \((batch, ...)\).

Returns: torch.Tensor with shape \((batch)\).

Examples:
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import multisrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(multisrc_mse, pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)
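With pit_from='perm_avg', the loss sees all sources at once, so permutation invariance comes from reordering the estimates and re-evaluating the loss. A brute-force sketch of that idea follows; `multi_src_mse` and `perm_avg_pit` are hypothetical helpers, not asteroid functions, and the wrapper's actual search may differ.

```python
import itertools

import torch


def multi_src_mse(est: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: (batch, nsrc, ...) -> (batch), averaging sources and samples.
    sq_err = (est - tgt) ** 2
    return sq_err.mean(dim=tuple(range(1, sq_err.ndim)))


def perm_avg_pit(loss_fn, est_targets: torch.Tensor, targets: torch.Tensor):
    """Hypothetical brute-force sketch of the 'perm_avg' idea: reorder the estimates
    by every permutation, evaluate the multi-source loss, keep the best per example."""
    n_src = targets.shape[1]
    perms = list(itertools.permutations(range(n_src)))
    # losses[:, k] is the loss when the estimates are reordered by perms[k].
    losses = torch.stack([loss_fn(est_targets[:, list(p)], targets) for p in perms], dim=1)
    best_loss, best_idx = losses.min(dim=1)
    return best_loss, torch.tensor(perms)[best_idx]


targets = torch.randn(10, 2, 32000)
est_targets = torch.flip(targets, dims=[1])  # estimates in swapped order
loss, perm = perm_avg_pit(multi_src_mse, est_targets, targets)
print(bool((loss == 0).all()), perm[0].tolist())  # True [1, 0]
```

Unlike the pairwise modes, this evaluates the full loss once per permutation (nsrc! calls), which is what lets it handle losses that do not decompose into independent source pairs.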