Shortcuts

Source code for asteroid.filterbanks.multiphase_gammatone_fb

import numpy as np
import torch
from .enc_dec import Filterbank


class MultiphaseGammatoneFB(Filterbank):
    """Multi-Phase Gammatone Filterbank as described in [1].

    Please cite [1] whenever using this.
    Original code repository: `<https://github.com/sp-uhh/mp-gtf>`

    The filters are generated once at construction time by
    ``generate_mpgtf`` and registered as a buffer named ``_filters``.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        sample_rate (int, optional): The sample rate (used for initialization).
        stride (int, optional): Stride of the convolution. If None (default),
            set to ``kernel_size // 2``.
        **kwargs: Accepted but not used by this class.

    References:
        [1] David Ditter, Timo Gerkmann, "A Multi-Phase Gammatone Filterbank
        for Speech Separation via TasNet", ICASSP 2020
        Available: `<https://ieeexplore.ieee.org/document/9053602/>`
    """

    def __init__(self, n_filters=128, kernel_size=16, sample_rate=8000, stride=None, **kwargs):
        super().__init__(n_filters, kernel_size, stride=stride)
        self.sample_rate = sample_rate
        self.n_feats_out = n_filters

        # generate_mpgtf expects a duration in seconds, not a sample count.
        bank = generate_mpgtf(sample_rate, kernel_size / sample_rate, n_filters)

        # NumPy array -> tensor with a singleton axis inserted at position 1,
        # cast to float32 before being stored as a buffer.
        bank = torch.from_numpy(bank)
        bank = bank.unsqueeze(1)
        self.register_buffer("_filters", bank.float())

    @property
    def filters(self):
        """The precomputed filterbank tensor stored in the ``_filters`` buffer."""
        return self._filters
def generate_mpgtf(samplerate_hz, len_sec, n_filters):
    """Generate a multi-phase gammatone filterbank.

    Filters are grouped by center frequency. There are 24 center frequencies,
    starting at 100 Hz and advancing by 1 on the ERB scale. For each one, a
    group of filters with phase shifts in [0, pi) is generated, immediately
    followed by the negation of that group (phase shifts in [pi, 2*pi)).
    The resulting bank is passed through ``normalize_filters``.

    Args:
        samplerate_hz (int): Sample rate in Hz.
        len_sec (float): Length of each filter in seconds.
        n_filters (int): Number of filters. Filters are produced in +/- pairs,
            so with an odd value the last row is never filled and is all
            zeros before normalization.

    Returns:
        numpy.ndarray: Filterbank of shape
        ``(n_filters, floor(samplerate_hz * len_sec))``.
    """
    # Set parameters
    center_freq_hz_min = 100
    n_center_freqs = 24
    len_sample = int(np.floor(samplerate_hz * len_sec))

    # Initialize variables
    index = 0
    filterbank = np.zeros((n_filters, len_sample))
    current_center_freq_hz = center_freq_hz_min

    # Determine number of phase shifts per center frequency: spread the
    # filter pairs evenly, then hand the leftover pairs to the lowest
    # center frequencies.
    phase_pair_count = (
        np.ones(n_center_freqs) * np.floor(n_filters / 2 / n_center_freqs)
    ).astype(int)
    remaining_phase_pairs = int((n_filters - np.sum(phase_pair_count) * 2) / 2)
    if remaining_phase_pairs > 0:
        phase_pair_count[:remaining_phase_pairs] += 1

    # Generate all filters for each center frequency
    for n_pairs in phase_pair_count:
        # First half of filters: phase shifts in [0, pi)
        for phase_index in range(n_pairs):
            # Builtin float: the ``np.float`` alias no longer exists in
            # current NumPy releases.
            current_phase_shift = float(phase_index) / n_pairs * np.pi
            filterbank[index, :] = gammatone_impulse_response(
                samplerate_hz,
                len_sec,
                current_center_freq_hz,
                current_phase_shift,
            )
            index = index + 1

        # Second half of filters: phase shifts in [pi, 2*pi), obtained by
        # negating the filters just generated (cos(x + pi) == -cos(x)).
        filterbank[index : index + n_pairs, :] = -filterbank[index - n_pairs : index, :]

        # Prepare for next center frequency (one step up on the ERB scale)
        index = index + n_pairs
        current_center_freq_hz = erb_scale_2_freq_hz(
            freq_hz_2_erb_scale(current_center_freq_hz) + 1
        )

    filterbank = normalize_filters(filterbank)
    return filterbank
