
asteroid.losses.cluster module

asteroid.losses.cluster.batch_matrix_norm(matrix, norm_order=2)[source]

Compute the norm of each matrix in the batch, according to norm_order.

Parameters:
  • matrix (torch.Tensor) – Expected shape [batch, *]
  • norm_order (int) – Norm order.
Returns:
  torch.Tensor – the norm of each matrix in the batch, of shape [batch].
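
Example
A minimal usage sketch, assuming only the documented signature and return shape; the tensor sizes and the names affinity/norms are illustrative:
>>> import torch
>>> from asteroid.losses.cluster import batch_matrix_norm
>>> affinity = torch.randn(8, 400, 5)  # batch of 8 matrices, shape [batch, *]
>>> norms = batch_matrix_norm(affinity, norm_order=2)  # one norm value per batch sample
>>> norms.shape
torch.Size([8])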

asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)[source]

Compute the deep clustering loss defined in [1].

Parameters:
  • embedding (torch.Tensor) – Estimated embeddings. Expected shape [batch, frequency * frame, embedding_dim]
  • tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: [batch, frequency, frame]
  • binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.filterbanks.transforms.ebased_vad.
Returns:
  torch.Tensor – deep clustering loss for each batch sample.

Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> targets = torch.randint(0, spk_cnt, (10, 400, 5))  # shape [batch, frequency, frame]
>>> loss = deep_clustering_loss(embedding, targets)
Reference
[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey,
“Alternative Objective Functions for Deep Clustering”, Proc. ICASSP 2018.

Note

Be careful when reshaping the embedding tensor. The target indices tgt_index have shape (batch, freq, frames). Even though the embedding is passed with shape (batch, freq * frames, emb), its underlying view must be (batch, freq, frames, emb) and not (batch, frames, freq, emb), so that each embedding row lines up with the corresponding TF bin in tgt_index.
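
To make the layout convention concrete, here is a minimal sketch, assuming an embedding network output of shape (batch, freq, frames, emb); the name net_out and the sizes are illustrative, not part of the library:
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> batch, freq, frames, emb, spk_cnt = 10, 400, 5, 20, 3
>>> net_out = torch.randn(batch, freq, frames, emb)         # one embedding per TF bin
>>> embedding = net_out.reshape(batch, freq * frames, emb)  # freq varies slowest, matching tgt_index
>>> tgt_index = torch.randint(0, spk_cnt, (batch, freq, frames))
>>> loss = deep_clustering_loss(embedding, tgt_index)
If a network instead emits (batch, frames, freq, emb), transpose it first (e.g. net_out.transpose(1, 2)) before flattening, so the embedding rows stay aligned with the tgt_index bins.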
