asteroid.models.base_models module

class asteroid.models.base_models.BaseTasNet(encoder, masker, decoder, encoder_activation=None)

Bases: torch.nn.Module

Base class for encoder-masker-decoder separation models.

Parameters:
  • encoder (Encoder) – Encoder instance.
  • masker (nn.Module) – Masker network.
  • decoder (Decoder) – Decoder instance.
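The encoder-masker-decoder composition can be sketched in NumPy as follows. Everything here is hypothetical and is not asteroid's implementation: ToyEncMaskDec, a frame-wise linear encoder, a sigmoid masker and a pseudo-inverse decoder stand in for the learned Encoder, masker network and Decoder. The point is the data flow: encode the mixture, estimate one mask per source, multiply, decode.

```python
import numpy as np


class ToyEncMaskDec:
    """Hypothetical encoder-masker-decoder pipeline (illustration only)."""

    def __init__(self, n_filters=8, kernel_size=4, n_src=2, seed=0):
        rng = np.random.default_rng(seed)
        self.kernel_size = kernel_size
        self.n_src = n_src
        # Encoder: one linear map applied to non-overlapping frames.
        self.enc = rng.standard_normal((n_filters, kernel_size))
        # Decoder: pseudo-inverse of the encoder, so decode(encode(x)) == x.
        self.dec = np.linalg.pinv(self.enc)
        # Masker: one linear layer per source, followed by a sigmoid.
        self.mask_w = rng.standard_normal((n_src, n_filters, n_filters))

    def encode(self, wav):
        # wav: (batch, time); time must be divisible by kernel_size here.
        batch, time = wav.shape
        frames = wav.reshape(batch, time // self.kernel_size, self.kernel_size)
        return frames @ self.enc.T  # (batch, n_frames, n_filters)

    def mask(self, tf_rep):
        # One mask in [0, 1] per source: (batch, n_src, n_frames, n_filters).
        logits = np.einsum("bfn,smn->bsfm", tf_rep, self.mask_w)
        return 1.0 / (1.0 + np.exp(-logits))

    def decode(self, masked):
        # (batch, n_src, n_frames, n_filters) -> (batch, n_src, time)
        frames = masked @ self.dec.T
        b, s, f, k = frames.shape
        return frames.reshape(b, s, f * k)

    def forward(self, wav):
        tf_rep = self.encode(wav)
        masks = self.mask(tf_rep)
        # Broadcast the mixture representation over the source axis.
        masked = masks * tf_rep[:, None]
        return self.decode(masked)
```

With a (2, 16) input and the defaults above, forward returns an array of shape (2, 2, 16): one estimate per source for each batch item.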
forward(wav)

Encoder/masker/decoder forward pass.

Parameters: wav (torch.Tensor) – Waveform tensor: 1D, 2D or 3D, time last.
Returns: torch.Tensor of shape (batch, n_src, time) or (n_src, time).
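Accepting 1D, 2D or 3D input usually means normalizing it to one internal layout before the forward pass and undoing that afterwards. The helpers below are a hypothetical NumPy sketch of that step; the names to_3d and restore_output and the assumption that a 2D input means (batch, time) are illustrative, not taken from asteroid.

```python
import numpy as np


def to_3d(wav):
    """Reshape a 1D, 2D or 3D waveform (time last) to (batch, channels, time).

    Hypothetical helper: assumes 1D is (time,) and 2D is (batch, time).
    Returns the reshaped array and the original number of dimensions.
    """
    wav = np.asarray(wav)
    if wav.ndim == 1:
        return wav[None, None, :], 1
    if wav.ndim == 2:
        return wav[:, None, :], 2
    if wav.ndim == 3:
        return wav, 3
    raise ValueError(f"Expected a 1D, 2D or 3D waveform, got {wav.ndim}D")


def restore_output(est, in_ndim):
    """Drop the batch dimension of (batch, n_src, time) for unbatched 1D input."""
    return est[0] if in_ndim == 1 else est
```

Under these assumptions, a 1D input of 16 samples becomes (1, 1, 16) internally, and an estimate of shape (1, 2, 16) is returned as (2, 16), matching the (n_src, time) case.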
classmethod from_pretrained(pretrained_model_conf_or_path, *args, **kwargs)

Instantiate separation model from a model config (file or dict).

Parameters: pretrained_model_conf_or_path (Union[dict, str]) – Model conf as returned by serialize, or a path to it. It must contain the keys model_args and state_dict.
Returns: Instance of BaseTasNet.
Raises: ValueError if the input config doesn’t contain the keys model_args and state_dict.
separate(wav)

Infer separated sources from input waveforms. Also supports filenames.

Parameters: wav (Union[torch.Tensor, numpy.ndarray, str]) – Waveform array/tensor, or a filename. Shape: 1D, 2D or 3D, time last.
Returns:
Union[torch.Tensor, numpy.ndarray, None], the estimated sources.
Shape (batch, n_src, time), or (n_src, time) without the batch dimension.
serialize()

Serialize the model and return it as a dictionary.

Returns: dict, serialized model with keys model_args and state_dict.
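The serialize/from_pretrained contract is a round trip: model_args holds the constructor arguments and state_dict holds the learned parameters, so a new instance can be rebuilt from the dictionary alone. The pure-Python ToyModel below is a hypothetical sketch of that contract; a list of floats stands in for the tensors a real torch.nn.Module state_dict would contain, and file loading is left out.

```python
class ToyModel:
    """Hypothetical model illustrating the serialize/from_pretrained contract."""

    def __init__(self, n_src=2, n_filters=64):
        self.n_src = n_src
        self.n_filters = n_filters
        # Stand-in for learned parameters (a real model would hold tensors).
        self.weights = [0.0] * n_filters

    def state_dict(self):
        return {"weights": list(self.weights)}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])

    def serialize(self):
        # Constructor arguments plus parameters: enough to rebuild the model.
        return {
            "model_args": {"n_src": self.n_src, "n_filters": self.n_filters},
            "state_dict": self.state_dict(),
        }

    @classmethod
    def from_pretrained(cls, conf):
        if "model_args" not in conf or "state_dict" not in conf:
            raise ValueError("conf must contain the keys model_args and state_dict")
        model = cls(**conf["model_args"])
        model.load_state_dict(conf["state_dict"])
        return model
```

Calling ToyModel.from_pretrained(model.serialize()) yields a model with the same n_src, n_filters and weights as the original.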