
Source code for asteroid.filterbanks

from .analytic_free_fb import AnalyticFreeFB
from .free_fb import FreeFB
from .param_sinc_fb import ParamSincFB
from .stft_fb import STFTFB
from .enc_dec import Filterbank, Encoder, Decoder
from .griffin_lim import griffin_lim, misi
from .multiphase_gammatone_fb import MultiphaseGammatoneFB

def make_enc_dec(
    fb_name,
    n_filters,
    kernel_size,
    stride=None,
    who_is_pinv=None,
    padding=0,
    output_padding=0,
    **kwargs,
):
    """Creates congruent encoder and decoder from the same filterbank family.

    Args:
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. Resolved with :func:`get`. To choose among
            [``'free'``, ``'analytic_free'``, ``'param_sinc'``, ``'stft'``,
            ``'multiphase_gammatone'`` (short alias ``'mpgtf'``)].
            Can also be a class defined in a submodule in this subpackage
            (e.g. :class:`~.FreeFB`).
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
            be used. If string (among [``'encoder'``, ``'decoder'``], or the
            short forms ``'enc'`` / ``'dec'``), decides which of ``Encoder``
            or ``Decoder`` will be the pseudo inverse of the other one.
            Any other value is treated like `None`.
        padding (int): Zero-padding added to both sides of the input.
            Passed to the Encoder and Decoder built directly from a
            filterbank; the pseudo-inverse side (built with ``pinv_of``)
            does not receive it.
        output_padding (int): Additional size added to one side of the
            output shape. Passed to the Decoder when it is built directly
            from a filterbank (not when it is the pseudo inverse).
        **kwargs: Arguments which will be passed to the filterbank class
            additionally to the usual `n_filters`, `kernel_size` and `stride`.
            Depends on the filterbank family.

    Returns:
        :class:`.Encoder`, :class:`.Decoder`
    """
    fb_class = get(fb_name)
    # Every branch needs one filterbank built from the same arguments;
    # build it once instead of repeating the call in each branch.
    fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
    if who_is_pinv in ["dec", "decoder"]:
        enc = Encoder(fb, padding=padding)
        # Decoder filterbank is pseudo inverse of encoder filterbank.
        # NOTE(review): the pseudo-inverse decoder gets no ``padding`` /
        # ``output_padding`` while the encoder is padded — confirm intended.
        dec = Decoder.pinv_of(fb)
    elif who_is_pinv in ["enc", "encoder"]:
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
        # Encoder filterbank is pseudo inverse of decoder filterbank.
        enc = Encoder.pinv_of(fb)
    else:
        enc = Encoder(fb, padding=padding)
        # Filters between encoder and decoder should not be shared, so the
        # decoder gets its own, independently constructed filterbank.
        dec_fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        dec = Decoder(dec_fb, padding=padding, output_padding=output_padding)
    return enc, dec
def register_filterbank(custom_fb):
    """Make a user-defined filterbank class retrievable through `filterbanks.get`.

    The class is bound in this module's namespace under its own
    ``__name__``, which is where ``get`` looks up string identifiers.

    Args:
        custom_fb: Custom filterbank class to register.

    Raises:
        ValueError: If the class name, or its lowercase form, is already
            bound in this module's namespace.
    """
    namespace = globals()
    fb_name = custom_fb.__name__
    # Also reject a clash with the lowercase spelling, the form used by the
    # module's lowercase aliases.
    if fb_name in namespace or fb_name.lower() in namespace:
        raise ValueError(f"Filterbank {custom_fb.__name__} already exists. Choose another name.")
    namespace[fb_name] = custom_fb
def get(identifier):
    """Returns a filterbank class from a string.

    Returns its input if it is callable (already a :class:`.Filterbank`
    for example).

    Args:
        identifier (str or Callable or None): the filterbank identifier.

    Returns:
        :class:`.Filterbank` or None

    Raises:
        ValueError: If ``identifier`` is a string that does not name a
            callable in this module's namespace, or is neither None, a
            callable, nor a string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # The module namespace also holds non-callable entries (``__name__``,
        # ``__doc__``, ``__all__``, ...); those can never be a filterbank
        # class, so treat them like an unknown name instead of returning them.
        if cls is None or not callable(cls):
            raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
        return cls
    raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
# Lowercase aliases, so string identifiers such as "stft" resolve via the
# module namespace.
free = FreeFB
analytic_free = AnalyticFreeFB
param_sinc = ParamSincFB
stft = STFTFB
multiphase_gammatone = MultiphaseGammatoneFB
mpgtf = MultiphaseGammatoneFB  # short alias for the same class

# Public names, listed for the documentation build.
__all__ = [
    # Encoder / decoder building blocks.
    "Filterbank",
    "Encoder",
    "Decoder",
    # Filterbank families.
    "FreeFB",
    "STFTFB",
    "AnalyticFreeFB",
    "ParamSincFB",
    "MultiphaseGammatoneFB",
    # Module-level functions.
    "griffin_lim",
    "misi",
    "make_enc_dec",
]
Read the Docs v: v0.3.4
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.