Losses & Metrics¶

class
asteroid.losses.
PITLossWrapper
(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]¶
Permutation invariant loss wrapper.
Parameters:  loss_func – function with signature (targets, est_targets, **kwargs).
 pit_from (str) –
Determines how PIT is applied.
'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \([batch, i, j]\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).
'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (tensors won't have the source axis). Output shape: \((batch)\). See get_pw_losses().
'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch)\). See best_perm_from_perm_avg_loss().
In terms of efficiency, 'perm_avg' is the least efficient.
 perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) –> (B, n_src!). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in the 'pw_mtx' and 'pw_pt' pit_from modes.
For each of these modes, the best permutation and reordering will be automatically computed.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce: perm_loss has shape (B, n_src!, n_src)
>>> def reduce(perm_loss, src):
...     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
...     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)
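To make the 'pw_mtx' convention concrete, here is a minimal NumPy sketch (a stand-in for the torch tensors used by the library, not Asteroid's implementation). The helper names `pairwise_mse` and `pit_loss` are hypothetical. It builds the pairwise matrix by broadcasting and takes the minimum over all permutations, so the result does not depend on the order of the estimates.

```python
import itertools
import numpy as np


def pairwise_mse(est_targets, targets):
    """Hypothetical helper: (batch, n_src, n_src) matrix where entry
    [b, i, j] is the MSE between targets[b, i] and est_targets[b, j]."""
    diff = targets[:, :, None, :] - est_targets[:, None, :, :]
    return (diff ** 2).mean(axis=-1)


def pit_loss(pw_losses):
    """Brute-force PIT: average the matched pairwise losses for every
    permutation, keep the best one per sample, then average over the batch."""
    n_src = pw_losses.shape[1]
    rows = np.arange(n_src)
    perm_losses = np.stack(
        [pw_losses[:, rows, list(perm)].mean(axis=-1)
         for perm in itertools.permutations(range(n_src))],
        axis=1,
    )  # (batch, n_src!)
    return perm_losses.min(axis=1).mean()


rng = np.random.default_rng(0)
targets = rng.standard_normal((4, 3, 100))
# Estimates are the targets in a different order, plus a little noise.
est_targets = targets[:, [2, 0, 1]] + 0.01 * rng.standard_normal((4, 3, 100))

loss = pit_loss(pairwise_mse(est_targets, targets))
loss_shuffled = pit_loss(pairwise_mse(est_targets[:, ::-1], targets))
```

Shuffling the estimates leaves the loss unchanged, which is the property PIT is meant to provide.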

static
best_perm_from_perm_avg_loss
(loss_func, est_targets, targets, **kwargs)[source]¶ Find best permutation from loss function with source axis.
Parameters:  loss_func – function with signature (targets, est_targets, **kwargs). The loss function to get batch losses from.
 est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
 targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
 **kwargs – additional keyword argument that will be passed to the loss function.
Returns: tuple –
torch.Tensor: The loss corresponding to the best permutation, of size (batch,).
torch.LongTensor: The indexes of the best permutations.
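The 'perm_avg' search can be sketched as follows in NumPy. This is an illustration under assumptions, not the library code: `multisrc_mse` and `best_perm_from_perm_avg_loss_sketch` are hypothetical names, the loss is called as `loss_func(est_targets, targets)`, and permutations are enumerated in `itertools.permutations` order. The loss is evaluated once per permutation of the estimates, which is why this mode is the least efficient.

```python
import itertools
import numpy as np


def multisrc_mse(est_targets, targets):
    """Hypothetical loss with a source axis: one value per batch sample."""
    return ((targets - est_targets) ** 2).mean(axis=(1, 2))


def best_perm_from_perm_avg_loss_sketch(loss_func, est_targets, targets, **kwargs):
    n_src = targets.shape[1]
    perms = list(itertools.permutations(range(n_src)))
    # Evaluate the loss for every permutation of the estimates: (batch, n_src!)
    loss_set = np.stack(
        [loss_func(est_targets[:, list(perm)], targets, **kwargs) for perm in perms],
        axis=1,
    )
    best_idx = loss_set.argmin(axis=1)
    return loss_set.min(axis=1), best_idx


rng = np.random.default_rng(0)
targets = rng.standard_normal((2, 2, 50))
est_targets = targets[:, ::-1]  # the two sources are swapped
min_loss, best_idx = best_perm_from_perm_avg_loss_sketch(
    multisrc_mse, est_targets, targets
)
```

Here the second permutation, (1, 0), undoes the swap, so it is selected for every sample.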

static
find_best_perm
(pair_wise_losses, n_src, perm_reduce=None, **kwargs)[source]¶ Find the best permutation, given the pairwise losses.
Parameters:  pair_wise_losses (torch.Tensor) – Tensor of shape [batch, n_src, n_src]. Pairwise losses.
 n_src (int) – Number of sources.
 perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) –> (B, n_src!).
 **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
Returns: tuple –
torch.Tensor: The loss corresponding to the best permutation, of size (batch,).
torch.LongTensor: The indexes of the best permutations.
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
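A NumPy sketch of this search, assuming the documented `perm_reduce` signature (B, n_src!, n_src) –> (B, n_src!). The function name `find_best_perm_sketch` is hypothetical and the permutation ordering is that of `itertools.permutations`, which may differ from the library's. It also shows `functools.partial` being used to fix a static argument of the reduce function.

```python
import functools
import itertools
import numpy as np


def find_best_perm_sketch(pair_wise_losses, perm_reduce=None, **kwargs):
    batch, n_src, _ = pair_wise_losses.shape
    perms = np.array(list(itertools.permutations(range(n_src))))  # (n_src!, n_src)
    rows = np.arange(n_src)[None, :]
    # pwl_set[b, p, i] = pair_wise_losses[b, i, perms[p, i]]: (B, n_src!, n_src)
    pwl_set = pair_wise_losses[:, rows, perms]
    if perm_reduce is None:
        loss_set = pwl_set.mean(axis=-1)  # default reduction: mean over sources
    else:
        loss_set = perm_reduce(pwl_set, **kwargs)
    best_idx = loss_set.argmin(axis=1)
    min_loss = loss_set[np.arange(batch), best_idx]
    return min_loss, best_idx


# One sample, two sources: the off-diagonal pairing has zero loss.
pw = np.array([[[1.0, 0.0], [0.0, 1.0]]])
mean_loss, mean_idx = find_best_perm_sketch(pw)
# Reduce with the worst source instead of the mean, fixing `axis` statically.
max_loss, max_idx = find_best_perm_sketch(
    pw, perm_reduce=functools.partial(np.max, axis=-1)
)
```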

forward
(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]¶ Find the best permutation and return the loss.
Parameters:  est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
 targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
 return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
 reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
 **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:  Best permutation loss for each batch sample, averaged over the batch. torch.Tensor(loss_value).
 The reordered target estimates if return_est is True. torch.Tensor of shape [batch, nsrc, *].

static
get_pw_losses
(loss_func, est_targets, targets, **kwargs)[source]¶ Get pairwise losses between the training targets and their estimates for a given loss function.
Parameters:  loss_func – function with signature (targets, est_targets, **kwargs) The loss function to get pairwise losses from.
 est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
 targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
 **kwargs – additional keyword argument that will be passed to the loss function.
Returns: torch.Tensor of size [batch, nsrc, nsrc], losses computed for all pairs of targets and est_targets.
This function can be called on a loss function which returns a tensor of size [batch]. There are more efficient ways to compute pairwise losses using broadcasting.
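The difference between the loop-based construction and broadcasting can be illustrated in NumPy. This is a sketch under assumptions: `single_src_mse` and `get_pw_losses_sketch` are hypothetical names, the loss is called as `loss_func(est, target)`, and the output follows the documented 'pw_mtx' indexing where [b, i, j] pairs target i with estimate j.

```python
import numpy as np


def single_src_mse(est_target, target):
    """Hypothetical 'pw_pt'-style loss: no source axis, returns shape (batch,)."""
    return ((target - est_target) ** 2).mean(axis=-1)


def get_pw_losses_sketch(loss_func, est_targets, targets, **kwargs):
    batch, n_src = targets.shape[:2]
    pw = np.empty((batch, n_src, n_src))
    # n_src ** 2 calls to the loss function, one per (target, estimate) pair.
    for tgt_idx in range(n_src):
        for est_idx in range(n_src):
            pw[:, tgt_idx, est_idx] = loss_func(
                est_targets[:, est_idx], targets[:, tgt_idx], **kwargs
            )
    return pw


rng = np.random.default_rng(0)
targets = rng.standard_normal((5, 3, 64))
est_targets = rng.standard_normal((5, 3, 64))

pw_loop = get_pw_losses_sketch(single_src_mse, est_targets, targets)
# Same matrix in a single vectorized expression using broadcasting.
pw_broadcast = (
    (targets[:, :, None, :] - est_targets[:, None, :, :]) ** 2
).mean(axis=-1)
```

Both matrices agree; the broadcast version avoids the Python loops at the cost of a larger intermediate array.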

static
reorder_source
(source, n_src, min_loss_idx)[source]¶ Reorder sources according to the best permutation.
Parameters:  source (torch.Tensor) – Tensor of shape [batch, n_src, time]
 n_src (int) – Number of sources.
 min_loss_idx (torch.LongTensor) – Tensor of shape [batch], each item is in [0, n_src!).
Returns: torch.Tensor – Reordered sources of shape [batch, n_src, time].
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
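A NumPy sketch of the reordering step, assuming that `min_loss_idx` indexes permutations in `itertools.permutations` order (the library's ordering may differ) and using the hypothetical name `reorder_source_sketch`.

```python
import itertools
import numpy as np


def reorder_source_sketch(source, min_loss_idx):
    batch, n_src = source.shape[:2]
    perms = np.array(list(itertools.permutations(range(n_src))))  # (n_src!, n_src)
    best = perms[min_loss_idx]  # (batch, n_src): one permutation per sample
    # out[b, i] = source[b, best[b, i]]
    return source[np.arange(batch)[:, None], best]


# source[b, s, :] is filled with the value s + 10 * b so the order is visible.
source = (np.arange(3)[None, :, None] + 10 * np.arange(2)[:, None, None]) * np.ones((2, 3, 4))
# Index 0 is the identity (0, 1, 2); index 5 is the reversal (2, 1, 0).
reordered = reorder_source_sketch(source, np.array([0, 5]))
```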

class
asteroid.losses.
SingleSrcPMSQE
(window_name='sqrt_hann', window_weight=1.0, bark_eq=True, gain_eq=True, sample_rate=16000)[source]¶
Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512 length DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see Tensorflow implementation). The SLL, frequency and gain equalization are applied in each sequence independently.
Parameters:  window_name (str) – Select the used window function for the correct factor to be applied. Defaults to sqrt hanning window. Among [‘rect’, ‘hann’, ‘sqrt_hann’, ‘hamming’, ‘flatTop’].
 window_weight (float, optional) – Correction to the window factor applied.
 bark_eq (bool, optional) – Whether to apply bark equalization.
 gain_eq (bool, optional) – Whether to apply gain equalization.
 sample_rate (int) – Sample rate of the input audio.
References
[1] J.M.Martin, A.M.Gomez, J.A.Gonzalez, A.M.Peinado ‘A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality’, IEEE Signal Processing Letters, 2018. Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.
Note
Inspired by the Perceptual Evaluation of the Speech Quality (PESQ) algorithm, this function consists of two regularization factors: the symmetrical and asymmetrical distortion in the loudness domain.
Examples
>>> import torch
>>> from asteroid.filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_value = loss_func(est_spec, ref_spec)

bark_freq_equalization
(ref_bark_spectra, deg_bark_spectra)[source]¶ This version is applied directly to the degraded signal.

forward
(est_targets, targets, pad_mask=None)[source]¶
Parameters:  est_targets (torch.Tensor) – Dimensions (B, T, F). Padded degraded power spectrum in time-frequency domain.
 targets (torch.Tensor) – Dimensions (B, T, F). Zero-padded reference power spectrum in time-frequency domain.
 pad_mask (torch.Tensor, optional) – Dimensions (B, T, 1). Mask to indicate the padding frames. Defaults to all ones.
Dimensions: B: Number of sequences in the batch. T: Number of time frames. F: Number of frequency bins.
Returns: torch.Tensor of shape (B, ), wD + 0.309 * wDA.
Note
Dimensions (B, F, T) are also supported by SingleSrcPMSQE but are less efficient because input tensors are transposed (not in place).

asteroid.losses.
SingleSrcNegSTOI
¶ alias of
asteroid.losses.stoi.NegSTOILoss

class
asteroid.losses.
SingleSrcMultiScaleSpectral
(n_filters=None, windows_size=None, hops_size=None, alpha=1.0)[source]¶
Measure multi-scale spectral loss as described in [1].
Parameters:  alpha (float) – Weighting factor for the log term.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, time]. Batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, time]. Batch of training targets.
Returns: torch.Tensor – with shape [batch].
Examples
>>> import torch
>>> from asteroid.losses import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 32000)
>>> est_targets = torch.randn(10, 32000)
>>> # Using it by itself on a pair of source/estimate
>>> loss_func = SingleSrcMultiScaleSpectral()
>>> loss = loss_func(est_targets, targets)
>>> import torch
>>> from asteroid.losses import PITLossWrapper, SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # Using it with PITLossWrapper with sets of source/estimates
>>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
...                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
References
[1] Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu and Adam Roberts, "DDSP: Differentiable Digital Signal Processing", International Conference on Learning Representations (ICLR), 2020.

class
asteroid.losses.
PairwiseNegSDR
(sdr_type, zero_mean=True, take_log=True)[source]¶
Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of training targets.
Returns: torch.Tensor – with shape [batch, n_src, n_src]. Pairwise losses.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, PairwiseNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
...                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
References
[1] Le Roux, Jonathan, et al. "SDR – half-baked or well done?" IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.
deep_clustering_loss
(embedding, tgt_index, binary_mask=None)[source]¶ Compute the deep clustering loss defined in [1].
Parameters:  embedding (torch.Tensor) – Estimated embeddings. Expected shape (batch, frequency x frame, embedding_dim)
 tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: [batch, frequency, frame]
 binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.filterbanks.transforms.ebased_vad.
Returns: torch.Tensor. Deep clustering loss for every batch sample.
Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> # Integer source indices of shape (batch, frequency, frame)
>>> targets = torch.randint(0, spk_cnt, (10, 400, 5))
>>> loss = deep_clustering_loss(embedding, targets)
Reference
[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, "Alternative Objective Functions for Deep Clustering".
Note
Be careful in viewing the embedding tensors. The target indices tgt_index are of shape (batch, freq, frames). Even if the embedding is of shape (batch, freq*frames, emb), the underlying view should be (batch, freq, frames, emb) and not (batch, frames, freq, emb).
Permutation invariant training (PIT) made easy¶

class
asteroid.losses.pit_wrapper.
PITLossWrapper
(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]¶
Permutation invariant loss wrapper.
Parameters:  loss_func – function with signature (targets, est_targets, **kwargs).
 pit_from (str) –
Determines how PIT is applied.
'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \([batch, i, j]\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).
'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (tensors won't have the source axis). Output shape: \((batch)\). See get_pw_losses().
'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch)\). See best_perm_from_perm_avg_loss().
In terms of efficiency, 'perm_avg' is the least efficient.
 perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) –> (B, n_src!). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in the 'pw_mtx' and 'pw_pt' pit_from modes.
For each of these modes, the best permutation and reordering will be automatically computed.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce: perm_loss has shape (B, n_src!, n_src)
>>> def reduce(perm_loss, src):
...     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
...     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)

static
best_perm_from_perm_avg_loss
(loss_func, est_targets, targets, **kwargs)[source]¶ Find best permutation from loss function with source axis.
Parameters:  loss_func – function with signature (targets, est_targets, **kwargs). The loss function to get batch losses from.
 est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
 targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
 **kwargs – additional keyword argument that will be passed to the loss function.
Returns: tuple –
torch.Tensor: The loss corresponding to the best permutation, of size (batch,).
torch.LongTensor: The indexes of the best permutations.

static
find_best_perm
(pair_wise_losses, n_src, perm_reduce=None, **kwargs)[source]¶ Find the best permutation, given the pairwise losses.
Parameters:  pair_wise_losses (torch.Tensor) – Tensor of shape [batch, n_src, n_src]. Pairwise losses.
 n_src (int) – Number of sources.
 perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) –> (B, n_src!).
 **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
Returns: tuple –
torch.Tensor: The loss corresponding to the best permutation, of size (batch,).
torch.LongTensor: The indexes of the best permutations.
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.

forward
(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]¶ Find the best permutation and return the loss.
Parameters:  est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
 targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
 return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
 reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
 **kwargs – additional keyword arguments that will be passed to the loss function.
Returns:  Best permutation loss for each batch sample, averaged over the batch. torch.Tensor(loss_value).
 The reordered target estimates if return_est is True. torch.Tensor of shape [batch, nsrc, *].

static
get_pw_losses
(loss_func, est_targets, targets, **kwargs)[source]¶ Get pairwise losses between the training targets and their estimates for a given loss function.
Parameters:  loss_func – function with signature (targets, est_targets, **kwargs) The loss function to get pairwise losses from.
 est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
 targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
 **kwargs – additional keyword argument that will be passed to the loss function.
Returns: torch.Tensor of size [batch, nsrc, nsrc], losses computed for all pairs of targets and est_targets.
This function can be called on a loss function which returns a tensor of size [batch]. There are more efficient ways to compute pairwise losses using broadcasting.

static
reorder_source
(source, n_src, min_loss_idx)[source]¶ Reorder sources according to the best permutation.
Parameters:  source (torch.Tensor) – Tensor of shape [batch, n_src, time]
 n_src (int) – Number of sources.
 min_loss_idx (torch.LongTensor) – Tensor of shape [batch], each item is in [0, n_src!).
Returns: torch.Tensor – Reordered sources of shape [batch, n_src, time].
MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
Available loss functions¶
PITLossWrapper supports three types of loss functions. For "easy" losses, we implement the three types (pairwise point, single-source loss and multi-source loss). For others, we only implement the single-source loss, which can be aggregated into both PIT and non-PIT training.
MSE¶

asteroid.losses.mse.
PairwiseMSE
(*args, **kwargs)[source]¶ Measure pairwise mean square error on a batch.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, nsrc, *]. The batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, nsrc, *]. The batch of training targets.
Returns: torch.Tensor – with shape [batch, nsrc, nsrc].
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import PairwiseMSE
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)

asteroid.losses.mse.
SingleSrcMSE
(*args, **kwargs)[source]¶ Measure mean square error on a batch. Supports both tensors with and without source axis.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, *]. The batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, *]. The batch of training targets.
Returns: torch.Tensor – with shape [batch].
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)

asteroid.losses.mse.
MultiSrcMSE
(*args, **kwargs)¶ Measure mean square error on a batch. Supports both tensors with and without source axis.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, *]. The batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, *]. The batch of training targets.
Returns: torch.Tensor – with shape [batch].
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
SDR¶

asteroid.losses.sdr.
PairwiseNegSDR
(*args, **kwargs)[source]¶ Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of training targets.
Returns: torch.Tensor – with shape [batch, n_src, n_src]. Pairwise losses.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import PairwiseNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
...                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
References
[1] Le Roux, Jonathan, et al. "SDR – half-baked or well done?" IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.sdr.
SingleSrcNegSDR
(*args, **kwargs)[source]¶ Base class for single-source negative SI-SDR, SD-SDR and SNR.
Parameters:  sdr_type (string) – choose between "snr" for plain SNR, "sisdr" for SI-SDR and "sdsdr" for SD-SDR [1].
 zero_mean (bool, optional) – by default, the target and estimate are zero-meaned before computing the loss.
 take_log (bool, optional) – by default the log10 of sdr is returned.
 reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, time]. Batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, time]. Batch of training targets.
Returns: torch.Tensor – with shape [batch] if reduction='none' else [] scalar if reduction='mean'.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import SingleSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
...                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
References
[1] Le Roux, Jonathan, et al. "SDR – half-baked or well done?" IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
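The three `sdr_type` options differ only in how the target is scaled and how the error is defined. The NumPy sketch below follows the definitions in [1]; the function name `neg_sdr_sketch` and the `EPS` constant are assumptions for illustration, not the library's code.

```python
import numpy as np

EPS = 1e-8  # assumed small constant for numerical stability


def neg_sdr_sketch(est_target, target, sdr_type="sisdr", zero_mean=True):
    """Negative SNR / SI-SDR / SD-SDR in dB for arrays of shape (batch, time)."""
    if zero_mean:
        est_target = est_target - est_target.mean(axis=-1, keepdims=True)
        target = target - target.mean(axis=-1, keepdims=True)
    if sdr_type in ("sisdr", "sdsdr"):
        # Project the estimate onto the target to get the optimal scaling.
        dot = (est_target * target).sum(axis=-1, keepdims=True)
        energy = (target ** 2).sum(axis=-1, keepdims=True) + EPS
        scaled_target = dot * target / energy
    else:  # "snr": the reference is the unscaled target
        scaled_target = target
    if sdr_type in ("sdsdr", "snr"):
        noise = est_target - target  # error measured against the raw target
    else:  # "sisdr": error measured against the scaled target
        noise = est_target - scaled_target
    ratio = (scaled_target ** 2).sum(axis=-1) / ((noise ** 2).sum(axis=-1) + EPS)
    return -10 * np.log10(ratio + EPS)


target = np.array([[1.0, -1.0, 1.0, -1.0]])
est = target + np.array([[0.1, 0.1, -0.1, -0.1]])  # noise orthogonal to target

sisdr = neg_sdr_sketch(est, target, "sisdr")
sisdr_scaled = neg_sdr_sketch(3 * est, target, "sisdr")
snr = neg_sdr_sketch(est, target, "snr")
snr_scaled = neg_sdr_sketch(3 * est, target, "snr")
```

Rescaling the estimate leaves the SI-SDR loss unchanged but changes the SNR loss, which is the distinction the `sdr_type` argument selects.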

asteroid.losses.sdr.
MultiSrcNegSDR
(*args, **kwargs)[source]¶ Base class for computing negative SI-SDR, SD-SDR and SNR for a given permutation of sources and their estimates.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of training targets.
Returns: torch.Tensor – with shape [batch] if reduction='none' else [] scalar if reduction='mean'.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import MultiSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(MultiSrcNegSDR("sisdr"),
...                            pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)
References
[1] Le Roux, Jonathan, et al. "SDR – half-baked or well done?" IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
PMSQE¶

asteroid.losses.pmsqe.
SingleSrcPMSQE
(*args, **kwargs)[source]¶ Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512 length DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see Tensorflow implementation). The SLL, frequency and gain equalization are applied in each sequence independently.
Parameters:  window_name (str) – Select the used window function for the correct factor to be applied. Defaults to sqrt hanning window. Among [‘rect’, ‘hann’, ‘sqrt_hann’, ‘hamming’, ‘flatTop’].
 window_weight (float, optional) – Correction to the window factor applied.
 bark_eq (bool, optional) – Whether to apply bark equalization.
 gain_eq (bool, optional) – Whether to apply gain equalization.
 sample_rate (int) – Sample rate of the input audio.
References
[1] J.M.Martin, A.M.Gomez, J.A.Gonzalez, A.M.Peinado ‘A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality’, IEEE Signal Processing Letters, 2018. Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.
Note
Inspired by the Perceptual Evaluation of the Speech Quality (PESQ) algorithm, this function consists of two regularization factors: the symmetrical and asymmetrical distortion in the loudness domain.
Examples
>>> import torch
>>> from asteroid.filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_value = loss_func(est_spec, ref_spec)
STOI¶

asteroid.losses.stoi.
NegSTOILoss
(*args, **kwargs)[source]¶ Negated Short Term Objective Intelligibility (STOI) metric, to be used as a loss function. Inspired by [1, 2, 3] but not exactly the same: it cannot be used as the STOI metric directly (use pystoi instead). See Note.
Shapes:
 (time,) –> (1, )
 (batch, time) –> (batch, )
 (batch, n_src, time) –> (batch, n_src)
Returns: torch.Tensor of shape (batch, *, ), only the time dimension has been reduced.
Note
In the NumPy version, a simple VAD was used to remove the silent frames before chunking the signal into short-term envelope vectors. We don't do the same here because removing frames in a batch is cumbersome and inefficient. If use_vad is set to True, we instead detect the silent frames and keep a mask tensor. At the end, the normalized correlation of short-term envelope vectors is masked using this mask (unfolded) and the mean is computed taking the mask values into account.
Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.stoi import NegSTOILoss
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(NegSTOILoss(sample_rate=8000), pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
References
[1] C.H. Taal, R.C. Hendriks, R. Heusdens, J. Jensen, 'A Short-Time Objective Intelligibility Measure for Time-Frequency Weighted Noisy Speech', ICASSP 2010, Dallas, Texas.
[2] C.H. Taal, R.C. Hendriks, R. Heusdens, J. Jensen, 'An Algorithm for Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', IEEE Transactions on Audio, Speech, and Language Processing, 2011.
[3] Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated Noise Maskers', IEEE Transactions on Audio, Speech and Language Processing, 2016.
MultiScale Spectral Loss¶

asteroid.losses.multi_scale_spectral.
SingleSrcMultiScaleSpectral
(*args, **kwargs)[source]¶ Measure multi-scale spectral loss as described in [1].
Parameters:  alpha (float) – Weighting factor for the log term.
Shape:
 est_targets (torch.Tensor): Expected shape [batch, time]. Batch of target estimates.
 targets (torch.Tensor): Expected shape [batch, time]. Batch of training targets.
Returns: torch.Tensor – with shape [batch].
Examples
>>> import torch
>>> from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 32000)
>>> est_targets = torch.randn(10, 32000)
>>> # Using it by itself on a pair of source/estimate
>>> loss_func = SingleSrcMultiScaleSpectral()
>>> loss = loss_func(est_targets, targets)
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # Using it with PITLossWrapper with sets of source/estimates
>>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
...                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
References
[1] Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu and Adam Roberts, "DDSP: Differentiable Digital Signal Processing", International Conference on Learning Representations (ICLR), 2020.
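For intuition, the loss in [1] sums, over several STFT resolutions, an L1 distance between magnitude spectrograms and an `alpha`-weighted L1 distance between their logarithms. The NumPy sketch below illustrates that idea only; the window sizes, hop sizes, `eps` value and the helper names `stft_mag` and `multi_scale_spectral_sketch` are assumptions and do not reflect the class defaults.

```python
import numpy as np


def stft_mag(x, n_fft, hop):
    """Magnitude STFT of x with shape (batch, time), Hann window, no padding."""
    n_frames = 1 + (x.shape[-1] - n_fft) // hop
    idx = np.arange(n_fft)[None, :] + hop * np.arange(n_frames)[:, None]
    frames = x[:, idx] * np.hanning(n_fft)  # (batch, n_frames, n_fft)
    return np.abs(np.fft.rfft(frames, axis=-1))


def multi_scale_spectral_sketch(est_targets, targets, n_ffts=(512, 256, 128),
                                alpha=1.0, eps=1e-7):
    loss = np.zeros(est_targets.shape[0])
    for n_fft in n_ffts:  # illustrative resolutions, hop = n_fft / 4
        est_mag = stft_mag(est_targets, n_fft, n_fft // 4)
        tgt_mag = stft_mag(targets, n_fft, n_fft // 4)
        linear_term = np.abs(tgt_mag - est_mag).sum(axis=(1, 2))
        log_term = np.abs(np.log(tgt_mag + eps) - np.log(est_mag + eps)).sum(axis=(1, 2))
        loss += linear_term + alpha * log_term
    return loss  # shape (batch,)


rng = np.random.default_rng(0)
targets = rng.standard_normal((3, 2048))
est_targets = rng.standard_normal((3, 2048))
loss_same = multi_scale_spectral_sketch(targets, targets)
loss_diff = multi_scale_spectral_sketch(est_targets, targets)
```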
Deep clustering (Affinity) loss¶

asteroid.losses.cluster.
deep_clustering_loss
(embedding, tgt_index, binary_mask=None)[source]¶ Compute the deep clustering loss defined in [1].
Parameters:  embedding (torch.Tensor) – Estimated embeddings. Expected shape (batch, frequency x frame, embedding_dim)
 tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: [batch, frequency, frame]
 binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.filterbanks.transforms.ebased_vad.
Returns: torch.Tensor. Deep clustering loss for every batch sample.
Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> # Integer source indices of shape (batch, frequency, frame)
>>> targets = torch.randint(0, spk_cnt, (10, 400, 5))
>>> loss = deep_clustering_loss(embedding, targets)
Reference
[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, "Alternative Objective Functions for Deep Clustering".
Note
Be careful in viewing the embedding tensors. The target indices tgt_index are of shape (batch, freq, frames). Even if the embedding is of shape (batch, freq*frames, emb), the underlying view should be (batch, freq, frames, emb) and not (batch, frames, freq, emb).
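The note about views can be made concrete with a NumPy sketch of the classic affinity objective, the squared Frobenius norm of VVᵀ − YYᵀ, computed without forming the large affinity matrices. This is an illustration under assumptions: the name `dc_loss_sketch` is hypothetical, and any normalization or masking applied by the library function is omitted.

```python
import numpy as np


def dc_loss_sketch(embedding, tgt_index, n_src):
    """embedding: (batch, freq * frames, emb); tgt_index: (batch, freq, frames)."""
    batch = embedding.shape[0]
    # Flatten frequency-major so bin (f, t) lands at row f * frames + t,
    # matching an embedding viewed as (batch, freq, frames, emb).
    labels = tgt_index.reshape(batch, -1)
    one_hot = np.eye(n_src)[labels]  # (batch, freq * frames, n_src)
    vtv = np.einsum("bnd,bne->bde", embedding, embedding)
    vty = np.einsum("bnd,bnc->bdc", embedding, one_hot)
    yty = np.einsum("bnc,bne->bce", one_hot, one_hot)
    # ||VV^T - YY^T||_F^2 = ||V^T V||_F^2 - 2 ||V^T Y||_F^2 + ||Y^T Y||_F^2
    return ((vtv ** 2).sum(axis=(1, 2))
            - 2 * (vty ** 2).sum(axis=(1, 2))
            + (yty ** 2).sum(axis=(1, 2)))


rng = np.random.default_rng(0)
n_src, freq, frames, emb_dim = 2, 4, 3, 5
embedding = rng.standard_normal((2, freq * frames, emb_dim))
tgt_index = rng.integers(0, n_src, size=(2, freq, frames))
loss = dc_loss_sketch(embedding, tgt_index, n_src)

# Direct (memory-hungry) computation for comparison.
one_hot = np.eye(n_src)[tgt_index.reshape(2, -1)]
affinity_diff = (embedding @ embedding.transpose(0, 2, 1)
                 - one_hot @ one_hot.transpose(0, 2, 1))
loss_direct = (affinity_diff ** 2).sum(axis=(1, 2))
```

If the embedding had been produced from a (batch, frames, freq, emb) view instead, the rows would no longer line up with the flattened labels and the loss would compare mismatched bins.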
Computing metrics¶

asteroid.metrics.
get_metrics
(mix, clean, estimate, sample_rate=16000, metrics_list='all', average=True, compute_permutation=False)[source]¶ Get speech separation/enhancement metrics from mix/clean/estimate.
Parameters:  mix (np.array) – ‘Shape(D, N)’ or ‘Shape(N, )’.
 clean (np.array) – ‘Shape(K_source, N)’ or ‘Shape(N, )’.
 estimate (np.array) – ‘Shape(K_target, N)’ or ‘Shape(N, )’.
 sample_rate (int) – sampling rate of the audio clips.
 metrics_list (Union [str, list]) – List of metrics to compute. Defaults to ‘all’ ([‘si_sdr’, ‘sdr’, ‘sir’, ‘sar’, ‘stoi’, ‘pesq’]).
 average (bool) – Return dict([float]) if True, else dict([array]).
 compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default False)
Returns: dict – Dictionary with all requested metrics, with the 'input_' prefix for metrics at the input (mixture against clean) and no prefix at the output (estimate against clean). Output format depends on average.
Examples
>>> import numpy as np
>>> import pprint
>>> from asteroid.metrics import get_metrics
>>> mix = np.random.randn(1, 16000)
>>> clean = np.random.randn(2, 16000)
>>> est = np.random.randn(2, 16000)
>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
...                            metrics_list='all')
>>> pprint.pprint(metrics_dict)
{'input_pesq': 1.924380898475647,
 'input_sar': 11.67667585294225,
 'input_sdr': 14.88667106190552,
 'input_si_sdr': 52.43849784881705,
 'input_sir': 0.10419427290163795,
 'input_stoi': 0.015112115177091223,
 'pesq': 1.7713886499404907,
 'sar': 11.610963379923195,
 'sdr': 14.527246041125844,
 'si_sdr': 46.26557128489802,
 'sir': 0.4799929272243427,
 'stoi': 0.022023073540350643}
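Because input metrics carry the 'input_' prefix and output metrics have no prefix, improvements can be derived from the returned dictionary when average=True. The helper below is a hypothetical convenience, not part of asteroid.metrics, and the sample values are made up for illustration.

```python
def metric_improvements(metrics_dict):
    """Return output-minus-input differences for every metric that has
    both an unprefixed entry and an 'input_' entry."""
    improvements = {}
    for name, value in metrics_dict.items():
        if name.startswith("input_"):
            continue
        input_key = "input_" + name
        if input_key in metrics_dict:
            improvements[name + "_imp"] = value - metrics_dict[input_key]
    return improvements


# Made-up values in the same layout as the dictionary returned by get_metrics.
example = {"input_si_sdr": 0.5, "si_sdr": 10.5, "input_stoi": 0.5, "stoi": 0.75}
imp = metric_improvements(example)
```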