import torch
from torch import nn
import math
from ..filterbanks import make_enc_dec
from ..masknn import SuDORMRF, SuDORMRFImproved
from .base_models import BaseTasNet
class SuDORMRFNet(BaseTasNet):
    """SuDORMRF separation model, as described in [1].

    Builds an encoder/decoder pair with ``make_enc_dec``, wraps the encoder
    in a ``_Padder`` so inputs are zero-padded before encoding, and uses a
    ``SuDORMRF`` masker. The encoder activation is ``"relu"``.

    Args:
        n_src (int): Number of sources in the input mixtures.
        bn_chan (int, optional): Number of bins in the bottleneck layer and
            the UNet blocks.
        num_blocks (int): Number of UBlocks.
        upsampling_depth (int): Depth of upsampling.
        mask_act (str): Name of output activation.
        in_chan (int, optional): Number of input channels of the masker.
            If given, it must equal the number of features produced by the
            encoder (``enc.n_feats_out``).
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``].
        kernel_size (int): Length of the filters.
        n_filters (int): Number of filters / Input dimension of the masker net.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        **fb_kwargs (dict): Additional kwargs to pass to the filterbank
            creation.

    Raises:
        AssertionError: If ``in_chan`` is given and differs from the number
            of encoder output features.

    References:
        [1] : "Sudo rm -rf: Efficient Networks for Universal Audio Source Separation",
        Tzinis et al. MLSP 2020.
    """

    def __init__(
        self,
        n_src,
        bn_chan=128,
        num_blocks=16,
        upsampling_depth=4,
        mask_act="softmax",
        in_chan=None,
        fb_name="free",
        kernel_size=21,
        n_filters=512,
        stride=None,
        **fb_kwargs,
    ):
        # Any falsy stride (None or 0) falls back to half the kernel size.
        stride = kernel_size // 2 if not stride else stride
        # Need the encoder to determine the number of input channels.
        # The resolved stride is forwarded so that a user-supplied value is
        # honoured (it used to be overridden by ``kernel_size // 2``).
        # NOTE(review): _Padder derives its length multiple from
        # ``kernel_size // 2``, not from ``stride`` -- confirm that custom
        # strides are padded as intended.
        enc, dec = make_enc_dec(
            fb_name,
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
            padding=kernel_size // 2,
            output_padding=(kernel_size // 2) - 1,
            **fb_kwargs,
        )
        # Read the feature count before the encoder is wrapped by the padder.
        n_feats = enc.n_feats_out
        enc = _Padder(enc, upsampling_depth=upsampling_depth, kernel_size=kernel_size)
        if in_chan is not None:
            assert in_chan == n_feats, (
                "Number of filterbank output channels"
                " and number of input channels should "
                "be the same. Received "
                f"{n_feats} and {in_chan}"
            )
        masker = SuDORMRF(
            n_feats,
            n_src,
            bn_chan=bn_chan,
            num_blocks=num_blocks,
            upsampling_depth=upsampling_depth,
            mask_act=mask_act,
        )
        super().__init__(enc, masker, dec, encoder_activation="relu")
class SuDORMRFImprovedNet(BaseTasNet):
    """Improved SuDORMRF separation model, as described in [1].

    Assembles a filterbank encoder/decoder through ``make_enc_dec``, wraps
    the encoder in a ``_Padder`` and pairs it with a ``SuDORMRFImproved``
    masker. No encoder activation is applied (``encoder_activation=None``).

    Args:
        n_src (int): Number of sources in the input mixtures.
        bn_chan (int, optional): Number of bins in the bottleneck layer and
            the UNet blocks.
        num_blocks (int): Number of UBlocks.
        upsampling_depth (int): Depth of upsampling.
        mask_act (str): Name of output activation.
        in_chan (int, optional): Number of input channels of the masker.
            If given, it must equal the number of features produced by the
            encoder (``enc.n_feats_out``).
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``].
        kernel_size (int): Length of the filters.
        n_filters (int): Number of filters / Input dimension of the masker net.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        **fb_kwargs (dict): Additional kwargs to pass to the filterbank
            creation.

    References:
        [1] : "Sudo rm -rf: Efficient Networks for Universal Audio Source Separation",
        Tzinis et al. MLSP 2020.
    """

    def __init__(
        self,
        n_src,
        bn_chan=128,
        num_blocks=16,
        upsampling_depth=4,
        mask_act="relu",
        in_chan=None,
        fb_name="free",
        kernel_size=21,
        n_filters=512,
        stride=None,
        **fb_kwargs,
    ):
        half_kernel = kernel_size // 2
        # A falsy stride (None or 0) defaults to half the kernel length.
        if not stride:
            stride = half_kernel

        # The encoder has to exist first: it tells us the masker input size.
        encoder, decoder = make_enc_dec(
            fb_name,
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
            padding=half_kernel,
            output_padding=half_kernel - 1,
            **fb_kwargs,
        )
        n_feats = encoder.n_feats_out
        padded_encoder = _Padder(
            encoder, upsampling_depth=upsampling_depth, kernel_size=kernel_size
        )

        # Only validate in_chan when the caller actually provided it.
        assert in_chan is None or in_chan == n_feats, (
            "Number of filterbank output channels"
            " and number of input channels should "
            "be the same. Received "
            f"{n_feats} and {in_chan}"
        )

        mask_net = SuDORMRFImproved(
            n_feats,
            n_src,
            bn_chan=bn_chan,
            num_blocks=num_blocks,
            upsampling_depth=upsampling_depth,
            mask_act=mask_act,
        )
        super().__init__(padded_encoder, mask_net, decoder, encoder_activation=None)
class _Padder(nn.Module):
    """Encoder wrapper that zero-pads the last axis before encoding.

    The input is right-padded with zeros so that its last dimension becomes
    a multiple of ``lcm(kernel_size // 2, 2 ** upsampling_depth)``, then it
    is passed to the wrapped encoder.

    Args:
        encoder (nn.Module): Encoder to wrap. It must expose a ``filterbank``
            attribute, which is re-exposed on this wrapper.
        upsampling_depth (int): Depth of upsampling; the padded length is a
            multiple of ``2 ** upsampling_depth``.
        kernel_size (int): Length of the encoder filters; the padded length
            is a multiple of ``kernel_size // 2``.
    """

    def __init__(self, encoder, upsampling_depth=4, kernel_size=21):
        super().__init__()
        self.encoder = encoder
        self.upsampling_depth = upsampling_depth
        self.kernel_size = kernel_size
        # Appropriate padding is needed for arbitrary lengths: the length
        # must be a multiple of the least common multiple of
        # ``kernel_size // 2`` and ``2 ** upsampling_depth``.
        self.lcm = abs(self.kernel_size // 2 * 2 ** self.upsampling_depth) // math.gcd(
            self.kernel_size // 2, 2 ** self.upsampling_depth
        )
        # For serialize
        self.filterbank = self.encoder.filterbank

    def forward(self, x):
        """Pad ``x`` along its last axis, then run the wrapped encoder."""
        x = self.pad(x)
        return self.encoder(x)

    def pad(self, x):
        """Right-pad the last axis of ``x`` with zeros up to a multiple of ``self.lcm``.

        Args:
            x (torch.Tensor): Tensor whose last dimension is padded.

        Returns:
            torch.Tensor: ``x`` itself when its last dimension is already a
            multiple of ``self.lcm``; otherwise a new zero-padded tensor with
            the same dtype and device as ``x``.
        """
        remainder = int(x.shape[-1]) % self.lcm
        if not remainder:
            return x
        # nn.functional.pad keeps the dtype and device of ``x``. Allocating a
        # fresh float32 buffer on the default device would move CUDA inputs
        # to CPU and change the precision of half/double inputs.
        return nn.functional.pad(x, (0, self.lcm - remainder))