
asteroid.losses.cluster module

asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)

Compute the deep clustering loss defined in [1].
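For reference, here is the classic deep clustering objective. Write \(V\) for the \(N \times D\) matrix of embeddings (\(N = frequency \times frame\) TF bins, \(D\) = embedding_dim) and \(Y\) for the \(N \times C\) one-hot matrix built from tgt_index (\(C\) sources). The loss and its standard expansion are:

\[
\mathcal{L} = \lVert V V^{\top} - Y Y^{\top} \rVert_F^2
= \lVert V^{\top} V \rVert_F^2 - 2 \lVert V^{\top} Y \rVert_F^2 + \lVert Y^{\top} Y \rVert_F^2
\]

The expanded form involves only \(D \times D\), \(D \times C\) and \(C \times C\) matrices, so the \(N \times N\) affinity matrices never need to be formed. This formula does not describe how the implementation weights bins with binary_mask or normalizes the result.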

Parameters:
  • embedding (torch.Tensor) – Estimated embeddings. Expected shape: \((batch, frequency * frame, embedding\_dim)\).
  • tgt_index (torch.Tensor) – Index of the dominant source in each TF bin. Expected shape: \((batch, frequency, frame)\).
  • binary_mask (torch.Tensor, optional) – Voice activity detection (VAD) mask in the TF plane, bool or float. See asteroid.dsp.vad.ebased_vad. If None, every TF bin is treated as active.
Returns:

torch.Tensor. Deep clustering loss for each sample in the batch, of shape \((batch,)\).
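A TF-plane mask for binary_mask can be built in several ways. The helper below is a hypothetical sketch that uses an energy threshold relative to the loudest bin of each sample. The name energy_tf_mask and its 40 dB default are assumptions. This is not asteroid.dsp.vad.ebased_vad, whose exact behavior may differ.

```python
import torch


def energy_tf_mask(mag_spec, threshold_db=40.0):
    """Hypothetical energy-based TF mask (not asteroid.dsp.vad.ebased_vad).

    mag_spec: (batch, freq, frames) magnitude spectrogram.
    Returns a bool mask of the same shape that is True for bins whose level
    is within `threshold_db` dB of the loudest bin of the same sample.
    """
    # Convert to dB; clamp to avoid log10(0).
    log_mag = 20 * torch.log10(mag_spec.clamp_min(1e-8))
    # Per-sample peak level, kept broadcastable against (batch, freq, frames).
    peak = log_mag.amax(dim=(1, 2), keepdim=True)
    return log_mag > peak - threshold_db
```

The resulting bool tensor has the same \((batch, frequency, frame)\) layout as tgt_index.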

Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
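>>> # Embeddings: (batch, frequency * frame, embedding_dim)
>>> embedding = torch.randn(10, 5 * 400, 20)
>>> # Targets: (batch, frequency, frame), one source index per TF bin
>>> targets = torch.randint(0, spk_cnt, (10, 400, 5))
>>> loss = deep_clustering_loss(embedding, targets)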
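To make the computation concrete, here is a minimal pure-PyTorch sketch of a masked deep clustering loss that uses the Gram-matrix expansion of \(\lVert V V^{\top} - Y Y^{\top} \rVert_F^2\). It is illustrative only and is not asteroid's implementation. The names dc_loss_sketch and squared_frobenius are hypothetical. The normalization by the number of active TF bins per sample is an assumption.

```python
import torch
import torch.nn.functional as F


def squared_frobenius(m):
    # Sum of squared entries over the two matrix dims -> shape (batch,).
    return m.pow(2).sum(dim=(1, 2))


def dc_loss_sketch(embedding, tgt_index, binary_mask=None):
    """Illustrative deep clustering loss; not asteroid's actual implementation.

    embedding:   (batch, freq * frames, emb_dim)
    tgt_index:   (batch, freq, frames) integer (int64) source index per TF bin
    binary_mask: optional (batch, freq, frames) bool/float VAD mask
    """
    batch, freq, frames = tgt_index.shape
    n_bins = freq * frames
    n_src = int(tgt_index.max()) + 1
    # One-hot targets Y: (batch, n_bins, n_src), flattened freq-major like tgt_index.
    y = F.one_hot(tgt_index.reshape(batch, n_bins), n_src).to(embedding.dtype)
    if binary_mask is None:
        mask = torch.ones(batch, n_bins, 1, dtype=embedding.dtype, device=embedding.device)
    else:
        mask = binary_mask.reshape(batch, n_bins, 1).to(embedding.dtype)
    # Zero out inactive bins in both V and Y.
    v = embedding * mask
    y = y * mask
    # Small Gram matrices (D x D), (D x C), (C x C) instead of (N x N) affinities.
    vtv = v.transpose(1, 2) @ v
    vty = v.transpose(1, 2) @ y
    yty = y.transpose(1, 2) @ y
    cost = squared_frobenius(vtv) - 2 * squared_frobenius(vty) + squared_frobenius(yty)
    # Assumption: normalize by the number of active TF bins in each sample.
    return cost / mask.sum(dim=(1, 2))
```

With the shapes from the example above, dc_loss_sketch(torch.randn(10, 5 * 400, 20), torch.randint(0, 3, (10, 400, 5))) returns one value per batch sample.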
Reference
[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, "Alternative Objective Functions for Deep Clustering."

Note

Be careful when reshaping the embedding tensor. The target indices tgt_index have shape \((batch, freq, frames)\). Even though the embedding is passed with shape \((batch, freq * frames, emb)\), its underlying layout must be \((batch, freq, frames, emb)\), not \((batch, frames, freq, emb)\). Otherwise the flattened TF bins do not line up with the entries of tgt_index.
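The snippet below illustrates the pitfall described in the note. It assumes a hypothetical network that outputs embeddings with time before frequency, \((batch, frames, freq, emb)\). The tensor must be permuted to \((batch, freq, frames, emb)\) before it is flattened.

```python
import torch

batch, freq, frames, emb_dim = 2, 3, 4, 5

# Hypothetical network output with time before frequency: (batch, frames, freq, emb_dim).
emb_tf = torch.randn(batch, frames, freq, emb_dim)

# Wrong: flattening directly orders bins as (frame, freq),
# which does not match tgt_index of shape (batch, freq, frames).
wrong = emb_tf.reshape(batch, freq * frames, emb_dim)

# Right: permute to (batch, freq, frames, emb_dim) first, then flatten.
# reshape (rather than view) is needed because permute returns a non-contiguous tensor.
right = emb_tf.permute(0, 2, 1, 3).reshape(batch, freq * frames, emb_dim)

# Flattened bin k of `right` is (f, t) with k = f * frames + t,
# the same order as tgt_index.reshape(batch, -1).
f, t = 1, 2
k = f * frames + t
assert torch.equal(right[:, k], emb_tf[:, t, f])
```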

asteroid.losses.cluster.batch_matrix_norm(matrix, norm_order=2)[source]

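Compute the norm of each matrix in a batch over all non-batch dimensions, raised to the power norm_order. With the default norm_order=2, this gives the squared Frobenius norm.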

Parameters:
  • matrix (torch.Tensor) – Batch of matrices. Expected shape: \((batch, *)\).
  • norm_order (int) – Order of the norm. Defaults to 2.
Returns:

torch.Tensor. Norm of each matrix in the batch, of shape \((batch,)\).
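The following is a minimal sketch of a per-sample norm that reduces over every non-batch dimension. It assumes, as in the deep clustering expansion, that the norm is raised to the power norm_order. Under that assumption \((\sum |x|^p)^{1/p}\) raised to \(p\) is simply \(\sum |x|^p\). The name batch_matrix_norm_sketch is hypothetical.

```python
import torch


def batch_matrix_norm_sketch(matrix, norm_order=2):
    """Hypothetical stand-in for a per-sample matrix norm.

    matrix: (batch, *) tensor. Returns a (batch,) tensor holding the
    entry-wise norm_order-norm of each sample, raised to the power norm_order.
    """
    # Reduce over every dimension except the batch dimension (dim 0).
    dims = tuple(range(1, matrix.ndim))
    # (sum |x|^p)^(1/p) raised to p equals sum |x|^p.
    return matrix.abs().pow(norm_order).sum(dim=dims)
```

For norm_order=2 this is the squared Frobenius norm, which is the quantity that appears in each term of the deep clustering expansion.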

Documentation version: v0.4.4