asteroid.losses.mse module

class asteroid.losses.mse.PairwiseMSE(*args, **kwargs)[source]

Bases: torch.nn.modules.loss._Loss

Measure pairwise mean square error on a batch.

Shape:
  • est_targets: \((batch, nsrc, ...)\).
  • targets: \((batch, nsrc, ...)\).
Returns: torch.Tensor – with shape \((batch, nsrc, nsrc)\)
Examples
>>> import torch
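>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import PairwiseMSE
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # PairwiseMSE returns a (batch, nsrc, nsrc) matrix, hence 'pw_mtx'.
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')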
>>> loss = loss_func(est_targets, targets)
forward(est_targets, targets)[source]
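
The pairwise computation can be checked by hand with a dependency-free sketch. This is not the library's implementation: `pairwise_mse_reference` is a hypothetical helper that works on nested lists shaped (batch, nsrc, time), and it assumes that entry [b][i][j] compares estimate i with target j.

```python
def pairwise_mse_reference(est_targets, targets):
    """Pairwise MSE for nested lists shaped (batch, nsrc, time).

    Assumed convention: out[b][i][j] is the MSE between
    estimate i and target j of batch item b.
    """
    out = []
    for est_b, tgt_b in zip(est_targets, targets):
        matrix = []
        for est in est_b:
            row = []
            for tgt in tgt_b:
                sq_err = [(e - t) ** 2 for e, t in zip(est, tgt)]
                row.append(sum(sq_err) / len(sq_err))
            matrix.append(row)
        out.append(matrix)
    return out


# One batch item, two sources, two samples each.
# The estimates are the targets in swapped order.
est = [[[1.0, 2.0], [0.0, 0.0]]]
tgt = [[[0.0, 0.0], [1.0, 2.0]]]
pw = pairwise_mse_reference(est, tgt)
print(pw)  # [[[2.5, 0.0], [0.0, 2.5]]]
```

The zero off-diagonal entries show that each estimate matches the other target exactly, which is the information a permutation-invariant wrapper needs.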
class asteroid.losses.mse.SingleSrcMSE(*args, **kwargs)[source]

Bases: torch.nn.modules.loss._Loss

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  • est_targets: \((batch, ...)\).
  • targets: \((batch, ...)\).
Returns: torch.Tensor – with shape \((batch)\)
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
forward(est_targets, targets)[source]
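
A plain-Python sketch of "with and without source axis", assuming the loss averages the squared error over every axis except the batch axis. `singlesrc_mse_reference` and `flatten` are hypothetical helpers for illustration, not part of asteroid.

```python
def flatten(x):
    """Flatten arbitrarily nested lists of numbers into one flat list."""
    if isinstance(x, (int, float)):
        return [float(x)]
    flat = []
    for item in x:
        flat.extend(flatten(item))
    return flat


def singlesrc_mse_reference(est_targets, targets):
    """Return one MSE value per batch item, averaged over all other axes."""
    losses = []
    for est_b, tgt_b in zip(est_targets, targets):
        e, t = flatten(est_b), flatten(tgt_b)
        losses.append(sum((a - b) ** 2 for a, b in zip(e, t)) / len(e))
    return losses


# Shape (batch, time): no source axis.
no_src = singlesrc_mse_reference([[1.0, 2.0]], [[0.0, 0.0]])
# Shape (batch, nsrc, time): the source axis is averaged over as well.
with_src = singlesrc_mse_reference(
    [[[1.0, 2.0], [0.0, 0.0]]],
    [[[0.0, 0.0], [0.0, 0.0]]],
)
print(no_src, with_src)  # [2.5] [1.25]
```

In both cases the result has one entry per batch item, matching the documented return shape \((batch)\).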
asteroid.losses.mse.MultiSrcMSE[source]

alias of asteroid.losses.mse.SingleSrcMSE

asteroid.losses.mse.pairwise_mse[source]

Measure pairwise mean square error on a batch.

Shape:
  • est_targets: \((batch, nsrc, ...)\).
  • targets: \((batch, nsrc, ...)\).
Returns: torch.Tensor – with shape \((batch, nsrc, nsrc)\)
Examples
>>> import torch
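>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import pairwise_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # pairwise_mse returns a (batch, nsrc, nsrc) matrix, hence 'pw_mtx'.
>>> loss_func = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')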
>>> loss = loss_func(est_targets, targets)
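
To illustrate how a (nsrc, nsrc) pairwise matrix can be reduced to a single permutation-invariant loss, here is a brute-force sketch for one batch item. `best_permutation` is a hypothetical helper, and the pairing convention (estimate perm[j] is matched with target j) is an assumption; the wrapper's own search may be organized differently.

```python
from itertools import permutations


def best_permutation(pw_matrix):
    """Search all source permutations of one (nsrc, nsrc) loss matrix.

    Assumed convention: pw_matrix[i][j] is the loss between
    estimate i and target j. Returns (lowest mean loss, permutation).
    """
    n_src = len(pw_matrix)
    best_loss, best_perm = None, None
    for perm in permutations(range(n_src)):
        # Mean loss when estimate perm[j] is paired with target j.
        loss = sum(pw_matrix[perm[j]][j] for j in range(n_src)) / n_src
        if best_loss is None or loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


# Estimates equal the targets in swapped order, so the swap wins.
loss, perm = best_permutation([[2.5, 0.0], [0.0, 2.5]])
print(loss, perm)  # 0.0 (1, 0)
```

The search visits nsrc! permutations, which is only practical for a small number of sources.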
asteroid.losses.mse.singlesrc_mse[source]

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  • est_targets: \((batch, ...)\).
  • targets: \((batch, ...)\).
Returns: torch.Tensor – with shape \((batch)\)
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.multisrc_mse[source]

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  • est_targets: \((batch, ...)\).
  • targets: \((batch, ...)\).
Returns: torch.Tensor – with shape \((batch)\)
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import multisrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(multisrc_mse, pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)