asteroid.models.base_models module
class asteroid.models.base_models.BaseModel
Bases: torch.nn.Module

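file_separate(filename: str, output_dir=None, force_overwrite=False, **kwargs) → None
Filename interface to separate: separates the audio file at filename and saves the estimated sources to disk. output_dir and force_overwrite behave as described for separate.
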
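Here is a minimal sketch of how output_dir and force_overwrite (documented under separate) might decide where the estimates for one input file are written. The helper estimate_paths and the <name>_est<i>.wav naming scheme are hypothetical and chosen for illustration only. They are not taken from Asteroid.

```python
import os
from typing import List, Optional


def estimate_paths(filename: str, n_src: int, output_dir: Optional[str] = None,
                   force_overwrite: bool = False) -> List[str]:
    """Hypothetical helper: build one output path per estimated source.

    With output_dir=None the estimates go next to the input file, mirroring the
    documented default. Existing files are only replaced if force_overwrite=True.
    """
    base = os.path.splitext(os.path.basename(filename))[0]
    target_dir = output_dir if output_dir is not None else os.path.dirname(filename)
    # Hypothetical naming scheme: <base>_est1.wav, <base>_est2.wav, ...
    paths = [os.path.join(target_dir, f"{base}_est{i + 1}.wav") for i in range(n_src)]
    existing = [p for p in paths if os.path.isfile(p)]
    if existing and not force_overwrite:
        raise FileExistsError(f"Refusing to overwrite {existing}; pass force_overwrite=True.")
    return paths
```

For example, estimate_paths("/data/mix.wav", 2) gives /data/mix_est1.wav and /data/mix_est2.wav on POSIX. If any of those files already exists, the call raises FileExistsError unless force_overwrite=True.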
classmethod from_pretrained(pretrained_model_conf_or_path, *args, **kwargs)
Instantiate a separation model from a model config (file or dict).
Parameters: - pretrained_model_conf_or_path (Union[dict, str]) – model conf as returned by serialize, or a path to it. Must contain the model_args and state_dict keys.
- *args – positional arguments passed to the model.
- **kwargs – keyword arguments passed to the model. They overwrite the ones in the model package.
Returns: nn.Module corresponding to the pretrained model conf/URL.
Raises: ValueError – if the input config doesn't contain the keys model_name, model_args or state_dict.

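The contract described above can be illustrated with a plain-Python sketch. The config is a dict with model_args and state_dict, and keyword arguments override the stored model_args. ToyModel and toy_from_pretrained are hypothetical stand-ins, not Asteroid classes. The sketch leaves out positional *args, path/URL loading and any model_name lookup.

```python
from typing import Any, Dict


class ToyModel:
    """Hypothetical stand-in for a separation model (not an Asteroid class)."""

    def __init__(self, n_src: int = 2, n_filters: int = 64):
        self.n_src = n_src
        self.n_filters = n_filters
        self.state: Dict[str, Any] = {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.state = dict(state_dict)


def toy_from_pretrained(conf: Dict[str, Any], **kwargs: Any) -> ToyModel:
    """Sketch of the documented contract: validate keys, build the model, load weights."""
    for key in ("model_args", "state_dict"):
        if key not in conf:
            raise ValueError(f"Expected config to have key {key!r}, found {sorted(conf)}")
    # Keyword arguments overwrite the ones stored in the model package.
    model = ToyModel(**{**conf["model_args"], **kwargs})
    model.load_state_dict(conf["state_dict"])
    return model
```

For example, toy_from_pretrained({"model_args": {"n_src": 3}, "state_dict": {}}, n_src=4) builds a model with n_src == 4, because the keyword argument wins over the stored value.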
numpy_separate(wav: numpy.ndarray, **kwargs) → numpy.ndarray
Numpy interface to separate.

separate(wav, output_dir=None, force_overwrite=False, **kwargs)
Infer separated sources from input waveforms. Also supports filenames.
Parameters: - wav (Union[torch.Tensor, numpy.ndarray, str]) – waveform array/tensor, or a filename. Shape: 1D, 2D or 3D, time last.
- output_dir (str) – directory in which to save the estimated sources as wav files when wav is a filename. If None, estimated sources are saved next to the original ones.
- force_overwrite (bool) – whether to overwrite existing files.
- **kwargs – keyword arguments passed to _separate.
Returns: Union[torch.Tensor, numpy.ndarray, None] – the estimated sources, of shape (batch, n_src, time), or (n_src, time) without a batch dimension. None when wav is a filename, since the estimates are written to disk.
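separate accepts tensors, arrays or filenames, and its return type depends on the input. The sketch below shows one way to route inputs by type, under stated assumptions. It covers only numpy arrays and filenames (no torch branch), and dispatch_separate and toy_separator are hypothetical names, not part of the documented API.

```python
from typing import Callable, Optional, Union

import numpy as np


def dispatch_separate(
    wav: Union[np.ndarray, str],
    separate_array: Callable[[np.ndarray], np.ndarray],
    separate_file: Callable[[str], None],
) -> Optional[np.ndarray]:
    """Route an input to the matching separation path based on its type."""
    if isinstance(wav, str):
        # Filename input: estimates are written to disk, nothing is returned.
        separate_file(wav)
        return None
    if isinstance(wav, np.ndarray):
        # Array input: estimates are returned to the caller.
        return separate_array(wav)
    raise ValueError(f"Unsupported input type: {type(wav).__name__}")


def toy_separator(wav: np.ndarray) -> np.ndarray:
    """Hypothetical 2-source 'separator' that copies the input into each source.

    Inserting the source axis just before time gives (n_src, time) for 1D input
    and (batch, n_src, time) for 2D (batch, time) input.
    """
    return np.stack([wav, wav], axis=-2)
```

For example, dispatch_separate(np.zeros(16), toy_separator, print) returns an array of shape (2, 16). Passing a filename instead calls the file callback and returns None.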
Note
By default, separate calls _separate, which calls forward. For models whose forward doesn't return waveform tensors, override _separate so that it returns waveform tensors.

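The following plain-Python sketch illustrates the note by mimicking the separate → _separate → forward chain with lists instead of tensors. ToyBase, DictOutputModel and the "estimates"/"masks" keys are hypothetical. The point is only that overriding _separate lets separate keep returning waveforms when forward returns something else.

```python
from typing import Dict, List


class ToyBase:
    """Hypothetical stand-in for the separate -> _separate -> forward chain."""

    def forward(self, wav: List[float]):
        raise NotImplementedError

    def _separate(self, wav: List[float]):
        # Default behaviour described in the note: just call forward.
        return self.forward(wav)

    def separate(self, wav: List[float]):
        return self._separate(wav)


class DictOutputModel(ToyBase):
    """Hypothetical model whose forward returns a dict instead of waveforms."""

    def forward(self, wav: List[float]) -> Dict[str, List[List[float]]]:
        estimates = [[0.5 * s for s in wav] for _ in range(2)]
        masks = [[0.5 for _ in wav] for _ in range(2)]
        return {"estimates": estimates, "masks": masks}

    def _separate(self, wav: List[float]) -> List[List[float]]:
        # Override so that separate() returns only the waveform estimates.
        return self.forward(wav)["estimates"]
```

With this override, DictOutputModel().separate([1.0, 2.0]) returns [[0.5, 1.0], [0.5, 1.0]] rather than the whole dict.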
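class asteroid.models.base_models.BaseTasNet(encoder, masker, decoder, encoder_activation=None)
Bases: asteroid.models.base_models.BaseModel
Base class for encoder-masker-decoder separation models.
Parameters: - encoder (nn.Module) – encoder module.
- masker (nn.Module) – masker network.
- decoder (nn.Module) – decoder module.
- encoder_activation (str, optional) – activation applied to the encoder output.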
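A toy numpy sketch can make the encoder-masker-decoder data flow concrete. It assumes that the masker output acts as masks multiplied with the encoder representation, and that encoder_activation is applied right after the encoder. ToyTasNet, its fixed orthogonal basis and its random softmax masker are hypothetical, untrained simplifications. In BaseTasNet the three stages are whatever modules the caller passes in.

```python
from typing import Optional

import numpy as np


class ToyTasNet:
    """Hypothetical numpy sketch of an encoder-masker-decoder separator.

    Encoder: fixed orthogonal basis applied to non-overlapping frames.
    Masker: random projection followed by a softmax over sources.
    Decoder: transposed basis, frames concatenated back into a waveform.
    """

    def __init__(self, n_src: int = 2, kernel_size: int = 16,
                 encoder_activation: Optional[str] = None, seed: int = 0):
        if encoder_activation not in (None, "relu"):
            raise ValueError("this sketch only supports None or 'relu'")
        rng = np.random.default_rng(seed)
        self.kernel_size = kernel_size
        # Orthogonal (kernel_size x kernel_size) basis from a QR decomposition.
        self.basis, _ = np.linalg.qr(rng.standard_normal((kernel_size, kernel_size)))
        # One random projection per source, used to produce mask logits.
        self.mask_w = rng.standard_normal((n_src, kernel_size, kernel_size))
        self.encoder_activation = encoder_activation

    def encode(self, wav: np.ndarray) -> np.ndarray:
        # (batch, time) -> (batch, n_filters, n_frames); time must divide evenly.
        batch, time = wav.shape
        if time % self.kernel_size:
            raise ValueError("time must be a multiple of kernel_size in this sketch")
        frames = wav.reshape(batch, time // self.kernel_size, self.kernel_size)
        tf_rep = np.einsum("fk,bnk->bfn", self.basis, frames)
        if self.encoder_activation == "relu":
            tf_rep = np.maximum(tf_rep, 0.0)
        return tf_rep

    def mask(self, tf_rep: np.ndarray) -> np.ndarray:
        # (batch, n_filters, n_frames) -> (batch, n_src, n_filters, n_frames)
        logits = np.einsum("sgf,bfn->bsgn", self.mask_w, tf_rep)
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        weights = np.exp(logits)
        return weights / weights.sum(axis=1, keepdims=True)  # softmax over sources

    def decode(self, masked: np.ndarray) -> np.ndarray:
        # (batch, n_src, n_filters, n_frames) -> (batch, n_src, time)
        frames = np.einsum("fk,bsfn->bsnk", self.basis, masked)
        batch, n_src, n_frames, _ = frames.shape
        return frames.reshape(batch, n_src, n_frames * self.kernel_size)

    def forward(self, wav: np.ndarray) -> np.ndarray:
        tf_rep = self.encode(wav)
        masks = self.mask(tf_rep)
        # Apply each source's mask to the shared encoder representation.
        return self.decode(masks * tf_rep[:, None])
```

In this toy, the masks sum to one across sources and the decoder exactly inverts the encoder. As a result, with encoder_activation=None the estimates add back up to the input. A trained model would instead learn masks that isolate each source.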
forward(wav)
Encoder/masker/decoder forward pass.
Parameters: wav (torch.Tensor) – waveform tensor. 1D, 2D or 3D tensor, time last.
Returns: torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
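forward accepts 1D, 2D or 3D input (time last) and returns (batch, n_src, time) or (n_src, time). The sketch below shows the shape bookkeeping this implies. It assumes that 2D input means (batch, time) and that only 1D input loses its batch dimension on output. to_3d and restore_output are hypothetical helpers.

```python
import numpy as np


def to_3d(wav: np.ndarray) -> np.ndarray:
    """Hypothetical helper: bring 1D/2D/3D time-last input to (batch, channels, time)."""
    if wav.ndim == 1:
        return wav.reshape(1, 1, -1)  # (time,) -> (1, 1, time)
    if wav.ndim == 2:
        return wav[:, None, :]  # assumed (batch, time) -> (batch, 1, time)
    if wav.ndim == 3:
        return wav  # already (batch, channels, time)
    raise ValueError(f"Expected a 1D, 2D or 3D array, got {wav.ndim}D")


def restore_output(estimates: np.ndarray, input_ndim: int) -> np.ndarray:
    """Drop the batch dimension again for 1D input: (1, n_src, time) -> (n_src, time)."""
    return estimates[0] if input_ndim == 1 else estimates
```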