
Source code for asteroid.filterbanks.stft_fb

import torch
import numpy as np
from .enc_dec import Filterbank

class STFTFB(Filterbank):
    """STFT filterbank.

    Args:
        n_filters (int): Number of filters. Determines the length of the STFT
            filters before windowing. Must be greater than or equal to
            ``kernel_size``.
        kernel_size (int): Length of the filters (i.e the window).
        stride (int, optional): Stride of the convolution (hop size). If None
            (default), set to ``kernel_size // 2``.
        window (:class:`numpy.ndarray`, optional): Window of length
            ``kernel_size`` applied to every filter. If None, defaults to the
            square root of a periodic Hann window,
            ``np.hanning(kernel_size + 1)[:-1] ** 0.5``.
        **kwargs: Accepted but not used by this class.

    Raises:
        AssertionError: If ``n_filters < kernel_size`` or if ``window`` does
            not contain exactly ``kernel_size`` samples.

    Attributes:
        n_feats_out (int): Number of output filters, ``2 * cutoff`` (the real
            parts of the kept DFT bins followed by their imaginary parts).
        cutoff (int): Number of kept DFT bins, ``n_filters // 2 + 1``.
        window (numpy.ndarray): Window multiplied into every filter.
    """

    def __init__(self, n_filters, kernel_size, stride=None, window=None, **kwargs):
        super(STFTFB, self).__init__(n_filters, kernel_size, stride=stride)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept so existing callers see the same exception type.
        if n_filters < kernel_size:
            raise AssertionError(
                "Expected n_filters >= kernel_size, got n_filters={} and "
                "kernel_size={}.".format(n_filters, kernel_size)
            )
        self.cutoff = int(n_filters / 2 + 1)
        self.n_feats_out = 2 * self.cutoff

        if window is None:
            # Periodic Hann window (length kernel_size + 1, last sample dropped),
            # square-rooted.
            self.window = np.hanning(kernel_size + 1)[:-1] ** 0.5
        else:
            window = np.asarray(window)
            ws = window.size
            if ws != kernel_size:
                raise AssertionError(
                    "Expected window of size {}. "
                    "Received window of size {} instead.".format(kernel_size, ws)
                )
            self.window = window

        # Create and normalize DFT filters (can be overcomplete).
        filters = np.fft.fft(np.eye(n_filters))
        filters /= 0.5 * np.sqrt(kernel_size * n_filters / self.stride)

        # Keep only the windowed centered part (kernel_size columns) to save
        # computation.
        lpad = int((n_filters - kernel_size) // 2)
        kept = filters[: self.cutoff, lpad : lpad + kernel_size]
        filters = np.vstack([np.real(kept), np.imag(kept)])

        # In the full n_filters-point spectrum every kept bin except DC (k=0)
        # and Nyquist (k=n_filters/2, which exists only for even n_filters)
        # has a mirrored counterpart at n_filters - k; those two bins are
        # scaled by 1/sqrt(2) to compensate. For odd n_filters, row
        # n_filters // 2 is an ordinary bin and must not be rescaled.
        filters[0, :] /= np.sqrt(2)
        if n_filters % 2 == 0:
            filters[n_filters // 2, :] /= np.sqrt(2)

        # Shape: (n_feats_out, 1, kernel_size).
        filters = torch.from_numpy(filters * self.window).unsqueeze(1).float()
        self.register_buffer("_filters", filters)

    @property
    def filters(self):
        """Return the STFT filter buffer of shape ``(n_feats_out, 1, kernel_size)``."""
        return self._filters
def perfect_synthesis_window(analysis_window, hop_size):
    """Computes a window for perfect synthesis given an analysis window and a hop size.

    For every position ``t`` inside the frame, the synthesis window is::

        w_s[t] = c * w_a[t] / sum_m w_a[t + m * hop_size] ** 2

    The sum runs over all integer ``m`` for which ``t + m * hop_size`` stays
    inside the frame, and ``c = 0.5 * len(w_a) / hop_size``. The factor ``c``
    is meant to undo the ``0.5 * sqrt(kernel_size * n_filters / stride)``
    normalization of :class:`STFTFB`. NOTE(review): this assumes the decoder
    is the transposed convolution with the same STFTFB filters; confirm
    against the decoder implementation.

    Args:
        analysis_window (np.array): Analysis window of the transform. Integer
            windows are converted to float64.
        hop_size (int): Hop size in number of samples. Must be positive.

    Returns:
        np.array : the synthesis window to use for perfectly inverting the STFT.

    Raises:
        ValueError: If ``hop_size`` is not positive.
    """
    if hop_size <= 0:
        raise ValueError("hop_size must be positive, got {}.".format(hop_size))

    window = np.asarray(analysis_window)
    # np.finfo below requires an inexact (float/complex) dtype.
    if not np.issubdtype(window.dtype, np.inexact):
        window = window.astype(np.float64)

    win_size = len(window)
    squared = window ** 2
    den = np.zeros_like(window)
    # Number of hops (each side) for which shifted copies still overlap the frame.
    loop_on = (win_size - 1) // hop_size
    for win_idx in range(-loop_on, loop_on + 1):
        shift = win_idx * hop_size
        shifted = np.roll(squared, shift)
        # Zero the samples that wrapped around instead of overlapping.
        if win_idx < 0:
            shifted[shift:] = 0
        elif win_idx > 0:
            shifted[:shift] = 0
        den += shifted
    # Avoid division by zero where no window energy overlaps.
    den = np.where(den != 0.0, den, np.finfo(den.dtype).tiny)
    # Float (not int) factor: truncation gave a zero window for
    # hop_size == len(window) and the wrong gain whenever
    # len(window) / (2 * hop_size) is not an integer.
    correction = 0.5 * win_size / hop_size
    return correction * window / den
Read the Docs v: v0.3.4
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.