
Source code for asteroid.masknn.norms

from functools import partial
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from .. import complex_nn

# Small constant added to variances before taking the square root so the
# normalizations below never divide by zero.
EPS = 1.0e-8

class _LayerNorm(nn.Module):
    """Layer Normalization base class.

    Owns a learnable per-channel gain ``gamma`` (initialised to ones) and a
    learnable per-channel bias ``beta`` (initialised to zeros). Subclasses
    compute the normalized tensor and finish with :meth:`apply_gain_and_bias`.

    Args:
        channel_size (int): Number of channels, i.e. the size of dimension 1
            of the inputs that subclasses normalize.
    """

    def __init__(self, channel_size):
        super().__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Scale by ``gamma`` and shift by ``beta`` along the channel axis.

        Assumes input of size `[batch, channel, *]`.

        Args:
            normed_x (:class:`torch.Tensor`): Already-normalized tensor.

        Returns:
            :class:`torch.Tensor`: Tensor with the same shape as ``normed_x``.
        """
        # Put the channel axis last so the 1-D gamma/beta broadcast over it,
        # then restore the original axis order.
        channels_last = normed_x.transpose(1, -1)
        scaled = self.gamma * channels_last + self.beta
        return scaled.transpose(1, -1)

class GlobLN(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies forward pass. Works for any input size > 2D.

        Each example is normalized with a single mean and variance taken over
        all of its non-batch dimensions. The per-channel gain and bias are
        applied afterwards.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        # Every dimension except the batch one.
        axes = list(range(1, x.dim()))
        centered = x - x.mean(dim=axes, keepdim=True)
        variance = centered.pow(2).mean(dim=axes, keepdim=True)
        return self.apply_gain_and_bias(centered / (variance + EPS).sqrt())
class ChanLN(_LayerNorm):
    """Channel-wise Layer Normalization (chanLN)."""

    def forward(self, x):
        """Applies forward pass. Works for any input size > 2D.

        Mean and (biased) variance are taken over the channel dimension only,
        independently for every remaining position.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
        """
        mu = x.mean(dim=1, keepdim=True)
        sigma2 = x.var(dim=1, keepdim=True, unbiased=False)
        normed = (x - mu) / (sigma2 + EPS).sqrt()
        return self.apply_gain_and_bias(normed)
class CumLN(_LayerNorm):
    """Cumulative Global layer normalization (cumLN).

    At frame ``t`` the input is normalized with the mean and variance taken
    over all channels and all frames ``0..t``. Only past and current frames
    contribute to the statistics.
    """

    def forward(self, x):
        """
        Args:
            x (:class:`torch.Tensor`): Shape `[batch, channels, length]`

        Returns:
            :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`
        """
        _, chan, spec_len = x.size()
        # Running sums over time of the per-frame channel sums of x and x**2.
        cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        # Number of elements accumulated up to and including each frame:
        # chan, 2*chan, ..., spec_len*chan. The tensor is built on x's device so
        # the layer also works for GPU inputs.
        cnt = torch.arange(
            start=chan,
            end=chan * (spec_len + 1),
            step=chan,
            dtype=x.dtype,
            device=x.device,
        ).view(1, 1, -1)
        cum_mean = cum_sum / cnt
        # Var[x] = E[x^2] - E[x]^2. Both moments must be divided by the count.
        # Clamping removes tiny negative values caused by floating-point
        # cancellation, which would otherwise turn into NaN under sqrt.
        cum_var = (cum_pow_sum / cnt - cum_mean.pow(2)).clamp(min=0)
        return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())
class FeatsGlobLN(_LayerNorm):
    """feature-wise global Layer Normalization (FeatsGlobLN).

    Applies normalization over frames for each channel.
    """

    def forward(self, x):
        """Applies forward pass. Works for any input size > 2D.

        Mean and (biased) variance are taken over every dimension after the
        channel one, separately for each batch element and channel.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, time]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, time]`
        """
        frame_axes = list(range(2, x.dim()))
        mu = x.mean(dim=frame_axes, keepdim=True)
        sigma2 = x.var(dim=frame_axes, keepdim=True, unbiased=False)
        normed = (x - mu) / (sigma2 + EPS).sqrt()
        return self.apply_gain_and_bias(normed)
class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D.

    A single module that accepts 2D `[batch, chan]`, 3D `[batch, chan, time]`
    and 4D `[batch, chan, *, *]` inputs. Any other rank is rejected with a
    :class:`ValueError`.
    """

    def _check_input_dim(self, input):
        """Reject inputs whose rank is not 2, 3 or 4.

        Raises:
            ValueError: If ``input`` has fewer than 2 or more than 4 dimensions.
        """
        ndim = input.dim()
        if not 2 <= ndim <= 4:
            # The message now lists every accepted rank. Previously it omitted
            # 2D even though 2D inputs pass the check.
            raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(ndim))
# Aliases, so the norms can also be looked up by their short names via `get`.
gLN = GlobLN
fgLN = FeatsGlobLN
cLN = ChanLN
cgLN = CumLN
bN = BatchNorm
def register_norm(custom_norm):
    """Register a custom norm, gettable with `norms.get`.

    The norm is stored in this module's global namespace under its
    ``__name__``. Registration is refused if that name, or its lower-cased
    form, is already taken.

    Args:
        custom_norm: Custom norm to register.

    Raises:
        ValueError: If the name (or its lower-cased form) already exists.
    """
    registry = globals()
    norm_name = custom_norm.__name__
    taken = norm_name in registry or norm_name.lower() in registry
    if taken:
        raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.")
    registry[norm_name] = custom_norm
def get(identifier):
    """Returns a norm class from a string.

    Returns its input if it is callable (already a :class:`._LayerNorm` for
    example). Strings are looked up in this module's global namespace.

    Args:
        identifier (str or Callable or None): the norm identifier.

    Returns:
        :class:`._LayerNorm` or None

    Raises:
        ValueError: If ``identifier`` is a string with no matching global, or
            is neither None, callable nor a string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        found = globals().get(identifier)
        if found is not None:
            return found
    # Both an unknown string and an unsupported type end up here.
    raise ValueError("Could not interpret normalization identifier: " + str(identifier))
def get_complex(identifier):
    """Like `.get` but returns a complex norm created with `asteroid.complex_nn.OnReIm`.

    Args:
        identifier (str or Callable or None): the norm identifier, resolved
            with `.get`.

    Returns:
        None if the identifier resolves to None. Otherwise a
        :func:`functools.partial` of `complex_nn.OnReIm` bound to the
        resolved norm.
    """
    base_norm = get(identifier)
    return None if base_norm is None else partial(complex_nn.OnReIm, base_norm)
Read the Docs v: v0.3.5
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.