
asteroid.models.demask module

class asteroid.models.demask.DeMask(input_type='mag', output_type='mag', hidden_dims=(1024, ), dropout=0.0, activation='relu', mask_act='relu', norm_type='gLN', fb_name='stft', n_filters=512, stride=256, kernel_size=512, sample_rate=16000, **fb_kwargs)

Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder

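Simple MLP model for surgical mask speech enhancement, i.e. restoring speech that is muffled by a surgical mask. Masks are estimated and applied in a transformed domain (the output of the analysis filterbank).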

Parameters:
  • input_type (str, optional) – whether the magnitude spectrogram “mag” or both the real and imaginary parts “reim” are passed as features to the masker network. Use “cat” to pass the concatenation of “mag” and “reim”.
  • output_type (str, optional) – whether the masker outputs a mask for the magnitude spectrogram “mag” or for both the real and imaginary parts “reim”.
  • hidden_dims (list, optional) – list of MLP hidden layer sizes.
  • dropout (float, optional) – dropout probability.
  • activation (str, optional) – type of activation used in hidden MLP layers.
  • mask_act (str, optional) – non-linear function used to generate the mask.
  • norm_type (str, optional) – normalization type, one of 'BN', 'gLN', 'cLN'.
  • fb_name (str) – type of analysis and synthesis filterbanks used, choose between [“stft”, “free”, “analytic_free”].
  • n_filters (int) – number of filters in the analysis and synthesis filterbanks.
  • stride (int) – filterbank filters stride.
  • kernel_size (int) – length of filters in the filterbank.
  • encoder_activation (str) –
  • sample_rate (float) – Sampling rate of the model.
  • **fb_kwargs (dict) – Additional keyword arguments passed on when creating the filterbanks.
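The choice of input_type and output_type fixes the width of the masker's input and output. The sketch below shows that bookkeeping under one assumption not stated above: the encoder stacks the real parts of all frequency bins on top of the imaginary parts, so n_feats_out channels carry n_feats_out // 2 complex bins. masker_sizes is a hypothetical helper, not part of asteroid.

```python
def masker_sizes(n_feats_out, input_type="mag", output_type="mag"):
    """Return (n_in, n_out) for the masker (hypothetical helper).

    Assumes the encoder output stacks real parts above imaginary parts,
    so n_feats_out is even and holds n_feats_out // 2 complex bins.
    """
    if n_feats_out % 2:
        raise ValueError("expected an even number of stacked real/imag channels")
    n_bins = n_feats_out // 2  # one magnitude value per (real, imag) pair

    if input_type == "mag":
        n_in = n_bins
    elif input_type == "reim":
        n_in = n_feats_out
    elif input_type == "cat":
        n_in = n_bins + n_feats_out  # magnitude concatenated with real/imag
    else:
        raise ValueError("input_type must be 'mag', 'reim' or 'cat'")

    if output_type == "mag":
        n_out = n_bins  # one gain per bin
    elif output_type == "reim":
        n_out = n_feats_out  # separate gains for real and imag parts
    else:
        raise ValueError("output_type must be 'mag' or 'reim'")
    return n_in, n_out
```

For example, if an STFT filterbank with n_filters=512 is assumed to produce 257 bins (514 stacked channels), masker_sizes(514, "cat", "mag") gives (771, 257).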
forward_masker(tf_rep)

Estimates masks based on time-frequency representations.

Parameters: tf_rep (torch.Tensor) – Time-frequency representation in (batch, freq, seq).
Returns: torch.Tensor – Estimated masks in (batch, freq, seq).
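Around the MLP itself, forward_masker mainly has to turn the representation into the features selected by input_type. A minimal sketch in plain Python, with nested lists standing in for tensors and the batch dimension dropped. It assumes real parts are stacked above imaginary parts along the frequency axis, and that a "mag" mask is repeated so the same gain can later multiply both halves. All function names are hypothetical.

```python
import math


def split_reim(tf_rep):
    """Split a stacked (2 * n_bins, seq) representation into (real, imag) halves."""
    half = len(tf_rep) // 2
    return tf_rep[:half], tf_rep[half:]


def magnitude(tf_rep):
    """Per-bin magnitude sqrt(re**2 + im**2), with shape (n_bins, seq)."""
    re, im = split_reim(tf_rep)
    return [[math.hypot(r, i) for r, i in zip(re_row, im_row)]
            for re_row, im_row in zip(re, im)]


def masker_features(tf_rep, input_type="mag"):
    """Select the masker input features (assumed layout: real rows above imag rows)."""
    if input_type == "mag":
        return magnitude(tf_rep)
    if input_type == "cat":
        # Magnitude rows first, then the stacked real/imag rows.
        return magnitude(tf_rep) + tf_rep
    if input_type == "reim":
        return tf_rep
    raise ValueError("input_type must be 'mag', 'reim' or 'cat'")


def expand_mag_mask(mag_mask):
    """Repeat an (n_bins, seq) magnitude mask over the real and imag halves."""
    return [row[:] for row in mag_mask + mag_mask]
```

With tf_rep = [[3.0], [4.0]] (a single bin 3 + 4j and one frame), masker_features(tf_rep) returns [[5.0]].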
apply_masks(tf_rep, est_masks)

Applies masks to time-frequency representations.

Parameters:
  • tf_rep (torch.Tensor) – Time-frequency representations in (batch, freq, seq).
  • est_masks (torch.Tensor) – Estimated masks in (batch, freq, seq).
Returns: torch.Tensor – Masked time-frequency representations.
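Applying the masks is an element-wise product. The sketch below (plain Python, batch dimension dropped, hypothetical apply_mask helper) also shows why a repeated magnitude mask is convenient under the stacked real/imaginary layout assumed here: scaling the real and imaginary parts of a bin by the same gain scales its magnitude and leaves its phase unchanged.

```python
def apply_mask(tf_rep, est_mask):
    """Element-wise product of a (freq, seq) representation and a same-shaped mask."""
    same_shape = len(tf_rep) == len(est_mask) and all(
        len(tf_row) == len(mask_row) for tf_row, mask_row in zip(tf_rep, est_mask)
    )
    if not same_shape:
        raise ValueError("tf_rep and est_mask must have the same (freq, seq) shape")
    return [[x * m for x, m in zip(tf_row, mask_row)]
            for tf_row, mask_row in zip(tf_rep, est_mask)]
```

For instance, apply_mask([[3.0], [4.0]], [[0.5], [0.5]]) returns [[1.5], [2.0]]: the bin 3 + 4j becomes 1.5 + 2j, so its magnitude drops from 5 to 2.5 while the phase stays the same.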

get_model_args()

Arguments needed to re-instantiate the model.
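get_model_args is typically paired with the constructor: the returned keyword arguments, stored next to the trained weights, are enough to rebuild a model of the same shape. The class below is a hypothetical stand-in that shows only this round-trip pattern; the exact keys returned by DeMask are not listed on this page.

```python
class ToyModel:
    """Hypothetical stand-in that remembers its constructor arguments."""

    def __init__(self, input_type="mag", hidden_dims=(1024,), dropout=0.0):
        self.input_type = input_type
        self.hidden_dims = tuple(hidden_dims)
        self.dropout = dropout

    def get_model_args(self):
        # Everything needed to call the constructor again.
        return {
            "input_type": self.input_type,
            "hidden_dims": self.hidden_dims,
            "dropout": self.dropout,
        }


model = ToyModel(input_type="cat", hidden_dims=(512, 512))
clone = ToyModel(**model.get_model_args())  # re-instantiate from the stored args
```

Only the architecture is rebuilt this way; the trained weights still have to be loaded separately.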

asteroid.models.demask.build_demask_masker(n_in, n_out, activation='relu', dropout=0.0, hidden_dims=(1024, ), mask_act='relu', norm_type='gLN')
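No description is given for build_demask_masker. Judging from its arguments, it builds the MLP masker used by DeMask. The sketch below lists one plausible layer layout, which is an assumption rather than the documented implementation: a normalization of the input, then for each entry of hidden_dims a frame-wise linear layer (equivalent to a 1x1 convolution over (batch, freq, seq)) followed by normalization, activation and dropout, and finally a frame-wise projection to n_out channels followed by mask_act. The hypothetical helper returns a plain description instead of PyTorch modules so that it runs without dependencies.

```python
def masker_layer_plan(n_in, n_out, hidden_dims=(1024,), activation="relu",
                      dropout=0.0, mask_act="relu", norm_type="gLN"):
    """Describe an ASSUMED masker layout as a list of (layer, details) tuples."""
    plan = [("norm", (norm_type, n_in))]  # assumed input normalization
    in_chan = n_in
    for hidden in hidden_dims:
        plan.append(("conv1x1", (in_chan, hidden)))  # frame-wise linear map
        plan.append(("norm", (norm_type, hidden)))
        plan.append(("act", activation))
        plan.append(("dropout", dropout))
        in_chan = hidden
    plan.append(("conv1x1", (in_chan, n_out)))  # project to mask channels
    plan.append(("mask_act", mask_act))  # non-linearity producing the mask
    return plan
```

Because every layer acts on one frame at a time, the number of frames (seq) never enters the layer sizes; only n_in, hidden_dims and n_out do.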
Documentation version: v0.4.4