Shortcuts

Source code for asteroid.losses.multi_scale_spectral

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from asteroid.filterbanks import STFTFB, Encoder
from asteroid.filterbanks.transforms import take_mag

# Floor added to spectrogram magnitudes before ``torch.log`` so that
# zero-magnitude bins yield a finite value instead of log(0) = -inf.
EPS = 1.0e-8


class SingleSrcMultiScaleSpectral(_Loss):
    """Measure multi-scale spectral loss as described in [1].

    For every STFT configuration, the magnitude spectrograms of the estimate
    and the target are compared with an L1 distance on the linear magnitudes
    plus ``alpha`` times an L1 distance on the log magnitudes. The losses of
    all configurations are summed.

    Args:
        n_filters (list): list containing the number of filters desired for
            each STFT.
        windows_size (list): list containing the size of the window desired
            for each STFT.
        hops_size (list): list containing the size of the hop desired for
            each STFT.
        alpha (float): Weighting factor for the log term.

    Shape:
        est_targets (:class:`torch.Tensor`): Expected shape [batch, time].
            Batch of target estimates.
        targets (:class:`torch.Tensor`): Expected shape [batch, time].
            Batch of training targets.

    Returns:
        :class:`torch.Tensor`: with shape [batch]

    Raises:
        ValueError: if ``n_filters``, ``windows_size`` and ``hops_size`` do
            not all have the same length.

    Examples:
        >>> import torch
        >>> targets = torch.randn(10, 32000)
        >>> est_targets = torch.randn(10, 32000)
        >>> # Using it by itself on a pair of source/estimate
        >>> loss_func = SingleSrcMultiScaleSpectral()
        >>> loss = loss_func(est_targets, targets)

        >>> import torch
        >>> from asteroid.losses import PITLossWrapper
        >>> targets = torch.randn(10, 2, 32000)
        >>> est_targets = torch.randn(10, 2, 32000)
        >>> # Using it with PITLossWrapper with sets of source/estimates
        >>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
        ...                            pit_from='pw_pt')
        >>> loss = loss_func(est_targets, targets)

    References:
        [1] Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu and
        Adam Roberts, "DDSP: Differentiable Digital Signal Processing",
        International Conference on Learning Representations (ICLR), 2020.
    """

    def __init__(self, n_filters=None, windows_size=None, hops_size=None, alpha=1.0):
        super().__init__()
        if windows_size is None:
            windows_size = [2048, 1024, 512, 256, 128, 64, 32]
        if n_filters is None:
            n_filters = [2048, 1024, 512, 256, 128, 64, 32]
        if hops_size is None:
            hops_size = [1024, 512, 256, 128, 64, 32, 16]
        # Copy so that later mutation of the caller's lists cannot desync the
        # stored configuration from the encoders built below.
        n_filters = list(n_filters)
        windows_size = list(windows_size)
        hops_size = list(hops_size)
        # The three lists describe one STFT per index; a length mismatch used
        # to either raise an obscure IndexError or silently drop entries.
        if not len(n_filters) == len(windows_size) == len(hops_size):
            raise ValueError(
                "n_filters, windows_size and hops_size must have the same length, "
                f"got {len(n_filters)}, {len(windows_size)} and {len(hops_size)}."
            )
        self.windows_size = windows_size
        self.n_filters = n_filters
        self.hops_size = hops_size
        self.alpha = alpha
        self.encoders = nn.ModuleList(
            Encoder(STFTFB(n_filt, win_size, hop_size))
            for n_filt, win_size, hop_size in zip(n_filters, windows_size, hops_size)
        )

    def forward(self, est_target, target):
        """Sum the spectral loss over every STFT configuration.

        Args:
            est_target (:class:`torch.Tensor`): [batch, time] estimates.
            target (:class:`torch.Tensor`): [batch, time] targets.

        Returns:
            :class:`torch.Tensor`: loss of shape [batch], with the dtype and
            device of ``est_target``.
        """
        batch_size = est_target.shape[0]
        # [batch, time] -> [batch, 1, time]: add a channel dimension.
        est_target = est_target.unsqueeze(1)
        target = target.unsqueeze(1)
        # new_zeros follows est_target's dtype and device; torch.zeros always
        # produced float32 and silently downcast float64 losses on `+=`.
        loss = est_target.new_zeros(batch_size)
        for encoder in self.encoders:
            loss += self.compute_spectral_loss(encoder, est_target, target)
        return loss

    def compute_spectral_loss(self, encoder, est_target, target):
        """Spectral loss for a single STFT configuration.

        Args:
            encoder: one of ``self.encoders``.
            est_target (:class:`torch.Tensor`): estimates, as prepared by
                ``forward`` ([batch, 1, time]).
            target (:class:`torch.Tensor`): targets, same shape as
                ``est_target``.

        Returns:
            :class:`torch.Tensor`: shape [batch]; the L1 distance between
            linear magnitudes plus ``alpha`` times the L1 distance between
            log magnitudes.
        """
        batch_size = est_target.shape[0]
        # Flatten all non-batch dimensions so norm1 reduces over dim=1;
        # reshape (unlike view) also accepts non-contiguous magnitudes.
        spect_est_target = take_mag(encoder(est_target)).reshape(batch_size, -1)
        spect_target = take_mag(encoder(target)).reshape(batch_size, -1)
        linear_loss = self.norm1(spect_est_target - spect_target)
        # EPS keeps the log finite on zero-magnitude bins.
        log_loss = self.norm1(torch.log(spect_est_target + EPS) - torch.log(spect_target + EPS))
        return linear_loss + self.alpha * log_loss

    @staticmethod
    def norm1(a):
        """Return the L1 norm of each row of the 2-D tensor ``a`` (shape [batch])."""
        # Equivalent to torch.norm(a, p=1, dim=1) for real tensors, without
        # relying on the deprecated torch.norm.
        return a.abs().sum(dim=1)
Read the Docs v: v0.3.3
Versions
latest
stable
v0.3.3
v0.3.2
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.