asteroid.losses.cluster module
asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)
Compute the deep clustering loss defined in [1].
Parameters:
- embedding (torch.Tensor) – Estimated embeddings. Expected shape: \((batch, frequency * frame, embedding\_dim)\).
- tgt_index (torch.Tensor) – Index of the dominant source in each time-frequency (TF) bin. Expected shape: \((batch, frequency, frame)\).
- binary_mask (torch.Tensor, optional) – Voice activity detection (VAD) mask in the TF plane, bool or float. See asteroid.dsp.vad.ebased_vad.
Returns: torch.Tensor. Deep clustering loss for every batch sample.
- Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> # Targets follow (batch, frequency, frame), matching the 5*400 bins above.
>>> targets = torch.randint(0, spk_cnt, (10, 5, 400))
>>> loss = deep_clustering_loss(embedding, targets)
- Reference
- [1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, "Alternative Objective Functions for Deep Clustering".
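The objective in [1] compares the affinity matrix of the estimated embeddings V with that of the one-hot targets Y. It is usually expanded so that the large (bins x bins) matrices are never formed. The sketch below illustrates that expansion in plain PyTorch. The names dc_loss_sketch and n_src are hypothetical, and the sketch leaves out any binary_mask weighting or normalization that the library function may apply, so treat it as an illustration rather than the library's implementation.

```python
import torch
import torch.nn.functional as F


def dc_loss_sketch(embedding, tgt_index, n_src):
    """Hypothetical helper: ||V V^T - Y Y^T||_F^2 per batch sample.

    embedding: (batch, freq * frame, emb), tgt_index: (batch, freq, frame).
    """
    batch = tgt_index.shape[0]
    # One-hot source assignment Y, shape (batch, freq * frame, n_src).
    y = F.one_hot(tgt_index.reshape(batch, -1), n_src).to(embedding.dtype)
    v = embedding
    # Small Gram matrices instead of (bins x bins) affinities.
    vtv = torch.einsum("bne,bnf->bef", v, v)  # (batch, emb, emb)
    yty = torch.einsum("bns,bnt->bst", y, y)  # (batch, n_src, n_src)
    vty = torch.einsum("bne,bns->bes", v, y)  # (batch, emb, n_src)

    def sq_fro(m):
        # Squared Frobenius norm of each matrix in the batch.
        return m.pow(2).sum(dim=(1, 2))

    return sq_fro(vtv) + sq_fro(yty) - 2 * sq_fro(vty)


# Tiny usage example: 2 samples, 3 frequency bins, 4 frames, 5-dim embeddings.
v = torch.randn(2, 3 * 4, 5, dtype=torch.float64)
idx = torch.randint(0, 3, (2, 3, 4))
loss = dc_loss_sketch(v, idx, n_src=3)  # shape (2,)
```

The expanded form only builds matrices of size emb x emb, n_src x n_src and emb x n_src, which is what makes the loss practical for spectrograms with thousands of TF bins.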
Note
Be careful when viewing (reshaping) the embedding tensors. The target indices tgt_index are of shape \((batch, freq, frames)\). Even though the embedding is of shape \((batch, freq * frames, emb)\), its underlying view should be \((batch, freq, frames, emb)\) and not \((batch, frames, freq, emb)\).
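The layout requirement in this note can be checked with a small PyTorch sketch. All sizes and variable names below are made up for illustration. The sketch shows that a frequency-major tensor flattens to rows indexed by f * frames + t, and that a time-major tensor needs a permute before flattening.

```python
import torch

batch, freq, frames, emb = 2, 3, 4, 5

# Embeddings laid out as (batch, freq, frames, emb): flattening is safe.
spec_emb = torch.randn(batch, freq, frames, emb)
flat_ok = spec_emb.reshape(batch, freq * frames, emb)

# The same data in time-major layout (batch, frames, freq, emb).
time_major = spec_emb.permute(0, 2, 1, 3)

# Flattening the time-major tensor directly misaligns rows with tgt_index.
flat_bad = time_major.reshape(batch, frames * freq, emb)

# Permuting back to (batch, freq, frames, emb) first restores the alignment.
flat_fixed = time_major.permute(0, 2, 1, 3).reshape(batch, freq * frames, emb)

# In the correct layout, row f * frames + t holds the embedding of bin (f, t),
# i.e. the bin labelled by tgt_index[:, f, t].
f, t = 1, 2
row = flat_ok[:, f * frames + t]
```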
asteroid.losses.cluster.batch_matrix_norm(matrix, norm_order=2)
Normalize a matrix according to norm_order.
Parameters:
- matrix (torch.Tensor) – Expected shape [batch, *].
- norm_order (int) – Norm order.
Returns: torch.Tensor, normed matrix of shape [batch].
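A [batch, *] input reduced to a [batch] output implies one value per batch sample. The sketch below shows one way such a per-sample reduction could look. It is an assumption for illustration only: per_sample_norm is a hypothetical helper, and the documentation above does not state whether the library additionally raises the result to a power or rescales it.

```python
import torch


def per_sample_norm(matrix, norm_order=2):
    """Hypothetical helper: p-norm over all non-batch dimensions."""
    # Collapse every dimension except the batch one: [batch, *] -> [batch, n].
    flat = matrix.flatten(start_dim=1)
    return flat.abs().pow(norm_order).sum(dim=1).pow(1.0 / norm_order)


x = torch.tensor([[[3.0, 4.0], [0.0, 0.0]],
                  [[1.0, 0.0], [0.0, 0.0]]])
norms = per_sample_norm(x)  # one value per batch sample, shape (2,)
```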