# Source code for asteroid.models.base_models

import torch
from torch import nn
import numpy as np

from .. import torch_utils
from ..utils.hub_utils import cached_download
from ..masknn import activations


class BaseTasNet(nn.Module):
    """Base class for encoder-masker-decoder separation models.

    Args:
        encoder (Encoder): Encoder instance.
        masker (nn.Module): masker network.
        decoder (Decoder): Decoder instance.
        encoder_activation (str, optional): name of the activation applied on
            the encoder output. A falsy value (``None``/empty) selects the
            identity ("linear") activation.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__()
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder
        # Keep the raw argument so `serialize` can round-trip it unchanged.
        self.encoder_activation = encoder_activation
        # Falsy values fall back to the identity ("linear") activation.
        self.enc_activation = activations.get(encoder_activation or "linear")()
[docs] def forward(self, wav): """ Enc/Mask/Dec model forward Args: wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last. Returns: torch.Tensor, of shape (batch, n_src, time) or (n_src, time). """ # Handle 1D, 2D or n-D inputs was_one_d = False if wav.ndim == 1: was_one_d = True wav = wav.unsqueeze(0).unsqueeze(1) if wav.ndim == 2: wav = wav.unsqueeze(1) # Real forward tf_rep = self.enc_activation(self.encoder(wav)) est_masks = self.masker(tf_rep) masked_tf_rep = est_masks * tf_rep.unsqueeze(1) out_wavs = torch_utils.pad_x_to_y(self.decoder(masked_tf_rep), wav) if was_one_d: return out_wavs.squeeze(0) return out_wavs
[docs] def separate(self, wav): """ Infer separated sources from input waveforms. Also supports filenames. Args: wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor. Shape: 1D, 2D or 3D tensor, time last. Returns: Union[torch.Tensor, numpy.ndarray, None], the estimated sources. (batch, n_src, time) or (n_src, time) w/o batch dim. """ with torch.no_grad(): return self._separate(wav)
def _separate(self, wav): """ Hidden separation method Args: wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor. Shape: 1D, 2D or 3D tensor, time last. Returns: Union[torch.Tensor, numpy.ndarray, None], the estimated sources. (batch, n_src, time) or (n_src, time) w/o batch dim. """ # Handle filename inputs was_file = False if isinstance(wav, str): import soundfile as sf was_file = True filename = wav wav, fs = sf.read(wav, dtype="float32") wav = torch.from_numpy(wav) # Handle numpy inputs was_numpy = False if isinstance(wav, np.ndarray): was_numpy = True wav = torch.from_numpy(wav) # Handle device placement input_device = wav.device model_device = next(self.parameters()).device wav = wav.to(model_device) # Forward out_wavs = self.forward(wav) # Back to input device (and numpy if necessary) out_wavs = out_wavs.to(input_device) if was_numpy: return out_wavs.cpu().data.numpy() if was_file: # Save wav files to filename_est1.wav etc... to_save = out_wavs.cpu().data.numpy() for src_idx, est_src in enumerate(to_save): base = ".".join(filename.split(".")[:-1]) save_name = base + "_est{}.".format(src_idx + 1) + filename.split(".")[-1] sf.write(save_name, est_src, fs) return return out_wavs
[docs] @classmethod def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs): """ Instantiate separation model from a model config (file or dict). Args: pretrained_model_conf_or_path (Union[dict, str]): model conf as returned by `serialize`, or path to it. Need to contain `model_args` and `state_dict` keys. Returns: Instance of BaseTasNet Raises: ValueError if the input config file doesn't contain the keys `model_args` and `state_dict`. """ if isinstance(pretrained_model_conf_or_path, str): cached_model = cached_download(pretrained_model_conf_or_path) conf = torch.load(cached_model, map_location="cpu") else: conf = pretrained_model_conf_or_path if "model_args" not in conf.keys(): raise ValueError( "Expected config dictionary to have field " "model_args`. Found only: {}".format(conf.keys()) ) if "state_dict" not in conf.keys(): raise ValueError( "Expected config dictionary to have field " "state_dict`. Found only: {}".format(conf.keys()) ) model = cls(*args, **conf["model_args"], **kwargs) model.load_state_dict(conf["state_dict"]) return model
[docs] def serialize(self): """ Serialize model and output dictionary. Returns: dict, serialized model with keys `model_args` and `state_dict`. """ from .. import __version__ as asteroid_version # Avoid circular imports import pytorch_lightning as pl # Not used in torch.hub model_conf = dict() fb_config = self.encoder.filterbank.get_config() masknet_config = self.masker.get_config() # Assert both dict are disjoint if not all(k not in fb_config for k in masknet_config): raise AssertionError( "Filterbank and Mask network config share" "common keys. Merging them is not safe." ) # Merge all args under model_args. model_conf["model_name"] = self.__class__.__name__ model_conf["model_args"] = { **fb_config, **masknet_config, "encoder_activation": self.encoder_activation, } model_conf["state_dict"] = self.state_dict() # Additional infos infos = dict() infos["software_versions"] = dict( torch_version=torch.__version__, pytorch_lightning_version=pl.__version__, asteroid_version=asteroid_version, ) model_conf["infos"] = infos return model_conf