Shortcuts

Source code for asteroid.models.conv_tasnet

from ..filterbanks import make_enc_dec
from ..masknn import TDConvNet
from .base_models import BaseTasNet


class ConvTasNet(BaseTasNet):
    """ConvTasNet separation model, as described in [1].

    Args:
        n_src (int): Number of sources in the input mixtures.
        out_chan (int, optional): Number of bins in the estimated masks.
            Forwarded to TDConvNet. If ``None``, `out_chan = in_chan`, where
            TDConvNet's ``in_chan`` is the encoder output dimension here.
        n_blocks (int, optional): Number of convolutional blocks in each
            repeat. Defaults to 8.
        n_repeats (int, optional): Number of repeats. Defaults to 3.
        bn_chan (int, optional): Number of channels after the bottleneck.
            Defaults to 128.
        hid_chan (int, optional): Number of channels in the convolutional
            blocks. Defaults to 512.
        skip_chan (int, optional): Number of channels in the skip
            connections. If 0 or None, TDConvNet won't have any skip
            connections and the masks will be computed from the residual
            output. Corresponds to the ConvTasnet architecture in v1 or the
            paper. Defaults to 128.
        conv_kernel_size (int, optional): Kernel size in convolutional
            blocks. Defaults to 3.
        norm_type (str, optional): To choose from ``'BN'``, ``'gLN'``,
            ``'cLN'``. Defaults to ``'gLN'``.
        mask_act (str, optional): Which non-linear function to generate
            mask. Defaults to ``'sigmoid'``.
        in_chan (int, optional): Expected number of input channels of the
            masker. Only used as a sanity check: if given, it must equal the
            encoder's output dimension (``encoder.n_feats_out``), which is
            not necessarily ``n_filters`` for every filterbank family.
        fb_name (str or class, optional): Filterbank family from which to
            make encoder and decoder. To choose among [``'free'``,
            ``'analytic_free'``, ``'param_sinc'``, ``'stft'``].
            Defaults to ``'free'``.
        kernel_size (int, optional): Length of the filters. Defaults to 16.
        n_filters (int, optional): Number of filters of the filterbank.
            Defaults to 512.
        stride (int, optional): Stride of the convolution. Defaults to 8.
            Forwarded as-is to ``make_enc_dec``; presumably ``None`` lets the
            filterbank choose ``kernel_size // 2`` (confirm in filterbanks).
        encoder_activation (str, optional): Forwarded to ``BaseTasNet``;
            presumably the non-linearity applied to the encoder output
            (confirm in base_models). Defaults to ``'relu'``.
        **fb_kwargs (dict): Additional kwargs to pass to the filterbank
            creation (``make_enc_dec``).

    Raises:
        AssertionError: If ``in_chan`` is given and differs from the
            encoder's output dimension. Note that this check is an
            ``assert`` and is therefore skipped when Python runs with ``-O``.

    References
        [1] : "Conv-TasNet: Surpassing ideal time-frequency magnitude masking
        for speech separation" TASLP 2019 Yi Luo, Nima Mesgarani
        https://arxiv.org/abs/1809.07454
    """

    def __init__(
        self,
        n_src,
        out_chan=None,
        n_blocks=8,
        n_repeats=3,
        bn_chan=128,
        hid_chan=512,
        skip_chan=128,
        conv_kernel_size=3,
        norm_type="gLN",
        mask_act="sigmoid",
        in_chan=None,
        fb_name="free",
        kernel_size=16,
        n_filters=512,
        stride=8,
        encoder_activation="relu",
        **fb_kwargs,
    ):
        encoder, decoder = make_enc_dec(
            fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs
        )
        # The masker's input size is read from the encoder rather than
        # assumed to be n_filters.
        n_feats = encoder.n_feats_out
        if in_chan is not None:
            # NOTE(review): kept as `assert` so the raised exception type is
            # unchanged for callers; it is stripped under `python -O`.
            assert in_chan == n_feats, (
                "Number of filterbank output channels"
                " and number of input channels should "
                "be the same. Received "
                f"{n_feats} and {in_chan}"
            )
        masker = TDConvNet(
            n_feats,
            n_src,
            out_chan=out_chan,
            n_blocks=n_blocks,
            n_repeats=n_repeats,
            bn_chan=bn_chan,
            hid_chan=hid_chan,
            skip_chan=skip_chan,
            conv_kernel_size=conv_kernel_size,
            norm_type=norm_type,
            mask_act=mask_act,
        )
        super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
Read the Docs v: v0.3.3
Versions
latest
stable
v0.3.3
v0.3.2
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.