Shortcuts

Source code for asteroid.masknn.activations

from functools import partial
import torch
from torch import nn
from .. import complex_nn


class Swish(nn.Module):
    """Swish activation module: multiplies the input by its own sigmoid.

    Computes ``x * sigmoid(x)`` elementwise and has no learnable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply ``x * sigmoid(x)`` to the input tensor ``x``."""
        gate = torch.sigmoid(x)
        return x * gate
def linear():
    """Build a no-op activation.

    Returns:
        :class:`nn.Identity`: module that returns its input unchanged.
    """
    identity = nn.Identity()
    return identity
def relu():
    """Build a ReLU activation.

    Returns:
        :class:`nn.ReLU`: out-of-place rectified linear unit.
    """
    return nn.ReLU(inplace=False)
def prelu(num_parameters=1, init=0.25):
    """Build a parametric ReLU activation with a learnable negative slope.

    Both defaults match ``nn.PReLU``'s own defaults, so ``prelu()`` builds the
    same module as before.

    Args:
        num_parameters (int): Number of learnable slopes (1 = one shared slope).
        init (float): Initial value of the learnable slope(s).

    Returns:
        :class:`nn.PReLU`
    """
    return nn.PReLU(num_parameters=num_parameters, init=init)
def leaky_relu(negative_slope=0.01):
    """Build a leaky ReLU activation.

    The default slope matches ``nn.LeakyReLU``'s own default, so
    ``leaky_relu()`` builds the same module as before.

    Args:
        negative_slope (float): Slope applied to negative inputs.

    Returns:
        :class:`nn.LeakyReLU`
    """
    return nn.LeakyReLU(negative_slope=negative_slope)
def sigmoid():
    """Build a logistic sigmoid activation.

    Returns:
        :class:`nn.Sigmoid`
    """
    module = nn.Sigmoid()
    return module
def softmax(dim=None):
    """Build a softmax activation.

    Args:
        dim (int or None): Dimension along which softmax is computed; passed
            straight through to :class:`nn.Softmax`.

    Returns:
        :class:`nn.Softmax`
    """
    layer = nn.Softmax(dim=dim)
    return layer
def tanh():
    """Build a hyperbolic tangent activation.

    Returns:
        :class:`nn.Tanh`
    """
    module = nn.Tanh()
    return module
def gelu():
    """Build a GELU activation.

    Returns:
        :class:`nn.GELU`
    """
    module = nn.GELU()
    return module
def swish():
    """Build a Swish activation.

    Returns:
        :class:`Swish`: a fresh instance of this module's ``Swish`` class.
    """
    module = Swish()
    return module
def register_activation(custom_act):
    """Register a custom activation so it can be retrieved with `activation.get`.

    The activation is stored in this module's global namespace under its
    ``__name__``.

    Args:
        custom_act: Custom activation function (or class) to register.

    Raises:
        ValueError: if a module-level name equal to ``custom_act.__name__``,
            or to its lowercase form, already exists.
    """
    name = custom_act.__name__
    namespace = globals()
    # Also check the lowercase form so e.g. "ReLU" cannot sit beside the built-in "relu".
    if name in namespace or name.lower() in namespace:
        raise ValueError(f"Activation {name} already exists. Choose another name.")
    namespace[name] = custom_act
def get(identifier):
    """Return an activation factory from an identifier.

    Returns its input if it is callable (already an activation, for example).
    A string is looked up among this module's global names, which include the
    built-in factories and anything added with `register_activation`.

    Args:
        identifier (str or Callable or None): the activation identifier.

    Returns:
        Callable or None: the activation factory/class (call it to build an
        ``nn.Module``), the callable passed in unchanged, or None if
        ``identifier`` is None.

    Raises:
        ValueError: if ``identifier`` is a string that does not name a callable
            in this module, or is neither None, callable, nor a string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # Module globals also hold non-activation objects (e.g. ``__name__``,
        # imported modules); only callables can serve as activation factories.
        if cls is None or not callable(cls):
            raise ValueError("Could not interpret activation identifier: " + str(identifier))
        return cls
    raise ValueError("Could not interpret activation identifier: " + str(identifier))
def get_complex(identifier):
    """Like `.get` but returns a complex activation created with `asteroid.complex_nn.OnReIm`.

    Args:
        identifier (str or Callable or None): the activation identifier, resolved with `get`.

    Returns:
        ``partial(complex_nn.OnReIm, activation)`` for the resolved activation,
        or None when `get` resolves the identifier to None.
    """
    real_activation = get(identifier)
    if real_activation is not None:
        return partial(complex_nn.OnReIm, real_activation)
    return None
Read the Docs v: v0.3.5
Versions
latest
stable
v0.3.5_b
v0.3.4
v0.3.3
v0.3.2
v0.3.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.