Source code for asteroid.masknn.recurrent

import torch
from torch import nn
from torch.nn.functional import fold, unfold

from . import norms, activations
from .norms import GlobLN, CumLN
from ..utils import has_arg


class SingleRNN(nn.Module):
    """Module for a RNN block.

    Inspired by https://github.com/yluo42/TAC/blob/master/utility/models.py
    Licensed under CC BY-NC-SA 3.0 US.

    Args:
        rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can
            also be passed in lowercase letters.
        input_size (int): Dimension of the input feature. The input should have
            shape [batch, seq_len, input_size].
        hidden_size (int): Dimension of the hidden state.
        n_layers (int, optional): Number of layers used in RNN. Default is 1.
        dropout (float, optional): Dropout ratio. Default is 0.
        bidirectional (bool, optional): Whether the RNN layers are
            bidirectional. Default is ``False``.
    """

    def __init__(
        self, rnn_type, input_size, hidden_size, n_layers=1, dropout=0, bidirectional=False
    ):
        super(SingleRNN, self).__init__()
        assert rnn_type.upper() in ["RNN", "LSTM", "GRU"]
        rnn_type = rnn_type.upper()
        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn = getattr(nn, rnn_type)(
            input_size,
            hidden_size,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=bool(bidirectional),
        )

    def forward(self, inp):
        """Input shape [batch, seq, feats]"""
        self.rnn.flatten_parameters()  # Enables faster multi-GPU training.
        output = inp
        rnn_output, _ = self.rnn(output)
        return rnn_output
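
# Illustrative usage sketch (editorial addition, not part of the original
# module). Shapes follow the SingleRNN docstring; the sizes are arbitrary.
example_rnn = SingleRNN("GRU", input_size=64, hidden_size=128, bidirectional=True)
example_inp = torch.randn(8, 100, 64)  # [batch, seq_len, input_size]
example_out = example_rnn(example_inp)  # [8, 100, 256]: hidden_size * 2 directions
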

class StackedResidualRNN(nn.Module):
    """Stacked RNN with builtin residual connection.
    Only supports forward RNNs.
    See StackedResidualBiRNN for bidirectional ones.

    Args:
        rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can
            also be passed in lowercase letters.
        n_units (int): Number of units in recurrent layers. This will also be
            the expected input size.
        n_layers (int): Number of recurrent layers.
        dropout (float): Dropout value, between 0. and 1. (Default: 0.)
        bidirectional (bool): If True, use bidirectional RNN, else
            unidirectional. Only ``False`` is currently supported.
            (Default: False)
    """

    def __init__(self, rnn_type, n_units, n_layers=4, dropout=0.0, bidirectional=False):
        super(StackedResidualRNN, self).__init__()
        self.rnn_type = rnn_type
        self.n_units = n_units
        self.n_layers = n_layers
        self.dropout = dropout
        assert bidirectional is False, "Bidirectional not supported yet"
        self.bidirectional = bidirectional

        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(
                SingleRNN(
                    rnn_type, input_size=n_units, hidden_size=n_units, bidirectional=bidirectional
                )
            )
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, x):
        """Builtin residual connections + dropout applied before residual.
        Input shape : [batch, time_axis, feat_axis]
        """
        for rnn in self.layers:
            rnn_out = rnn(x)
            dropped_out = self.dropout_layer(rnn_out)
            x = x + dropped_out
        return x
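
# Illustrative usage sketch (editorial addition, not part of the original
# module). Residual connections keep the input shape unchanged.
example_stack = StackedResidualRNN("LSTM", n_units=64, n_layers=4, dropout=0.2)
example_x = torch.randn(8, 100, 64)  # [batch, time_axis, feat_axis]
example_y = example_stack(example_x)  # same shape: [8, 100, 64]
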

class StackedResidualBiRNN(nn.Module):
    """Stacked Bidirectional RNN with builtin residual connection.
    Residual connections are applied on both RNN directions.
    Only supports bidirectional RNNs.
    See StackedResidualRNN for unidirectional ones.

    Args:
        rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can
            also be passed in lowercase letters.
        n_units (int): Number of units in recurrent layers. This will also be
            the expected input size.
        n_layers (int): Number of recurrent layers.
        dropout (float): Dropout value, between 0. and 1. (Default: 0.)
        bidirectional (bool): If True, use bidirectional RNN, else
            unidirectional. Only ``True`` is currently supported.
            (Default: True)
    """

    def __init__(self, rnn_type, n_units, n_layers=4, dropout=0.0, bidirectional=True):
        super().__init__()
        self.rnn_type = rnn_type
        self.n_units = n_units
        self.n_layers = n_layers
        self.dropout = dropout
        assert bidirectional is True, "Only bidirectional is supported for now"
        self.bidirectional = bidirectional

        # The first layer has as many units as input size
        self.first_layer = SingleRNN(
            rnn_type, input_size=n_units, hidden_size=n_units, bidirectional=bidirectional
        )
        # As the first layer outputs 2*n_units, the following layers need
        # 2*n_units as input size
        self.layers = nn.ModuleList()
        for i in range(n_layers - 1):
            input_size = 2 * n_units
            self.layers.append(
                SingleRNN(
                    rnn_type,
                    input_size=input_size,
                    hidden_size=n_units,
                    bidirectional=bidirectional,
                )
            )
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, x):
        """Builtin residual connections + dropout applied before residual.
        Input shape : [batch, time_axis, feat_axis]
        """
        # First layer
        rnn_out = self.first_layer(x)
        dropped_out = self.dropout_layer(rnn_out)
        x = torch.cat([x, x], dim=-1) + dropped_out
        # Rest of the layers
        for rnn in self.layers:
            rnn_out = rnn(x)
            dropped_out = self.dropout_layer(rnn_out)
            x = x + dropped_out
        return x
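
# Illustrative usage sketch (editorial addition, not part of the original
# module). The first layer doubles the feature size (torch.cat residual),
# so the output has 2 * n_units features.
example_bi_stack = StackedResidualBiRNN("LSTM", n_units=64, n_layers=3)
example_x = torch.randn(8, 100, 64)  # [batch, time_axis, feat_axis]
example_y = example_bi_stack(example_x)  # [8, 100, 128]
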

class DPRNNBlock(nn.Module):
    """Dual-Path RNN Block as proposed in [1].

    Args:
        in_chan (int): Number of input channels.
        hid_size (int): Number of hidden neurons in the RNNs.
        norm_type (str, optional): Type of normalization to use. To choose from

            - ``'gLN'``: global Layernorm
            - ``'cLN'``: channelwise Layernorm
        bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN.
        rnn_type (str, optional): Type of RNN used. Choose from ``'RNN'``,
            ``'LSTM'`` and ``'GRU'``.
        num_layers (int, optional): Number of layers used in each RNN.
        dropout (float, optional): Dropout ratio. Must be in [0, 1].

    References:
        [1] "Dual-path RNN: efficient long sequence modeling for time-domain
        single-channel speech separation", Yi Luo, Zhuo Chen and Takuya Yoshioka.
        https://arxiv.org/abs/1910.06379
    """

    def __init__(
        self,
        in_chan,
        hid_size,
        norm_type="gLN",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNNBlock, self).__init__()
        # IntraRNN and linear projection layer (always bi-directional)
        self.intra_RNN = SingleRNN(
            rnn_type, in_chan, hid_size, num_layers, dropout=dropout, bidirectional=True
        )
        self.intra_linear = nn.Linear(hid_size * 2, in_chan)
        self.intra_norm = norms.get(norm_type)(in_chan)
        # InterRNN block and linear projection layer (uni or bi-directional)
        self.inter_RNN = SingleRNN(
            rnn_type, in_chan, hid_size, num_layers, dropout=dropout, bidirectional=bidirectional
        )
        num_direction = int(bidirectional) + 1
        self.inter_linear = nn.Linear(hid_size * num_direction, in_chan)
        self.inter_norm = norms.get(norm_type)(in_chan)

    def forward(self, x):
        """Input shape : [batch, feats, chunk_size, num_chunks]"""
        B, N, K, L = x.size()
        output = x  # for skip connection
        # Intra-chunk processing
        x = x.transpose(1, -1).reshape(B * L, K, N)
        x = self.intra_RNN(x)
        x = self.intra_linear(x)
        x = x.reshape(B, L, K, N).transpose(1, -1)
        x = self.intra_norm(x)
        output = output + x
        # Inter-chunk processing
        x = output.transpose(1, 2).transpose(2, -1).reshape(B * K, L, N)
        x = self.inter_RNN(x)
        x = self.inter_linear(x)
        x = x.reshape(B, K, L, N).transpose(1, -1).transpose(2, -1)
        x = self.inter_norm(x)
        return output + x
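
# Illustrative usage sketch (editorial addition, not part of the original
# module). A single dual-path block keeps the chunked shape unchanged.
example_block = DPRNNBlock(in_chan=64, hid_size=128)
example_x = torch.randn(8, 64, 100, 20)  # [batch, feats, chunk_size, num_chunks]
example_y = example_block(example_x)  # same shape: [8, 64, 100, 20]
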

class DPRNN(nn.Module):
    """Dual-path RNN Network for Single-Channel Source Separation
    introduced in [1].

    Args:
        in_chan (int): Number of input filters.
        n_src (int): Number of masks to estimate.
        out_chan (int or None): Number of bins in the estimated masks.
            Defaults to `in_chan`.
        bn_chan (int): Number of channels after the bottleneck.
            Defaults to 128.
        hid_size (int): Number of neurons in the RNNs cell state.
            Defaults to 128.
        chunk_size (int): Window size of overlap-and-add processing.
            Defaults to 100.
        hop_size (int or None): Hop size (stride) of overlap-and-add processing.
            Defaults to `chunk_size // 2` (50% overlap).
        n_repeats (int): Number of repeats. Defaults to 6.
        norm_type (str, optional): Type of normalization to use. To choose from

            - ``'gLN'``: global Layernorm
            - ``'cLN'``: channelwise Layernorm
        mask_act (str, optional): Which non-linear function to generate mask.
        bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
            (Intra-Chunk is always bidirectional).
        rnn_type (str, optional): Type of RNN used. Choose between ``'RNN'``,
            ``'LSTM'`` and ``'GRU'``.
        num_layers (int, optional): Number of layers in each RNN.
        dropout (float, optional): Dropout ratio, must be in [0, 1].

    References:
        [1] "Dual-path RNN: efficient long sequence modeling for time-domain
        single-channel speech separation", Yi Luo, Zhuo Chen and Takuya Yoshioka.
        https://arxiv.org/abs/1910.06379
    """

    def __init__(
        self,
        in_chan,
        n_src,
        out_chan=None,
        bn_chan=128,
        hid_size=128,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        mask_act="relu",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNN, self).__init__()
        self.in_chan = in_chan
        out_chan = out_chan if out_chan is not None else in_chan
        self.out_chan = out_chan
        self.bn_chan = bn_chan
        self.hid_size = hid_size
        self.chunk_size = chunk_size
        hop_size = hop_size if hop_size is not None else chunk_size // 2
        self.hop_size = hop_size
        self.n_repeats = n_repeats
        self.n_src = n_src
        self.norm_type = norm_type
        self.mask_act = mask_act
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout

        layer_norm = norms.get(norm_type)(in_chan)
        bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1)
        self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv)

        # Succession of DPRNNBlocks.
        net = []
        for x in range(self.n_repeats):
            net += [
                DPRNNBlock(
                    bn_chan,
                    hid_size,
                    norm_type=norm_type,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    num_layers=num_layers,
                    dropout=dropout,
                )
            ]
        self.net = nn.Sequential(*net)
        # Masking in 3D space
        net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
        self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
        self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False)

        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()

    def forward(self, mixture_w):
        """
        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape
                [batch, n_filters, n_frames]

        Returns:
            :class:`torch.Tensor`: Estimated mask of shape
                [batch, n_src, n_filters, n_frames]
        """
        batch, n_filters, n_frames = mixture_w.size()
        output = self.bottleneck(mixture_w)  # [batch, bn_chan, n_frames]
        # Chunk the time axis into overlapping chunks of length chunk_size
        output = unfold(
            output.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        n_chunks = output.size(-1)
        output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks)
        # Apply stacked DPRNN Blocks sequentially
        output = self.net(output)
        # Map to sources with 2D masks
        output = self.first_out(output)
        output = output.reshape(batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks)
        # Overlap and add:
        # [batch, out_chan, chunk_size, n_chunks] -> [batch, out_chan, n_frames]
        to_unfold = self.bn_chan * self.chunk_size
        output = fold(
            output.reshape(batch * self.n_src, to_unfold, n_chunks),
            (n_frames, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        # Apply gating
        output = output.reshape(batch * self.n_src, self.bn_chan, -1)
        output = self.net_out(output) * self.net_gate(output)
        # Compute mask
        score = self.mask_net(output)
        est_mask = self.output_act(score)
        est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames)
        return est_mask

    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "out_chan": self.out_chan,
            "bn_chan": self.bn_chan,
            "hid_size": self.hid_size,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "rnn_type": self.rnn_type,
            "num_layers": self.num_layers,
            "dropout": self.dropout,
        }
        return config
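
# Illustrative usage sketch (editorial addition, not part of the original
# module). Estimates masks for 2 sources from a 64-filter encoder output;
# n_repeats is kept small here only to keep the example light.
example_dprnn = DPRNN(in_chan=64, n_src=2, chunk_size=100, n_repeats=2)
example_mix_w = torch.randn(8, 64, 500)  # [batch, n_filters, n_frames]
example_mask = example_dprnn(example_mix_w)  # [8, 2, 64, 500]
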

class LSTMMasker(nn.Module):
    """LSTM mask network introduced in [1], without skip connections.

    Args:
        in_chan (int): Number of input filters.
        n_src (int): Number of masks to estimate.
        out_chan (int or None): Number of bins in the estimated masks.
            Defaults to `in_chan`.
        rnn_type (str, optional): Type of RNN used. Choose between ``'RNN'``,
            ``'LSTM'`` and ``'GRU'``.
        n_layers (int, optional): Number of layers in each RNN.
        hid_size (int): Number of neurons in the RNNs cell state.
        mask_act (str, optional): Which non-linear function to generate mask.
        bidirectional (bool, optional): Whether to use BiLSTM.
        dropout (float, optional): Dropout ratio, must be in [0, 1].

    References:
        [1] Yi Luo et al. "Real-time Single-channel Dereverberation and
        Separation with Time-domain Audio Separation Network", Interspeech 2018
    """

    def __init__(
        self,
        in_chan,
        n_src,
        out_chan=None,
        rnn_type="lstm",
        n_layers=4,
        hid_size=512,
        dropout=0.3,
        mask_act="sigmoid",
        bidirectional=True,
    ):
        super().__init__()
        self.in_chan = in_chan
        self.n_src = n_src
        out_chan = out_chan if out_chan is not None else in_chan
        self.out_chan = out_chan
        self.rnn_type = rnn_type
        self.n_layers = n_layers
        self.hid_size = hid_size
        self.dropout = dropout
        self.mask_act = mask_act
        self.bidirectional = bidirectional

        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()

        # Create TasNet masker
        out_size = hid_size * (int(bidirectional) + 1)
        if bidirectional:
            self.bn_layer = GlobLN(in_chan)
        else:
            self.bn_layer = CumLN(in_chan)
        self.masker = nn.Sequential(
            SingleRNN(
                # Note: the RNN type is hardcoded to "lstm" here; the
                # ``rnn_type`` argument is stored in the config but not used.
                "lstm",
                in_chan,
                hidden_size=hid_size,
                n_layers=n_layers,
                bidirectional=bidirectional,
                dropout=dropout,
            ),
            nn.Linear(out_size, self.n_src * out_chan),
            self.output_act,
        )

    def forward(self, x):
        batch_size = x.shape[0]
        to_sep = self.bn_layer(x)
        est_masks = self.masker(to_sep.transpose(-1, -2)).transpose(-1, -2)
        est_masks = est_masks.view(batch_size, self.n_src, self.out_chan, -1)
        return est_masks

    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "n_src": self.n_src,
            "out_chan": self.out_chan,
            "rnn_type": self.rnn_type,
            "n_layers": self.n_layers,
            "hid_size": self.hid_size,
            "dropout": self.dropout,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
        }
        return config
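
# Illustrative usage sketch (editorial addition, not part of the original
# module). A TasNet-style BiLSTM masker estimating masks for 2 sources.
example_masker = LSTMMasker(in_chan=64, n_src=2, hid_size=256, n_layers=2)
example_tf_rep = torch.randn(8, 64, 300)  # [batch, in_chan, n_frames]
example_masks = example_masker(example_tf_rep)  # [8, 2, 64, 300]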