def gammatone_impulse_response(samplerate_hz, len_sec, center_freq_hz, phase_shift):
    """Generate a single parametrized gammatone filter.

    The impulse response is
    ``a * t**(p - 1) * exp(-2*pi*b*t) * cos(2*pi*center_freq_hz*t + phase_shift)``
    with filter order ``p = 2`` and amplitude ``a = 1``, sampled at
    ``t = 1/samplerate_hz, ..., len_sec``.

    Args:
        samplerate_hz (int): Sample rate in Hz.
        len_sec (float): Length of the filter in seconds.
        center_freq_hz (float): Center frequency of the filter in Hz.
        phase_shift (float): Phase shift of the cosine carrier in radians.

    Returns:
        numpy.ndarray: 1-D array of ``floor(samplerate_hz * len_sec)`` samples.
    """
    # Local stdlib import: the ``np.math`` alias used previously no longer
    # exists in current NumPy releases.
    import math

    p = 2  # filter order
    erb = 24.7 + 0.108 * center_freq_hz  # equivalent rectangular bandwidth
    divisor = (
        np.pi * math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))
    ) / np.square(math.factorial(p - 1))
    b = erb / divisor  # bandwidth parameter
    a = 1.0  # amplitude. This is varied later by the normalization process.

    len_sample = int(np.floor(samplerate_hz * len_sec))
    # Time axis starts one sample period after zero, not at t = 0.
    t = np.linspace(1.0 / samplerate_hz, len_sec, len_sample)
    gammatone_ir = (
        a
        * np.power(t, p - 1)
        * np.exp(-2 * np.pi * b * t)
        * np.cos(2 * np.pi * center_freq_hz * t + phase_shift)
    )
    return gammatone_ir
def erb_scale_2_freq_hz(freq_erb):
    """Map a frequency on the ERB scale to a frequency in Hertz.

    Args:
        freq_erb (float or numpy.ndarray): Frequency on the ERB scale.

    Returns:
        float or numpy.ndarray: Corresponding frequency in Hz.
    """
    growth = np.exp(freq_erb / 9.265) - 1
    return growth * 24.7 * 9.265
def freq_hz_2_erb_scale(freq_hz):
    """Map a frequency in Hertz to a frequency on the ERB scale.

    Args:
        freq_hz (float or numpy.ndarray): Frequency in Hz.

    Returns:
        float or numpy.ndarray: Corresponding frequency on the ERB scale.
    """
    scaled_hz = freq_hz / (24.7 * 9.265)
    return 9.265 * np.log(1 + scaled_hz)
def normalize_filters(filterbank):
    """Normalize a filterbank so all filters have the same root mean square (RMS).

    Every filter with non-zero energy is rescaled to the largest RMS found in
    the bank. Filters that are all zeros have no defined gain and are left as
    zeros rather than being turned into NaN by a division by zero.

    Args:
        filterbank (numpy.ndarray): 2-D array with one filter per row.

    Returns:
        numpy.ndarray: Normalized filterbank with the same shape as the input.
    """
    rms_per_filter = np.sqrt(np.mean(np.square(filterbank), axis=1))

    # Gain stays 0 for silent rows; their samples are all zero anyway.
    rms_normalization_values = np.zeros_like(rms_per_filter)
    has_energy = rms_per_filter > 0
    if np.any(has_energy):
        rms_normalization_values[has_energy] = 1.0 / (
            rms_per_filter[has_energy] / np.amax(rms_per_filter)
        )

    normalized_filterbank = filterbank * rms_normalization_values[:, np.newaxis]
    return normalized_filterbank
Read the Docs v: v0.3.3
Versions
latest
stable
v0.3.3
v0.3.2
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.