Shortcuts

Source code for asteroid.models.lstm_tasnet

import torch
from torch import nn
from copy import deepcopy

from ..filterbanks import make_enc_dec
from ..masknn import LSTMMasker
from .base_models import BaseTasNet


[docs]class LSTMTasNet(BaseTasNet): """ TasNet separation model, as described in [1]. Args: n_src (int): Number of masks to estimate. out_chan (int or None): Number of bins in the estimated masks. Defaults to `in_chan`. hid_size (int): Number of neurons in the RNNs cell state. Defaults to 128. mask_act (str, optional): Which non-linear function to generate mask. bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN (Intra-Chunk is always bidirectional). rnn_type (str, optional): Type of RNN used. Choose between ``'RNN'``, ``'LSTM'`` and ``'GRU'``. n_layers (int, optional): Number of layers in each RNN. dropout (float, optional): Dropout ratio, must be in [0,1]. in_chan (int, optional): Number of input channels, should be equal to n_filters. fb_name (str, className): Filterbank family from which to make encoder and decoder. To choose among [``'free'``, ``'analytic_free'``, ``'param_sinc'``, ``'stft'``]. n_filters (int): Number of filters / Input dimension of the masker net. kernel_size (int): Length of the filters. stride (int, optional): Stride of the convolution. If None (default), set to ``kernel_size // 2``. **fb_kwargs (dict): Additional kwards to pass to the filterbank creation. References: [1]: Yi Luo et al. "Real-time Single-channel Dereverberation and Separation with Time-domain Audio Separation Network", Interspeech 2018 """ def __init__( self, n_src, out_chan=None, rnn_type="lstm", n_layers=4, hid_size=512, dropout=0.3, mask_act="sigmoid", bidirectional=True, in_chan=None, fb_name="free", n_filters=64, kernel_size=16, stride=8, encoder_activation=None, **fb_kwargs, ): encoder, decoder = make_enc_dec( fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs ) n_feats = encoder.n_feats_out if in_chan is not None: assert in_chan == n_feats, ( "Number of filterbank output channels" " and number of input channels should " "be the same. Received " f"{n_feats} and {in_chan}" ) # Real gated encoder encoder = _GatedEncoder(encoder) # Masker masker = LSTMMasker( n_feats, n_src, out_chan=out_chan, hid_size=hid_size, mask_act=mask_act, bidirectional=bidirectional, rnn_type=rnn_type, n_layers=n_layers, dropout=dropout, ) super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
class _GatedEncoder(nn.Module): def __init__(self, encoder): super().__init__() # For config self.filterbank = encoder.filterbank # Gated encoder. self.encoder_relu = encoder self.encoder_sig = deepcopy(encoder) def forward(self, x): relu_out = torch.relu(self.encoder_relu(x)) sig_out = torch.sigmoid(self.encoder_sig(x)) return sig_out * relu_out
Read the Docs v: v0.3.3
Versions
latest
stable
v0.3.3
v0.3.2
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.