Losses & Metrics

class asteroid.losses.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]


Permutation invariant loss wrapper.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs).
  • pit_from (str) –

    Determines how PIT is applied.

    • 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \([batch, i, j]\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\)
    • 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (the tensors won't have the source axis). Output shape : \((batch)\). See get_pw_losses().
    • 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape : \((batch)\). See best_perm_from_perm_avg_loss().

    In terms of efficiency, 'perm_avg' is the least efficient.

  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) --> (B, n_src!). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial (see the sketch after the examples below). Only used in 'pw_mtx' and 'pw_pt' pit_from modes.

For each of these modes, the best permutation and reordering will be automatically computed.

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
>>>     weighted = perm_loss * src.norm(dim=-1, keepdim=True)
>>>     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
>>>                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
>>>                      reduce_kwargs=reduce_kwargs)
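
A minimal sketch of the functools.partial alternative mentioned above for static perm_reduce arguments (reusing the reduce function and the sources tensor from this example):

>>> # Sketch: bind static kwargs once with functools.partial instead of
>>> # passing reduce_kwargs at every call.
>>> from functools import partial
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
>>>                            perm_reduce=partial(reduce, src=sources))
>>> loss_val = loss_func(est_sources, sources)
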
static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)[source]

Find best permutation from loss function with source axis.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs). The loss function to get batch losses from.
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

tuple:

  • torch.Tensor – The loss corresponding to the best permutation, of shape (batch,).
  • torch.LongTensor – The indexes of the best permutations.

static find_best_perm(pair_wise_losses, n_src, perm_reduce=None, **kwargs)[source]

Find the best permutation, given the pair-wise losses.

Parameters:
  • pair_wise_losses (torch.Tensor) – Tensor of shape [batch, n_src, n_src]. Pairwise losses.
  • n_src (int) – Number of sources.
  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs) : (B, n_src!, n_src) –> (B, n_src!)
  • **kwargs – additional keyword argument that will be passed to the permutation reduce function.
Returns:

tuple:

  • torch.Tensor – The loss corresponding to the best permutation, of shape (batch,).
  • torch.LongTensor – The indexes of the best permutations.

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
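
For intuition, a naive sketch of what this computes from the pairwise loss matrix (the library implementation is vectorized; this loop is only illustrative):

>>> import torch
>>> from itertools import permutations
>>> def naive_find_best_perm(pair_wise_losses, n_src):
>>>     # pair_wise_losses: (batch, n_src, n_src); entry [b, i, j] is the loss
>>>     # between targets[:, i] and est_targets[:, j].
>>>     rows = torch.arange(n_src)
>>>     perms = list(permutations(range(n_src)))
>>>     # Mean pairwise loss along each permutation: (batch, n_src!)
>>>     perm_losses = torch.stack(
>>>         [pair_wise_losses[:, rows, torch.tensor(p)].mean(-1) for p in perms], dim=1)
>>>     min_loss, min_loss_idx = torch.min(perm_losses, dim=1)
>>>     return min_loss, min_loss_idx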

forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the loss.

Parameters:
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets
  • return_est – Boolean. Whether to also return the reordered target estimates (e.g. to compute metrics or to save examples).
  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

  • Best permutation loss for each batch sample, averaged over the batch: torch.Tensor(loss_value).
  • The reordered target estimates if return_est is True: torch.Tensor of shape [batch, nsrc, *].
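
For example, a short sketch (reusing pairwise_neg_sisdr and the tensors from the class-level example above) showing how to retrieve the reordered estimates:

>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val, reordered = loss_func(est_sources, sources, return_est=True)
>>> reordered.shape  # aligned with sources
torch.Size([10, 3, 16000])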

static get_pw_losses(loss_func, est_targets, targets, **kwargs)[source]

Get pair-wise losses between the training targets and their estimates for a given loss function.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs) The loss function to get pair-wise losses from.
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

torch.Tensor of size [batch, nsrc, nsrc], losses computed for all pairs of targets and est_targets.

This function can be called on a loss function which returns a tensor of size [batch]. There are more efficient ways to compute pair-wise losses using broadcasting.
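
As an illustration of what this helper does (a naive sketch, not the library implementation), using the (targets, est_targets) signature documented above:

>>> import torch
>>> def naive_get_pw_losses(loss_func, est_targets, targets, **kwargs):
>>>     batch, n_src = targets.shape[:2]
>>>     pw_losses = torch.zeros(batch, n_src, n_src)
>>>     for i in range(n_src):        # target index
>>>         for j in range(n_src):    # estimate index
>>>             pw_losses[:, i, j] = loss_func(targets[:, i], est_targets[:, j], **kwargs)
>>>     return pw_losses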

static reorder_source(source, n_src, min_loss_idx)[source]

Reorder sources according to the best permutation.

Parameters:
  • source (torch.Tensor) – Tensor of shape [batch, n_src, time]
  • n_src (int) – Number of sources.
  • min_loss_idx (torch.LongTensor) – Tensor of shape [batch], each item is in [0, n_src!).
Returns:

torch.Tensor – Reordered sources of shape [batch, n_src, time].

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.
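
A naive sketch of the reordering step, assuming the permutations are enumerated in the same order as in the find_best_perm sketch above:

>>> import torch
>>> from itertools import permutations
>>> def naive_reorder_source(source, n_src, min_loss_idx):
>>>     perms = list(permutations(range(n_src)))
>>>     # For each batch item, place the estimate selected by the best permutation
>>>     # at the position of the target it matches.
>>>     return torch.stack(
>>>         [source[b, list(perms[idx])] for b, idx in enumerate(min_loss_idx.tolist())])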

class asteroid.losses.SingleSrcPMSQE(window_name='sqrt_hann', window_weight=1.0, bark_eq=True, gain_eq=True, sample_rate=16000)[source]


Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512 length DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see Tensorflow implementation). The SLL, frequency and gain equalization are applied in each sequence independently.

Parameters:
  • window_name (str) – Select the window function used, so that the correct correction factor is applied. Defaults to the square-root Hann window ('sqrt_hann'). Among ['rect', 'hann', 'sqrt_hann', 'hamming', 'flatTop'].
  • window_weight (float, optional) – Correction to the window factor applied.
  • bark_eq (bool, optional) – Whether to apply bark equalization.
  • gain_eq (bool, optional) – Whether to apply gain equalization.
  • sample_rate (int) – Sample rate of the input audio.

References

[1] J. M. Martin, A. M. Gomez, J. A. Gonzalez, A. M. Peinado, 'A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality', IEEE Signal Processing Letters, 2018. Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es. Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.

Note

Inspired by the Perceptual Evaluation of Speech Quality (PESQ) algorithm, this loss consists of two regularization factors: the symmetrical and asymmetrical distortions in the loudness domain.

Examples

>>> import torch
>>> from asteroid.filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_value = loss_func(est_spec, ref_spec)
bark_freq_equalization(ref_bark_spectra, deg_bark_spectra)[source]

This equalization is applied directly to the degraded Bark spectra.

forward(est_targets, targets, pad_mask=None)[source]

Parameters:
  • est_targets (torch.Tensor) – Dimensions (B, T, F). Padded degraded power spectrum in time-frequency domain.
  • targets (torch.Tensor) – Dimensions (B, T, F). Zero-padded reference power spectrum in time-frequency domain.
  • pad_mask (torch.Tensor, optional) – Dimensions (B, T, 1). Mask to indicate the padding frames. Defaults to all ones.

Dimensions: B is the number of sequences in the batch, T the number of time frames and F the number of frequency bins.

Returns:

torch.Tensor of shape (B,), wD + 0.309 * wDA

Note

Dimensions (B, F, T) are also supported by SingleSrcPMSQE but are less efficient because the input tensors are transposed (not in-place).
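
The pad_mask argument above can be built from the number of valid frames per sequence; a minimal sketch (hypothetical helper, not part of the library):

>>> import torch
>>> def make_pad_mask(n_valid_frames, max_frames):
>>>     # n_valid_frames: LongTensor (B,) with the number of non-padded frames per sequence.
>>>     frame_idx = torch.arange(max_frames).unsqueeze(0)          # (1, T)
>>>     mask = (frame_idx < n_valid_frames.unsqueeze(1)).float()   # (B, T)
>>>     return mask.unsqueeze(-1)                                  # (B, T, 1)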

static get_correction_factor(window_name)[source]

Returns the power correction factor depending on the window.

asteroid.losses.SingleSrcNegSTOI

alias of asteroid.losses.stoi.NegSTOILoss

class asteroid.losses.PairwiseNegSDR(sdr_type, zero_mean=True, take_log=True)[source]


Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.

Parameters:
  • sdr_type (str) – choose between “snr” for plain SNR, “sisdr” for SI-SDR and “sdsdr” for SD-SDR [1].
  • zero_mean (bool, optional) – by default it zero-means the target and estimate before computing the loss.
  • take_log (bool, optional) – by default the log10 of sdr is returned.
Shape:
  est_targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of training targets.

Returns: torch.Tensor – with shape [batch, n_src, n_src]. Pairwise losses.

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper, PairwiseNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
>>>                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)

References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.deep_clustering_loss(embedding, tgt_index, binary_mask=None)[source]

Compute the deep clustering loss defined in [1].

Parameters:
  • embedding (torch.Tensor) – Estimated embeddings. Expected shape (batch, frequency x frame, embedding_dim)
  • tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: [batch, frequency, frame]
  • binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.filterbanks.transforms.ebased_vad.
Returns:

torch.Tensor. Deep clustering loss for every batch sample.

Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
>>> loss = deep_clustering_loss(embedding, targets)

References

[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, "Alternative Objective Functions for Deep Clustering", ICASSP 2018.

Note

Be careful in viewing the embedding tensors. The target indices tgt_index are of shape (batch, freq, frames). Even if the embedding is of shape (batch, freq*frames, emb), the underlying view should be (batch, freq, frames, emb) and not (batch, frames, freq, emb).
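
A short sketch of the safe reshape, assuming the network output has shape (batch, freq, frames, emb):

>>> import torch
>>> batch, freq, frames, emb = 10, 400, 5, 20
>>> net_out = torch.randn(batch, freq, frames, emb)
>>> # Flatten (freq, frames) together, keeping freq leading, so the flattened axis
>>> # lines up with tgt_index of shape (batch, freq, frames).
>>> embedding = net_out.reshape(batch, freq * frames, emb)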

Permutation invariant training (PIT) made easy

class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]


Permutation invariant loss wrapper.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs).
  • pit_from (str) –

    Determines how PIT is applied.

    • 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \([batch, i, j]\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\)
    • 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (the tensors won't have the source axis). Output shape : \((batch)\). See get_pw_losses().
    • 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape : \((batch)\). See best_perm_from_perm_avg_loss().

    In terms of efficiency, 'perm_avg' is the least efficient.

  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function: (pwl_set, **kwargs) : (B, n_src!, n_src) --> (B, n_src!). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in 'pw_mtx' and 'pw_pt' pit_from modes.

For each of these modes, the best permutation and reordering will be automatically computed.

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
>>>     weighted = perm_loss * src.norm(dim=-1, keepdim=True)
>>>     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
>>>                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
>>>                      reduce_kwargs=reduce_kwargs)
static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)[source]

Find best permutation from loss function with source axis.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs). The loss function to get batch losses from.
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

tuple:

  • torch.Tensor – The loss corresponding to the best permutation, of shape (batch,).
  • torch.LongTensor – The indexes of the best permutations.

static find_best_perm(pair_wise_losses, n_src, perm_reduce=None, **kwargs)[source]

Find the best permutation, given the pair-wise losses.

Parameters:
  • pair_wise_losses (torch.Tensor) – Tensor of shape [batch, n_src, n_src]. Pairwise losses.
  • n_src (int) – Number of sources.
  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs) : (B, n_src!, n_src) –> (B, n_src!)
  • **kwargs – additional keyword argument that will be passed to the permutation reduce function.
Returns:

tuple:

  • torch.Tensor – The loss corresponding to the best permutation, of shape (batch,).
  • torch.LongTensor – The indexes of the best permutations.

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.

forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the loss.

Parameters:
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets
  • return_est – Boolean. Whether to also return the reordered target estimates (e.g. to compute metrics or to save examples).
  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

  • Best permutation loss for each batch sample, averaged over the batch: torch.Tensor(loss_value).
  • The reordered target estimates if return_est is True: torch.Tensor of shape [batch, nsrc, *].

static get_pw_losses(loss_func, est_targets, targets, **kwargs)[source]

Get pair-wise losses between the training targets and their estimates for a given loss function.

Parameters:
  • loss_func – function with signature (targets, est_targets, **kwargs) The loss function to get pair-wise losses from.
  • est_targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of target estimates.
  • targets – torch.Tensor. Expected shape [batch, nsrc, *]. The batch of training targets.
  • **kwargs – additional keyword argument that will be passed to the loss function.
Returns:

torch.Tensor of size [batch, nsrc, nsrc], losses computed for all pairs of targets and est_targets.

This function can be called on a loss function which returns a tensor of size [batch]. There are more efficient ways to compute pair-wise losses using broadcasting.

static reorder_source(source, n_src, min_loss_idx)[source]

Reorder sources according to the best permutation.

Parameters:
  • source (torch.Tensor) – Tensor of shape [batch, n_src, time]
  • n_src (int) – Number of sources.
  • min_loss_idx (torch.LongTensor) – Tensor of shape [batch], each item is in [0, n_src!).
Returns:

torch.Tensor – Reordered sources of shape [batch, n_src, time].

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.

Available loss functions

PITLossWrapper supports three types of loss function. For "easy" losses, we implement all three types (pairwise matrix, single-source and multi-source). For others, we only implement the single-source loss, which can be aggregated into both PIT and non-PIT training.
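
To make the three forms concrete, here is a toy sketch of the same MSE written in each style (illustrative functions, not the library implementations, following the (targets, est_targets) signature):

>>> import torch
>>> # 'pw_mtx' (pairwise matrix): (batch, n_src, time) -> (batch, n_src, n_src)
>>> def toy_pairwise_mse(targets, est_targets):
>>>     diff = targets.unsqueeze(2) - est_targets.unsqueeze(1)
>>>     return (diff ** 2).mean(dim=-1)
>>> # 'pw_pt' (pairwise point): single source / single estimate -> (batch,)
>>> def toy_singlesrc_mse(target, est_target):
>>>     return ((target - est_target) ** 2).mean(dim=-1)
>>> # 'perm_avg' (permutation average): a full permutation at once -> (batch,)
>>> def toy_multisrc_mse(targets, est_targets):
>>>     return ((targets - est_targets) ** 2).mean(dim=(-2, -1))

Each would then be wrapped as PITLossWrapper(toy_pairwise_mse, pit_from='pw_mtx'), PITLossWrapper(toy_singlesrc_mse, pit_from='pw_pt') or PITLossWrapper(toy_multisrc_mse, pit_from='perm_avg').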

MSE

asteroid.losses.mse.PairwiseMSE(*args, **kwargs)[source]

Measure pairwise mean square error on a batch.

Shape:
  est_targets (torch.Tensor): Expected shape [batch, nsrc, *]. The batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, nsrc, *]. The batch of training targets.

Returns: torch.Tensor – with shape [batch, nsrc, nsrc]

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import PairwiseMSE
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.SingleSrcMSE(*args, **kwargs)[source]

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  est_targets (torch.Tensor): Expected shape [batch, *]. The batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, *]. The batch of training targets.

Returns: torch.Tensor – with shape [batch]

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.MultiSrcMSE(*args, **kwargs)

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  est_targets (torch.Tensor): Expected shape [batch, *]. The batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, *]. The batch of training targets.

Returns: torch.Tensor – with shape [batch]

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)

SDR

asteroid.losses.sdr.PairwiseNegSDR(*args, **kwargs)[source]

Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.

Parameters:
  • sdr_type (str) – choose between “snr” for plain SNR, “sisdr” for SI-SDR and “sdsdr” for SD-SDR [1].
  • zero_mean (bool, optional) – by default it zero-means the target and estimate before computing the loss.
  • take_log (bool, optional) – by default the log10 of sdr is returned.
Shape:
  est_targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of training targets.

Returns: torch.Tensor – with shape [batch, n_src, n_src]. Pairwise losses.

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper, PairwiseNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
>>>                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)

References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.sdr.SingleSrcNegSDR(*args, **kwargs)[source]

Base class for single-source negative SI-SDR, SD-SDR and SNR.

Parameters:
  • sdr_type (string) – choose between “snr” for plain SNR, “sisdr” for SI-SDR and “sdsdr” for SD-SDR [1].
  • zero_mean (bool, optional) – by default it zero-means the target and estimate before computing the loss.
  • take_log (bool, optional) – by default the log10 of sdr is returned.
  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied; 'mean': the sum of the output will be divided by the number of elements in the output.
Shape:
  est_targets (torch.Tensor): Expected shape [batch, time]. Batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, time]. Batch of training targets.

Returns: torch.Tensor – with shape [batch] if reduction='none', else a scalar if reduction='mean'.

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import SingleSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
>>>                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)

References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.sdr.MultiSrcNegSDR(*args, **kwargs)[source]

Base class for computing negative SI-SDR, SD-SDR and SNR for a given permutation of the sources and their estimates.

Parameters:
  • sdr_type (string) – choose between “snr” for plain SNR, “sisdr” for SI-SDR and “sdsdr” for SD-SDR [1].
  • zero_mean (bool, optional) – by default it zero-means the target and estimate before computing the loss.
  • take_log (bool, optional) – by default the log10 of sdr is returned.
Shape:
  est_targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, n_src, time]. Batch of training targets.

Returns: torch.Tensor – with shape [batch]

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import MultiSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(MultiSrcNegSDR("sisdr"),
>>>                            pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)

References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

PMSQE

asteroid.losses.pmsqe.SingleSrcPMSQE(*args, **kwargs)[source]

Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512 length DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see Tensorflow implementation). The SLL, frequency and gain equalization are applied in each sequence independently.

Parameters:
  • window_name (str) – Select the window function used, so that the correct correction factor is applied. Defaults to the square-root Hann window ('sqrt_hann'). Among ['rect', 'hann', 'sqrt_hann', 'hamming', 'flatTop'].
  • window_weight (float, optional) – Correction to the window factor applied.
  • bark_eq (bool, optional) – Whether to apply bark equalization.
  • gain_eq (bool, optional) – Whether to apply gain equalization.
  • sample_rate (int) – Sample rate of the input audio.

References

[1] J. M. Martin, A. M. Gomez, J. A. Gonzalez, A. M. Peinado, 'A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality', IEEE Signal Processing Letters, 2018. Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es. Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.

Note

Inspired by the Perceptual Evaluation of Speech Quality (PESQ) algorithm, this loss consists of two regularization factors: the symmetrical and asymmetrical distortions in the loudness domain.

Examples

>>> import torch
>>> from asteroid.filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> loss_value = loss_func(est_spec, ref_spec)

STOI

asteroid.losses.stoi.NegSTOILoss(*args, **kwargs)[source]
Negated Short-Term Objective Intelligibility (STOI) metric, to be used as a loss function. Inspired by [1, 2, 3] but not exactly the same: it cannot be used as the STOI metric directly (use pystoi instead). See the Notes.
Parameters:
  • sample_rate (int) – sample rate of the audio files
  • use_vad (bool) – Whether to use simple VAD (see Notes)
  • extended (bool) – Whether to compute extended version [3].
Shapes:
  (time,) –> (1,)
  (batch, time) –> (batch,)
  (batch, n_src, time) –> (batch, n_src)

Returns: torch.Tensor of shape (batch, *); only the time dimension has been reduced.

Note

In the NumPy version, some kind of simple VAD was used to remove the silent frames before chunking the signal into short-term envelope vectors. We don’t do the same here because removing frames in a batch is cumbersome and inefficient. If use_vad is set to True, instead we detect the silent frames and keep a mask tensor. At the end, the normalized correlation of short-term envelope vectors is masked using this mask (unfolded) and the mean is computed taking the mask values into account.

Examples

>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.stoi import NegSTOILoss
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(NegSTOILoss(sample_rate=8000), pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)

References

[1] C. H. Taal, R. C. Hendriks, R. Heusdens, J. Jensen, 'A Short-Time Objective Intelligibility Measure for Time-Frequency Weighted Noisy Speech', ICASSP 2010, Dallas, Texas.

[2] C. H. Taal, R. C. Hendriks, R. Heusdens, J. Jensen, 'An Algorithm for Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', IEEE Transactions on Audio, Speech, and Language Processing, 2011.

[3] Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated Noise Maskers', IEEE Transactions on Audio, Speech and Language Processing, 2016.

MultiScale Spectral Loss

asteroid.losses.multi_scale_spectral.SingleSrcMultiScaleSpectral(*args, **kwargs)[source]

Measure multi-scale spectral loss as described in [1]

Parameters:
  • n_filters (list) – list containing the number of filters desired for each STFT.
  • windows_size (list) – list containing the size of the window desired for each STFT.
  • hops_size (list) – list containing the size of the hop desired for each STFT.
  • alpha (float) – weighting factor for the log term.

Shape:
  est_targets (torch.Tensor): Expected shape [batch, time]. Batch of target estimates.
  targets (torch.Tensor): Expected shape [batch, time]. Batch of training targets.

Returns: torch.Tensor – with shape [batch]

Examples

>>> import torch
>>> from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 32000)
>>> est_targets = torch.randn(10, 32000)
>>> # Using it by itself on a pair of source/estimate
>>> loss_func = SingleSrcMultiScaleSpectral()
>>> loss = loss_func(est_targets, targets)
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # Using it with PITLossWrapper with sets of source/estimates
>>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
>>>                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)

References

[1] Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu and Adam Roberts, "DDSP: Differentiable Digital Signal Processing", International Conference on Learning Representations (ICLR), 2020.

Deep clustering (Affinity) loss

asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)[source]

Compute the deep clustering loss defined in [1].

Parameters:
  • embedding (torch.Tensor) – Estimated embeddings. Expected shape (batch, frequency x frame, embedding_dim)
  • tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: [batch, frequency, frame]
  • binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.filterbanks.transforms.ebased_vad.
Returns:

torch.Tensor. Deep clustering loss for every batch sample.

Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
>>> loss = deep_clustering_loss(embedding, targets)

References

[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, "Alternative Objective Functions for Deep Clustering", ICASSP 2018.

Note

Be careful in viewing the embedding tensors. The target indices tgt_index are of shape (batch, freq, frames). Even if the embedding is of shape (batch, freq*frames, emb), the underlying view should be (batch, freq, frames, emb) and not (batch, frames, freq, emb).

Computing metrics

asteroid.metrics.get_metrics(mix, clean, estimate, sample_rate=16000, metrics_list='all', average=True, compute_permutation=False)[source]

Get speech separation/enhancement metrics from mix/clean/estimate.

Parameters:
  • mix (np.array) – Shape (D, N) or (N,).
  • clean (np.array) – Shape (K_source, N) or (N,).
  • estimate (np.array) – Shape (K_target, N) or (N,).
  • sample_rate (int) – sampling rate of the audio clips.
  • metrics_list (Union [str, list]) – List of metrics to compute. Defaults to ‘all’ ([‘si_sdr’, ‘sdr’, ‘sir’, ‘sar’, ‘stoi’, ‘pesq’]).
  • average (bool) – Return dict([float]) if True, else dict([array]).
  • compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default False)
Returns:

dict – Dictionary with all requested metrics, with an 'input_' prefix for metrics computed at the input (mixture against clean) and no prefix for metrics computed at the output (estimate against clean). The output format depends on average.

Examples

>>> import numpy as np
>>> import pprint
>>> from asteroid.metrics import get_metrics
>>> mix = np.random.randn(1, 16000)
>>> clean = np.random.randn(2, 16000)
>>> est = np.random.randn(2, 16000)
>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
>>>                            metrics_list='all')
>>> pprint.pprint(metrics_dict)
{'input_pesq': 1.924380898475647,
 'input_sar': -11.67667585294225,
 'input_sdr': -14.88667106190552,
 'input_si_sdr': -52.43849784881705,
 'input_sir': -0.10419427290163795,
 'input_stoi': 0.015112115177091223,
 'pesq': 1.7713886499404907,
 'sar': -11.610963379923195,
 'sdr': -14.527246041125844,
 'si_sdr': -46.26557128489802,
 'sir': 0.4799929272243427,
 'stoi': 0.022023073540350643}