asteroid.metrics module
asteroid.metrics.get_metrics(mix, clean, estimate, sample_rate=16000, metrics_list='all', average=True, compute_permutation=False, ignore_metrics_errors=False, filename=None)
Get speech separation/enhancement metrics from mix/clean/estimate.
Parameters:
- mix (np.array) – Mixture array.
- clean (np.array) – Reference array.
- estimate (np.array) – Estimate array.
- sample_rate (int) – Sampling rate of the audio clips.
- metrics_list (Union[List[str], str]) – List of metrics to compute. Defaults to ‘all’ ([‘si_sdr’, ‘sdr’, ‘sir’, ‘sar’, ‘stoi’, ‘pesq’]).
- average (bool) – Return dict([float]) if True, else dict([array]).
- compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default False).
- ignore_metrics_errors (bool) – Whether to ignore errors that occur while computing the metrics. A warning is printed instead.
- filename (str, optional) – If computing a metric fails, this filename is printed along with the exception/warning message for debugging purposes.
Shape:
- mix: (D, N) or (N,).
- clean: (K_source, N) or (N,).
- estimate: (K_target, N) or (N,).
Returns: dict – Dictionary with all requested metrics. Metrics computed at the input (mixture against clean) carry the ‘input_’ prefix; metrics computed at the output (estimate against clean) have no prefix. Values are floats if average is True, arrays otherwise.
Examples
>>> import numpy as np
>>> import pprint
>>> from asteroid.metrics import get_metrics
>>> mix = np.random.randn(1, 16000)
>>> clean = np.random.randn(2, 16000)
>>> est = np.random.randn(2, 16000)
>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
...                            metrics_list='all')
>>> pprint.pprint(metrics_dict)
{'input_pesq': 1.924380898475647,
 'input_sar': -11.67667585294225,
 'input_sdr': -14.88667106190552,
 'input_si_sdr': -52.43849784881705,
 'input_sir': -0.10419427290163795,
 'input_stoi': 0.015112115177091223,
 'pesq': 1.7713886499404907,
 'sar': -11.610963379923195,
 'sdr': -14.527246041125844,
 'si_sdr': -46.26557128489802,
 'sir': 0.4799929272243427,
 'stoi': 0.022023073540350643}
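A common follow-up is to compare each output metric with its ‘input_’ counterpart to see how far the estimate moves away from the mixture. The sketch below does this on a plain dict shaped like the return value of get_metrics with average=True, using four values copied from the example output. metric_improvements and the ‘_imp’ suffix are hypothetical names chosen for illustration; they are not part of asteroid.

```python
def metric_improvements(metrics_dict):
    """Return output-minus-input differences for every metric that has an
    'input_' counterpart. Hypothetical helper, not an asteroid function."""
    improvements = {}
    for name, value in metrics_dict.items():
        if name.startswith("input_"):
            continue  # input-side values are only used as the baseline
        input_key = "input_" + name
        if input_key in metrics_dict:
            improvements[name + "_imp"] = value - metrics_dict[input_key]
    return improvements


# Values copied from the example output (random noise, so the numbers
# themselves carry no meaning; only the dict layout matters here).
example = {
    "input_si_sdr": -52.43849784881705,
    "si_sdr": -46.26557128489802,
    "input_sdr": -14.88667106190552,
    "sdr": -14.527246041125844,
}

improvements = metric_improvements(example)
for name, value in sorted(improvements.items()):
    print(f"{name}: {value:+.2f}")
```

The helper assumes scalar values; with average=False the values are arrays, and the subtraction would need to be element-wise.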
class asteroid.metrics.MetricTracker(sample_rate, metrics_list=('si_sdr', 'sdr', 'sir', 'sar', 'stoi', 'pesq'), average=True, compute_permutation=False, ignore_metrics_errors=False)
Bases: object
Metric tracker, subject to change.
Parameters:
- sample_rate (int) – Sampling rate of the audio clips.
- metrics_list (Union[List[str], str]) – List of metrics to compute. Defaults to all metrics ([‘si_sdr’, ‘sdr’, ‘sir’, ‘sar’, ‘stoi’, ‘pesq’]).
- average (bool) – Return dict([float]) if True, else dict([array]).
- compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default False).
- ignore_metrics_errors (bool) – Whether to ignore errors that occur while computing the metrics. A warning is printed instead.
__call__(*, mix: numpy.ndarray, clean: numpy.ndarray, estimate: numpy.ndarray, filename=None, **kwargs)
Compute metrics for mix/clean/estimate and log them to the class.
Parameters:
- mix (np.array) – mixture array.
- clean (np.array) – reference array.
- estimate (np.array) – estimate array.
- sample_rate (int) – sampling rate of the audio clips.
- filename (str, optional) – If computing a metric fails, print this filename along with the exception/warning message for debugging purposes.
- **kwargs – Any key, value pair to log in the utterance metric (filename, speaker ID, etc…)
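The call signature above is keyword-only and accepts arbitrary extra fields that are logged with the utterance metrics. The toy class below is not asteroid's implementation; it is a minimal pure-Python sketch of that pattern, assuming a simple SNR as the only metric. ToyTracker, toy_snr_db, records and averages are all hypothetical names.

```python
import math


def toy_snr_db(reference, estimate):
    """Plain SNR in dB between two equal-length sequences (illustrative metric)."""
    signal = sum(r * r for r in reference)
    noise = sum((r - e) ** 2 for r, e in zip(reference, estimate))
    return 10.0 * math.log10(signal / noise)


class ToyTracker:
    """Hypothetical stand-in that mimics the keyword-only call pattern."""

    def __init__(self):
        self.records = []  # one dict per successfully scored utterance

    def __call__(self, *, mix, clean, estimate, filename=None, **kwargs):
        try:
            record = {
                "input_snr": toy_snr_db(clean, mix),  # mixture against clean
                "snr": toy_snr_db(clean, estimate),   # estimate against clean
            }
        except (ValueError, ZeroDivisionError) as err:
            # Report the failing file instead of aborting the whole evaluation.
            print(f"Metric failed for {filename}: {err!r}")
            return
        record.update(kwargs)  # extra fields are stored next to the metrics
        self.records.append(record)

    def averages(self):
        """Mean of each metric over all logged utterances."""
        keys = ("input_snr", "snr")
        return {k: sum(r[k] for r in self.records) / len(self.records) for k in keys}


tracker = ToyTracker()
clean = [1.0, -1.0, 1.0, -1.0]
tracker(mix=[1.5, -0.5, 1.5, -0.5], clean=clean,
        estimate=[1.1, -0.9, 1.1, -0.9], speaker_id="spk1")
tracker(mix=[0.5, -1.5, 0.5, -1.5], clean=clean,
        estimate=[0.9, -1.1, 0.9, -1.1], speaker_id="spk2")
print(tracker.records[0]["speaker_id"])
print(tracker.averages())
```

Positional calls such as tracker(mix, clean, estimate) raise a TypeError in this sketch, which is the point of the bare * in the signature: the three arrays cannot be swapped by accident.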
class asteroid.metrics.WERTracker(model_name, trans_df)
Bases: object
Word Error Rate Tracker. Subject to change.
Parameters:
- model_name (str) – Name of the pretrained model to use.
- trans_df (dataframe) – DataFrame containing the fields utt_id and text. See the librimix/ConvTasNet recipe.
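Word error rate is conventionally the word-level edit distance (substitutions, deletions and insertions) between a reference transcript and a hypothesis, divided by the number of reference words. The sketch below computes it in pure Python on records that carry the utt_id and text fields mentioned above. It illustrates the metric only and makes no claim about how WERTracker computes it internally; word_error_rate and the sample data are hypothetical.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference length."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the reference words processed
    # so far (all but the current one) and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / len(ref)


# Reference transcripts as plain records with the two documented fields.
transcripts = [
    {"utt_id": "utt1", "text": "the cat sat on the mat"},
    {"utt_id": "utt2", "text": "hello world"},
]
# Made-up recognizer outputs, keyed by utterance id.
hypotheses = {"utt1": "the cat sat on mat", "utt2": "hello there world"}

wer_by_utt = {
    row["utt_id"]: word_error_rate(row["text"], hypotheses[row["utt_id"]])
    for row in transcripts
}
print(wer_by_utt)
```

Because insertions count as errors, the rate can exceed 1.0 when the hypothesis is much longer than the reference.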