Source code for asteroid.filterbanks.transforms

import torch
import numpy as np

EPS = 1e-8


def mul_c(inp, other, dim=-2):
    """Multiply two complex-like tensors element by element.

    Both operands store their complex entries as the real parts followed by
    the imaginary parts along dimension `dim`. For example, with ``dim = 1``,
    the matrix

    .. code::

        [[1, 2, 3, 4],
         [5, 6, 7, 8]]

    stands for

    .. code::

        [[1 + 3j, 2 + 4j],
         [5 + 7j, 6 + 8j]]

    where `j` is such that `j * j = -1`.

    Args:
        inp (:class:`torch.Tensor`): First operand, real and imaginary parts
            concatenated on the `dim` axis.
        other (:class:`torch.Tensor`): Second operand, laid out the same way.
        dim (int, optional): Frequency (or equivalent) dimension along which
            real and imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: The complex product of `inp` and `other`, with
        real and imaginary parts concatenated along `dim`.

    Raises:
        AssertionError: If either operand has an odd size along `dim`.
    """
    # Validate the operands in the same order as they are consumed below.
    for operand in (inp, other):
        check_complex(operand, dim=dim)
    re_a, im_a = inp.chunk(2, dim=dim)
    re_b, im_b = other.chunk(2, dim=dim)
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    real_part = re_a * re_b - im_a * im_b
    imag_part = re_a * im_b + im_a * re_b
    return torch.cat([real_part, imag_part], dim=dim)
def take_reim(x, dim=-2):
    """Return the input unchanged (real/imaginary representation).

    Args:
        x: The representation to pass through.
        dim (int): Unused; kept so the signature matches the other input
            transforms in this module.

    Returns:
        The very same object that was passed as `x`.
    """
    return x
def take_mag(x, dim=-2):
    """Compute the magnitude of a complex-like tensor.

    The operand stores the real parts of each entry followed by the imaginary
    parts along dimension `dim`. For example, with ``dim = 1``, the matrix

    .. code::

        [[1, 2, 3, 4],
         [5, 6, 7, 8]]

    stands for

    .. code::

        [[1 + 3j, 2 + 4j],
         [5 + 7j, 6 + 8j]]

    where `j` is such that `j * j = -1`.

    Args:
        x (:class:`torch.Tensor`): Complex-like tensor.
        dim (int): Frequency (or equivalent) dimension along which real and
            imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: ``sqrt(re ** 2 + im ** 2 + EPS)``; its size
        along `dim` is half that of `x`.

    Raises:
        AssertionError: If `x` has an odd size along `dim`.
    """
    check_complex(x, dim=dim)
    # Put the real and imaginary halves side by side on a new last axis.
    re_im = torch.stack(torch.chunk(x, 2, dim=dim), dim=-1)
    # EPS is added before the square root, so the result is never exactly 0.
    squared_norm = re_im.pow(2).sum(dim=-1) + EPS
    return squared_norm.pow(0.5)
def take_cat(x, dim=-2):
    """Concatenate the magnitude of `x` with `x` itself along `dim`.

    Args:
        x (:class:`torch.Tensor`): Complex-like tensor with real and imaginary
            parts concatenated along `dim`.
        dim (int): Frequency (or equivalent) dimension along which real and
            imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: The magnitude of `x` followed by `x` along
        `dim`.
    """
    magnitude = take_mag(x, dim=dim)
    return torch.cat([magnitude, x], dim=dim)
def apply_real_mask(tf_rep, mask, dim=-2):
    """Apply a real-valued mask to a real-valued representation.

    This is the mask function registered under the ``"reim"`` key.

    Args:
        tf_rep (:class:`torch.Tensor`): The time-frequency representation to
            apply the mask to.
        mask (:class:`torch.Tensor`): The real-valued mask to be applied.
        dim (int): Unused; kept so the signature matches the other mask
            functions.

    Returns:
        :class:`torch.Tensor`: The entrywise product of `tf_rep` and `mask`.
    """
    masked = tf_rep * mask
    return masked
