
asteroid.models.base_models module

class asteroid.models.base_models.BaseModel(sample_rate: float = None, in_channels: Optional[int] = 1)[source]

Bases: torch.nn.Module

Base class for serializable models.

Defines the saving/loading procedures and the separation interface used by separate. Subclasses need to overwrite the forward and get_model_args methods.

Models inheriting from BaseModel can be used by asteroid.separate and by the asteroid-infer CLI. For models whose forward doesn’t go from waveform to waveform tensors, overwrite forward_wav to return waveform tensors.

Parameters:
  • sample_rate (float) – Operating sample rate of the model.
  • in_channels (Optional[int]) – Number of input channels in the signal. If None, no checks will be performed.

forward(*args, **kwargs)[source]

sample_rate[source]

Operating sample rate of the model (float).

separate(*args, **kwargs)[source]

Convenience for asteroid.separate.separate().

torch_separate(*args, **kwargs)[source]

Convenience for asteroid.separate.torch_separate().

numpy_separate(*args, **kwargs)[source]

Convenience for asteroid.separate.numpy_separate().

file_separate(*args, **kwargs)[source]

Convenience for asteroid.separate.file_separate().

forward_wav(wav, *args, **kwargs)[source]

Separation method for waveforms.

In case the network’s forward doesn’t have waveforms as input/output, overwrite this method to separate from waveform to waveform. Should return a single torch.Tensor, the separated waveforms.

Parameters: wav (torch.Tensor) – Waveform array/tensor. Shape: 1D, 2D or 3D tensor, time last.

classmethod from_pretrained(pretrained_model_conf_or_path, *args, **kwargs)[source]

Instantiate separation model from a model config (file or dict).

Parameters:
  • pretrained_model_conf_or_path (Union[dict, str]) – Model conf as returned by serialize, or path to it. Must contain the model_args and state_dict keys.
  • *args – Positional arguments to be passed to the model.
  • **kwargs – Keyword arguments to be passed to the model. They overwrite the ones in the model package.
Returns:

nn.Module corresponding to the pretrained model conf/URL.

Raises:

ValueError if the input config file doesn’t contain the keys model_name, model_args or state_dict.

serialize()[source]

Serialize model and output dictionary.

Returns: dict, serialized model with keys model_args and state_dict.

get_state_dict()[source]

Overwrite this method in case the state dict needs to be modified before sharing the model.

get_model_args()[source]

Should return args to re-instantiate the class.
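
The serialize / from_pretrained round trip described above can be sketched without any dependency. This is only an illustration of the pattern, not asteroid's implementation: ToyModel and its gain argument are hypothetical, a plain dict stands in for the torch state dict, and the toy config carries only the model_args and state_dict keys listed for serialize (model_name is left out).

```python
class ToyModel:
    """Hypothetical stand-in for a BaseModel subclass (no torch involved)."""

    def __init__(self, sample_rate=8000.0, gain=1.0):
        self.sample_rate = sample_rate
        self.gain = gain
        # Stand-in for learned parameters; a real model would hold tensors.
        self.weights = {"scale": 1.0}

    def get_model_args(self):
        # Everything needed to call __init__ again.
        return {"sample_rate": self.sample_rate, "gain": self.gain}

    def get_state_dict(self):
        # Hook for editing the state before sharing; here a plain copy.
        return dict(self.weights)

    def serialize(self):
        return {
            "model_args": self.get_model_args(),
            "state_dict": self.get_state_dict(),
        }

    @classmethod
    def from_pretrained(cls, conf, **kwargs):
        for key in ("model_args", "state_dict"):
            if key not in conf:
                raise ValueError(f"Expected config to contain the key {key!r}.")
        # Keyword arguments overwrite the stored model arguments.
        model_args = {**conf["model_args"], **kwargs}
        model = cls(**model_args)
        model.weights = dict(conf["state_dict"])
        return model


original = ToyModel(sample_rate=16000.0, gain=0.5)
original.weights["scale"] = 3.0
conf = original.serialize()
restored = ToyModel.from_pretrained(conf, gain=2.0)
```

In this sketch the restored model keeps the stored sample_rate and weights, while the gain passed as a keyword argument replaces the stored value, mirroring the "they overwrite the ones in the model package" behaviour documented for **kwargs.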

class asteroid.models.base_models.BaseEncoderMaskerDecoder(encoder, masker, decoder, encoder_activation=None)[source]

Bases: asteroid.models.base_models.BaseModel

Base class for encoder-masker-decoder separation models.

Parameters:
  • encoder (Encoder) – Encoder instance.
  • masker (nn.Module) – Masker network.
  • decoder (Decoder) – Decoder instance.
  • encoder_activation (Optional[str], optional) – Activation to apply after encoder. See asteroid.masknn.activations for valid values.

forward(wav)[source]

Encoder/masker/decoder model forward pass.

Parameters: wav (torch.Tensor) – Waveform tensor. 1D, 2D or 3D tensor, time last.
Returns: torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
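
The shape contract stated for forward (1D, 2D or 3D input with time last, output of shape (batch, n_src, time) or (n_src, time)) can be sketched on shape tuples alone. The helper names to_3d_shape and output_shape are hypothetical, and the mapping itself is an assumption rather than something stated on this page: a 1D input is read as (time,), a 2D input as (batch, time), and the decoder is assumed to preserve the time length.

```python
def to_3d_shape(shape):
    """Map a 1D/2D/3D input shape (time last) to (batch, channels, time)."""
    if len(shape) == 1:  # assumed layout: (time,)
        return (1, 1, shape[0])
    if len(shape) == 2:  # assumed layout: (batch, time)
        return (shape[0], 1, shape[1])
    if len(shape) == 3:  # already (batch, channels, time)
        return tuple(shape)
    raise ValueError(f"Expected a 1D, 2D or 3D shape, got {len(shape)}D.")


def output_shape(input_shape, n_src):
    """Expected output shape, assuming the time length is preserved."""
    batch, _, time = to_3d_shape(input_shape)
    if len(input_shape) == 1:
        # An unbatched input gives an unbatched output.
        return (n_src, time)
    return (batch, n_src, time)


single = output_shape((16000,), n_src=2)
batched = output_shape((4, 16000), n_src=2)
multichannel = output_shape((4, 1, 16000), n_src=3)
```
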
forward_encoder(wav: torch.Tensor) → torch.Tensor[source]

Computes time-frequency representation of wav.

Parameters: wav (torch.Tensor) – Waveform tensor in 3D shape, time last.
Returns: torch.Tensor, of shape (batch, feat, seq).

forward_masker(tf_rep: torch.Tensor) → torch.Tensor[source]

Estimates masks from time-frequency representation.

Parameters: tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq).
Returns: torch.Tensor – Estimated masks.

apply_masks(tf_rep: torch.Tensor, est_masks: torch.Tensor) → torch.Tensor[source]

Applies masks to time-frequency representation.

Parameters:
  • tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq) shape.
  • est_masks (torch.Tensor) – Estimated masks.
Returns:

torch.Tensor – Masked time-frequency representations.

forward_decoder(masked_tf_rep: torch.Tensor) → torch.Tensor[source]

Reconstructs time-domain waveforms from masked representations.

Parameters: masked_tf_rep (torch.Tensor) – Masked time-frequency representation.
Returns: torch.Tensor – Time-domain waveforms.

get_model_args()[source]

Arguments needed to re-instantiate the model.
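
The four stages documented above (forward_encoder, forward_masker, apply_masks, forward_decoder) chain into a single forward pass. The following dependency-free sketch shows that data flow for one unbatched waveform. ToyEncoderMaskerDecoder is hypothetical: nested lists stand in for torch.Tensor, the "encoder" merely cuts the signal into frames stored as (seq, feat) rather than (batch, feat, seq), the "masker" returns constant masks instead of learned ones, and encoder_activation is left out.

```python
class ToyEncoderMaskerDecoder:
    """Pure-Python sketch of the four stages for a single waveform."""

    def __init__(self, frame_len=2, masks=(0.25, 0.75)):
        self.frame_len = frame_len
        self.masks = masks  # one constant mask value per source

    def forward_encoder(self, wav):
        # Cut the waveform into non-overlapping frames.
        n = self.frame_len
        return [wav[i:i + n] for i in range(0, len(wav), n)]

    def forward_masker(self, tf_rep):
        # One mask per source, each with the same layout as tf_rep.
        return [[[m] * len(frame) for frame in tf_rep] for m in self.masks]

    def apply_masks(self, tf_rep, est_masks):
        # Element-wise product of the representation with every mask.
        return [
            [[x * w for x, w in zip(frame, mask_frame)]
             for frame, mask_frame in zip(tf_rep, mask)]
            for mask in est_masks
        ]

    def forward_decoder(self, masked_tf_rep):
        # Concatenate the frames back into one waveform per source.
        return [[x for frame in src for x in frame] for src in masked_tf_rep]

    def forward(self, wav):
        tf_rep = self.forward_encoder(wav)
        est_masks = self.forward_masker(tf_rep)
        masked_tf_rep = self.apply_masks(tf_rep, est_masks)
        return self.forward_decoder(masked_tf_rep)


model = ToyEncoderMaskerDecoder()
mixture = [4.0, 8.0, -4.0, 0.0]
sources = model.forward(mixture)  # list of n_src waveforms, time last
```

Because the two constant masks sum to one, the estimated sources in this toy add back up to the mixture; a real masker network gives no such guarantee.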

asteroid.models.base_models.BaseTasNet[source]

alias of asteroid.models.base_models.BaseEncoderMaskerDecoder
