asteroid.losses.sinkpit_wrapper module
class asteroid.losses.sinkpit_wrapper.SinkPITLossWrapper(loss_func, n_iter=200, hungarian_validation=True)
Bases: torch.nn.Module
Permutation invariant loss wrapper.
Parameters: loss_func (callable) – Function that computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).

The wrapper evaluates an approximate value of the PIT loss using Sinkhorn’s iterative algorithm. See best_softperm_sinkhorn() and http://arxiv.org/abs/2010.11871.

Examples
>>> import torch
>>> import pytorch_lightning as pl
>>> from torch import optim
>>> from torch.utils import data
>>> from asteroid.engine.system import System
>>> from asteroid.losses import pairwise_neg_sisdr
>>> from asteroid.losses.sinkpit_wrapper import SinkPITLossWrapper, SinkPITBetaScheduler
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute SinkPIT loss based on pairwise losses
>>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
>>> loss_val = loss_func(est_sources, sources)
>>> # A fixed temperature parameter `beta` (=10) is used
>>> # unless a cooling callback is set. The value can be
>>> # changed dynamically with a cooling callback, as follows.
>>> model = NeuralNetworkModel()  # placeholder: substitute your own nn.Module
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> dataset = YourDataset()  # placeholder: substitute your own Dataset
>>> loader = data.DataLoader(dataset, batch_size=16)
>>> system = System(
...     model,
...     optimizer,
...     loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
...     train_loader=loader,
...     val_loader=loader,
... )
>>> trainer = pl.Trainer(
...     max_epochs=100,
...     callbacks=[SinkPITBetaScheduler(lambda epoch: 1.02 ** epoch)],
... )
>>> trainer.fit(system)
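The loss_func parameter must return a \((batch, n\_src, n\_src)\) tensor in which element \((b, i, j)\) is the loss between targets[b, i] and est_targets[b, j]. The sketch below illustrates that convention in NumPy (chosen only so the snippet runs without PyTorch). It rests on assumptions: `pairwise_mse` is a hypothetical helper, not part of asteroid; it uses mean squared error instead of the negative SI-SDR of the example; and it assumes the wrapper passes `(est_targets, targets)` in that order, as in the call `loss_func(est_sources, sources)` above.

```python
import numpy as np


def pairwise_mse(est_targets, targets):
    """Hypothetical pairwise MSE (illustration only, not part of asteroid).

    est_targets, targets: arrays of shape (batch, n_src, ...).
    Returns an array of shape (batch, n_src, n_src) whose element [b, i, j]
    is the MSE between targets[b, i] and est_targets[b, j].
    """
    # targets[:, :, None] has shape (batch, n_src, 1, ...): axis 1 indexes targets (i)
    # est_targets[:, None] has shape (batch, 1, n_src, ...): axis 2 indexes estimates (j)
    diff = targets[:, :, None] - est_targets[:, None]
    # Average the squared error over every axis after the two source axes
    return (diff ** 2).mean(axis=tuple(range(3, diff.ndim)))


rng = np.random.default_rng(0)
targets = rng.standard_normal((2, 3, 100))
est_targets = rng.standard_normal((2, 3, 100))
pw = pairwise_mse(est_targets, targets)
print(pw.shape)  # (2, 3, 3)
```

A PyTorch version would follow the same broadcasting pattern (`unsqueeze` instead of `None` indexing), which keeps the whole pairwise matrix in a single vectorized operation.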
forward(est_targets, targets, return_est=False, **kwargs)
Evaluate the loss using Sinkhorn’s algorithm.
Parameters: - est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
- return_est – Boolean. Whether to return the reordered target estimates (e.g. to compute metrics or to save examples).
- **kwargs – Additional keyword arguments passed to the loss function.
Returns: - The best-permutation loss for each batch sample, averaged over the batch: torch.Tensor(loss_value).
- Only if return_est is True, the reordered target estimates: torch.Tensor of shape \((batch, nsrc, ...)\).
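The two return values of forward can be illustrated with an exact, brute-force PIT computation: for each sample, pick the permutation that minimizes the mean pairwise loss, average those minima over the batch, and reorder the estimates so that slot i holds the estimate matched to target i. This NumPy sketch is an illustration under assumptions, not asteroid’s implementation: `exact_pit_and_reorder` is a hypothetical helper, it enumerates all n_src! permutations (practical only for small n_src), and it reuses the \((batch, i, j)\) convention documented for loss_func. Sinkhorn’s algorithm approximates this discrete search with a differentiable soft permutation.

```python
import itertools

import numpy as np


def exact_pit_and_reorder(pw_losses, est_targets):
    """Hypothetical brute-force PIT (illustration only).

    pw_losses: (batch, n_src, n_src), pw_losses[b, i, j] = loss(targets[b, i], est_targets[b, j]).
    est_targets: (batch, n_src, ...).
    Returns (loss averaged over the batch, reordered estimates).
    """
    batch, n_src, _ = pw_losses.shape
    perms = list(itertools.permutations(range(n_src)))
    # perm_losses[b, k] = mean over i of pw_losses[b, i, perms[k][i]]
    perm_losses = np.stack(
        [pw_losses[:, np.arange(n_src), list(p)].mean(axis=-1) for p in perms],
        axis=1,
    )
    best = perm_losses.argmin(axis=1)  # index of the best permutation per sample
    min_loss = perm_losses[np.arange(batch), best]
    # Target i is matched to estimate perms[best[b]][i]; move that estimate to slot i
    best_perms = np.array([perms[k] for k in best])  # (batch, n_src)
    reordered = est_targets[np.arange(batch)[:, None], best_perms]
    return min_loss.mean(), reordered


rng = np.random.default_rng(0)
targets = rng.standard_normal((2, 3, 50))
est_targets = targets[:, [2, 0, 1]]  # estimates are a shuffled copy of the targets
# Pairwise MSE with the documented convention: [b, i, j] compares targets[b, i] and est_targets[b, j]
pw = ((targets[:, :, None] - est_targets[:, None]) ** 2).mean(axis=-1)
loss, reordered = exact_pit_and_reorder(pw, est_targets)
print(loss)  # 0.0
```

Because the estimates are an exact permutation of the targets, the best permutation has zero loss and reordering restores the original source order.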
static best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200)
Compute an approximate PIT loss using Sinkhorn’s algorithm. See http://arxiv.org/abs/2010.11871
Parameters: - pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
- beta (float) – Inverse temperature parameter (default = 10).
- n_iter (int) – Number of iterations. Should be an even number (default = 200).
Returns: - torch.Tensor – The loss corresponding to the best permutation, of shape \((batch,)\).
- torch.Tensor – A soft permutation matrix.
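The idea behind best_softperm_sinkhorn can be sketched with log-domain Sinkhorn normalization: start from \(-\beta\) times the pairwise losses, alternately normalize rows and columns so that the exponentiated matrix becomes approximately doubly stochastic, and weight the pairwise losses by that soft permutation. This NumPy sketch is an assumption-laden illustration, not asteroid’s exact code: `soft_perm_sinkhorn` is a hypothetical name, it performs one row pass and one column pass per pair of iterations (one plausible reason for an even n_iter), and it returns the plain soft-permutation-weighted mean loss rather than whatever exact expression the library computes. As \(\beta\) grows, the soft permutation approaches a hard one and the result approaches the exact PIT loss.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along `axis`, keeping the reduced dimension
    m = z.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))


def soft_perm_sinkhorn(pair_wise_losses, beta=10.0, n_iter=200):
    """Hypothetical NumPy sketch of Sinkhorn-based soft PIT (illustration only).

    pair_wise_losses: (batch, n_src, n_src).
    Returns (approximate PIT loss of shape (batch,), soft permutation matrices).
    """
    n_src = pair_wise_losses.shape[-1]
    z = -beta * pair_wise_losses  # log-domain scores: lower loss -> higher score
    for _ in range(n_iter // 2):  # each pass = one row + one column normalization
        z = z - _logsumexp(z, axis=2)  # rows of exp(z) now sum to 1
        z = z - _logsumexp(z, axis=1)  # columns of exp(z) now sum to 1
    soft_perm = np.exp(z)
    loss = (soft_perm * pair_wise_losses).sum(axis=(1, 2)) / n_src
    return loss, soft_perm


# One sample, 3 sources; the exact best matching is 0->0, 1->2, 2->1 (mean loss 0.2)
pw = np.array([[[0.1, 2.0, 3.0],
                [2.5, 3.0, 0.2],
                [1.8, 0.3, 2.2]]])
for beta in (1.0, 10.0, 100.0):
    loss, perm = soft_perm_sinkhorn(pw, beta=beta)
    print(beta, loss.round(3))
```

Working in the log domain avoids the underflow that \(\exp(-\beta \cdot loss)\) would cause for large \(\beta\), which matters when a cooling schedule increases \(\beta\) during training.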