asteroid.utils.torch_utils module
-
asteroid.utils.torch_utils.are_models_equal(model1, model2)
Check for weights equality between models.
Parameters:
- model1 (nn.Module) – First model instance to be compared.
- model2 (nn.Module) – Second model instance to be compared.
Returns: bool – Whether all model weights are equal.
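The check described above can be illustrated with a minimal sketch. This is not Asteroid's implementation: the helper name `models_have_equal_weights` is hypothetical, and comparing parameters pairwise in iteration order is an assumption about how such a check might be written.

```python
import copy

import torch
from torch import nn


def models_have_equal_weights(model1: nn.Module, model2: nn.Module) -> bool:
    """Hypothetical helper: True if both models hold identical parameters."""
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        return False
    # torch.equal is False when shapes or values differ.
    return all(torch.equal(p1, p2) for p1, p2 in zip(params1, params2))


model_a = nn.Linear(4, 2)
model_b = copy.deepcopy(model_a)  # exact copy of the weights
equal_copy = models_have_equal_weights(model_a, model_b)

# Perturb one weight tensor so the two models differ.
with torch.no_grad():
    model_b.weight.add_(1.0)
equal_perturbed = models_have_equal_weights(model_a, model_b)
```

A comparison of this kind is typically useful in tests, for example to confirm that a model reloaded from a checkpoint matches the one that was saved.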
-
asteroid.utils.torch_utils.load_state_dict_in(state_dict, model)
Strictly loads state_dict in model, or the next submodel. Useful to load a standalone model after training it with System.
Parameters:
- state_dict (OrderedDict) – The state_dict to load.
- model (torch.nn.Module) – The model to load it into.
Returns: torch.nn.Module – Model with loaded weights.
Note: Keys in a state_dict look like object1.object2.layer_name.weight.etc. We first try to load the model in the classic way. If this fails, we remove the leftmost part of each key to obtain object2.layer_name.weight.etc. Blindly loading with strict=False should be done with some logging of the missing keys in the state_dict and the model.
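The fallback described in the note can be sketched as follows. This is an illustration under stated assumptions, not Asteroid's code: `load_with_prefix_fallback` and `Wrapper` are hypothetical names, and `Wrapper` only stands in for a training wrapper that stores the model as a submodule.

```python
from collections import OrderedDict

import torch
from torch import nn


def load_with_prefix_fallback(state_dict, model):
    """Hypothetical helper: strict load, then retry without the first key prefix."""
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Drop everything up to and including the first dot of each key,
        # e.g. "model.weight" -> "weight". Keys without a dot are kept as-is.
        stripped = OrderedDict(
            (key[key.find(".") + 1:], value) for key, value in state_dict.items()
        )
        model.load_state_dict(stripped, strict=True)
    return model


class Wrapper(nn.Module):
    """Stand-in for a training wrapper that holds the model as a submodule."""

    def __init__(self, model):
        super().__init__()
        self.model = model


trained = Wrapper(nn.Linear(3, 1))  # state_dict keys: "model.weight", "model.bias"
fresh = nn.Linear(3, 1)             # expects keys: "weight", "bias"
load_with_prefix_fallback(trained.state_dict(), fresh)
weights_match = torch.equal(fresh.weight, trained.model.weight)
```

Both attempts use `strict=True`, so a state_dict that matches neither the model nor its prefix-stripped form still raises an error instead of loading partially.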
-
asteroid.utils.torch_utils.pad_x_to_y(x, y, axis=-1)
Pad first argument to have the same size as the second argument.
Parameters:
- x (torch.Tensor) – Tensor to be padded.
- y (torch.Tensor) – Tensor to pad x to.
- axis (int) – Axis to pad on.
Returns: torch.Tensor – x padded to match y’s shape.
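A minimal sketch of this kind of padding is shown below. It is not the library's implementation: the helper name `pad_last_axis_to_match` is hypothetical, it only handles the last axis, and padding with zeros on the right is an assumption, since the documentation above does not state the padding side or value.

```python
import torch
import torch.nn.functional as F


def pad_last_axis_to_match(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero-pad x on the right of its last axis to y's length."""
    missing = y.size(-1) - x.size(-1)
    # F.pad takes (left, right) amounts for the last dimension.
    return F.pad(x, [0, missing])


x = torch.ones(2, 5)   # e.g. a model output that came out shorter than its input
y = torch.zeros(2, 8)  # reference tensor with the target length
padded = pad_last_axis_to_match(x, y)
```

Padding like this is commonly needed when an encoder and decoder pair returns a signal a few samples shorter than its input.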
-
asteroid.utils.torch_utils.tensors_to_device(tensors, device)
Transfer tensor, dict or list of tensors to device.
Parameters:
- tensors (torch.Tensor) – May be a single tensor, a list or a dictionary of tensors.
- device (torch.device) – The device where to place the tensors.
Returns: Union[torch.Tensor, list, tuple, dict] – Same as input but transferred to device. Goes through lists and dicts and transfers the torch.Tensor to device. Leaves the rest untouched.
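The traversal described above can be sketched as a small recursive function. This is an illustration only: `move_to_device` is a hypothetical name, and the handling of tuples and nested containers is an assumption rather than a statement about Asteroid's actual code.

```python
import torch


def move_to_device(tensors, device):
    """Hypothetical helper: move tensors found in nested lists, tuples and dicts."""
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, list):
        return [move_to_device(item, device) for item in tensors]
    if isinstance(tensors, tuple):
        return tuple(move_to_device(item, device) for item in tensors)
    if isinstance(tensors, dict):
        return {key: move_to_device(value, device) for key, value in tensors.items()}
    # Anything else (strings, ints, None, ...) is returned untouched.
    return tensors


batch = {
    "mix": torch.zeros(2, 4),
    "sources": [torch.ones(2, 4), torch.ones(2, 4)],
    "ids": ["utt1", "utt2"],
}
# "cpu" keeps the example runnable on any machine.
moved = move_to_device(batch, torch.device("cpu"))
```

Non-tensor entries such as the string identifiers pass through unchanged, which lets a whole batch dictionary be moved in one call.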
-
asteroid.utils.torch_utils.to_cuda(tensors)
Transfer tensor, dict or list of tensors to GPU.
Parameters:
- tensors (torch.Tensor, list or dict) – May be a single tensor, a list or a dictionary of tensors.
Returns: torch.Tensor – Same as input but transferred to cuda. Goes through lists and dicts and transfers the torch.Tensor to cuda. Leaves the rest untouched.