import torch
from scipy.signal import get_window
from asteroid.losses import PITLossWrapper
from torch import nn
class LambdaOverlapAdd(torch.nn.Module):
""" Segment signal, apply func, combine with OLA.
Args:
nnet (callable): function to apply to each segment.
n_src (int): Number of sources in the output of nnet.
window_size (int): Size of segmenting window.
        hop_size (int): Hop size between segments. If None, defaults to
            ``window_size // 2`` (50% overlap).
        window (str): Name of the window (see scipy.signal.get_window).
        reorder_chunks (bool): Whether to reorder the sources in each segment
            to best match the sources of the previous segment (see
            ``_reorder_sources``).
Examples:
>>> from asteroid import ConvTasNet
>>> nnet = ConvTasNet(n_src=2)
        >>> continuous_nnet = LambdaOverlapAdd(
        ...     nnet=nnet,
        ...     n_src=2,
        ...     window_size=64000,
        ...     hop_size=None,
        ...     window="hann",
        ...     reorder_chunks=True,
        ...     enable_grad=False,
        ... )
>>> wav = torch.randn(1, 1, 500000)
>>> out_wavs = continuous_nnet.forward(wav)
"""
def __init__(
self,
nnet,
n_src,
window_size,
hop_size=None,
window="hanning",
reorder_chunks=True,
enable_grad=False,
):
super().__init__()
assert window_size % 2 == 0, "Window size must be even"
self.nnet = nnet
self.window_size = window_size
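        # Default hop is half the window, i.e. 50% overlap between segments.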
self.hop_size = hop_size if hop_size is not None else window_size // 2
self.n_src = n_src
        if window:
            window = get_window(window, self.window_size).astype("float32")
            window = torch.from_numpy(window)
            self.use_window = True
        else:
            # No tapering window: register None so the buffer name always exists.
            window = None
            self.use_window = False
        self.register_buffer("window", window)
self.reorder_chunks = reorder_chunks
self.enable_grad = enable_grad
    def ola_forward(self, x):
"""Heart of the class: segment signal, apply func, combine with OLA."""
assert x.ndim == 3
batch, channels, n_frames = x.size()
        # Segment the signal into overlapping frames:
        # (batch, chans, n_frames) -> (batch, chans * window_size, n_chunks)
folded = torch.nn.functional.unfold(
x.unsqueeze(-1),
kernel_size=(self.window_size, 1),
padding=(self.window_size, 0),
stride=(self.hop_size, 1),
)
out = []
n_chunks = folded.shape[-1]
for frame_idx in range(n_chunks): # for loop to spare memory
tmp = self.nnet(folded[..., frame_idx])
            # The user must handle multichannel input by folding channels into the batch dim.
if frame_idx == 0:
assert tmp.ndim == 3, "nnet should return (batch, n_src, time)"
assert tmp.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
tmp = tmp.reshape(batch * self.n_src, -1)
if frame_idx != 0 and self.reorder_chunks:
# we determine best perm based on xcorr with previous sources
tmp = _reorder_sources(tmp, out[-1], self.n_src, self.window_size, self.hop_size)
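            # Taper with the synthesis window before OLA; without a window,
            # rescale so that overlapping segments sum back to unity gain.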
if self.use_window:
tmp = tmp * self.window
else:
tmp = tmp / (self.window_size / self.hop_size)
out.append(tmp)
out = torch.stack(out).reshape(n_chunks, batch * self.n_src, self.window_size)
out = out.permute(1, 2, 0)
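        # out is (batch * n_src, window_size, n_chunks): fold expects
        # (batch, channels * prod(kernel_size), L) and sums overlapping frames.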
out = torch.nn.functional.fold(
out,
(n_frames, 1),
kernel_size=(self.window_size, 1),
padding=(self.window_size, 0),
stride=(self.hop_size, 1),
)
return out.squeeze(-1).reshape(batch, self.n_src, -1)
    def forward(self, x):
""" Forward module: segment signal, apply func, combine with OLA.
Args:
x (:class:`torch.Tensor`): waveform signal of shape (batch, 1, time).
Returns:
:class:`torch.Tensor`: The output of the lambda OLA.
"""
        # Disable gradients during OLA unless explicitly enabled (inference use case).
with torch.autograd.set_grad_enabled(self.enable_grad):
olad = self.ola_forward(x)
return olad
def _reorder_sources(
current: torch.FloatTensor,
previous: torch.FloatTensor,
n_src: int,
window_size: int,
hop_size: int,
):
"""
Reorder sources in current chunk to maximize correlation with previous chunk.
    Used for continuous source separation. Standard DSP correlation is used
    for reordering.
Args:
        current (:class:`torch.Tensor`): current chunk, tensor
            of shape (batch * n_src, window_size).
        previous (:class:`torch.Tensor`): previous chunk, tensor
            of shape (batch * n_src, window_size).
n_src (:class:`int`): number of sources.
window_size (:class:`int`): window_size, equal to last dimension of
both current and previous.
hop_size (:class:`int`): hop_size between current and previous tensors.
    Returns:
        :class:`torch.Tensor`: reordered current chunk, same shape as the input.
"""
batch, frames = current.size()
current = current.reshape(-1, n_src, frames)
previous = previous.reshape(-1, n_src, frames)
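    # Consecutive chunks overlap by window_size - hop_size frames; the
    # correlation used for reordering is computed on this region only.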
overlap_f = window_size - hop_size
    pw_losses = PITLossWrapper.get_pw_losses(
        # find_best_perm minimizes the pairwise loss, so we use the negative
        # correlation over the overlap to maximize correlation.
        lambda x, y: -torch.sum(x * y, dim=-1),
        current[..., :overlap_f],
        previous[..., -overlap_f:],
    )
_, perms = PITLossWrapper.find_best_perm(pw_losses, n_src)
current = PITLossWrapper.reorder_source(current, n_src, perms)
return current.reshape(batch, frames)
class DualPathProcessing(nn.Module):
""" Perform Dual-Path processing via overlap-add as in DPRNN [1].
Args:
chunk_size (int): Size of segmenting window.
hop_size (int): segmentation hop size.
References:
[1] "Dual-path RNN: efficient long sequence modeling for
time-domain single-channel speech separation", Yi Luo, Zhuo Chen
and Takuya Yoshioka. https://arxiv.org/abs/1910.06379
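
    Examples:
        A minimal round-trip sketch (the tensor sizes here are illustrative
        only, not prescribed by the class):

        >>> x = torch.randn(4, 32, 10000)
        >>> dsp = DualPathProcessing(chunk_size=100, hop_size=50)
        >>> unfolded = dsp.unfold(x)
        >>> unfolded.shape
        torch.Size([4, 32, 100, 203])
        >>> dsp.fold(unfolded).shape
        torch.Size([4, 32, 10000])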
"""
def __init__(self, chunk_size, hop_size):
        super().__init__()
self.chunk_size = chunk_size
self.hop_size = hop_size
self.n_orig_frames = None
    def unfold(self, x):
""" Unfold the feature tensor from
(batch, channels, time) to (batch, channels, chunk_size, n_chunks).
Args:
x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).
Returns:
x: (:class:`torch.Tensor`): spliced feature tensor of shape
(batch, channels, chunk_size, n_chunks).
"""
        assert x.ndim == 3
        batch, chan, frames = x.size()
        self.n_orig_frames = frames
unfolded = torch.nn.functional.unfold(
x.unsqueeze(-1),
kernel_size=(self.chunk_size, 1),
padding=(self.chunk_size, 0),
stride=(self.hop_size, 1),
)
return unfolded.reshape(
batch, chan, self.chunk_size, -1
) # (batch, chan, chunk_size, n_chunks)
    def fold(self, x, output_size=None):
""" Folds back the spliced feature tensor.
Input shape (batch, channels, chunk_size, n_chunks) to original shape
(batch, channels, time) using overlap-add.
Args:
x: (:class:`torch.Tensor`): spliced feature tensor of shape
(batch, channels, chunk_size, n_chunks).
output_size: (int, optional): sequence length of original feature tensor.
If None, the original length cached by the previous call of `unfold`
will be used.
Returns:
x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).
        .. note:: `fold` caches the original length of the input to `unfold`
            and falls back to it when `output_size` is None.
"""
output_size = output_size if output_size is not None else self.n_orig_frames
# x is (batch, chan, chunk_size, n_chunks)
batch, chan, chunk_size, n_chunks = x.size()
to_unfold = x.reshape(batch, chan * self.chunk_size, n_chunks)
x = torch.nn.functional.fold(
to_unfold,
(output_size, 1),
kernel_size=(self.chunk_size, 1),
padding=(self.chunk_size, 0),
stride=(self.hop_size, 1),
)
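        # Compensate the overlap-add gain: each frame is summed
        # chunk_size / hop_size times.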
x /= self.chunk_size / self.hop_size
        return x.reshape(batch, chan, output_size)
    @staticmethod
def intra_process(x, module):
""" Performs intra-chunk processing.
Args:
x (:class:`torch.Tensor`): spliced feature tensor of shape
(batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module to apply to each chunk
                of the spliced feature tensor.
Returns:
x (:class:`torch.Tensor`): processed spliced feature tensor of shape
(batch, channels, chunk_size, n_chunks).
.. note:: the module should have the channel first convention and accept
a 3D tensor of shape (batch, channels, time).
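
        Example (a minimal sketch; the 1x1 convolution is only an illustrative
        per-chunk module, not part of the API):

            >>> x = torch.randn(4, 32, 100, 50)
            >>> conv = nn.Conv1d(32, 32, kernel_size=1)
            >>> out = DualPathProcessing.intra_process(x, conv)
            >>> out.shape
            torch.Size([4, 32, 100, 50])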
"""
# x is (batch, channels, chunk_size, n_chunks)
batch, channels, chunk_size, n_chunks = x.size()
        # reshape to (batch * n_chunks, channels, chunk_size) so the module
        # processes each chunk independently
x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
x = module(x)
x = x.reshape(batch, n_chunks, channels, chunk_size).transpose(1, -1).transpose(1, 2)
return x
    @staticmethod
def inter_process(x, module):
""" Performs inter-chunk processing.
Args:
x (:class:`torch.Tensor`): spliced feature tensor of shape
(batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module to apply across chunks,
                i.e. along the n_chunks dimension of the spliced feature tensor.
Returns:
x (:class:`torch.Tensor`): processed spliced feature tensor of shape
(batch, channels, chunk_size, n_chunks).
.. note:: the module should have the channel first convention and accept
a 3D tensor of shape (batch, channels, time).
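
        Example (a minimal sketch; the 1x1 convolution is only an illustrative
        inter-chunk module, not part of the API):

            >>> x = torch.randn(4, 32, 100, 50)
            >>> conv = nn.Conv1d(32, 32, kernel_size=1)
            >>> out = DualPathProcessing.inter_process(x, conv)
            >>> out.shape
            torch.Size([4, 32, 100, 50])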
"""
batch, channels, chunk_size, n_chunks = x.size()
x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
x = module(x)
x = x.reshape(batch, chunk_size, channels, n_chunks).transpose(1, 2)
return x