
asteroid.losses.sinkpit_wrapper module

class asteroid.losses.sinkpit_wrapper.SinkPITLossWrapper(loss_func, n_iter=200, hungarian_validation=True)

Bases: torch.nn.Module

Permutation invariant loss wrapper.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs) that returns pairwise losses.
  • n_iter (int) – number of Sinkhorn iterations (default = 200). Should be an even number.
  • hungarian_validation (boolean) – whether to use the Hungarian algorithm for validation (default = True).

loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\), where element \((batch, i, j)\) is the loss between \(targets[:, i]\) and \(est\_targets[:, j]\). The wrapper evaluates an approximate value of the PIT loss using Sinkhorn's iterative algorithm. See best_softperm_sinkhorn() and http://arxiv.org/abs/2010.11871.
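To make the index convention concrete, here is a dependency-free sketch for a single batch item. It uses plain Python lists and a mean-squared-error loss as a stand-in for pairwise_neg_sisdr; the helpers mse, pairwise_losses and exact_pit are hypothetical, not part of asteroid. exact_pit computes the exact PIT loss by brute force over permutations, which is the quantity the Sinkhorn iterations approximate.

```python
from itertools import permutations


def mse(a, b):
    """Mean squared error between two equal-length sequences (stand-in loss)."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def pairwise_losses(targets, est_targets, loss=mse):
    """Hypothetical helper: (n_src, n_src) loss matrix for ONE batch item.

    Entry [i][j] is the loss between targets[i] and est_targets[j],
    matching the (batch, i, j) convention described above.
    """
    n_src = len(targets)
    return [[loss(targets[i], est_targets[j]) for j in range(n_src)]
            for i in range(n_src)]


def exact_pit(pw):
    """Hypothetical helper: exact PIT loss by brute force.

    Returns (mean loss over sources, perm), where perm[i] is the index
    of the estimate assigned to target i.
    """
    n_src = len(pw)
    best = min(permutations(range(n_src)),
               key=lambda p: sum(pw[i][p[i]] for i in range(n_src)))
    return sum(pw[i][best[i]] for i in range(n_src)) / n_src, best


targets = [[1.0, 0.0, 1.0], [0.0, 2.0, 0.0]]
est_targets = [[0.0, 2.0, 0.1], [1.0, 0.0, 0.9]]  # the two sources are swapped
pw = pairwise_losses(targets, est_targets)
loss, perm = exact_pit(pw)
print(loss, perm)
```

Because the estimates arrive in swapped order, the best permutation is (1, 0) and the loss is small, while the identity assignment would pay the large diagonal losses.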

Examples
>>> import torch
>>> import pytorch_lightning as pl
>>> from torch import optim
>>> from torch.utils import data
>>> from asteroid.engine.system import System
>>> from asteroid.losses import pairwise_neg_sisdr
>>> from asteroid.losses.sinkpit_wrapper import SinkPITLossWrapper, SinkPITBetaScheduler
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute SinkPIT loss based on pairwise losses
>>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
>>> loss_val = loss_func(est_sources, sources)
>>> # A fixed inverse temperature `beta` (=10) is used
>>> # unless a cooling callback is set. The value can be
>>> # changed dynamically with a cooling callback, as follows.
>>> # NeuralNetworkModel and YourDataset are placeholders for your own classes.
>>> model = NeuralNetworkModel()
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> dataset = YourDataset()
>>> loader = data.DataLoader(dataset, batch_size=16)
>>> system = System(
...     model,
...     optimizer,
...     loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
...     train_loader=loader,
...     val_loader=loader,
... )
>>> trainer = pl.Trainer(
...     max_epochs=100,
...     callbacks=[SinkPITBetaScheduler(lambda epoch: 1.02 ** epoch)],
... )
>>> trainer.fit(system)
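Assuming the cooling callback simply overwrites the wrapper's beta with schedule(epoch) at the start of each training epoch, the schedule in the example can be traced with the plain-Python sketch below. ToySinkPITLoss and on_epoch_start are hypothetical stand-ins, not asteroid or PyTorch Lightning classes. Under that assumption, beta starts at 1.0 and stays below the fixed default of 10 for all 100 epochs (it would first reach 10 at epoch 117), so training moves from a soft assignment toward a sharper one.

```python
class ToySinkPITLoss:
    """Hypothetical stand-in for a loss object that exposes a `beta` attribute."""

    def __init__(self, beta=10.0):
        # Fixed inverse temperature, used when no cooling callback is set.
        self.beta = beta


def cooling_schedule(epoch):
    """Same schedule as the example above: beta = 1.02 ** epoch."""
    return 1.02 ** epoch


def on_epoch_start(loss_obj, schedule, epoch):
    """Hypothetical callback body (an assumption): overwrite beta from the schedule."""
    loss_obj.beta = schedule(epoch)


loss_obj = ToySinkPITLoss()
history = []
for epoch in range(100):  # max_epochs=100 -> epochs 0..99
    on_epoch_start(loss_obj, cooling_schedule, epoch)
    history.append(loss_obj.beta)

print(history[0], round(history[-1], 3))
# First epoch at which the schedule reaches the fixed default of 10.
print(next(e for e in range(1000) if cooling_schedule(e) >= 10))
```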
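beta

The inverse temperature parameter used by Sinkhorn's algorithm (default = 10). It can be changed during training, for example by a cooling callback.
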
forward(est_targets, targets, return_est=False, **kwargs)

Evaluate the loss using Sinkhorn’s algorithm.

Parameters:
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
  • return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
  • **kwargs – additional keyword arguments passed to the loss function.
Returns:

  • The best-permutation loss for each batch sample, averaged over
    the batch, as a scalar torch.Tensor.
  • The reordered target estimates if return_est is True, as a
    torch.Tensor of shape \((batch, nsrc, ...)\).
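With hungarian_validation=True, validation uses an exact assignment instead of the Sinkhorn approximation. The sketch below mimics that path on plain Python lists: it finds the exact best permutation per batch item, averages the per-item loss over the batch and, when return_est is set, reorders the estimates. It is a hedged illustration, not asteroid's code: brute-force search over permutations stands in for the Hungarian algorithm (same optimum, but O(n_src!) cost), the per-item loss is assumed to be the mean over sources, and best_perm, pit_forward and the string placeholders for estimates are hypothetical.

```python
from itertools import permutations


def best_perm(pw):
    """Exact best assignment for one (n_src, n_src) pairwise-loss matrix.

    Brute force over all permutations, standing in for the Hungarian
    algorithm. perm[i] is the estimate index assigned to target i.
    """
    n_src = len(pw)
    return min(permutations(range(n_src)),
               key=lambda p: sum(pw[i][p[i]] for i in range(n_src)))


def pit_forward(batch_pw, batch_est, return_est=False):
    """Mean best-permutation loss over the batch, optionally with reordered estimates."""
    losses, reordered = [], []
    for pw, est in zip(batch_pw, batch_est):
        perm = best_perm(pw)
        n_src = len(pw)
        # Per-item loss: mean over sources under the best permutation (assumption).
        losses.append(sum(pw[i][perm[i]] for i in range(n_src)) / n_src)
        # Put the estimate matched to target i at position i.
        reordered.append([est[j] for j in perm])
    mean_loss = sum(losses) / len(losses)
    return (mean_loss, reordered) if return_est else mean_loss


batch_pw = [
    [[5.0, 1.0], [1.0, 5.0]],  # item 0: sources swapped
    [[0.5, 4.0], [3.0, 0.5]],  # item 1: already in order
]
batch_est = [["est_a", "est_b"], ["est_c", "est_d"]]
mean_loss, reordered = pit_forward(batch_pw, batch_est, return_est=True)
print(mean_loss, reordered)
```

Item 0 is swapped back to ["est_b", "est_a"], item 1 keeps its order, and the batch loss is the mean of the two per-item losses.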

static best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200)

Compute an approximate PIT loss using Sinkhorn’s algorithm. See http://arxiv.org/abs/2010.11871

Parameters:
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
  • beta (float) – Inverse temperature parameter. (default = 10)
  • n_iter (int) – Number of Sinkhorn iterations. Should be an even number. (default = 200)
Returns:

  • torch.Tensor – The loss corresponding to the best permutation, of shape (batch,).
  • torch.Tensor – A soft permutation matrix.

Documentation version: v0.4.4