Losses & Metrics
Permutation invariant training (PIT) made easy
Asteroid supports regular Permutation Invariant Training (PIT), its extension using the Sinkhorn algorithm (SinkPIT), and Mixture Invariant Training (MixIT).
PIT
class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)
Bases: torch.nn.Module
Permutation invariant loss wrapper.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs).
- pit_from (str) – Determines how PIT is applied.
  - 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).
  - 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (tensors won't have the source axis). Output shape: \((batch)\). See get_pw_losses().
  - 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch)\). See best_perm_from_perm_avg_loss().
  In terms of efficiency, 'perm_avg' is the least efficient.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function (pwl_set, **kwargs): \((B, n\_src!, n\_src) \rightarrow (B, n\_src!)\). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in the 'pw_mtx' and 'pw_pt' pit_from modes.
For each of these modes, the best permutation and reordering are computed automatically. When either 'pw_mtx' or 'pw_pt' is used and the number of sources is larger than three, the Hungarian algorithm is used to find the best permutation.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
...     weighted = perm_loss * src.norm(dim=-1, keepdim=True)
...     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)
forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)
Find the best permutation and return the loss.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
- return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
- reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- Best permutation loss for each batch sample, averaged over the batch.
- The reordered target estimates if return_est is True. torch.Tensor of shape \((batch, nsrc, ...)\).
static get_pw_losses(loss_func, est_targets, targets, **kwargs)
Get pair-wise losses between the training targets and their estimates for a given loss function.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get pair-wise losses from.
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns: torch.Tensor of size \((batch, nsrc, nsrc)\), losses computed for all pairs of targets and est_targets.
This function can be called on a loss function which returns a tensor of size \((batch)\). There are more efficient ways to compute pair-wise losses using broadcasting.
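As a rough illustration of what a pairwise loss matrix contains, the sketch below builds one for a single batch element in plain Python, with mean squared error standing in for the single-source loss. The helper names `mse` and `pairwise_losses` are hypothetical and not part of Asteroid; the indexing follows the convention stated for 'pw_mtx', where entry (i, j) compares targets[i] with est_targets[j].

```python
def mse(est, tgt):
    """Mean squared error between two equal-length signals (lists of floats)."""
    return sum((e - t) ** 2 for e, t in zip(est, tgt)) / len(tgt)


def pairwise_losses(loss_func, est_targets, targets):
    """Return an n_src x n_src matrix for one batch element.

    Entry [i][j] is loss_func(est_targets[j], targets[i]).
    """
    return [[loss_func(est, tgt) for est in est_targets] for tgt in targets]


# Two sources; the estimates come out in swapped order.
targets = [[1.0, 1.0, 1.0], [0.0, 2.0, 4.0]]
est_targets = [[0.0, 2.0, 4.0], [1.0, 1.0, 1.0]]
pw = pairwise_losses(mse, est_targets, targets)
```

In this toy case the off-diagonal entries are zero, which is what lets a permutation search recover the swapped ordering.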
static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)
Find the best permutation from a loss function with source axis.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, *)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nsrc, *)\). The batch of training targets.
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
static find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs)
Find the best permutation, given the pair-wise losses.
Dispatches between the factorial method if the number of sources is small (at most three) and the Hungarian method for more sources. If perm_reduce is not None, the factorial method is always used.
Parameters:
- pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function (pwl_set, **kwargs): \((B, n\_src!, n\_src) \rightarrow (B, n\_src!)\).
- **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
static reorder_source(source, batch_indices)
Reorder sources according to the best permutation.
Parameters:
- source (torch.Tensor) – Tensor of shape \((batch, n\_src, time)\).
- batch_indices (torch.Tensor) – Tensor of shape \((batch, n\_src)\). Contains optimal permutation indices for each batch.
Returns: torch.Tensor – Reordered sources.
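The reordering step can be pictured with nested lists. The sketch below is only an illustration: it assumes that row b of batch_indices lists, in output order, which source index to pick from batch element b. The helper name `reorder` and that gathering convention are assumptions, not a statement about the library's internals.

```python
def reorder(source, batch_indices):
    """Reorder the source axis of each batch element.

    source: nested list of shape (batch, n_src, time).
    batch_indices: nested list of shape (batch, n_src); row b gives, for each
    output position, the index of the source to place there (assumed convention).
    """
    return [[src[i] for i in idx] for src, idx in zip(source, batch_indices)]


source = [
    [[1, 1], [2, 2], [3, 3]],  # batch element 0
    [[4, 4], [5, 5], [6, 6]],  # batch element 1
]
batch_indices = [[2, 0, 1], [0, 1, 2]]
reordered = reorder(source, batch_indices)
```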
static find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs)
Find the best permutation given the pair-wise losses by looping through all the permutations.
Parameters:
- pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
- perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function (pwl_set, **kwargs): \((B, n\_src!, n\_src) \rightarrow (B, n\_src!)\).
- **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
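To make the exhaustive search concrete, here is a minimal plain-Python sketch for a single batch element, using the default mean reduction. The function name `best_perm_factorial` and the toy loss matrix are illustrative assumptions; the sketch does not reproduce the batched tensor implementation.

```python
from itertools import permutations


def best_perm_factorial(pw_losses):
    """Exhaustive PIT search for one batch element.

    pw_losses[i][j] is the loss between target i and estimate j.
    Returns (min_loss, perm), where perm[i] is the estimate matched to target i.
    """
    n_src = len(pw_losses)
    best_loss, best_perm = float("inf"), None
    for perm in permutations(range(n_src)):
        # Mean of the pairwise losses selected by this permutation.
        loss = sum(pw_losses[i][perm[i]] for i in range(n_src)) / n_src
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


pw = [
    [4.0, 0.5, 3.0],
    [0.2, 5.0, 2.0],
    [3.0, 2.5, 0.1],
]
min_loss, perm = best_perm_factorial(pw)
```

The loop visits n_src! permutations, which is why an assignment solver becomes preferable as the number of sources grows.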
static find_best_perm_hungarian(pair_wise_losses: torch.Tensor)
Find the best permutation given the pair-wise losses, using the Hungarian algorithm.
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – The indices of the best permutations.
class asteroid.losses.pit_wrapper.PITReorder(loss_func, pit_from='pw_mtx', perm_reduce=None)
Bases: asteroid.losses.pit_wrapper.PITLossWrapper
Permutation invariant reorderer. Only returns the reordered estimates. See asteroid.losses.PITLossWrapper.
forward(est_targets, targets, reduce_kwargs=None, **kwargs)
Find the best permutation and return the reordered estimates.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
- reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns: The reordered target estimates. torch.Tensor of shape \((batch, nsrc, ...)\).
MixIT
class asteroid.losses.mixit_wrapper.MixITLossWrapper(loss_func, generalized=True)
Bases: torch.nn.Module
Mixture invariant loss wrapper.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs).
- generalized (bool) – Determines how MixIT is applied. If False, apply MixIT for any number of mixtures as long as they contain the same number of sources (see best_part_mixit()). If True (default), apply MixIT for two mixtures, but those mixtures do not necessarily have to contain the same number of sources. See best_part_mixit_generalized().
For each of these modes, the best partition and reordering are computed automatically.
Examples
>>> import torch
>>> from asteroid.losses import multisrc_mse
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute MixIT loss based on pairwise losses
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)
- References
- [1] Scott Wisdom et al. “Unsupervised sound separation using mixtures of mixtures.” arXiv:2006.12701 (2020)
forward(est_targets, targets, return_est=False, **kwargs)
Find the best partition and return the loss.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, *)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.
- return_est – Boolean. Whether to return the estimated mixtures (to compute metrics or to save examples).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- Best partition loss for each batch sample, averaged over the batch. torch.Tensor(loss_value)
- The estimated mixtures (estimated sources summed according to the partition) if return_est is True. torch.Tensor of shape \((batch, nmix, ...)\).
static best_part_mixit(loss_func, est_targets, targets, **kwargs)
Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid for any number of mixtures as long as they contain the same number of sources.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- torch.Tensor – The loss corresponding to the best partition, of size \((batch,)\).
- torch.LongTensor – The indices of the best partition.
- list – List of the possible partitions of the sources.
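The partition search for the equal-size case can be sketched in plain Python for one batch element. Everything here is illustrative: `equal_partitions`, `mse`, and `best_part` are hypothetical helper names, mean squared error stands in for the loss function, and the enumeration strategy is an assumption about one reasonable way to list the partitions.

```python
from itertools import combinations


def equal_partitions(items, size, n_parts):
    """Yield every ordered split of `items` into `n_parts` groups of `size`."""
    items = list(items)
    if n_parts == 0:
        yield []
        return
    for group in combinations(items, size):
        rest = [x for x in items if x not in group]
        for tail in equal_partitions(rest, size, n_parts - 1):
            yield [group, *tail]


def mse(est, tgt):
    return sum((e - t) ** 2 for e, t in zip(est, tgt)) / len(tgt)


def best_part(est_sources, mixtures):
    """Return (min_loss, best_partition) for one batch element."""
    n_src, n_mix = len(est_sources), len(mixtures)
    if n_src % n_mix != 0:
        raise ValueError("mixtures are assumed to contain the same number of sources")
    best_loss, best = float("inf"), None
    for part in equal_partitions(range(n_src), n_src // n_mix, n_mix):
        # Sum, sample by sample, the estimated sources assigned to each mixture.
        est_mix = [
            [sum(samples) for samples in zip(*(est_sources[i] for i in group))]
            for group in part
        ]
        loss = sum(mse(e, m) for e, m in zip(est_mix, mixtures)) / n_mix
        if loss < best_loss:
            best_loss, best = loss, part
    return best_loss, best


est_sources = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, -1.0]]
mixtures = [[4.0, -1.0], [2.0, 3.0]]  # sources 0+3 and sources 1+2
min_loss, partition = best_part(est_sources, mixtures)
```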
static best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs)
Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid only for two mixtures, but those mixtures do not necessarily have to contain the same number of sources, e.g. the case where one mixture is silent is allowed.
Parameters:
- loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- torch.Tensor – The loss corresponding to the best partition, of size \((batch,)\).
- torch.LongTensor – The indices of the best partitions.
- list – List of the possible partitions of the sources.
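For the two-mixture case, the candidate partitions are simply all ways of assigning each estimated source to one of two groups. The sketch below enumerates them in plain Python; the function name `two_way_partitions` is hypothetical, and the enumeration order is an assumption made for illustration.

```python
from itertools import combinations


def two_way_partitions(n_src):
    """List every assignment of n_src sources to two mixtures (2 ** n_src in total).

    Either group may be empty, which covers the case of a silent mixture.
    """
    items = range(n_src)
    parts = []
    for k in range(n_src + 1):
        for group in combinations(items, k):
            rest = tuple(i for i in items if i not in group)
            parts.append((group, rest))
    return parts


parts = two_way_partitions(3)
```

The number of candidates doubles with every additional estimated source, so the search stays cheap only for a small number of sources.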
static loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs)
Common loop shared by best_part_mixit and best_part_mixit_generalized.
static reorder_source(est_targets, targets, min_loss_idx, parts)
Reorder sources according to the best partition.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.
- min_loss_idx – torch.LongTensor. The indices of the best partitions.
- parts – list of the possible partitions of the sources.
Returns: torch.Tensor – Reordered sources of shape \((batch, nmix, time)\).
SinkPIT
class asteroid.losses.sinkpit_wrapper.SinkPITLossWrapper(loss_func, n_iter=200, hungarian_validation=True)
Bases: torch.nn.Module
Permutation invariant loss wrapper.
Parameters: loss_func – computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).
The wrapper evaluates an approximate value of the PIT loss using Sinkhorn's iterative algorithm. See best_softperm_sinkhorn() and http://arxiv.org/abs/2010.11871
Examples
>>> import torch
>>> import pytorch_lightning as pl
>>> from asteroid.losses import pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute SinkPIT loss based on pairwise losses
>>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
>>> loss_val = loss_func(est_sources, sources)
>>> # A fixed temperature parameter `beta` (=10) is used
>>> # unless a cooling callback is set. The value can be
>>> # dynamically changed using a cooling callback module as follows.
>>> model = NeuralNetworkModel()
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> dataset = YourDataset()
>>> loader = data.DataLoader(dataset, batch_size=16)
>>> system = System(
...     model,
...     optimizer,
...     loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
...     train_loader=loader,
...     val_loader=loader,
... )
>>>
>>> trainer = pl.Trainer(
...     max_epochs=100,
...     callbacks=[SinkPITBetaScheduler(lambda epoch : 1.02 ** epoch)],
... )
>>>
>>> trainer.fit(system)
forward(est_targets, targets, return_est=False, **kwargs)
Evaluate the loss using Sinkhorn's algorithm.
Parameters:
- est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
- targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
- return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
- **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:
- Best permutation loss for each batch sample, averaged over the batch. torch.Tensor(loss_value)
- The reordered target estimates if return_est is True. torch.Tensor of shape \((batch, nsrc, ...)\).
static best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200)
Compute an approximate PIT loss using Sinkhorn's algorithm. See http://arxiv.org/abs/2010.11871
Parameters:
- pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
- beta (float) – Inverse temperature parameter. (default = 10)
- n_iter (int) – Number of iterations. Must be an even number. (default = 200)
Returns:
- torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).
- torch.Tensor – A soft permutation matrix.
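The idea behind the soft permutation can be sketched in plain Python for one batch element: start from the scaled negative losses and alternately normalise rows and columns in the log domain. This is a hedged illustration of Sinkhorn normalisation in general, not the library's implementation; the names `logsumexp` and `soft_perm_sinkhorn` are hypothetical, while `beta` and `n_iter` mirror the documented parameters.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def soft_perm_sinkhorn(pw_losses, beta=10.0, n_iter=200):
    """Approximate a permutation matrix from pairwise losses (one batch element)."""
    n = len(pw_losses)
    # Low losses become large log-weights; beta controls how sharp the result is.
    z = [[-beta * c for c in row] for row in pw_losses]
    for _ in range(n_iter // 2):
        # Row normalisation: each row of exp(z) sums to one.
        normalised = []
        for row in z:
            lse = logsumexp(row)
            normalised.append([v - lse for v in row])
        z = normalised
        # Column normalisation: each column of exp(z) sums to one.
        col_lse = [logsumexp([z[i][j] for i in range(n)]) for j in range(n)]
        z = [[z[i][j] - col_lse[j] for j in range(n)] for i in range(n)]
    return [[math.exp(v) for v in row] for row in z]


pw = [
    [4.0, 0.5, 3.0],
    [0.2, 5.0, 2.0],
    [3.0, 2.5, 0.1],
]
soft = soft_perm_sinkhorn(pw)
```

With a large inverse temperature the result concentrates on a single permutation; smaller values of `beta` give a smoother, more evenly spread matrix.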
Available loss functions
PITLossWrapper supports three types of loss function. For "easy" losses, we implement the three types (pairwise point, single-source loss and multi-source loss). For others, we only implement the single-source loss, which can be aggregated into both PIT and non-PIT training.
MSE
asteroid.losses.mse.PairwiseMSE(*args, **kwargs)
Measure pairwise mean square error on a batch.
Shape:
- est_targets: \((batch, nsrc, ...)\).
- targets: \((batch, nsrc, ...)\).
Returns: torch.Tensor – with shape \((batch, nsrc, nsrc)\)
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import PairwiseMSE
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.SingleSrcMSE(*args, **kwargs)
Measure mean square error on a batch. Supports both tensors with and without source axis.
Shape:
- est_targets: \((batch, ...)\).
- targets: \((batch, ...)\).
Returns: torch.Tensor – with shape \((batch)\)
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.MultiSrcMSE(*args, **kwargs)
Measure mean square error on a batch. Supports both tensors with and without source axis.
Shape:
- est_targets: \((batch, ...)\).
- targets: \((batch, ...)\).
Returns: torch.Tensor – with shape \((batch)\)
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
SDR
asteroid.losses.sdr.PairwiseNegSDR(*args, **kwargs)
Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.
Shape:
- est_targets: \((batch, nsrc, ...)\).
- targets: \((batch, nsrc, ...)\).
Returns: torch.Tensor – with shape \((batch, nsrc, nsrc)\). Pairwise losses.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import PairwiseNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
...                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
- References
- [1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
asteroid.losses.sdr.SingleSrcNegSDR(*args, **kwargs)
Base class for single-source negative SI-SDR, SD-SDR and SNR.
Parameters:
- sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].
- zero_mean (bool, optional) – by default, the target and estimate are zero-meaned before computing the loss.
- take_log (bool, optional) – by default, the log10 of the SDR is returned.
- reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output.
Shape:
- est_targets: \((batch, time)\).
- targets: \((batch, time)\).
Returns: torch.Tensor – with shape \((batch)\) if reduction='none', else a scalar ([]) if reduction='mean'.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import SingleSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
...                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
- References
- [1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
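For readers who want to see what the SI-SDR variant measures, the sketch below computes a negative SI-SDR for one estimate and one target in plain Python, following the scaled-target definition from [1]. It is an illustration only: the function name `neg_si_sdr` and the stabilising constant `eps` are assumptions, and the code is not the library's tensor implementation.

```python
import math


def neg_si_sdr(est, tgt, zero_mean=True, eps=1e-8):
    """Negative SI-SDR in dB between one estimate and one target signal."""
    if zero_mean:
        est_mean = sum(est) / len(est)
        tgt_mean = sum(tgt) / len(tgt)
        est = [e - est_mean for e in est]
        tgt = [t - tgt_mean for t in tgt]
    dot = sum(e * t for e, t in zip(est, tgt))
    tgt_energy = sum(t * t for t in tgt) + eps
    # Project the estimate onto the target to obtain the scaled target.
    scaled_tgt = [dot / tgt_energy * t for t in tgt]
    # Whatever is left after the projection counts as noise.
    noise = [e - s for e, s in zip(est, scaled_tgt)]
    ratio = sum(s * s for s in scaled_tgt) / (sum(n * n for n in noise) + eps)
    return -10 * math.log10(ratio + eps)


tgt = [0.0, 1.0, 0.0, -1.0, 0.5, -0.5]
est = [0.1, 0.9, -0.1, -1.1, 0.6, -0.4]    # small error
worse = [0.5, 0.5, -0.5, -1.5, 1.0, 0.0]   # same error pattern, five times larger
loss_est = neg_si_sdr(est, tgt)
loss_worse = neg_si_sdr(worse, tgt)
loss_scaled = neg_si_sdr([3 * e for e in est], tgt)
```

Rescaling the estimate leaves the value essentially unchanged, which is the scale invariance that distinguishes SI-SDR from plain SNR.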
asteroid.losses.sdr.MultiSrcNegSDR(*args, **kwargs)
Base class for computing negative SI-SDR, SD-SDR and SNR for a given permutation of sources and their estimates.
Shape:
- est_targets: \((batch, nsrc, time)\).
- targets: \((batch, nsrc, time)\).
Returns: torch.Tensor – with shape \((batch)\) if reduction='none', else a scalar ([]) if reduction='mean'.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import MultiSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(MultiSrcNegSDR("sisdr"),
...                            pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)
- References
- [1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
PMSQE
asteroid.losses.pmsqe.SingleSrcPMSQE(*args, **kwargs)
Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512 length DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see Tensorflow implementation). The SLL, frequency and gain equalization are applied in each sequence independently.
Parameters: - window_name (str) – Select the used window function for the correct factor to be applied. Defaults to sqrt hanning window. Among [‘rect’, ‘hann’, ‘sqrt_hann’, ‘hamming’, ‘flatTop’].
- window_weight (float, optional) – Correction to the window factor applied.
- bark_eq (bool, optional) – Whether to apply bark equalization.
- gain_eq (bool, optional) – Whether to apply gain equalization.
- sample_rate (int) – Sample rate of the input audio.
- References
[1] J.M.Martin, A.M.Gomez, J.A.Gonzalez, A.M.Peinado ‘A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality’, IEEE Signal Processing Letters, 2018. Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es
Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.
Note
Inspired by the Perceptual Evaluation of the Speech Quality (PESQ) algorithm, this function consists of two regularization factors: the symmetrical and asymmetrical distortion in the loudness domain.
- Examples
>>> import torch
>>> from asteroid_filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> # Estimates first, references second, as in the PITLossWrapper forward signature
>>> loss_value = loss_func(est_spec, ref_spec)
STOI
MultiScale Spectral Loss
asteroid.losses.multi_scale_spectral.SingleSrcMultiScaleSpectral(*args, **kwargs)
Measure multi-scale spectral loss as described in [1].
Shape:
- est_targets: \((batch, time)\).
- targets: \((batch, time)\).
Returns: torch.Tensor – with shape \((batch)\)
Examples
>>> import torch
>>> from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 32000)
>>> est_targets = torch.randn(10, 32000)
>>> # Using it by itself on a pair of source/estimate
>>> loss_func = SingleSrcMultiScaleSpectral()
>>> loss = loss_func(est_targets, targets)
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # Using it with PITLossWrapper with sets of source/estimates
>>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
...                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
- References
- [1] Jesse Engel and Lamtharn (Hanoi) Hantrakul and Chenjie Gu and Adam Roberts “DDSP: Differentiable Digital Signal Processing” ICLR 2020.
Deep clustering (Affinity) loss
asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)
Compute the deep clustering loss defined in [1].
Parameters:
- embedding (torch.Tensor) – Estimated embeddings. Expected shape \((batch, frequency * frame, embedding\_dim)\).
- tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: \((batch, frequency, frame)\).
- binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.dsp.vad.ebased_vad.
Returns: torch.Tensor. Deep clustering loss for every batch sample.
Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> # Target indices of shape (batch, frequency, frame), matching the 5*400 bins above
>>> targets = torch.randint(0, spk_cnt, (10, 5, 400))
>>> loss = deep_clustering_loss(embedding, targets)
- Reference
- [1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey “ALTERNATIVE OBJECTIVE FUNCTIONS FOR DEEP CLUSTERING”
Note
Be careful in viewing the embedding tensors. The target indices
tgt_index
are of shape \((batch, freq, frames)\). Even if the embedding is of shape \((batch, freq * frames, emb)\), the underlying view should be \((batch, freq, frames, emb)\) and not \((batch, frames, freq, emb)\).
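The note above comes down to row-major flattening: TF bin (f, t) must land at flat position f * frames + t so that each embedding row lines up with its target index. The sketch below checks this on a toy example in plain Python; `flat_index` and the toy values are hypothetical and only illustrate the indexing.

```python
def flat_index(f, t, n_frames):
    """Row-major position of TF bin (f, t) when (freq, frames) is flattened."""
    return f * n_frames + t


n_freq, n_frames = 3, 4
# A toy tgt_index of shape (freq, frames) for one batch element.
tgt_index = [[(f + t) % 2 for t in range(n_frames)] for f in range(n_freq)]
# Flattening with frequency as the outer axis keeps row k of the embedding
# aligned with tgt_flat[k].
tgt_flat = [tgt_index[f][t] for f in range(n_freq) for t in range(n_frames)]
aligned = all(
    tgt_flat[flat_index(f, t, n_frames)] == tgt_index[f][t]
    for f in range(n_freq)
    for t in range(n_frames)
)
```

Flattening with frames as the outer axis instead would pair embedding rows with the wrong bins without raising any shape error, which is why the note asks for care.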
Computing metrics
asteroid.metrics.get_metrics(mix, clean, estimate, sample_rate=16000, metrics_list='all', average=True, compute_permutation=False, ignore_metrics_errors=False, filename=None)
Get speech separation/enhancement metrics from mix/clean/estimate.
Parameters:
- mix (np.array) – mixture array.
- clean (np.array) – reference array.
- estimate (np.array) – estimate array.
- sample_rate (int) – sampling rate of the audio clips.
- metrics_list (Union[List[str], str]) – List of metrics to compute. Defaults to 'all' (['si_sdr', 'sdr', 'sir', 'sar', 'stoi', 'pesq']).
- average (bool) – Return dict([float]) if True, else dict([array]).
- compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default False).
- ignore_metrics_errors (bool) – Whether to ignore errors that occur in computing the metrics. A warning will be printed instead.
- filename (str, optional) – If computing a metric fails, print this filename along with the exception/warning message for debugging purposes.
Shape:
- mix: \((D, N)\) or \((N,)\).
- clean: \((K\_source, N)\) or \((N,)\).
- estimate: \((K\_target, N)\) or \((N,)\).
Returns: dict – Dictionary with all requested metrics, with 'input_' prefix for metrics at the input (mixture against clean), no prefix at the output (estimate against clean). Output format depends on average.
Examples
>>> import numpy as np
>>> import pprint
>>> from asteroid.metrics import get_metrics
>>> mix = np.random.randn(1, 16000)
>>> clean = np.random.randn(2, 16000)
>>> est = np.random.randn(2, 16000)
>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
...                            metrics_list='all')
>>> pprint.pprint(metrics_dict)
{'input_pesq': 1.924380898475647,
 'input_sar': -11.67667585294225,
 'input_sdr': -14.88667106190552,
 'input_si_sdr': -52.43849784881705,
 'input_sir': -0.10419427290163795,
 'input_stoi': 0.015112115177091223,
 'pesq': 1.7713886499404907,
 'sar': -11.610963379923195,
 'sdr': -14.527246041125844,
 'si_sdr': -46.26557128489802,
 'sir': 0.4799929272243427,
 'stoi': 0.022023073540350643}
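Because input metrics carry the 'input_' prefix and output metrics carry none, an improvement figure can be derived by subtracting one from the other. The sketch below does this on a few values copied from the example output; the helper `improvements` and the '_imp' suffix are hypothetical conventions chosen for illustration, not something get_metrics returns.

```python
def improvements(metrics):
    """Return output-minus-input differences for every metric that has both."""
    return {
        name + "_imp": value - metrics["input_" + name]
        for name, value in metrics.items()
        if not name.startswith("input_") and "input_" + name in metrics
    }


# A subset of the dictionary printed in the example above.
metrics_dict = {
    "input_sdr": -14.88667106190552,
    "input_si_sdr": -52.43849784881705,
    "sdr": -14.527246041125844,
    "si_sdr": -46.26557128489802,
}
imp = improvements(metrics_dict)
```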