asteroid.models.base_models module
class asteroid.models.base_models.BaseTasNet(encoder, masker, decoder, encoder_activation=None)
Bases: torch.nn.Module
Base class for encoder-masker-decoder separation models.
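Parameters:
- encoder – encoder module that maps the input waveform to a latent representation.
- masker – masker network that estimates one mask per source from that representation.
- decoder – decoder module that maps the masked representations back to waveforms.
- encoder_activation – optional activation applied to the encoder output. Defaults to None.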
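As a minimal illustration of the encoder-masker-decoder pattern (not asteroid's implementation), the pure-Python sketch below uses the hypothetical helpers toy_encoder, toy_masker and toy_decoder. The encoder frames the waveform, the masker produces one mask per source, and the decoder turns each masked representation back into a waveform. Real models use learned modules operating on tensors instead of lists.

```python
from typing import Callable, List, Sequence

Frames = List[List[float]]


def toy_encoder(wav: Sequence[float], frame: int = 4) -> Frames:
    """Split a 1D waveform into non-overlapping frames (stand-in for a learned encoder)."""
    return [list(wav[i:i + frame]) for i in range(0, len(wav), frame)]


def toy_masker(feats: Frames) -> List[Frames]:
    """Return two masks: source 0 keeps positive samples, source 1 keeps the rest.

    Here the masks sum to 1 at every position.
    """
    pos = [[1.0 if x > 0 else 0.0 for x in fr] for fr in feats]
    neg = [[1.0 - m for m in fr] for fr in pos]
    return [pos, neg]


def toy_decoder(feats: Frames) -> List[float]:
    """Exact inverse of toy_encoder: concatenate frames back into a waveform."""
    return [x for fr in feats for x in fr]


def enc_mask_dec(
    wav: Sequence[float],
    encoder: Callable[[Sequence[float]], Frames],
    masker: Callable[[Frames], List[Frames]],
    decoder: Callable[[Frames], List[float]],
) -> List[List[float]]:
    """Encode once, apply one mask per source, decode each masked representation."""
    feats = encoder(wav)
    estimates = []
    for mask in masker(feats):
        # Elementwise product of this source's mask with the shared representation.
        masked = [[m * f for m, f in zip(mfr, ffr)] for mfr, ffr in zip(mask, feats)]
        estimates.append(decoder(masked))
    return estimates  # shape (n_src, time), like forward() without a batch dim
```

Because the toy decoder is linear and exactly inverts the toy encoder, and the two masks sum to one, the estimated sources add back up to the mixture in this sketch.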
forward(wav)
Encoder/masker/decoder forward pass.
Parameters: wav (torch.Tensor) – waveform tensor; 1D, 2D or 3D, with time as the last dimension.
Returns: torch.Tensor of shape (batch, n_src, time), or (n_src, time) without a batch dimension.
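The shape contract of forward can be summarized by a small helper. This is a sketch, assuming that 2D input is (batch, time), 3D input is (batch, 1, time), and the output keeps the input's time length; the name forward_output_shape is hypothetical and works on shape tuples only, so it runs without torch.

```python
from typing import Tuple


def forward_output_shape(in_shape: Tuple[int, ...], n_src: int) -> Tuple[int, ...]:
    """Output shape of an Enc/Mask/Dec forward pass for a time-last input shape.

    Assumptions: 2D input is (batch, time), 3D input is (batch, 1, time), and the
    output keeps the input's time length.
    """
    if len(in_shape) == 1:  # (time,): no batch dim, so none in the output either
        return (n_src, in_shape[-1])
    if len(in_shape) in (2, 3):  # (batch, time) or (batch, 1, time)
        return (in_shape[0], n_src, in_shape[-1])
    raise ValueError(f"expected a 1D, 2D or 3D input, got {len(in_shape)}D")
```

For example, a single 1D waveform of 16000 samples separated into 2 sources gives (2, 16000), while a batch of shape (4, 16000) gives (4, 2, 16000).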
classmethod from_pretrained(pretrained_model_conf_or_path, *args, **kwargs)
Instantiate a separation model from a model config (file or dict).
Parameters: pretrained_model_conf_or_path (Union[dict, str]) – model config as returned by serialize, or a path to it. Must contain the model_args and state_dict keys.
Returns: an instance of BaseTasNet.
Raises: ValueError – if the input config doesn't contain the model_args and state_dict keys.
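A minimal sketch of the from_pretrained pattern, using a hypothetical ToySeparator class rather than BaseTasNet: check that the config has both model_args and state_dict, build the model from model_args, then load state_dict. Reading a path with json is an assumption made so the sketch runs on the standard library alone; it stands in for whatever loader the real method uses.

```python
import json
from typing import Any, Dict, Union


class ToySeparator:
    """Hypothetical stand-in for a separation model; not asteroid's class."""

    def __init__(self, n_src: int = 2, n_filters: int = 64):
        self.n_src = n_src
        self.n_filters = n_filters
        self.state: Dict[str, Any] = {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.state = dict(state_dict)

    @classmethod
    def from_pretrained(cls, conf_or_path: Union[dict, str], *args, **kwargs):
        # A str is treated as a path; JSON is only a stand-in serialization format.
        if isinstance(conf_or_path, str):
            with open(conf_or_path) as f:
                conf = json.load(f)
        else:
            conf = conf_or_path
        missing = [k for k in ("model_args", "state_dict") if k not in conf]
        if missing:
            raise ValueError(f"pretrained config is missing keys: {missing}")
        # Constructor arguments come from the config; extra args/kwargs are forwarded.
        model = cls(*args, **conf["model_args"], **kwargs)
        model.load_state_dict(conf["state_dict"])
        return model
```

Validating the keys before construction means a malformed config fails with a clear ValueError instead of a confusing error deep inside the constructor.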
separate(wav)
Infer separated sources from input waveforms. Also supports filenames.
Parameters: wav (Union[torch.Tensor, numpy.ndarray, str]) – waveform array or tensor (1D, 2D or 3D, time last), or a filename.
Returns: Union[torch.Tensor, numpy.ndarray, None], the estimated sources, of shape (batch, n_src, time), or (n_src, time) without a batch dimension.
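The filename support in separate can be sketched as input-type dispatch: a str is decoded to a waveform first, and anything else is used as-is. This sketch assumes 16-bit mono WAV files read with the standard-library wave module, handles 1D input only, and takes a hypothetical forward callable that returns one list per source. The names read_wav_mono16 and separate_sketch are hypothetical. The sketch always returns the estimates, so the None case in the documented return type is not modeled.

```python
import struct
import wave
from typing import Callable, List, Sequence, Union

Separator = Callable[[List[float]], List[List[float]]]


def read_wav_mono16(path: str) -> List[float]:
    """Read a 16-bit mono PCM WAV file into floats in [-1, 1)."""
    with wave.open(path, "rb") as w:
        if w.getnchannels() != 1 or w.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit mono WAV")
        frames = w.readframes(w.getnframes())
    ints = struct.unpack(f"<{len(frames) // 2}h", frames)  # little-endian int16
    return [s / 32768.0 for s in ints]


def separate_sketch(forward: Separator, wav: Union[str, Sequence[float]]) -> List[List[float]]:
    # Filenames are decoded first; everything else is treated as a 1D waveform.
    if isinstance(wav, str):
        wav = read_wav_mono16(wav)
    return forward(list(wav))
```

Keeping file decoding in a separate helper leaves the separation step itself independent of where the audio came from.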