def apply_mag_mask(tf_rep, mask, dim=-2):
    """Apply a real-valued mask to a complex-like representation.

    When `tf_rep` has 2N elements along `dim`, `mask` has N elements and is
    duplicated along `dim` so that the same mask multiplies both the real and
    the imaginary parts.

    `tf_rep` stores the real parts of each entry followed by the imaginary
    parts along dimension `dim`. For example, with ``dim = 1``, the matrix

    .. code::

        [[1, 2, 3, 4],
         [5, 6, 7, 8]]

    stands for

    .. code::

        [[1 + 3j, 2 + 4j],
         [5 + 7j, 6 + 8j]]

    where `j` is such that `j * j = -1`.

    Args:
        tf_rep (:class:`torch.Tensor`): The time-frequency representation to
            apply the mask to. Re and Im are concatenated along `dim`.
        mask (:class:`torch.Tensor`): The real-valued mask to be applied.
        dim (int): The frequency (or equivalent) dimension of both `tf_rep`
            and `mask` along which real and imaginary values are
            concatenated.

    Returns:
        :class:`torch.Tensor`: `tf_rep` multiplied by the duplicated `mask`.

    Raises:
        AssertionError: If `tf_rep` has an odd size along `dim`.
    """
    check_complex(tf_rep, dim=dim)
    # Repeat the mask once so it covers the real half and the imaginary half.
    doubled_mask = torch.cat([mask, mask], dim=dim)
    return tf_rep * doubled_mask
def apply_complex_mask(tf_rep, mask, dim=-2):
    """Apply a complex-like mask to a complex-like representation.

    Both operands store the real parts of each entry followed by the
    imaginary parts along dimension `dim`. For example, with ``dim = 1``, the
    matrix

    .. code::

        [[1, 2, 3, 4],
         [5, 6, 7, 8]]

    stands for

    .. code::

        [[1 + 3j, 2 + 4j],
         [5 + 7j, 6 + 8j]]

    where `j` is such that `j * j = -1`.

    Args:
        tf_rep (:class:`torch.Tensor`): The time-frequency representation to
            apply the mask to.
        mask (:class:`torch.Tensor`): The complex-valued mask to be applied.
        dim (int): The frequency (or equivalent) dimension of both `tf_rep`
            and `mask` along which real and imaginary values are
            concatenated.

    Returns:
        :class:`torch.Tensor`: `tf_rep` multiplied by `mask` in the complex
        sense.

    Raises:
        AssertionError: If `tf_rep` or `mask` has an odd size along `dim`.
    """
    check_complex(tf_rep, dim=dim)
    masked = mul_c(tf_rep, mask, dim=dim)
    return masked
def check_complex(tensor, dim=-2):
    """Assert that a tensor is complex-like along a given dimension.

    A tensor is considered complex-like along `dim` when its size on that
    axis is even, so it can be split into equal real and imaginary halves.

    Args:
        tensor (torch.Tensor): Tensor to be checked.
        dim (int): The frequency (or equivalent) dimension along which real
            and imaginary values are concatenated.

    Raises:
        AssertionError: If the size of `tensor` along `dim` is odd. The
            message reports the offending shape and dimension.
    """
    if tensor.shape[dim] % 2 != 0:
        # The placeholders must actually be filled in, otherwise the message
        # shows literal "{}" and hides the shape/dim needed for debugging.
        raise AssertionError(
            "Could not equally chunk the tensor (shape {}) "
            "along the given dimension ({}). Dim axis is "
            "probably wrong".format(tuple(tensor.shape), dim)
        )
def to_numpy(tensor, dim=-2):
    """Convert a complex-like torch tensor to a complex numpy array.

    Args:
        tensor (torch.Tensor): Complex-like tensor to convert to numpy.
        dim (int, optional): The frequency (or equivalent) dimension along
            which real and imaginary values are concatenated.

    Returns:
        :class:`numpy.array`: Corresponding complex array, whose size along
        `dim` is half that of `tensor`.

    Raises:
        AssertionError: If `tensor` has an odd size along `dim`.
    """
    check_complex(tensor, dim=dim)
    real, imag = torch.chunk(tensor, 2, dim=dim)
    # detach() drops autograd tracking and cpu() moves the data to host memory
    # when needed, so tensors that require grad or live on another device
    # convert as well. For plain CPU tensors the result is unchanged.
    real_np = real.detach().cpu().numpy()
    imag_np = imag.detach().cpu().numpy()
    return real_np + 1j * imag_np
def from_numpy(array, dim=-2):
    """Convert a complex numpy array to a complex-like torch tensor.

    Args:
        array (np.array): Array to be converted.
        dim (int, optional): The frequency (or equivalent) dimension along
            which real and imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: Corresponding tensor, with the real part
        followed by the imaginary part along `dim`.
    """
    real_part = torch.from_numpy(np.real(array))
    imag_part = torch.from_numpy(np.imag(array))
    return torch.cat([real_part, imag_part], dim=dim)
def to_torchaudio(tensor, dim=-2):
    """Convert a complex-like tensor to a torchaudio-style complex tensor.

    The tensor is split in two along `dim` and the two pieces are stacked on
    a new trailing axis.

    Args:
        tensor (torch.tensor): asteroid-style complex-like torch tensor.
        dim (int, optional): The frequency (or equivalent) dimension along
            which real and imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: torchaudio-style complex-like torch tensor,
        with real and imaginary parts on the last axis.
    """
    halves = torch.chunk(tensor, 2, dim=dim)
    return torch.stack(halves, dim=-1)
