Utils
Parser utils
Asteroid has its own argument parser (built on argparse) that handles dict-like structures created from a config YAML file.
-
asteroid.utils.parser_utils.isfloat(value)
Computes whether value can be cast to a float.
Parameters: value (str) – Value to check.
Returns: bool – Whether value can be cast to a float.
-
asteroid.utils.parser_utils.isint(value)
Computes whether value can be cast to an int.
Parameters: value (str) – Value to check.
Returns: bool – Whether value can be cast to an int.
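The two checks above can be pictured as a try/except around the built-in casts. The following is a minimal sketch under that assumption, not Asteroid's actual source; the `_sketch` names are hypothetical and chosen to avoid confusion with the library functions.

```python
def isfloat_sketch(value):
    """Return True if `value` can be cast to a float (illustrative stand-in)."""
    try:
        float(value)
        return True
    except (TypeError, ValueError):
        return False


def isint_sketch(value):
    """Return True if `value` can be cast to an int (illustrative stand-in)."""
    try:
        int(value)
        return True
    except (TypeError, ValueError):
        return False


# A string holding a decimal number casts to float but not to int.
float_ok = isfloat_sketch("3.5")
int_ok = isint_sketch("3.5")
```

Note that with this approach strings such as "nan" or "inf" count as castable to float, since `float()` accepts them.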
-
asteroid.utils.parser_utils.parse_args_as_dict(parser, return_plain_args=False, args=None)
Get a dict of dicts out of parser.parse_args().
Top-level keys correspond to groups and bottom-level keys correspond to arguments. Arguments which don't belong to an argparse group (i.e. main arguments defined before parsing from a dict) can be found under 'main_args'.
Parameters: - parser (argparse.ArgumentParser) – ArgumentParser instance containing groups. Output of prepare_parser_from_dict.
- return_plain_args (bool) – Whether to also return the direct output of parser.parse_args().
- args (list) – List of arguments as read from the command line. Used for unit testing.
Returns: dict – Dictionary of dictionaries containing the arguments. Optionally the direct output of parser.parse_args().
-
asteroid.utils.parser_utils.prepare_parser_from_dict(dic, parser=None)
Prepare an argparser from a dictionary.
Parameters: - dic (dict) – Two-level config dictionary with unique bottom-level keys.
- parser (argparse.ArgumentParser, optional) – If a parser already exists, add the keys from the dictionary on top of it.
Returns: argparse.ArgumentParser – Parser instance with groups corresponding to the first-level keys and arguments corresponding to the second-level keys, with default values given by the values.
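To make the round trip between a two-level config dictionary and a nested dictionary of parsed arguments concrete, here is a minimal sketch built only on the standard library's argparse. It does not call Asteroid; `parser_from_dict_sketch` and `nest_args_sketch` are hypothetical helpers, and the rule used to pick each argument's type is an assumption of this sketch.

```python
import argparse


def parser_from_dict_sketch(dic, parser=None):
    """Add one argument group per top-level key, one option per bottom-level key."""
    if parser is None:
        parser = argparse.ArgumentParser()
    for group_name, options in dic.items():
        group = parser.add_argument_group(group_name)
        for name, default in options.items():
            kwargs = {"default": default}
            # Assumption: only int, float and str defaults drive the cast.
            if isinstance(default, (int, float, str)) and not isinstance(default, bool):
                kwargs["type"] = type(default)
            group.add_argument(f"--{name}", **kwargs)
    return parser


def nest_args_sketch(namespace, dic):
    """Regroup a flat namespace by the groups of `dic`; the rest goes to 'main_args'."""
    flat = vars(namespace)
    nested = {group: {key: flat[key] for key in options} for group, options in dic.items()}
    grouped = {key for options in dic.values() for key in options}
    nested["main_args"] = {k: v for k, v in flat.items() if k not in grouped}
    return nested


# Hypothetical config, as it might look after loading a YAML file.
conf = {"training": {"epochs": 10, "lr": 1e-3}, "data": {"sample_rate": 8000}}
base = argparse.ArgumentParser()
base.add_argument("--exp_dir", default="exp/tmp")  # a "main" argument
parser = parser_from_dict_sketch(conf, parser=base)
plain_args = parser.parse_args(["--lr", "0.01", "--exp_dir", "exp/run1"])
nested = nest_args_sketch(plain_args, conf)
```

Passing an explicit list to `parse_args` mirrors the role of the `args` parameter described above: it lets the parsing be exercised without touching `sys.argv`.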
-
asteroid.utils.parser_utils.str2bool(value)
Type to convert strings to Boolean (returns the input if it is not a Boolean).
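A converter of this kind is typically passed as the `type` of an argparse argument, because `type=bool` would treat any non-empty string as true. The sketch below illustrates the idea; `str2bool_sketch` is a hypothetical stand-in, and the set of accepted spellings is an assumption rather than something stated in this reference.

```python
import argparse


def str2bool_sketch(value):
    """Map common true/false spellings to bool; return anything else unchanged."""
    if not isinstance(value, str):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "y", "1"):  # assumed spellings
        return True
    if lowered in ("no", "false", "n", "0"):  # assumed spellings
        return False
    return value


parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", type=str2bool_sketch, default=True)
parsed = parser.parse_args(["--use_gpu", "false"])
```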
Torch utils
-
asteroid.utils.torch_utils.are_models_equal(model1, model2)
Check for weights equality between models.
Parameters: - model1 (nn.Module) – Model instance to be compared.
- model2 (nn.Module) – Second model instance to be compared.
Returns: bool – Whether all model weights are equal.
-
asteroid.utils.torch_utils.load_state_dict_in(state_dict, model)
Strictly loads state_dict in model, or the next submodel. Useful to load a standalone model after training it with System.
Parameters: - state_dict (OrderedDict) – The state_dict to load.
- model (torch.nn.Module) – The model to load it into.
Returns: torch.nn.Module – Model with loaded weights.
Note: Keys in a state_dict look like object1.object2.layer_name.weight.etc. We first try to load the model in the classic way. If this fails, we remove the leftmost part of each key to obtain object2.layer_name.weight.etc. Blindly loading with strict=False should be done with some logging of the missing keys in the state_dict and the model.
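The key-stripping fallback described in the note can be sketched without PyTorch installed, since it only manipulates dictionary keys. The helper names below are hypothetical, the values in `demo` are placeholders standing in for tensors, and the assumption that a strict load failure surfaces as a `RuntimeError` reflects how `torch.nn.Module.load_state_dict` reports mismatched keys.

```python
from collections import OrderedDict


def strip_first_prefix(state_dict):
    """Drop the leftmost dotted component of every key, keeping order."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        _, _, tail = key.partition(".")
        stripped[tail or key] = value  # keys without a dot are kept as-is
    return stripped


def load_with_fallback_sketch(state_dict, model):
    """Try a strict load; on failure, retry once with the first prefix removed."""
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        model.load_state_dict(strip_first_prefix(state_dict), strict=True)
    return model


# Placeholder values instead of tensors, to show only the key handling.
demo = OrderedDict([("model.encoder.weight", 1), ("model.encoder.bias", 2)])
stripped = strip_first_prefix(demo)
```

Both attempts stay strict, so a state_dict that matches neither the model nor its submodel still raises instead of loading partially.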
-
asteroid.utils.torch_utils.pad_x_to_y(x, y, axis=-1)
Pad first argument to have same size as second argument.
Parameters: - x (torch.Tensor) – Tensor to be padded.
- y (torch.Tensor) – Tensor to pad x to.
- axis (int) – Axis to pad on.
Returns: torch.Tensor – x padded to match y's shape.
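As an illustration of padding one tensor to the length of another, the sketch below uses `torch.nn.functional.pad` on the last axis. It is a hypothetical stand-in rather than the library function: padding with zeros, padding on the right only, and handling only the last axis are all assumptions of this sketch.

```python
import torch
import torch.nn.functional as F


def pad_last_axis_to_match(x, y):
    """Right-pad `x` with zeros so its last axis is as long as that of `y`."""
    missing = y.size(-1) - x.size(-1)
    if missing < 0:
        raise ValueError("x is already longer than y on the last axis")
    # The pad list is read from the last axis backwards: [left, right].
    return F.pad(x, [0, missing])


x = torch.ones(2, 5)
y = torch.zeros(2, 8)
padded = pad_last_axis_to_match(x, y)
```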
-
asteroid.utils.torch_utils.tensors_to_device(tensors, device)
Transfer tensor, dict or list of tensors to device.
Parameters: - tensors (torch.Tensor) – May be a single, a list or a dictionary of tensors.
- device (torch.device) – The device where to place the tensors.
Returns: Union[torch.Tensor, list, tuple, dict] – Same as input but transferred to device. Goes through lists and dicts and transfers the torch.Tensor to device. Leaves the rest untouched.
-
asteroid.utils.torch_utils.to_cuda(tensors)
Transfer tensor, dict or list of tensors to GPU.
Parameters: tensors (torch.Tensor, list or dict) – May be a single, a list or a dictionary of tensors.
Returns: torch.Tensor – Same as input but transferred to cuda. Goes through lists and dicts and transfers the torch.Tensor to cuda. Leaves the rest untouched.
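The traversal described for the two transfer helpers above (walk lists and dicts, move tensors, leave everything else untouched) can be sketched as a small recursive function. `move_to_device_sketch` is a hypothetical name and this is not the library's code; the example targets the CPU so that it runs without a GPU.

```python
import torch


def move_to_device_sketch(obj, device):
    """Recursively move tensors found in lists, tuples and dicts to `device`."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to_device_sketch(val, device) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [move_to_device_sketch(item, device) for item in obj]
        return tuple(moved) if isinstance(obj, tuple) else moved
    return obj  # non-tensor leaves are returned unchanged


batch = {"mix": torch.zeros(2), "ids": ["utt1", torch.ones(1)], "sample_rate": 8000}
moved_batch = move_to_device_sketch(batch, torch.device("cpu"))
```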
Hub utils
-
asteroid.utils.hub_utils.cached_download(filename_or_url)
Download from URL with torch.hub and cache the result in ASTEROID_CACHE.
Parameters: filename_or_url (str) – Name of a model as named on the Zenodo Community page (ex: mpariente/ConvTasNet_WHAM!_sepclean), or a URL to a model file (ex: https://zenodo.org/…/model.pth), or a filename that exists locally (ex: local/tmp_model.pth).
Returns: str – Normalized path to the model (downloaded or not).
Generic utils
-
asteroid.utils.generic_utils.average_arrays_in_dic(dic)
Take average of numpy arrays in a dictionary.
Parameters: dic (dict) – Input dictionary to take average from.
Returns: dict – New dictionary with arrays averaged.
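A minimal sketch of this kind of averaging is shown below. It assumes that each numpy array is reduced to a single Python float via its mean and that non-array values are copied over unchanged; neither detail is stated in this reference, and `average_arrays_sketch` is a hypothetical name.

```python
import numpy as np


def average_arrays_sketch(dic):
    """Return a new dict where every numpy array is replaced by its mean."""
    return {
        key: float(value.mean()) if isinstance(value, np.ndarray) else value
        for key, value in dic.items()
    }


# Hypothetical per-utterance metrics.
metrics = {"si_sdr": np.array([8.0, 12.0]), "model": "demo"}
averaged = average_arrays_sketch(metrics)
```

Building a new dictionary leaves the input untouched, which matches the "new dictionary" wording of the return value.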
-
asteroid.utils.generic_utils.flatten_dict(d, parent_key='', sep='_')
Flattens a dictionary into a single-level dictionary while preserving parent keys. Taken from https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys?answertab=votes#tab-top
Returns: dict – Single-level dictionary, flattened.
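The flattening can be sketched as a short recursion in which each nested key is prefixed by its parent keys joined with `sep`. This is an illustrative version in the spirit of the referenced Stack Overflow answer, not necessarily identical to the library's code; `flatten_dict_sketch` is a hypothetical name.

```python
from collections.abc import MutableMapping


def flatten_dict_sketch(d, parent_key="", sep="_"):
    """Flatten nested mappings, joining parent and child keys with `sep`."""
    items = []
    for key, value in d.items():
        new_key = parent_key + sep + key if parent_key else key
        if isinstance(value, MutableMapping):
            # Recurse, carrying the accumulated prefix down one level.
            items.extend(flatten_dict_sketch(value, new_key, sep=sep).items())
        else:
            items.append((new_key, value))
    return dict(items)


nested_conf = {"optim": {"lr": 0.001, "sched": {"gamma": 0.5}}, "seed": 0}
flat_conf = flatten_dict_sketch(nested_conf)
```

If two branches produce the same joined key, the later one silently overwrites the earlier one, so a separator that cannot appear inside the keys is the safer choice.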