asteroid.losses.cluster module
asteroid.losses.cluster.batch_matrix_norm(matrix, norm_order=2)
Compute a norm of order norm_order over each element of a batch of matrices.
Parameters: - matrix (torch.Tensor) – Expected shape [batch, *]
- norm_order (int) – Norm order.
Returns: torch.Tensor, one norm value per batch element, shape [batch]
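To illustrate the [batch, *] to [batch] shape contract described above, here is a minimal sketch. It assumes an entrywise p-norm over all non-batch dimensions; `per_sample_norm` is a hypothetical stand-in written for this example, not the asteroid implementation, whose exact definition may differ.

```python
import torch


def per_sample_norm(matrix: torch.Tensor, norm_order: int = 2) -> torch.Tensor:
    """Hypothetical helper (assumption): entrywise p-norm per batch element."""
    # Collapse every non-batch dimension into one axis: [batch, *] -> [batch, n]
    flat = matrix.flatten(start_dim=1)
    # Sum |x|^p over the collapsed axis, then take the p-th root
    return flat.abs().pow(norm_order).sum(dim=1).pow(1.0 / norm_order)


x = torch.ones(4, 3, 3)      # batch of four 3x3 matrices
norms = per_sample_norm(x)   # one value per batch element
print(norms.shape)           # torch.Size([4])
```

Each 3x3 matrix of ones has nine unit entries, so its 2-norm under this assumption is 3.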
asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)
Compute the deep clustering loss defined in [1].
Parameters: - embedding (torch.Tensor) – Estimated embeddings. Expected shape: [batch, frequency * frame, embedding_dim]
- tgt_index (torch.Tensor) – Index of the dominant source in each time-frequency (TF) bin. Expected shape: [batch, frequency, frame]
- binary_mask (torch.Tensor) – Voice activity detection (VAD) mask in the TF plane, bool or float. See asteroid.filterbanks.transforms.ebased_vad.
Returns: torch.Tensor. Deep clustering loss for every batch sample.
- Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 400*5, 20)
>>> targets = torch.randint(0, spk_cnt, (10, 400, 5))
>>> loss = deep_clustering_loss(embedding, targets)
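For readers who want to see what the deep clustering objective computes, the sketch below evaluates the classic form of the loss, the squared Frobenius distance between the estimated affinity matrix V Vᵀ and the target affinity matrix Y Yᵀ, using the low-rank expansion ‖VᵀV‖² − 2‖VᵀY‖² + ‖YᵀY‖². `dc_loss_sketch` is a hypothetical function written for illustration; it ignores `binary_mask` and any normalization the library may apply, so its values are not guaranteed to match `deep_clustering_loss`.

```python
import torch
import torch.nn.functional as F


def squared_frobenius(m: torch.Tensor) -> torch.Tensor:
    # Sum of squared entries over the two matrix dimensions, per batch element
    return m.pow(2).sum(dim=(1, 2))


def dc_loss_sketch(embedding: torch.Tensor, tgt_index: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch (assumption): unweighted, unnormalized DC loss."""
    batch = tgt_index.shape[0]
    n_src = int(tgt_index.max()) + 1
    # One-hot targets Y, shape [batch, freq * frame, n_src]
    y = F.one_hot(tgt_index.reshape(batch, -1), n_src).to(embedding.dtype)
    v = embedding
    # Small Gram matrices avoid building the (freq*frame)^2 affinity matrices
    vtv = torch.einsum("bnd,bne->bde", v, v)
    vty = torch.einsum("bnd,bnc->bdc", v, y)
    yty = torch.einsum("bnc,bnk->bck", y, y)
    return squared_frobenius(vtv) - 2 * squared_frobenius(vty) + squared_frobenius(yty)


torch.manual_seed(0)
emb = torch.randn(2, 3 * 4, 5, dtype=torch.float64)  # [batch, freq * frame, emb]
tgt = torch.randint(0, 3, (2, 3, 4))                 # [batch, freq, frame]
sketch_loss = dc_loss_sketch(emb, tgt)               # one value per batch sample

# Direct (memory-hungry) evaluation of ||V V^T - Y Y^T||_F^2 for comparison
y_ref = F.one_hot(tgt.reshape(2, -1), int(tgt.max()) + 1).to(emb.dtype)
diff = emb @ emb.transpose(1, 2) - y_ref @ y_ref.transpose(1, 2)
direct_loss = diff.pow(2).sum(dim=(1, 2))
```

The expansion is preferred in practice because the Gram matrices are only embedding_dim or n_src wide, whereas the full affinity matrices grow quadratically with the number of TF bins.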
- Reference
[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, “Alternative Objective Functions for Deep Clustering”.
Note
Be careful when viewing (reshaping) the embedding tensor. The target indices tgt_index have shape (batch, freq, frames). Although the embedding has shape (batch, freq*frames, emb), its underlying layout must be (batch, freq, frames, emb), not (batch, frames, freq, emb).
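The layout constraint in this note can be checked with a small sketch. The tensor names and sizes below are arbitrary choices for illustration; the point is that a network producing (batch, frames, freq, emb) output needs a permute before flattening, otherwise the embeddings no longer line up with tgt_index.

```python
import torch

torch.manual_seed(0)
batch, n_freq, n_frames, emb_dim = 2, 3, 4, 5

# Embeddings in the layout the loss expects before flattening
emb_4d = torch.randn(batch, n_freq, n_frames, emb_dim)

# Correct: flatten (freq, frames) in that order, matching tgt_index
emb_ok = emb_4d.reshape(batch, n_freq * n_frames, emb_dim)

# Suppose a model instead emits (batch, frames, freq, emb)
emb_time_major = emb_4d.permute(0, 2, 1, 3)

# Wrong: flattening directly interleaves bins in frame-major order
emb_wrong = emb_time_major.reshape(batch, n_frames * n_freq, emb_dim)

# Fixed: permute back to (batch, freq, frames, emb) first, then flatten
emb_fixed = emb_time_major.permute(0, 2, 1, 3).reshape(
    batch, n_freq * n_frames, emb_dim
)

# Row f * n_frames + t of the flattened tensor is the embedding of bin (f, t)
f, t = 1, 2
flat_idx = f * n_frames + t
same_bin = torch.equal(emb_ok[:, flat_idx], emb_4d[:, f, t])
```

Both `emb_wrong` and `emb_ok` have the same shape, so the mismatch raises no error and only shows up as a loss computed against misaligned targets.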