Source code for asteroid.filterbanks.analytic_free_fb

import torch
import torch.nn as nn
import numpy as np
from .enc_dec import Filterbank

class AnalyticFreeFB(Filterbank):
    """Free analytic (fully learned with analyticity constraints) filterbank.

    Only half of the filters are stored as learnable parameters; the other
    half are derived from them on every access to :attr:`filters` as their
    Hilbert transforms. For more details, see [1].

    Args:
        n_filters (int): Number of filters. Half of `n_filters` will have
            parameters, the other half will be the Hilbert transforms.
            `n_filters` should be even.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        **kwargs: Extra keyword arguments are accepted but ignored by
            this class.

    Attributes:
        cutoff (int): Number of learned filters (``n_filters // 2``).
        n_feats_out (int): Number of output filters (``2 * cutoff``).

    References:
        [1] : "Filterbank design for end-to-end speech separation".
        Submitted to ICASSP 2020. Manuel Pariente, Samuele Cornell,
        Antoine Deleforge, Emmanuel Vincent.
    """

    def __init__(self, n_filters, kernel_size, stride=None, **kwargs):
        super(AnalyticFreeFB, self).__init__(n_filters, kernel_size, stride=stride)
        self.cutoff = int(n_filters // 2)
        # Always even: each learned filter is paired with its transform.
        self.n_feats_out = 2 * self.cutoff
        if n_filters % 2 != 0:
            print(
                "If the number of filters `n_filters` is odd, the "
                "output size of the layer will be `n_filters - 1`."
            )
        # Learned half of the bank, shape (n_filters // 2, 1, kernel_size).
        # The initial ones are overwritten by the Xavier init below.
        self._filters = nn.Parameter(
            torch.ones(n_filters // 2, 1, kernel_size), requires_grad=True
        )
        for p in self.parameters():
            nn.init.xavier_normal_(p, gain=1.0 / np.sqrt(2.0))

    @property
    def filters(self):
        """Return the learned filters stacked with their Hilbert transforms.

        The transform is computed in the frequency domain by multiplying
        the one-sided spectrum of each learned filter by ``-1j`` (i.e.
        mapping ``a + ib`` to ``b - ia``) and transforming back to
        ``kernel_size`` samples.

        Returns:
            torch.Tensor: Tensor of shape
            ``(2 * (n_filters // 2), 1, kernel_size)``; the first half along
            dim 0 is the learned filters, the second half the transforms.
        """
        # NOTE(review): `self.kernel_size` is presumably set by
        # `Filterbank.__init__` -- confirm against the base class.
        if hasattr(torch, "rfft"):
            # Legacy API (removed in PyTorch 1.8): complex values are
            # returned as a trailing dimension of size 2, (real, imag).
            ft_f = torch.rfft(self._filters, 1, normalized=True)
            hft_f = torch.stack([ft_f[:, :, :, 1], -ft_f[:, :, :, 0]], dim=-1)
            hft_f = torch.irfft(
                hft_f, 1, normalized=True, signal_sizes=(self.kernel_size,)
            )
        else:
            # Modern API: native complex tensors. `norm="ortho"` matches the
            # legacy `normalized=True`; `n` matches `signal_sizes`.
            ft_f = torch.fft.rfft(self._filters, dim=-1, norm="ortho")
            hft_f = torch.complex(ft_f.imag, -ft_f.real)
            hft_f = torch.fft.irfft(
                hft_f, n=self.kernel_size, dim=-1, norm="ortho"
            )
        return torch.cat([self._filters, hft_f], dim=0)