def from_torchaudio(tensor, dim=-2):
    """Convert a torchaudio-style complex tensor to a complex-like tensor.

    Index 0 of the last axis is taken as the real part and index 1 as the
    imaginary part; the two are concatenated along `dim`.

    Args:
        tensor (torch.tensor): torchaudio-style complex-like torch tensor.
        dim (int, optional): The frequency (or equivalent) dimension along
            which real and imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: asteroid-style complex-like torch tensor.
    """
    real_part = tensor[..., 0]
    imag_part = tensor[..., 1]
    return torch.cat([real_part, imag_part], dim=dim)
def angle(tensor, dim=-2):
    """Return the phase angle of a complex-like torch tensor.

    Args:
        tensor (torch.Tensor): The complex-like tensor from which to extract
            the phase.
        dim (int, optional): The frequency (or equivalent) dimension along
            which real and imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: The counterclockwise angle from the positive
        real axis on the complex plane, in radians.

    Raises:
        AssertionError: If `tensor` has an odd size along `dim`.
    """
    check_complex(tensor, dim=dim)
    re_part, im_part = torch.chunk(tensor, 2, dim=dim)
    # atan2 takes the imaginary part first, then the real part.
    phase = torch.atan2(im_part, re_part)
    return phase
def from_mag_and_phase(mag, phase, dim=-2):
    """Build a complex-like torch tensor from magnitude and phase.

    Args:
        mag (torch.tensor): Magnitude of the tensor.
        phase (torch.tensor): Angle of the tensor.
        dim (int, optional): The frequency (or equivalent) dimension along
            which real and imaginary values are concatenated.

    Returns:
        :class:`torch.Tensor`: The corresponding complex-like torch tensor,
        i.e. ``mag * cos(phase)`` followed by ``mag * sin(phase)`` along
        `dim`.
    """
    real_part = mag * torch.cos(phase)
    imag_part = mag * torch.sin(phase)
    return torch.cat([real_part, imag_part], dim=dim)
def ebased_vad(mag_spec, th_db=40):
    """Compute energy-based VAD from a magnitude spectrogram (or equivalent).

    Args:
        mag_spec (torch.Tensor): The spectrogram to perform VAD on.
            Expected shape (batch, *, freq, time).
            The VAD mask is computed independently for all the leading
            dimensions until the last two. Independent of the ordering of the
            last two dimensions.
        th_db (int): The threshold in dB from which a TF-bin is considered
            silent.

    Returns:
        torch.BoolTensor, the VAD mask, with the same shape as `mag_spec`.

    Examples:
        >>> import torch
        >>> mag_spec = torch.abs(torch.randn(10, 2, 65, 16))
        >>> batch_src_mask = ebased_vad(mag_spec)
    """
    log_mag = 20 * torch.log10(mag_spec)
    # Compute VAD for each utterance in a batch independently: flatten the
    # last two dimensions and take the maximum over them.
    to_view = list(mag_spec.shape[:-2]) + [1, -1]
    # reshape (rather than view) also accepts non-contiguous inputs, e.g. a
    # spectrogram whose last two dimensions were transposed.
    flat_log_mag = log_mag.reshape(to_view)
    max_log_mag = torch.max(flat_log_mag, -1, keepdim=True)[0]
    # A bin is active when it is within `th_db` dB of the loudest bin.
    return log_mag > (max_log_mag - th_db)
# Registry of input transforms: name -> (function, numeric factor).
# NOTE(review): the factor is presumably a channel-count multiplier relative
# to the encoder output; confirm against the callers of `_inputs`.
_inputs = {
    "reim": (take_reim, 1),
    "mag": (take_mag, 1 / 2),
    "cat": (take_cat, 1 + 1 / 2),
}
# Aliases pointing at the very same (function, factor) tuples.
_inputs.update(
    real=_inputs["reim"],
    mod=_inputs["mag"],
    concat=_inputs["cat"],
)

# Registry of mask applications: name -> (function, numeric factor).
# NOTE(review): same caveat as above regarding the meaning of the factor.
_masks = {
    "reim": (apply_real_mask, 1),
    "mag": (apply_mag_mask, 1 / 2),
    "complex": (apply_complex_mask, 1),
}
# Aliases pointing at the very same (function, factor) tuples.
_masks.update(
    real=_masks["reim"],
    mod=_masks["mag"],
    comp=_masks["complex"],
)