
asteroid.models.base_models module

class asteroid.models.base_models.BaseModel

Bases: torch.nn.Module

file_separate(filename: str, output_dir=None, force_overwrite=False, **kwargs) → None

Filename interface to separate.
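
A hypothetical helper can show where the estimates end up, based on the output_dir and force_overwrite parameters documented under separate. The following are assumptions for this sketch and may not match what asteroid does:

- the name estimate_paths
- the <stem>_est<i>.wav naming scheme
- raising FileExistsError when a file already exists

```python
import os
from typing import List, Optional


def estimate_paths(
    filename: str,
    n_src: int,
    output_dir: Optional[str] = None,
    force_overwrite: bool = False,
) -> List[str]:
    """Hypothetical helper: paths where the n_src estimates of `filename` would be saved."""
    # With output_dir=None, estimates are saved next to the original file.
    target_dir = output_dir if output_dir is not None else os.path.dirname(filename)
    stem, ext = os.path.splitext(os.path.basename(filename))
    # Assumed naming scheme: <stem>_est1.wav, <stem>_est2.wav, ...
    paths = [os.path.join(target_dir, f"{stem}_est{i + 1}{ext}") for i in range(n_src)]
    if not force_overwrite:
        # Assumption: refuse to overwrite by raising (the library may warn or skip instead).
        existing = [p for p in paths if os.path.isfile(p)]
        if existing:
            raise FileExistsError(f"Refusing to overwrite existing files: {existing}")
    return paths


print(estimate_paths("/data/mix.wav", n_src=2))
# On POSIX: ['/data/mix_est1.wav', '/data/mix_est2.wav']
```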

forward(*args, **kwargs)

classmethod from_pretrained(pretrained_model_conf_or_path, *args, **kwargs)

Instantiate separation model from a model config (file or dict).

Parameters:
  • pretrained_model_conf_or_path (Union[dict, str]) – model conf as returned by serialize, or path to it. Must contain the model_args and state_dict keys.
  • *args – Positional arguments to be passed to the model.
  • **kwargs – Keyword arguments to be passed to the model. They overwrite the ones in the model package.
Returns:

nn.Module corresponding to the pretrained model conf/URL.

Raises:

ValueError if the input config file doesn’t contain the keys model_name, model_args or state_dict.
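
The loading contract described above can be sketched without torch. It covers three points: a conf dict or a path to one, the required keys, and keyword arguments overriding the stored model_args. ToyModel, the JSON file format and the dict-based load_state_dict are hypothetical stand-ins. The real method deals with nn.Module instances and their serialized packages.

```python
import json
from typing import Any, Dict, Union

# Keys listed in the Raises section above.
REQUIRED_KEYS = ("model_name", "model_args", "state_dict")


class ToyModel:
    """Hypothetical stand-in for a BaseModel subclass (no torch involved)."""

    def __init__(self, n_src: int = 2, n_filters: int = 64):
        self.n_src = n_src
        self.n_filters = n_filters
        self.weights: Dict[str, Any] = {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.weights = dict(state_dict)

    @classmethod
    def from_pretrained(cls, conf_or_path: Union[dict, str], *args, **kwargs):
        # Accept an in-memory conf or a path to a serialized one (JSON here, by assumption).
        if isinstance(conf_or_path, str):
            with open(conf_or_path) as f:
                conf = json.load(f)
        else:
            conf = conf_or_path
        missing = [k for k in REQUIRED_KEYS if k not in conf]
        if missing:
            raise ValueError(f"Expected config to contain keys {missing}")
        # Keyword arguments overwrite the ones stored in the package.
        model_args = {**conf["model_args"], **kwargs}
        model = cls(*args, **model_args)
        model.load_state_dict(conf["state_dict"])
        return model


conf = {"model_name": "ToyModel", "model_args": {"n_src": 2, "n_filters": 64}, "state_dict": {"w": [0.1]}}
model = ToyModel.from_pretrained(conf, n_filters=128)
print(model.n_src, model.n_filters)  # 2 128
```

Passing n_filters=128 overrides the stored value of 64, mirroring how **kwargs overwrite the arguments in the model package.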

get_model_args()

get_state_dict()

Returns the state dict of the model. Override this method in case the state dict needs to be modified before sharing the model.
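
The sketch below illustrates this hook. It overrides get_state_dict to drop an entry before sharing, then builds a serialize-style dict from get_model_args and get_state_dict. ToyModel, its plain-dict weights, the _cache entry and the model_name key are assumptions made for the example.

```python
from typing import Any, Dict


class ToyModel:
    """Hypothetical model storing its weights in a plain dict."""

    def __init__(self, n_src: int = 2):
        self.n_src = n_src
        self.weights: Dict[str, Any] = {"masker.w": [0.5, 0.5], "_cache": [1.0, 2.0]}

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.weights)

    def get_model_args(self) -> Dict[str, Any]:
        # Arguments needed to re-instantiate the model.
        return {"n_src": self.n_src}

    def get_state_dict(self) -> Dict[str, Any]:
        # Override hook: drop (hypothetical) runtime-only entries before sharing.
        return {k: v for k, v in self.state_dict().items() if not k.startswith("_")}

    def serialize(self) -> Dict[str, Any]:
        # In this sketch, serialize packages the two getters above.
        return {
            "model_name": type(self).__name__,
            "model_args": self.get_model_args(),
            "state_dict": self.get_state_dict(),
        }


package = ToyModel(n_src=3).serialize()
print(package)
# {'model_name': 'ToyModel', 'model_args': {'n_src': 3}, 'state_dict': {'masker.w': [0.5, 0.5]}}
```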

numpy_separate(wav: numpy.ndarray, **kwargs) → numpy.ndarray

Numpy interface to separate.

separate(wav, output_dir=None, force_overwrite=False, **kwargs)

Infer separated sources from input waveforms. Also supports filenames.

Parameters:
  • wav (Union[torch.Tensor, numpy.ndarray, str]) – waveform array/tensor. Shape: 1D, 2D or 3D tensor, time last.
  • output_dir (str) – path to save all the wav files. If None, estimated sources will be saved next to the original ones.
  • force_overwrite (bool) – whether to overwrite existing files.
  • **kwargs – keyword arguments to be passed to _separate.
Returns:

Union[torch.Tensor, numpy.ndarray, None], the estimated sources (None when wav is a filename), of shape (batch, n_src, time), or (n_src, time) without a batch dimension.

Note

By default, separate calls _separate, which calls forward. For models whose forward doesn’t return waveform tensors, override _separate to return waveform tensors.
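
The routing between the filename, NumPy and torch interfaces can be sketched as a type dispatch. DispatchSketch and its stub methods are hypothetical; they only record which path was taken. Treating every input that is neither a str nor an ndarray as a torch.Tensor is an assumption of the sketch.

```python
import numpy as np


class DispatchSketch:
    """Hypothetical class mirroring how separate could route its input."""

    def separate(self, wav, output_dir=None, force_overwrite=False, **kwargs):
        if isinstance(wav, str):
            # Filenames: estimates are written to disk, nothing is returned.
            self.file_separate(wav, output_dir=output_dir, force_overwrite=force_overwrite, **kwargs)
            return None
        if isinstance(wav, np.ndarray):
            return self.numpy_separate(wav, **kwargs)
        # Assumption: anything else is treated as a torch.Tensor.
        return self.torch_separate(wav, **kwargs)

    # Stubs that only record which interface was used.
    def file_separate(self, filename, output_dir=None, force_overwrite=False, **kwargs):
        self.last_call = ("file", filename)

    def numpy_separate(self, wav, **kwargs):
        self.last_call = ("numpy", wav.shape)
        return wav

    def torch_separate(self, wav, **kwargs):
        self.last_call = ("torch", type(wav).__name__)
        return wav


sketch = DispatchSketch()
sketch.separate("mix.wav")
print(sketch.last_call)  # ('file', 'mix.wav')
```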

serialize()

Serialize the model and output a dictionary.

Returns: dict, serialized model with keys model_args and state_dict.

torch_separate(wav: torch.Tensor, **kwargs) → torch.Tensor

Core logic of separate.
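
The following NumPy sketch shows the shape bookkeeping implied by the separate docstring: 1D, 2D or 3D input with time last, and output with or without a batch dimension. Two parts are assumptions: treating a 2D input as (batch, time), and the toy_forward "model", which simply splits the mixture equally. A real model's forward would produce the actual estimates.

```python
import numpy as np


def unsqueeze_to_3d(wav: np.ndarray) -> np.ndarray:
    """Bring a 1D (time,), 2D (batch, time) or 3D (batch, chan, time) array to 3D."""
    if wav.ndim == 1:
        return wav.reshape(1, 1, -1)
    if wav.ndim == 2:
        return wav[:, np.newaxis, :]
    if wav.ndim == 3:
        return wav
    raise ValueError(f"Expected a 1D, 2D or 3D input, got {wav.ndim}D")


def toy_forward(wav3d: np.ndarray, n_src: int = 2) -> np.ndarray:
    """Hypothetical 'model': splits the (channel-averaged) mixture equally into n_src sources."""
    mono = wav3d.mean(axis=1, keepdims=True)      # (batch, 1, time)
    return np.repeat(mono / n_src, n_src, axis=1)  # (batch, n_src, time)


def core_separate(wav: np.ndarray, n_src: int = 2) -> np.ndarray:
    est = toy_forward(unsqueeze_to_3d(wav), n_src=n_src)
    # A 1D input has no batch dimension, so return (n_src, time).
    return est[0] if wav.ndim == 1 else est


print(core_separate(np.ones(8)).shape)       # (2, 8)
print(core_separate(np.ones((4, 8))).shape)  # (4, 2, 8)
```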

class asteroid.models.base_models.BaseTasNet(encoder, masker, decoder, encoder_activation=None)

Bases: asteroid.models.base_models.BaseModel

Base class for encoder-masker-decoder separation models.

Parameters:
  • encoder (Encoder) – Encoder instance.
  • masker (nn.Module) – masker network.
  • decoder (Decoder) – Decoder instance.

forward(wav)

Encoder/masker/decoder model forward pass.

Parameters: wav (torch.Tensor) – waveform tensor. 1D, 2D or 3D tensor, time last.
Returns: torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
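
The encoder, masker and decoder pipeline can be illustrated with a NumPy sketch. The forward pass encodes the mixture, estimates one mask per source, applies the masks in the encoded domain and decodes back to waveforms. In asteroid, the Encoder and Decoder wrap filterbanks and the masker is a neural network. Here they are replaced by hypothetical stand-ins:

- non-overlapping frames projected on a random square basis
- its inverse as the decoder
- random softmax masks

encoder_activation is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
KERNEL, N_FILTERS, N_SRC = 4, 4, 2
BASIS = rng.standard_normal((N_FILTERS, KERNEL))  # hypothetical encoder basis (square)
INV_BASIS = np.linalg.pinv(BASIS)                  # decoder basis, shape (KERNEL, N_FILTERS)


def encoder(wav):
    """(batch, time) -> (batch, n_filters, n_frames); time must be a multiple of KERNEL."""
    batch, time = wav.shape
    frames = wav.reshape(batch, time // KERNEL, KERNEL)  # (batch, n_frames, kernel)
    return np.einsum("fk,bnk->bfn", BASIS, frames)


def masker(tf_rep):
    """Toy masker: random logits, softmax over sources so the masks sum to 1."""
    batch, n_filters, n_frames = tf_rep.shape
    logits = rng.standard_normal((batch, N_SRC, n_filters, n_frames))
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


def decoder(masked):
    """(batch, n_src, n_filters, n_frames) -> (batch, n_src, time)."""
    frames = np.einsum("kf,bsfn->bsnk", INV_BASIS, masked)  # (batch, n_src, n_frames, kernel)
    batch, n_src, n_frames, kernel = frames.shape
    return frames.reshape(batch, n_src, n_frames * kernel)


def forward(wav):
    tf_rep = encoder(wav)                   # encode the mixture
    masks = masker(tf_rep)                  # one mask per source
    masked = tf_rep[:, np.newaxis] * masks  # apply the masks in the encoded domain
    return decoder(masked)                  # back to waveforms


mix = rng.standard_normal((1, 16))
est = forward(mix)
print(est.shape)                          # (1, 2, 16)
print(np.allclose(est.sum(axis=1), mix))  # True
```

Because the toy masks sum to one across sources and the decoder inverts the encoder, the estimated sources add back up to the mixture. This makes a handy sanity check for the shape handling.
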

get_model_args()

Arguments needed to re-instantiate the model.

Documentation version: v0.3.3