asteroid.models.base_models module
-
class
asteroid.models.base_models.
BaseModel
(sample_rate: float = None, in_channels: Optional[int] = 1) Bases:
torch.nn.Module
Base class for serializable models.
Defines the saving/loading procedures and the interface to separate. Subclasses must override the forward and get_model_args methods.
Models inheriting from BaseModel can be used by
asteroid.separate
and by the asteroid-infer CLI. For models whose forward doesn’t go from waveform to waveform tensors, override forward_wav to return waveform tensors.
Parameters: - sample_rate (float) – Operating sample rate of the model.
- in_channels – Number of input channels in the signal. If None, no checks will be performed.
-
separate
(*args, **kwargs) Convenience for separate().
-
torch_separate
(*args, **kwargs) Convenience for torch_separate().
-
numpy_separate
(*args, **kwargs) Convenience for numpy_separate().
-
file_separate
(*args, **kwargs) Convenience for file_separate().
-
forward_wav
(wav, *args, **kwargs) Separation method for waveforms.
If the network’s forward doesn’t take waveforms as input and return waveforms as output, override this method to separate from waveform to waveform. It should return a single torch.Tensor containing the separated waveforms.
Parameters: wav (torch.Tensor) – waveform array/tensor. Shape: 1D, 2D or 3D tensor, time last.
-
classmethod
from_pretrained
(pretrained_model_conf_or_path, *args, **kwargs) Instantiate separation model from a model config (file or dict).
Parameters: - pretrained_model_conf_or_path (Union[dict, str]) – model conf as returned by serialize, or path to it. It must contain the model_args and state_dict keys.
- *args – Positional arguments to be passed to the model.
- **kwargs – Keyword arguments to be passed to the model. They overwrite the ones in the model package.
Returns: nn.Module corresponding to the pretrained model conf/URL.
Raises: ValueError – if the input config file doesn’t contain the keys model_name, model_args or state_dict.
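The key check and keyword override described for from_pretrained can be sketched in plain Python. This is only an illustration under stated assumptions: the helper name check_and_merge_conf and the sample values are hypothetical and not part of asteroid, and the real method goes on to build the model and load its weights, which is not shown here.

```python
REQUIRED_KEYS = ("model_name", "model_args", "state_dict")


def check_and_merge_conf(conf, **kwargs):
    """Hypothetical helper (not asteroid's API): validate a conf dict, merge overrides."""
    missing = [key for key in REQUIRED_KEYS if key not in conf]
    if missing:
        raise ValueError(f"Expected config to contain the keys {missing}")
    # Keyword arguments take precedence over the stored model_args.
    merged_args = dict(conf["model_args"])  # copy, so the conf is left untouched
    merged_args.update(kwargs)
    return merged_args


# Illustrative conf; every value here is made up for the example.
conf = {
    "model_name": "ToyModel",
    "model_args": {"n_src": 2, "sample_rate": 8000},
    "state_dict": {},
}
merged = check_and_merge_conf(conf, sample_rate=16000)
```

Copying model_args before updating keeps the loaded conf unchanged, so the same conf can be reused with different overrides.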
-
class
asteroid.models.base_models.
BaseEncoderMaskerDecoder
(encoder, masker, decoder, encoder_activation=None) Bases:
asteroid.models.base_models.BaseModel
Base class for encoder-masker-decoder separation models.
Parameters: -
forward
(wav) Enc/Mask/Dec model forward.
Parameters: wav (torch.Tensor) – waveform tensor. 1D, 2D or 3D tensor, time last.
Returns: torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
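The shape convention stated for forward can be made concrete with a small sketch that works on shape tuples only. It rests on assumptions that the text above does not confirm: a 2D input is read as (batch, time), the output keeps the input's time length, and only a 1D input yields the unbatched (n_src, time) form. Both helper names are hypothetical.

```python
def unsqueeze_to_3d_shape(shape):
    """Hypothetical helper: shape of a waveform once viewed as (batch, channels, time)."""
    if len(shape) == 1:  # (time,)
        return (1, 1, shape[0])
    if len(shape) == 2:  # assumed to mean (batch, time)
        return (shape[0], 1, shape[1])
    if len(shape) == 3:  # already (batch, channels, time)
        return tuple(shape)
    raise ValueError(f"Expected a 1D, 2D or 3D shape, got {len(shape)}D")


def output_shape(input_shape, n_src):
    """Hypothetical helper: expected output shape for a given input shape."""
    time = input_shape[-1]  # time is always the last dimension
    if len(input_shape) == 1:
        return (n_src, time)  # unbatched input, unbatched output
    return (input_shape[0], n_src, time)
```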
-
forward_encoder
(wav: torch.Tensor) → torch.Tensor Computes time-frequency representation of wav.
Parameters: wav (torch.Tensor) – waveform tensor in 3D shape, time last.
Returns: torch.Tensor, of shape (batch, feat, seq).
-
forward_masker
(tf_rep: torch.Tensor) → torch.Tensor Estimates masks from time-frequency representation.
Parameters: tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq).
Returns: torch.Tensor – Estimated masks.
-
apply_masks
(tf_rep: torch.Tensor, est_masks: torch.Tensor) → torch.Tensor Applies masks to time-frequency representation.
Parameters: - tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq) shape.
- est_masks (torch.Tensor) – Estimated masks.
Returns: torch.Tensor – Masked time-frequency representations.
-
forward_decoder
(masked_tf_rep: torch.Tensor) → torch.Tensor Reconstructs time-domain waveforms from masked representations.
Parameters: masked_tf_rep (torch.Tensor) – Masked time-frequency representation.
Returns: torch.Tensor – Time-domain waveforms.
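How the four steps listed above might chain together in forward can be illustrated with a toy stand-in. This is a sketch, not asteroid's implementation: it works on flat Python lists instead of tensors, the class ToyEncoderMaskerDecoder and the three toy functions are hypothetical, and the encode, mask, apply, decode order is inferred from the method names.

```python
class ToyEncoderMaskerDecoder:
    """Toy stand-in working on flat lists of floats; not asteroid's implementation."""

    def __init__(self, encoder, masker, decoder):
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder

    def forward_encoder(self, wav):
        return self.encoder(wav)

    def forward_masker(self, tf_rep):
        return self.masker(tf_rep)

    def apply_masks(self, tf_rep, est_masks):
        # One masked copy of the representation per estimated source.
        return [[m * x for m, x in zip(mask, tf_rep)] for mask in est_masks]

    def forward_decoder(self, masked_tf_rep):
        return [self.decoder(rep) for rep in masked_tf_rep]

    def forward(self, wav):
        tf_rep = self.forward_encoder(wav)
        est_masks = self.forward_masker(tf_rep)
        masked_tf_rep = self.apply_masks(tf_rep, est_masks)
        return self.forward_decoder(masked_tf_rep)


def toy_encoder(wav):
    return [2.0 * x for x in wav]  # invertible scaling stands in for a filterbank


def toy_decoder(rep):
    return [x / 2.0 for x in rep]  # inverse of toy_encoder


def toy_masker(tf_rep):
    # Two constant, complementary masks, i.e. two sources.
    return [[0.25] * len(tf_rep), [0.75] * len(tf_rep)]


model = ToyEncoderMaskerDecoder(toy_encoder, toy_masker, toy_decoder)
mixture = [1.0, -2.0, 4.0]
sources = model.forward(mixture)
```

Because the two toy masks sum to one and the toy decoder inverts the toy encoder, the estimated sources add back up to the mixture, which makes the data flow easy to check by hand.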
-
-
asteroid.models.base_models.
BaseTasNet
alias of
asteroid.models.base_models.BaseEncoderMaskerDecoder