Source code for asteroid.masknn.activations
import torch
from torch import nn
class Swish(nn.Module):
    """Swish activation module, computing ``x * sigmoid(x)`` element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Multiply ``x`` by its own sigmoid gate.

        Args:
            x: input tensor.

        Returns:
            Tensor equal to ``x * torch.sigmoid(x)``.
        """
        gate = torch.sigmoid(x)
        return x * gate
def linear():
    """Build the identity activation (``nn.Identity``), which applies no transform."""
    identity = nn.Identity()
    return identity
def relu():
    """Build a fresh ``nn.ReLU`` activation module (not in-place)."""
    return nn.ReLU(inplace=False)
def prelu(num_parameters=1, init=0.25):
    """Build an ``nn.PReLU`` activation with a learnable negative slope.

    Calling ``prelu()`` without arguments behaves exactly as before; the
    keyword arguments simply expose ``nn.PReLU``'s own options.

    Args:
        num_parameters (int): number of learnable slopes handed to
            ``nn.PReLU``. Defaults to 1, the ``nn.PReLU`` default.
        init (float): initial value of the slope(s). Defaults to 0.25, the
            ``nn.PReLU`` default.

    Returns:
        nn.PReLU: the newly created activation module.
    """
    return nn.PReLU(num_parameters=num_parameters, init=init)
def leaky_relu(negative_slope=0.01):
    """Build an ``nn.LeakyReLU`` activation.

    Calling ``leaky_relu()`` without arguments behaves exactly as before; the
    keyword argument exposes ``nn.LeakyReLU``'s slope option.

    Args:
        negative_slope (float): slope used for negative inputs. Defaults to
            0.01, the ``nn.LeakyReLU`` default.

    Returns:
        nn.LeakyReLU: the newly created activation module.
    """
    return nn.LeakyReLU(negative_slope=negative_slope)
def sigmoid():
    """Build a fresh ``nn.Sigmoid`` activation module."""
    module = nn.Sigmoid()
    return module
def softmax(dim=None):
    """Build an ``nn.Softmax`` activation.

    Args:
        dim (int or None): dimension along which the softmax is taken; passed
            straight through to ``nn.Softmax``.
    """
    layer = nn.Softmax(dim=dim)
    return layer
def tanh():
    """Build a fresh ``nn.Tanh`` activation module."""
    module = nn.Tanh()
    return module
def gelu():
    """Build a fresh ``nn.GELU`` activation module."""
    module = nn.GELU()
    return module
def swish():
    """Build a fresh ``Swish`` activation module (``x * sigmoid(x)``)."""
    module = Swish()
    return module
def register_activation(custom_act):
    """Register a custom activation, gettable with `activation.get`.

    The activation is stored as a global of this module under its
    ``__name__``, which is where ``get`` looks string identifiers up.

    Args:
        custom_act: Custom activation function (or class) to register. It must
            have a ``__name__`` attribute.

    Returns:
        The ``custom_act`` passed in, unchanged, so this function can also be
        used as a decorator (``@register_activation``).

    Raises:
        ValueError: if ``custom_act.__name__``, or its lowercase form, already
            names something in this module's globals.
    """
    name = custom_act.__name__
    module_globals = globals()
    # Module globals also hold the built-in factories and internals such as
    # `torch`, `nn` and `get`, so this check also prevents overwriting them.
    # NOTE(review): the lowercase check additionally rejects names that only
    # differ in case from an existing one (e.g. "ReLU" vs "relu"), presumably
    # to avoid confusing near-duplicates.
    if name in module_globals or name.lower() in module_globals:
        raise ValueError(f"Activation {name} already exists. Choose another name.")
    module_globals[name] = custom_act
    return custom_act
def get(identifier):
    """Returns an activation from a string, or its input if it is callable.

    Args:
        identifier (str or Callable or None): the activation identifier. A
            string is looked up among this module's global names, which
            includes anything added by ``register_activation``. A callable is
            returned unchanged, and ``None`` yields ``None``.

    Returns:
        Callable or None: for a string, the matching callable, e.g. the
        ``relu`` factory, which must still be called to build the
        ``nn.Module``. Otherwise the input callable itself, or ``None``.

    Raises:
        ValueError: if ``identifier`` is a string that does not name a
            callable in this module, or is not ``None``, a callable or a
            string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # Module globals also contain non-activations such as imported modules
        # (`torch`, `nn`) and dunders like `__name__`; none of them are
        # callable, so reject anything that cannot act as an activation.
        if cls is None or not callable(cls):
            raise ValueError("Could not interpret activation identifier: " + str(identifier))
        return cls
    raise ValueError("Could not interpret activation identifier: " + str(identifier))