import numpy as np
import torch
from torch import tensor
import torch.nn as nn
from scipy.io import loadmat
import pathlib
import os
class SingleSrcPMSQE(nn.Module):
    """Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE)
    as described in [1].

    Two configurations are implemented and selected with ``sample_rate``:
    16 kHz (512-point DFT, 49 Bark bands) and 8 kHz (256-point DFT,
    42 Bark bands). The input spectra must therefore have
    ``nbins // 2 + 1`` frequency bins (257 or 129 respectively).

    The SLL, frequency and gain equalization are applied in each
    sequence independently.

    Parameters:
        window_name (str): Select the used window function for the correct
            factor to be applied. Defaults to sqrt hanning window.
            Among ['rect', 'hann', 'sqrt_hann', 'hamming', 'flatTop'].
        window_weight (float, optional): Correction to the window factor
            applied.
        bark_eq (bool, optional): Whether to apply bark equalization.
        gain_eq (bool, optional): Whether to apply gain equalization.
        sample_rate (int): Sample rate of the input audio, 16000 or 8000.

    Raises:
        ValueError: if ``sample_rate`` is neither 16000 nor 8000, or if
            ``window_name`` is not one of the supported windows.

    References:
        [1] J.M.Martin, A.M.Gomez, J.A.Gonzalez, A.M.Peinado 'A Deep Learning
        Loss Function based on the Perceptual Evaluation of the
        Speech Quality', IEEE Signal Processing Letters, 2018.
        Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es
        Copyright 2019: University of Granada, Signal Processing, Multimedia
        Transmission and Speech/Audio Technologies (SigMAT) Group.

    .. note:: Inspired by the Perceptual Evaluation of the Speech Quality
        (PESQ) algorithm, this function consists of two regularization
        factors: the symmetrical and asymmetrical distortion in the
        loudness domain.

    Examples:
        >>> import torch
        >>> from asteroid.filterbanks import STFTFB, Encoder, transforms
        >>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
        >>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
        >>> # Usage by itself
        >>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
        >>> ref_spec = transforms.take_mag(stft(ref))
        >>> est_spec = transforms.take_mag(stft(est))
        >>> loss_func = SingleSrcPMSQE()
        >>> loss_value = loss_func(est_spec, ref_spec)
        >>> # Usage with PITLossWrapper
        >>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
        >>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
        >>> ref_spec = transforms.take_mag(stft(ref))
        >>> est_spec = transforms.take_mag(stft(est))
        >>> loss_value = loss_func(est_spec, ref_spec)
    """

    # Per-Bark-band tables for the 16 kHz configuration (49 bands). The
    # 8 kHz configuration uses exactly the first 42 entries of each table,
    # so they are stored once and sliced in `register_8k_constants`.
    # fmt: off
    _ABS_THRESH_POWER = (
        51286152.00, 2454709.500, 70794.593750, 4897.788574, 1174.897705, 389.045166, 104.712860,
        45.708820, 17.782795, 9.772372, 4.897789, 3.090296, 1.905461, 1.258925,
        0.977237, 0.724436, 0.562341, 0.457088, 0.389045, 0.331131, 0.295121,
        0.269153, 0.257040, 0.251189, 0.251189, 0.251189, 0.251189, 0.263027,
        0.288403, 0.309030, 0.338844, 0.371535, 0.398107, 0.436516, 0.467735,
        0.489779, 0.501187, 0.501187, 0.512861, 0.524807, 0.524807, 0.524807,
        0.512861, 0.478630, 0.426580, 0.371535, 0.363078, 0.416869, 0.537032,
    )
    _MODIFIED_ZWICKER_POWER = (
        0.25520097857560436, 0.25520097857560436, 0.25520097857560436, 0.25520097857560436,
        0.25168783742879913, 0.24806665731869609, 0.244767379124259, 0.24173800119368227,
        0.23893798876066405, 0.23633516221479894, 0.23390360348392067, 0.23162209128929445,
    ) + (0.23,) * 37
    _WIDTH_OF_BAND_BARK = (
        0.157344, 0.317994, 0.322441, 0.326934, 0.331474, 0.336061, 0.340697,
        0.345381, 0.350114, 0.354897, 0.359729, 0.364611, 0.369544, 0.374529,
        0.379565, 0.384653, 0.389794, 0.394989, 0.400236, 0.405538, 0.410894,
        0.416306, 0.421773, 0.427297, 0.432877, 0.438514, 0.444209, 0.449962,
        0.455774, 0.461645, 0.467577, 0.473569, 0.479621, 0.485736, 0.491912,
        0.498151, 0.504454, 0.510819, 0.517250, 0.523745, 0.530308, 0.536934,
        0.543629, 0.550390, 0.557220, 0.564119, 0.571085, 0.578125, 0.585232,
    )
    # fmt: on

    def __init__(
        self,
        window_name="sqrt_hann",
        window_weight=1.0,
        bark_eq=True,
        gain_eq=True,
        sample_rate=16000,
    ):
        super().__init__()
        self.window_name = window_name
        self.window_weight = window_weight
        self.bark_eq = bark_eq
        self.gain_eq = gain_eq

        if sample_rate not in [16000, 8000]:
            raise ValueError("Unsupported sample rate {}".format(sample_rate))
        self.sample_rate = sample_rate
        if sample_rate == 16000:
            self.Sp = 6.910853e-006
            self.Sl = 1.866055e-001
            self.nbins = 512
            self.nbark = 49
        else:
            self.Sp = 2.764344e-5
            self.Sl = 1.866055e-1
            self.nbins = 256
            self.nbark = 42
        # As described in [1] and used in the TF implementation.
        self.alpha = 0.1
        self.beta = 0.309 * self.alpha

        pow_correc_factor = self.get_correction_factor(window_name)
        self.pow_correc_factor = pow_correc_factor * self.window_weight

        # Initialize to None and populate as a function of sample rate.
        self.abs_thresh_power = None
        self.modified_zwicker_power = None
        self.width_of_band_bark = None
        self.bark_matrix = None
        self.mask_sll = None
        self.populate_constants(self.sample_rate)

        # Plain tensor attribute (not a registered parameter or buffer).
        self.sqrt_total_width = torch.sqrt(torch.sum(self.width_of_band_bark))
        self.EPS = 1e-8

    def forward(self, est_targets, targets, pad_mask=None):
        """Compute the PMSQE value of each sequence in the batch.

        Args
            est_targets (torch.Tensor): Dimensions (B, T, F).
                Padded degraded power spectrum in time-frequency domain.
            targets (torch.Tensor): Dimensions (B, T, F).
                Zero-Padded reference power spectrum in time-frequency domain.
            pad_mask (torch.Tensor, optional): Dimensions (B, T, 1). Mask
                to indicate the padding frames. Defaults to all ones.

        Dimensions
            B: Number of sequences in the batch.
            T: Number of time frames.
            F: Number of frequency bins (``nbins // 2 + 1``).

        Returns
            torch.Tensor of shape (B, ): ``alpha * wD + beta * wDA``
            averaged over the non-padded frames, with ``alpha = 0.1`` and
            ``beta = 0.309 * alpha``.

        Raises
            ValueError: if no non-batch dimension has ``nbins // 2 + 1``
                elements.

        Notes
            Dimensions (B, F, T) are also supported by SingleSrcPMSQE but are
            less efficient because input tensors are transposed (not inplace).
            In that case ``pad_mask`` is transposed as well. When both
            non-batch dimensions have F elements, the (B, F, T) layout is
            assumed.
        """
        assert est_targets.shape == targets.shape, "est_targets and targets must share a shape"

        # Locate the frequency axis among the non-batch axes only, so that a
        # batch size equal to the number of bins cannot be mistaken for it.
        n_freq = self.nbins // 2 + 1
        if n_freq not in est_targets.shape[1:]:
            raise ValueError(
                "Could not find dimension with {} elements in "
                "input tensors, verify your inputs"
                "".format(n_freq)
            )
        freq_first = est_targets.shape[1] == n_freq

        if freq_first:
            est_targets = est_targets.transpose(1, 2)
            targets = targets.transpose(1, 2)
            if pad_mask is not None:
                # Transpose the pad mask as well.
                pad_mask = pad_mask.transpose(1, 2)
        if pad_mask is None:
            # Suppose no padding if no pad_mask is provided.
            pad_mask = torch.ones(
                est_targets.shape[0], est_targets.shape[1], 1, device=est_targets.device
            )

        # SLL equalization
        ref_spectra = self.magnitude_at_sll(targets, pad_mask)
        deg_spectra = self.magnitude_at_sll(est_targets, pad_mask)
        # Bark spectra computation
        ref_bark_spectra = self.bark_computation(ref_spectra)
        deg_bark_spectra = self.bark_computation(deg_spectra)
        # (Optional) frequency and gain equalization
        if self.bark_eq:
            deg_bark_spectra = self.bark_freq_equalization(ref_bark_spectra, deg_bark_spectra)
        if self.gain_eq:
            deg_bark_spectra = self.bark_gain_equalization(ref_bark_spectra, deg_bark_spectra)

        # Distortion matrix computation
        sym_d, asym_d = self.compute_distortion_tensors(ref_bark_spectra, deg_bark_spectra)
        # Per-frame distortion
        audible_power_ref = self.compute_audible_power(ref_bark_spectra, 1.0)
        wd_frame, wda_frame = self.per_frame_distortion(sym_d, asym_d, audible_power_ref)

        # Mean distortions over frames : keep batch dims
        dims = [-1, -2]
        pmsqe_frame = (self.alpha * wd_frame + self.beta * wda_frame) * pad_mask
        pmsqe = torch.sum(pmsqe_frame, dim=dims) / pad_mask.sum(dims)
        return pmsqe

    def magnitude_at_sll(self, spectra, pad_mask):
        """Rescale ``spectra`` so its masked mean power equals 1e7.

        The mean is taken over the bins weighted by ``mask_sll`` and over
        the frames selected by ``pad_mask``, independently per sequence.
        """
        # Apply padding and SLL masking
        masked_spectra = spectra * pad_mask * self.mask_sll
        # Compute mean over frequency
        freq_mean_masked_spectra = torch.mean(masked_spectra, dim=-1, keepdim=True)
        # Compute mean over time (taking into account padding)
        sum_spectra = torch.sum(freq_mean_masked_spectra, dim=-2, keepdim=True)
        seq_len = torch.sum(pad_mask, dim=-2, keepdim=True)
        mean_pow = sum_spectra / seq_len
        # Compute final SLL spectra
        return 10000000.0 * spectra / mean_pow

    def bark_computation(self, spectra):
        """Project ``spectra`` onto the Bark bands and scale by ``Sp``."""
        return self.Sp * torch.matmul(spectra, self.bark_matrix)

    def compute_audible_power(self, bark_spectra, factor=1.0):
        """Sum, per frame, the band powers above ``factor`` times the
        absolute hearing threshold. The last dimension is kept with size 1.
        """
        # Apply absolute hearing threshold to each band
        thr_bark = torch.where(
            bark_spectra > self.abs_thresh_power * factor,
            bark_spectra,
            torch.zeros_like(bark_spectra),
        )
        # Sum band power over frequency
        return torch.sum(thr_bark, dim=-1, keepdim=True)

    def bark_gain_equalization(self, ref_bark_spectra, deg_bark_spectra):
        """Scale each degraded frame by the ratio of audible powers
        (reference over degraded), limited to the range [3e-4, 5].
        """
        # Compute audible power
        audible_power_ref = self.compute_audible_power(ref_bark_spectra, 1.0)
        audible_power_deg = self.compute_audible_power(deg_bark_spectra, 1.0)
        # Compute gain factor
        gain = (audible_power_ref + 5.0e3) / (audible_power_deg + 5.0e3)
        # Limit the range of the gain factor
        limited_gain = torch.min(gain, torch.full_like(gain, 5.0))
        limited_gain = torch.max(limited_gain, torch.full_like(limited_gain, 3.0e-4))
        # Apply gain correction on degraded
        return limited_gain * deg_bark_spectra

    def bark_freq_equalization(self, ref_bark_spectra, deg_bark_spectra):
        """Equalize the degraded Bark spectra band by band.

        This version is applied in the degraded directly. The per-band
        equalizer is the ratio of reference to degraded power accumulated
        over speech-active frames, limited to the range [0.01, 100].
        """
        # Identification of speech active frames
        audible_power_x100 = self.compute_audible_power(ref_bark_spectra, 100.0)
        not_silent = audible_power_x100 >= 1.0e7
        # Threshold for active bark bins
        cond_thr = ref_bark_spectra >= self.abs_thresh_power * 100.0
        ref_thresholded = torch.where(
            cond_thr, ref_bark_spectra, torch.zeros_like(ref_bark_spectra)
        )
        deg_thresholded = torch.where(
            cond_thr, deg_bark_spectra, torch.zeros_like(deg_bark_spectra)
        )
        # Total power per bark bin (ppb)
        avg_ppb_ref = torch.sum(
            torch.where(not_silent, ref_thresholded, torch.zeros_like(ref_thresholded)),
            dim=-2,
            keepdim=True,
        )
        avg_ppb_deg = torch.sum(
            torch.where(not_silent, deg_thresholded, torch.zeros_like(deg_thresholded)),
            dim=-2,
            keepdim=True,
        )
        # Compute equalizer
        equalizer = (avg_ppb_ref + 1000.0) / (avg_ppb_deg + 1000.0)
        equalizer = torch.min(equalizer, torch.full_like(equalizer, 100.0))
        equalizer = torch.max(equalizer, torch.full_like(equalizer, 0.01))
        # Apply frequency correction on degraded
        return equalizer * deg_bark_spectra

    def loudness_computation(self, bark_spectra):
        """Transform Bark spectra to a sone loudness scale using Zwicker's
        law. Bands below the absolute hearing threshold map to zero.
        """
        aterm = torch.pow(self.abs_thresh_power / 0.5, self.modified_zwicker_power)
        bterm = (
            torch.pow(0.5 + 0.5 * bark_spectra / self.abs_thresh_power, self.modified_zwicker_power)
            - 1.0
        )
        loudness_dens = self.Sl * aterm * bterm
        cond = bark_spectra < self.abs_thresh_power
        return torch.where(cond, torch.zeros_like(loudness_dens), loudness_dens)

    def compute_distortion_tensors(self, ref_bark_spec, deg_bark_spec):
        """Return the symmetric and asymmetric disturbance tensors
        ``(sym_d, asym_d)`` computed in the loudness domain.
        """
        # After bark spectra are compensated, transform to sone loudness
        original_loudness = self.loudness_computation(ref_bark_spec)
        distorted_loudness = self.loudness_computation(deg_bark_spec)
        # Loudness difference
        r = torch.abs(distorted_loudness - original_loudness)
        # Masking effect computation
        m = 0.25 * torch.min(original_loudness, distorted_loudness)
        # Center clipping using masking effect (floored at EPS)
        sym_d = torch.max(r - m, torch.ones_like(r) * self.EPS)
        # Asymmetry factor computation: zero below 3, capped at 12
        asym = torch.pow((deg_bark_spec + 50.0) / (ref_bark_spec + 50.0), 1.2)
        asym_factor = torch.where(
            asym < 3.0, torch.zeros_like(asym), torch.min(asym, torch.full_like(asym, 12.0))
        )
        # Asymmetric Disturbance matrix computation
        asym_d = asym_factor * sym_d
        return sym_d, asym_d

    def per_frame_distortion(self, sym_d, asym_d, total_power_ref):
        """Reduce the disturbance tensors over Bark bands for each frame.

        Returns ``(wd_frame, wda_frame)``, both weighted by the reference
        audible power and bounded above by 45.
        """
        # Computation of the norms over bark bands for each frame
        # 2 and 1 for sym_d and asym_d, respectively
        d_frame = torch.sum(
            torch.pow(sym_d * self.width_of_band_bark, 2.0) + self.EPS, dim=-1, keepdim=True
        )
        d_frame = torch.sqrt(d_frame) * self.sqrt_total_width
        da_frame = torch.sum(asym_d * self.width_of_band_bark, dim=-1, keepdim=True)
        # Weighting by the audible power raised to 0.04
        weights = torch.pow((total_power_ref + 1e5) / 1e7, 0.04)
        # Bounded computation of the per frame distortion metric
        wd_frame = torch.min(d_frame / weights, torch.full_like(d_frame, 45.0))
        wda_frame = torch.min(da_frame / weights, torch.full_like(da_frame, 45.0))
        return wd_frame, wda_frame

    @staticmethod
    def get_correction_factor(window_name):
        """Returns the power correction factor depending on the window.

        Raises:
            ValueError: if ``window_name`` is not a supported window.
        """
        factors = {
            "rect": 1.0,
            "hann": 2.666666666666754,
            "sqrt_hann": 2.0,
            "hamming": 2.51635879188799,
            "flatTop": 5.70713295690759,
        }
        try:
            return factors[window_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which are equally invalid.
            raise ValueError("Unexpected window type {}".format(window_name)) from None

    def populate_constants(self, sample_rate):
        """Register the per-band constants, the Bark matrix and the SLL mask
        for ``sample_rate``.

        Raises:
            ValueError: if ``sample_rate`` is neither 8000 nor 16000.
        """
        if sample_rate == 8000:
            self.register_8k_constants()
        elif sample_rate == 16000:
            self.register_16k_constants()
        else:
            raise ValueError("Unsupported sample rate {}".format(sample_rate))
        # Mask SLL
        mask_sll = np.zeros(shape=[self.nbins // 2 + 1], dtype=np.float32)
        mask_sll[11] = 0.5 * 25.0 / 31.25
        mask_sll[12:104] = 1.0
        mask_sll[104] = 0.5
        correction = self.pow_correc_factor * (self.nbins + 2.0) / self.nbins ** 2
        mask_sll = mask_sll * correction
        self.mask_sll = nn.Parameter(tensor(mask_sll), requires_grad=False)

    def register_16k_constants(self):
        """Register the 49-band tables and load ``bark_matrix_16k.mat``
        from the directory containing this file.
        """
        self._register_constants(49, "bark_matrix_16k.mat", "Bark_matrix_16k")

    def register_8k_constants(self):
        """Register the 42-band tables and load ``bark_matrix_8k.mat``
        from the directory containing this file.
        """
        self._register_constants(42, "bark_matrix_8k.mat", "Bark_matrix_8k")

    def _register_constants(self, n_bands, mat_name, mat_key):
        """Register the first ``n_bands`` entries of each per-band table and
        the Bark matrix stored under ``mat_key`` in the file ``mat_name``.
        """
        # Absolute threshold power
        self.abs_thresh_power = nn.Parameter(
            tensor(self._ABS_THRESH_POWER[:n_bands]), requires_grad=False
        )
        # Modified zwicker power
        self.modified_zwicker_power = nn.Parameter(
            tensor(self._MODIFIED_ZWICKER_POWER[:n_bands]), requires_grad=False
        )
        # Width of band bark
        self.width_of_band_bark = nn.Parameter(
            tensor(self._WIDTH_OF_BAND_BARK[:n_bands]), requires_grad=False
        )
        # Bark matrix
        local_path = pathlib.Path(__file__).parent.absolute()
        bark_path = os.path.join(local_path, mat_name)
        bark_matrix = loadmat(bark_path)[mat_key].astype("float32")
        self.bark_matrix = nn.Parameter(tensor(bark_matrix), requires_grad=False)