asteroid.utils.torch_utils module
-
asteroid.utils.torch_utils.to_cuda(tensors)
Transfer a tensor, or a list or dict of tensors, to the GPU.
Parameters: tensors (torch.Tensor, list or dict) – A single tensor, a list of tensors or a dictionary of tensors.
Returns: torch.Tensor – Same as the input, but transferred to CUDA. The function goes through lists and dicts, transfers every torch.Tensor it finds to CUDA and leaves everything else untouched.
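The traversal described above can be sketched as a small recursive function. This is a hypothetical re-implementation (to_cuda_sketch is not part of asteroid) written only from the documented behavior; it builds new containers instead of modifying them in place, which the library may or may not do. The usage block only runs when a CUDA device is available.

```python
import torch


def to_cuda_sketch(tensors):
    """Hypothetical sketch of the documented to_cuda traversal (not asteroid's code)."""
    if isinstance(tensors, torch.Tensor):
        return tensors.cuda()  # Raises if no CUDA device is available
    if isinstance(tensors, list):
        return [to_cuda_sketch(t) for t in tensors]
    if isinstance(tensors, dict):
        return {key: to_cuda_sketch(value) for key, value in tensors.items()}
    return tensors  # Non-tensor leaves (strings, ints, ...) are left untouched


if torch.cuda.is_available():
    batch = {"mixture": torch.randn(2, 8000), "sources": [torch.randn(2, 8000)], "id": "utt_0"}
    batch = to_cuda_sketch(batch)
    print(batch["mixture"].device, batch["id"])
```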
-
asteroid.utils.torch_utils.tensors_to_device(tensors, device)
Transfer a tensor, or a list or dict of tensors, to device.
Parameters:
- tensors (torch.Tensor, list or dict) – A single tensor, a list of tensors or a dictionary of tensors.
- device (torch.device) – The device on which to place the tensors.
Returns: Union[torch.Tensor, list, tuple, dict] – Same as the input, but transferred to device. The function goes through lists and dicts, transfers every torch.Tensor it finds to device and leaves everything else untouched.
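A device-agnostic version of the same traversal might look like the sketch below. tensors_to_device_sketch is a hypothetical name, and preserving tuples as tuples is an assumption suggested by the tuple in the documented return type, not confirmed behavior.

```python
from typing import Any, Union

import torch


def tensors_to_device_sketch(tensors: Any, device: Union[str, torch.device]) -> Any:
    """Hypothetical sketch of the documented tensors_to_device behavior."""
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, (list, tuple)):
        moved = [tensors_to_device_sketch(t, device) for t in tensors]
        # Preserve the container kind (namedtuples become plain tuples here)
        return tuple(moved) if isinstance(tensors, tuple) else moved
    if isinstance(tensors, dict):
        return {key: tensors_to_device_sketch(value, device) for key, value in tensors.items()}
    return tensors  # Anything else is returned untouched


batch = {"mixture": torch.randn(2, 8000), "meta": (torch.zeros(1), "utt_0")}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moved = tensors_to_device_sketch(batch, device)
print(moved["mixture"].device, moved["meta"][1])
```

Passing a torch.device (or a string such as "cpu") instead of hard-coding .cuda() lets the same code path run on machines without a GPU.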
-
asteroid.utils.torch_utils.get_device(tensor_or_module, default=None)
Get the device of a tensor or a module.
Parameters:
- tensor_or_module (Union[torch.Tensor, torch.nn.Module]) – The object to get the device from. Can be a torch.Tensor, a torch.nn.Module, or anything else that has a device attribute or a parameters() -> Iterator[torch.Tensor] method.
- default (Optional[Union[str, torch.device]]) – If the device cannot be determined, return this device instead. If None (the default), raise a TypeError instead.
Returns: torch.device – The device that tensor_or_module is on.
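One lookup order consistent with the description above is sketched below. get_device_sketch is hypothetical and the library's actual checks may differ; reading the device of the first parameter assumes that all parameters of a module live on the same device.

```python
from typing import Optional, Union

import torch
from torch import nn


def get_device_sketch(tensor_or_module, default: Optional[Union[str, torch.device]] = None) -> torch.device:
    """Hypothetical sketch of the lookup described for get_device."""
    # 1. Objects exposing a `device` attribute (e.g. torch.Tensor)
    if hasattr(tensor_or_module, "device"):
        return tensor_or_module.device
    # 2. Objects with parameters() (e.g. nn.Module): use the first parameter
    if hasattr(tensor_or_module, "parameters"):
        try:
            return next(tensor_or_module.parameters()).device
        except StopIteration:
            pass  # A module without parameters has no device to report
    # 3. Fall back to `default`, or raise TypeError when no default is given
    if default is None:
        raise TypeError(f"Cannot determine the device of {type(tensor_or_module).__name__}")
    return default if isinstance(default, torch.device) else torch.device(default)


print(get_device_sketch(nn.Linear(4, 2)))
print(get_device_sketch(nn.ReLU(), default="cpu"))  # ReLU has no parameters
```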
-
asteroid.utils.torch_utils.is_tracing()
Returns True when called during tracing (that is, when the calling function is being traced with torch.jit.trace) and False otherwise.
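The True/False contract can be observed with PyTorch's public torch.jit.is_tracing(), used here as a stand-in so the example runs without asteroid installed; whether asteroid's is_tracing delegates to it is not stated above. Passing check_trace=False is a choice made for this sketch so that the function body runs exactly once under the tracer.

```python
import torch

calls = []


def double(x: torch.Tensor) -> torch.Tensor:
    # Record whether this call happens inside torch.jit.trace
    calls.append(torch.jit.is_tracing())
    return x * 2


double(torch.ones(3))  # Eager call
# check_trace=False so the Python body runs only once while tracing
traced = torch.jit.trace(double, torch.ones(3), check_trace=False)
print(calls)
```

Calling traced afterwards does not append to calls: the Python side effect is not part of the recorded graph, which is why code that must behave differently under tracing needs an explicit check like this.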
-
asteroid.utils.torch_utils.script_if_tracing(fn)
Compiles fn when it is first called during tracing. torch.jit.script has a non-negligible start-up time when it is first called, due to the lazy initialization of many compiler builtins, so you should not use it in library code. However, you may want parts of your library to work in tracing even if they use control flow. In these cases, use @torch.jit.script_if_tracing as a substitute for torch.jit.script.
Parameters: fn – A function to compile.
Returns: If called during tracing, a ScriptFunction created by torch.jit.script is returned. Otherwise, the original function fn is returned.
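A minimal decorator following the described contract is sketched below: it runs fn eagerly outside tracing and only pays the torch.jit.script cost when called under the tracer. script_if_tracing_sketch, flip_if_negative and model_fn are hypothetical names, and this is not asteroid's or PyTorch's implementation.

```python
import functools

import torch


def script_if_tracing_sketch(fn):
    """Hypothetical decorator mirroring the documented script_if_tracing idea."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch.jit.is_tracing():
            return fn(*args, **kwargs)  # Eager mode: no compilation cost
        # Under tracing, compile fn so its control flow ends up in the graph
        return torch.jit.script(fn)(*args, **kwargs)

    return wrapper


@script_if_tracing_sketch
def flip_if_negative(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: plain tracing would freeze whichever branch ran
    if bool(x.sum() < 0):
        return -x
    return x


def model_fn(x: torch.Tensor) -> torch.Tensor:
    return flip_if_negative(x) * 2


traced = torch.jit.trace(model_fn, torch.ones(3), check_trace=False)
print(traced(-torch.ones(3)))  # The branch survives inside the traced graph
```

Without the decorator, tracing with torch.ones(3) would record only the "return x" branch, so the traced function would return negative values for a negative input.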
-
asteroid.utils.torch_utils.pad_x_to_y(x: torch.Tensor, y: torch.Tensor, axis: int = -1) → torch.Tensor
Right-pad or right-trim the first argument so that it has the same size as the second argument along axis.
Parameters:
- x (torch.Tensor) – Tensor to be padded.
- y (torch.Tensor) – Tensor whose size x is padded to.
- axis (int) – Axis to pad on.
Returns: torch.Tensor – x, padded or trimmed to match y's size along axis.
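The pad-or-trim logic can be sketched with torch.nn.functional.pad, whose negative padding values trim instead of pad. pad_x_to_y_sketch is hypothetical: zero padding is an assumption (the fill value is not stated above), and the sketch accepts any axis even though the library may support fewer.

```python
import torch
import torch.nn.functional as F


def pad_x_to_y_sketch(x: torch.Tensor, y: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Hypothetical sketch: right-pad (with zeros) or right-trim x to y's size along axis."""
    axis = axis % x.dim()  # Normalize negative axes
    diff = y.shape[axis] - x.shape[axis]  # > 0: pad, < 0: trim
    # F.pad takes (left, right) pairs starting from the LAST dimension,
    # so dimensions after `axis` get (0, 0) and `axis` itself gets (0, diff).
    pad = [0, 0] * (x.dim() - 1 - axis) + [0, diff]
    return F.pad(x, pad)  # Negative padding trims from the right


estimate = torch.ones(2, 7990)  # e.g. a decoder output a few samples short
mixture = torch.zeros(2, 8000)
print(pad_x_to_y_sketch(estimate, mixture).shape)
```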
-
asteroid.utils.torch_utils.load_state_dict_in(state_dict, model)
Strictly loads state_dict into model, or into the next submodel. Useful to load a standalone model after training it with System.
Parameters:
- state_dict (OrderedDict) – The state_dict to load.
- model (torch.nn.Module) – The model to load it into.
Returns: torch.nn.Module – model with loaded weights.
Note
Keys in a state_dict look like object1.object2.layer_name.weight.etc. We first try to load the model in the classic way. If this fails, we remove the leftmost part of each key to obtain object2.layer_name.weight.etc. Blindly loading with strict=False should be done with some logging of the missing keys in the state_dict and the model.
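The fallback strategy from the note can be sketched as follows. load_state_dict_in_sketch and the Wrapper class are hypothetical (Wrapper only stands in for a training wrapper such as System); the sketch assumes that a key mismatch surfaces as the RuntimeError raised by a strict load_state_dict call.

```python
from collections import OrderedDict

import torch
from torch import nn


def load_state_dict_in_sketch(state_dict, model: nn.Module) -> nn.Module:
    """Hypothetical sketch of the fallback strategy described in the note."""
    try:
        model.load_state_dict(state_dict, strict=True)  # Classic, strict load
    except RuntimeError:
        # "object1.object2.layer.weight" -> "object2.layer.weight"
        stripped = OrderedDict(
            (key.split(".", 1)[-1], value) for key, value in state_dict.items()
        )
        model.load_state_dict(stripped, strict=True)  # Retry, still strict
    return model


class Wrapper(nn.Module):
    """Hypothetical stand-in for a training wrapper such as System."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model


trained = nn.Linear(4, 2)
wrapped_state = Wrapper(trained).state_dict()  # Keys: "model.weight", "model.bias"
standalone = load_state_dict_in_sketch(wrapped_state, nn.Linear(4, 2))
```

Keeping strict=True on the retry means a genuine mismatch still raises instead of silently leaving weights at their initial values, in line with the caution about strict=False.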
-
asteroid.utils.torch_utils.are_models_equal(model1, model2)
Check for weight equality between models.
Parameters:
- model1 (nn.Module) – First model instance to be compared.
- model2 (nn.Module) – Second model instance to be compared.
Returns: bool – Whether all model weights are equal.
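A pairwise parameter comparison is one way to implement such a check. are_models_equal_sketch is hypothetical: it compares parameters in registration order and ignores buffers (e.g. BatchNorm running statistics), and whether the library also compares buffers is not stated above.

```python
import copy

import torch
from torch import nn


def are_models_equal_sketch(model1: nn.Module, model2: nn.Module) -> bool:
    """Hypothetical sketch: compare parameters pairwise, in registration order."""
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        return False
    # torch.equal is False when shapes differ, so no separate shape check is needed
    return all(torch.equal(p1, p2) for p1, p2 in zip(params1, params2))


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
clone = copy.deepcopy(model)
print(are_models_equal_sketch(model, clone))
with torch.no_grad():
    clone[0].weight.add_(1.0)  # Perturb one weight tensor
print(are_models_equal_sketch(model, clone))
```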
-
asteroid.utils.torch_utils.jitable_shape(tensor)
Gets the shape of tensor as a torch.Tensor, for the JIT compiler.
Note
Returning tensor.shape or tensor.size() directly is not TorchScript compatible, as the return type would not be supported.
Parameters: tensor (torch.Tensor) – Tensor.
Returns: torch.Tensor – Shape of tensor.
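Wrapping the shape list in a tensor is one way to obtain a plain Tensor return type. The scripted function below, jitable_shape_sketch, is a hypothetical illustration of that idea and is not claimed to be asteroid's implementation.

```python
import torch


@torch.jit.script
def jitable_shape_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: return the shape of `tensor` as a 1-D integer tensor."""
    # Inside TorchScript, tensor.shape is a List[int]; torch.tensor turns it
    # into a Tensor, so the function has a plain Tensor return type.
    return torch.tensor(tensor.shape)


shape = jitable_shape_sketch(torch.zeros(2, 3, 5))
print(shape)
```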