Source code for asteroid.losses.cluster

import torch

def deep_clustering_loss(embedding, tgt_index, binary_mask=None):
    r"""Compute the deep clustering loss defined in [1].

    Args:
        embedding (torch.Tensor): Estimated embeddings.
            Expected shape (batch, frequency x frame, embedding_dim).
        tgt_index (torch.Tensor): Dominating source index in each TF bin.
            Expected shape: (batch, frequency, frame).
        binary_mask (torch.Tensor): VAD in TF plane. Bool or Float.
            See asteroid.filterbanks.transforms.ebased_vad. Any layout holding
            batch x frequency x frame elements is accepted, e.g.
            (batch, frequency, frame), (batch, frequency x frame) or
            (batch, frequency x frame, 1).

    Returns:
        `torch.Tensor`. Deep clustering loss for every batch sample, shape
        (batch,). A sample whose mask has no active bin yields NaN (0 / 0).

    Examples
        >>> import torch
        >>> from asteroid.losses.cluster import deep_clustering_loss
        >>> spk_cnt = 3
        >>> embedding = torch.randn(10, 5*400, 20)
        >>> targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
        >>> loss = deep_clustering_loss(embedding, targets)

    Reference
        [1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey
        "ALTERNATIVE OBJECTIVE FUNCTIONS FOR DEEP CLUSTERING"

    .. note::
        Be careful in viewing the embedding tensors. The target indices
        `tgt_index` are of shape (batch, freq, frames). Even if the embedding
        is of shape (batch, freq*frames, emb), the underlying view should be
        (batch, freq, frames, emb) and not (batch, frames, freq, emb).
    """
    batch, bins, frames = tgt_index.shape
    n_tf = bins * frames
    # Size the one-hot axis from the largest index, not the number of distinct
    # values: with a source missing from the batch (e.g. indices {0, 2}) the
    # distinct count is too small and scatter_ would index out of range.
    # Extra all-zero one-hot columns leave every norm below unchanged.
    spk_cnt = int(tgt_index.max()) + 1 if tgt_index.numel() > 0 else 0

    if binary_mask is None:
        binary_mask = torch.ones(batch, n_tf, 1, device=tgt_index.device)
    # Bool masks become numeric here; matching the embedding dtype keeps all
    # einsum operands in one dtype (e.g. float64 embeddings).
    binary_mask = binary_mask.to(device=tgt_index.device, dtype=embedding.dtype)
    # Flatten the TF plane so the mask broadcasts over the embedding dimension.
    binary_mask = binary_mask.reshape(batch, n_tf, 1)

    # Fill in one-hot vector for each TF bin
    tgt_embedding = torch.zeros(
        batch, n_tf, spk_cnt, device=tgt_index.device, dtype=embedding.dtype
    )
    tgt_embedding.scatter_(2, tgt_index.reshape(batch, n_tf, 1), 1)

    # Compute VAD-weighted DC loss
    tgt_embedding = tgt_embedding * binary_mask
    embedding = embedding * binary_mask
    est_proj = torch.einsum("ijk,ijl->ikl", embedding, embedding)
    true_proj = torch.einsum("ijk,ijl->ikl", tgt_embedding, tgt_embedding)
    true_est_proj = torch.einsum("ijk,ijl->ikl", embedding, tgt_embedding)
    # Equation (1) in [1]: ||V^T V||^2 + ||Y^T Y||^2 - 2 ||V^T Y||^2
    cost = batch_matrix_norm(est_proj) + batch_matrix_norm(true_proj)
    cost = cost - 2 * batch_matrix_norm(true_est_proj)
    # Divide by number of active bins, for each element in batch
    return cost / torch.sum(binary_mask, dim=[1, 2])

def batch_matrix_norm(matrix, norm_order=2):
    """Compute the entry-wise p-norm of each batch element, raised to the power p.

    This does *not* normalize ``matrix``: it reduces every batch element to a
    single number. With the default ``norm_order=2`` the result is the squared
    Frobenius norm of each ``matrix[b]``.

    Args:
        matrix (torch.Tensor): Expected shape [batch, *]. Every dimension after
            the first (batch) one is reduced.
        norm_order (int): Norm order ``p``.

    Returns:
        torch.Tensor, ``||matrix[b]||_p ** p`` for every batch element,
        of shape [batch].
    """
    # Reduce over all axes except the leading batch axis.
    reduce_dims = list(range(1, matrix.ndim))
    return torch.norm(matrix, p=norm_order, dim=reduce_dims) ** norm_order
