
asteroid.models.base_models module

class asteroid.models.base_models.BaseEncoderMaskerDecoder(encoder, masker, decoder, encoder_activation=None)

Bases: asteroid.models.base_models.BaseModel

Base class for encoder-masker-decoder separation models.

Parameters:
  • encoder (Encoder) – Encoder instance.
  • masker (nn.Module) – Masker network.
  • decoder (Decoder) – Decoder instance.
  • encoder_activation (Optional[str], optional) – Activation to apply after encoder. See asteroid.masknn.activations for valid values.
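To make the data flow concrete, here is a toy NumPy sketch of what an encoder-masker-decoder model computes. It is not asteroid's implementation: `encode`, `toy_masker` and `decode` are hypothetical stand-ins for the Encoder, masker and Decoder instances. The sketch assumes non-overlapping frames, an orthonormal basis, a (batch, time) input instead of (batch, 1, time), and no encoder_activation.

```python
import numpy as np


def encode(wav, basis):
    """Toy encoder: cut (batch, time) into non-overlapping frames and project
    each frame onto the rows of `basis` -> (batch, n_filters, n_frames)."""
    batch, time = wav.shape
    frame_len = basis.shape[1]
    frames = wav.reshape(batch, time // frame_len, frame_len)
    return np.einsum("fl,bnl->bfn", basis, frames)


def toy_masker(tf_rep, n_src, rng):
    """Stand-in masker: random logits, softmax over the source axis, so the
    masks of all sources sum to 1 -> (batch, n_src, n_filters, n_frames)."""
    batch, n_filters, n_frames = tf_rep.shape
    logits = rng.standard_normal((batch, n_src, n_filters, n_frames))
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


def decode(masked_tf_rep, basis):
    """Toy decoder: undo the projection with the pseudo-inverse of `basis`
    and concatenate the frames -> (batch, n_src, time)."""
    inv = np.linalg.pinv(basis)  # (frame_len, n_filters)
    frames = np.einsum("lf,bsfn->bsnl", inv, masked_tf_rep)
    batch, n_src, n_frames, frame_len = frames.shape
    return frames.reshape(batch, n_src, n_frames * frame_len)


rng = np.random.default_rng(0)
frame_len, n_src = 16, 2
basis, _ = np.linalg.qr(rng.standard_normal((frame_len, frame_len)))  # orthonormal
wav = rng.standard_normal((3, 160))  # (batch, time), time divisible by frame_len

tf_rep = encode(wav, basis)                  # (3, 16, 10)
masks = toy_masker(tf_rep, n_src, rng)       # (3, 2, 16, 10)
masked_tf_rep = masks * tf_rep[:, None]      # broadcast over the source axis
est_sources = decode(masked_tf_rep, basis)   # (3, 2, 160)
```

Because the toy masks sum to one over sources and the toy decoder exactly inverts the toy encoder, the estimated sources add back up to the mixture; this property is specific to the toy setup.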
forward(wav)

Forward pass of the encoder-masker-decoder model.

Parameters: wav (torch.Tensor) – waveform tensor. 1D, 2D or 3D tensor, time last.
Returns: torch.Tensor of shape (batch, n_src, time), or (n_src, time) for a 1D input.
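The 1D/2D/3D handling described above can be sketched with NumPy arrays in place of tensors. `unsqueeze_to_3d` and `reconstruct_shape` are hypothetical helpers, and the sketch assumes a 3D input means single-channel (batch, 1, time).

```python
import numpy as np


def unsqueeze_to_3d(wav):
    """(time,) -> (1, 1, time); (batch, time) -> (batch, 1, time); 3D unchanged."""
    if wav.ndim == 1:
        return wav.reshape(1, 1, -1)
    if wav.ndim == 2:
        return wav[:, None, :]
    if wav.ndim == 3:
        return wav
    raise ValueError(f"Expected a 1D, 2D or 3D input, got {wav.ndim}D")


def reconstruct_shape(est_sources, input_ndim):
    """Drop the dummy batch axis again when the input was 1D."""
    return est_sources[0] if input_ndim == 1 else est_sources


n_src = 2
for wav in (np.zeros(8000), np.zeros((4, 8000)), np.zeros((4, 1, 8000))):
    x = unsqueeze_to_3d(wav)
    est = np.repeat(x, n_src, axis=1)  # placeholder for the model: (batch, n_src, time)
    print(wav.shape, "->", reconstruct_shape(est, wav.ndim).shape)
# (8000,) -> (2, 8000)
# (4, 8000) -> (4, 2, 8000)
# (4, 1, 8000) -> (4, 2, 8000)
```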
get_model_args()

Arguments needed to re-instantiate the model.

postprocess_decoded(decoded)

Hook to perform transformations on the decoded, time-domain representation (output of the decoder) before the original shape is reconstructed.

Parameters: decoded (Tensor of shape (batch, n_src, time)) – Output of the decoder, before original shape reconstruction.
Returns: Transformed decoded.
postprocess_encoded(tf_rep)

Hook to perform transformations on the encoded, time-frequency domain representation (output of the encoder) before the encoder activation is applied.

Parameters: tf_rep (Tensor of shape (batch, freq, time)) – Output of the encoder, before the encoder activation is applied.
Returns: Transformed tf_rep.
postprocess_masked(masked_tf_rep)

Hook to perform transformations on the masked time-frequency domain representation (result of masking in the time-frequency domain) before decoding.

Parameters: masked_tf_rep (Tensor of shape (batch, n_src, freq, time)) – Masked time-frequency representation, before decoding.
Returns: Transformed masked_tf_rep.
postprocess_masks(masks)

Hook to perform transformations on the masks (output of the masker) before the masks are applied.

Parameters: masks (Tensor of shape (batch, n_src, freq, time)) – Output of the masker.
Returns: Transformed masks.
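The hooks above run in the order their descriptions imply: encoder, postprocess_encoded, encoder activation, masker, postprocess_masks, masking, postprocess_masked, decoder, postprocess_decoded. The sketch below mirrors that order in a hypothetical NumPy stand-in (`HookedPipeline` and `ClampedMasks` are not asteroid classes) and shows the intended usage: a subclass overrides only the hook it needs.

```python
import numpy as np


class HookedPipeline:
    """Hypothetical stand-in for an encoder-masker-decoder forward pass,
    showing where each postprocess_* hook runs. Not asteroid code."""

    def __init__(self, encoder, masker, decoder, activation=lambda x: x):
        self.encoder, self.masker, self.decoder = encoder, masker, decoder
        self.activation = activation

    # Default hooks are identities, so subclasses only override what they need.
    def postprocess_encoded(self, tf_rep):
        return tf_rep

    def postprocess_masks(self, masks):
        return masks

    def postprocess_masked(self, masked_tf_rep):
        return masked_tf_rep

    def postprocess_decoded(self, decoded):
        return decoded

    def forward(self, wav):  # wav: (batch, 1, time)
        tf_rep = self.postprocess_encoded(self.encoder(wav))     # (batch, freq, time)
        tf_rep = self.activation(tf_rep)
        masks = self.postprocess_masks(self.masker(tf_rep))      # (batch, n_src, freq, time)
        masked = self.postprocess_masked(masks * tf_rep[:, None])
        return self.postprocess_decoded(self.decoder(masked))    # (batch, n_src, time)


class ClampedMasks(HookedPipeline):
    """Example override: keep masks in [0, 1] before they are applied."""

    def postprocess_masks(self, masks):
        return np.clip(masks, 0.0, 1.0)
```

For example, with an identity encoder, a masker that emits constant masks 2.0 and -1.0, and a decoder that sums over the frequency axis, ClampedMasks returns sources of ones and zeros, while HookedPipeline returns 2.0 and -1.0 times the input.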
class asteroid.models.base_models.BaseModel

Bases: torch.nn.Module

file_separate(filename: str, output_dir=None, force_overwrite=False, **kwargs) → None

Filename interface to separate.

forward(*args, **kwargs)
classmethod from_pretrained(pretrained_model_conf_or_path, *args, **kwargs)

Instantiate separation model from a model config (file or dict).

Parameters:
  • pretrained_model_conf_or_path (Union[dict, str]) – model conf as returned by serialize, or path to it. Must contain the model_name, model_args and state_dict keys.
  • *args – Positional arguments to be passed to the model.
  • **kwargs – Keyword arguments to be passed to the model. They overwrite the ones in the model package.
Returns:

nn.Module corresponding to the pretrained model conf/URL.

Raises:

ValueError if the input config file doesn’t contain the keys model_name, model_args or state_dict.
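The key check and the rule that **kwargs overwrite the stored arguments can be sketched in plain Python. `resolve_model_args` is a hypothetical helper that mirrors the documented behaviour; it does not load weights and is not asteroid's code. "MyModel" and the argument values are made up.

```python
REQUIRED_KEYS = ("model_name", "model_args", "state_dict")


def resolve_model_args(conf, **kwargs):
    """Validate a model conf and merge caller kwargs over its model_args."""
    missing = [key for key in REQUIRED_KEYS if key not in conf]
    if missing:
        raise ValueError(f"Model conf is missing key(s): {missing}")
    # Keyword arguments passed by the caller take precedence.
    return {**conf["model_args"], **kwargs}


conf = {
    "model_name": "MyModel",
    "model_args": {"n_src": 2, "n_filters": 512},
    "state_dict": {},
}
print(resolve_model_args(conf, n_filters=256))  # {'n_src': 2, 'n_filters': 256}
```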

get_model_args()
get_state_dict()

Return the state dict to be shared. Override this method in case the state dict needs to be modified before sharing the model.

numpy_separate(wav: numpy.ndarray, **kwargs) → numpy.ndarray

Numpy interface to separate.

separate(wav, output_dir=None, force_overwrite=False, **kwargs)

Infer separated sources from input waveforms. Also supports filenames.

Parameters:
  • wav (Union[torch.Tensor, numpy.ndarray, str]) – waveform array/tensor. Shape: 1D, 2D or 3D tensor, time last.
  • output_dir (str) – directory in which to save the estimated sources as wav files when wav is a filename. If None, they are saved next to the original file.
  • force_overwrite (bool) – whether to overwrite existing files.
  • **kwargs – keyword arguments to be passed to _separate.
Returns:

Union[torch.Tensor, numpy.ndarray, None], the estimated sources, of shape (batch, n_src, time), or (n_src, time) without the batch dimension (None for filename input).

Note

By default, separate calls _separate, which calls forward. For models whose forward doesn’t return waveform tensors, override _separate to return waveform tensors.
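The input routing and the override described in the Note can be sketched with NumPy only. `SeparatorSketch` and `MaskOnlyModel` are hypothetical classes, not asteroid code; the torch branch is omitted, file_separate is left unimplemented, and inputs are assumed to be 1D (time) or 2D (batch, time).

```python
import numpy as np


class SeparatorSketch:
    """Hypothetical stand-in for the separate() dispatch; not asteroid code."""

    def forward(self, wav):
        # Toy "model": (batch, time) -> (batch, 2, time), each source = half the mix.
        return np.stack([0.5 * wav, 0.5 * wav], axis=1)

    def _separate(self, wav):
        # Default: forward() already returns waveforms.
        return self.forward(wav)

    def numpy_separate(self, wav):
        was_1d = wav.ndim == 1
        est = self._separate(wav[None] if was_1d else wav)
        return est[0] if was_1d else est  # (n_src, time) or (batch, n_src, time)

    def file_separate(self, filename, output_dir=None):
        # A real version would read `filename`, separate it and write wav files.
        raise NotImplementedError(filename)

    def separate(self, wav, output_dir=None):
        if isinstance(wav, str):
            return self.file_separate(wav, output_dir=output_dir)
        if isinstance(wav, np.ndarray):
            return self.numpy_separate(wav)
        raise TypeError(f"Unsupported input type: {type(wav).__name__}")


class MaskOnlyModel(SeparatorSketch):
    """A model whose forward() returns masks, so _separate is overridden."""

    def forward(self, wav):
        return np.full((wav.shape[0], 2, wav.shape[-1]), 0.5)

    def _separate(self, wav):
        # Apply the masks to the mixture to get waveforms.
        return self.forward(wav) * wav[:, None]
```

With this split, separate itself stays unchanged for every model; only _separate knows how to turn the model output into waveforms.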

serialize()

Serialize the model and output a dictionary.

Returns: dict, serialized model with keys model_args and state_dict.
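The relationship between serialize, get_model_args and get_state_dict can be sketched with a pure-Python stand-in. `ToyModel` is hypothetical (no torch), uses plain dicts as its "state dict", and returns exactly the two keys listed in the Returns line above; its get_state_dict override drops a made-up runtime buffer before sharing.

```python
import copy


class ToyModel:
    """Hypothetical pure-Python stand-in for a separation model (no torch)."""

    def __init__(self, n_src=2, n_filters=4):
        self.n_src = n_src
        self.n_filters = n_filters
        self.weights = {"masker.weight": [0.1] * n_filters, "cache.buffer": [0.0]}

    def get_model_args(self):
        # Everything needed to call __init__ again.
        return {"n_src": self.n_src, "n_filters": self.n_filters}

    def get_state_dict(self):
        # Override point: drop a runtime-only buffer before sharing.
        state = copy.deepcopy(self.weights)
        state.pop("cache.buffer", None)
        return state

    def serialize(self):
        return {"model_args": self.get_model_args(), "state_dict": self.get_state_dict()}


conf = ToyModel(n_src=3).serialize()
# conf == {"model_args": {"n_src": 3, "n_filters": 4},
#          "state_dict": {"masker.weight": [0.1, 0.1, 0.1, 0.1]}}
rebuilt = ToyModel(**conf["model_args"])  # re-instantiate from the stored args
```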
torch_separate(wav: torch.Tensor, **kwargs) → torch.Tensor

Core logic of separate.

asteroid.models.base_models.BaseTasNet

alias of asteroid.models.base_models.BaseEncoderMaskerDecoder

Documentation version: v0.3.4