asteroid.models.dcunet module
class asteroid.models.dcunet.BaseDCUNet(architecture, stft_n_filters=1024, stft_kernel_size=1024, stft_stride=256, sample_rate=16000.0, **masknet_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
Base class for the DCUNet and DCCRNet classes.
Parameters:
- architecture (str) – The architecture to use. Overridden by subclasses.
- stft_n_filters (int) –
- stft_kernel_size (int) – STFT frame length to use.
- stft_stride (int, optional) – STFT hop length to use.
- sample_rate (float) – Sampling rate of the model.
- masknet_kwargs (optional) – Passed to the masknet constructor.
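To make the default STFT settings concrete, the sketch below converts stft_kernel_size and stft_stride into durations at the given sample_rate and counts the frames an input of a given length would produce. The helper name stft_frame_info is hypothetical, and the frame count assumes an STFT with no padding at the signal edges; the encoder used by the library may pad differently, so treat the count as an approximation.

```python
def stft_frame_info(n_samples, kernel_size=1024, stride=256, sample_rate=16000.0):
    """Return (frame_ms, hop_ms, n_frames) for an STFT without edge padding.

    Hypothetical helper for illustration only; it is not part of asteroid.
    """
    # Duration of one analysis frame and of one hop, in milliseconds.
    frame_ms = 1000.0 * kernel_size / sample_rate
    hop_ms = 1000.0 * stride / sample_rate
    # Without padding, a frame only exists where a full window fits.
    if n_samples < kernel_size:
        n_frames = 0
    else:
        n_frames = 1 + (n_samples - kernel_size) // stride
    return frame_ms, hop_ms, n_frames


# One second of audio at 16 kHz with the default settings.
frame_ms, hop_ms, n_frames = stft_frame_info(16000)
print(frame_ms, hop_ms, n_frames)
```

With the defaults, each frame spans 64 ms and consecutive frames start 16 ms apart.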
forward_encoder(wav)
Computes time-frequency representation of wav.
Parameters: wav (torch.Tensor) – waveform tensor in 3D shape, time last.
Returns: torch.Tensor, of shape (batch, feat, seq).
apply_masks(tf_rep, est_masks)
Applies masks to time-frequency representation.
Parameters:
- tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq) shape.
- est_masks (torch.Tensor) – Estimated masks.
Returns: torch.Tensor – Masked time-frequency representations.
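This page does not state how the masks are combined with the representation. As a minimal sketch, assuming the masks are complex-valued and applied by elementwise complex multiplication (the style of masking described in the DCUNet paper), the operation can be illustrated with plain Python complex numbers. The function name apply_complex_masks and the list-based layout are hypothetical; the real method operates on torch.Tensor inputs.

```python
def apply_complex_masks(tf_rep, est_masks):
    """Multiply each mask elementwise with a complex time-frequency representation.

    tf_rep: list of complex bins (a flattened stand-in for a spectrogram).
    est_masks: list with one mask per source, each a list of complex values.
    Hypothetical illustration only; not asteroid's implementation.
    """
    return [[m * x for m, x in zip(mask, tf_rep)] for mask in est_masks]


tf_rep = [1 + 1j, 2 + 0j, 0 - 1j]
est_masks = [
    [0.5 + 0j, 1j, 1 + 0j],      # scales bin 0, rotates bin 1 by 90 degrees, keeps bin 2
    [0.5 + 0j, 1 + 0j, 0j],      # scales bin 0, keeps bin 1, silences bin 2
]
masked = apply_complex_masks(tf_rep, est_masks)
print(masked)
```

A complex mask can change both the magnitude and the phase of each bin, which a real-valued mask cannot do.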
class asteroid.models.dcunet.DCUNet(architecture, stft_n_filters=1024, stft_kernel_size=1024, stft_stride=256, sample_rate=16000.0, **masknet_kwargs)
Bases: asteroid.models.dcunet.BaseDCUNet
DCUNet as proposed in [1].
Parameters:
- architecture (str) – The architecture to use, any of “DCUNet-10”, “DCUNet-16”, “DCUNet-20”, “Large-DCUNet-20”.
- stft_n_filters (int) –
- stft_kernel_size (int) – STFT frame length to use.
- stft_stride (int, optional) – STFT hop length to use.
- sample_rate (float) – Sampling rate of the model.
- masknet_kwargs (optional) – Passed to DCUMaskNet.
References
[1] Hyeong-Seok Choi et al., “Phase-aware Speech Enhancement with Deep Complex U-Net”. https://arxiv.org/abs/1903.03107