from typing import Dict, Hashable, Optional, Union
import numpy as np
import torch
from swyft.types import Array, Device
def dict_to_device(d, device, non_blocking=False):
    """Return a new dict whose values are the values of ``d`` moved with ``.to``.

    Args:
        d: mapping whose values expose ``.to(device, non_blocking=...)``.
        device: target passed as the first argument of each ``.to`` call.
        non_blocking: forwarded to each ``.to`` call.

    Returns:
        A new dict with the same keys, in the same order, and the moved values.
    """
    moved = {}
    for key, value in d.items():
        moved[key] = value.to(device, non_blocking=non_blocking)
    return moved
def dict_array_to_tensor(
    d, device="cpu", non_blocking=False, indices=slice(0, None)
) -> Dict[Hashable, torch.Tensor]:
    """Convert each value of ``d`` to a tensor and move it to ``device``.

    Every value is indexed with ``indices``, converted with
    ``array_to_tensor`` (using its default dtype selection) and then moved
    with ``.to(device, non_blocking=non_blocking)``.

    Args:
        d: mapping of keys to indexable array-like values.
        device: target passed to ``.to``.
        non_blocking: forwarded to ``.to``.
        indices: index applied to each value before conversion; the default
            slice selects everything along the first axis.

    Returns:
        A new dict with the same keys and the converted tensors.
    """
    converted = {}
    for key, values in d.items():
        selected = array_to_tensor(values[indices])
        converted[key] = selected.to(device, non_blocking=non_blocking)
    return converted
# Dtype families consulted by array_to_tensor when it has to pick a default
# target dtype. Membership is tested with `in`, i.e. by equality against each
# entry. Dtypes that appear in none of these lists (for example unsigned
# integers or float16) make array_to_tensor raise TypeError when its `dtype`
# argument is None.

# numpy-side families
np_bool_types = [bool]  # the Python builtin bool, not a numpy scalar type
np_int_types = [np.int8, np.int16, np.int32, np.int64]
np_float_types = [np.float32, np.float64]
np_complex_types = [np.complex64, np.complex128]
# torch-side families, mirroring the numpy lists above
torch_bool_types = [torch.bool]
torch_int_types = [torch.int8, torch.int16, torch.int32, torch.int64]
torch_float_types = [torch.float32, torch.float64]
torch_complex_types = [torch.complex64, torch.complex128]
def array_to_tensor(
    array: Array, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """Convert an array-like or tensor to a ``torch.Tensor`` with dtype and device.

    Inputs that are neither ``np.ndarray`` nor ``torch.Tensor`` are first passed
    through ``np.asarray``.

    When ``dtype`` is None the target dtype is derived from the input dtype
    with an unsafe cast: dtypes in the module's float lists become
    ``torch.float32``, int lists ``torch.int64``, bool lists ``torch.bool`` and
    complex lists ``torch.complex64``.

    Args:
        array: data to convert.
        dtype: target dtype; None selects the default described above.
        device: forwarded to ``Tensor.to``.

    Returns:
        The converted tensor.

    Raises:
        TypeError: if ``dtype`` is None and the input dtype is in none of the
            supported type lists, or if the input dtype is neither a
            ``np.dtype`` nor a ``torch.dtype``.
    """
    if not isinstance(array, (np.ndarray, torch.Tensor)):
        array = np.asarray(array)

    input_dtype = array.dtype
    from_numpy = isinstance(input_dtype, np.dtype)
    if not from_numpy and not isinstance(input_dtype, torch.dtype):
        raise TypeError(
            f"{input_dtype} was not recognized as a supported numpy.dtype or torch.dtype."
        )

    if dtype is None:
        # The family lists are only looked up when a default must be inferred.
        if from_numpy:
            library = "numpy"
            families = (np_float_types, np_int_types, np_bool_types, np_complex_types)
        else:
            library = "torch"
            families = (
                torch_float_types,
                torch_int_types,
                torch_bool_types,
                torch_complex_types,
            )
        defaults = (torch.float32, torch.int64, torch.bool, torch.complex64)
        for family, default in zip(families, defaults):
            if input_dtype in family:
                dtype = default
                break
        else:
            raise TypeError(
                f"{input_dtype} was not a supported {library} int, float, bool, or complex."
            )

    if from_numpy:
        # torch.from_numpy rejects negative strides (e.g. a reversed view),
        # so such views are materialised first.
        if any(stride < 0 for stride in array.strides):
            array = array.copy()
        return torch.from_numpy(array).to(dtype=dtype, device=device)
    return array.to(dtype=dtype, device=device)
def tensor_to_array(
    tensor: Array, dtype: Optional[np.dtype] = None, copy: bool = True
) -> np.ndarray:
    """Convert a tensor or array-like to a ``np.ndarray``.

    A ``torch.Tensor`` is detached, moved to the CPU and exposed through
    ``Tensor.numpy`` before being passed to ``np.asarray``; any other input
    goes to ``np.asarray`` directly.

    Args:
        tensor: data to convert.
        dtype: forwarded to ``np.asarray``.
        copy: when True, the result of ``np.asarray`` is copied before it is
            returned; when False it is returned as is.

    Returns:
        The resulting numpy array.
    """
    source = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
    result = np.asarray(source, dtype=dtype)
    return result.copy() if copy else result
def tobytes(x: Array) -> bytes:
    """Return the raw bytes of an array or tensor.

    Args:
        x: a ``np.ndarray`` or ``torch.Tensor``.

    Returns:
        The result of ``ndarray.tobytes`` on the array, or on the tensor's
        numpy view.

    Raises:
        TypeError: if ``x`` is neither a ``np.ndarray`` nor a ``torch.Tensor``.
    """
    if isinstance(x, np.ndarray):
        return x.tobytes()
    if isinstance(x, torch.Tensor):
        # Tensor.numpy refuses tensors that require grad or live off-CPU;
        # detach and move first, as tensor_to_array does.
        return x.detach().cpu().numpy().tobytes()
    raise TypeError(f"{type(x)} does not support tobytes.")
def _all_finite(x: Array) -> Array:
    """Reduce ``isfinite`` over ``x`` with the library matching its type.

    A ``torch.Tensor`` goes through ``torch.isfinite``/``torch.all``; anything
    else goes through ``np.isfinite``/``np.all``. The reduction result is
    returned as produced by that library, without conversion to ``bool``.
    """
    if isinstance(x, torch.Tensor):
        return torch.all(torch.isfinite(x))
    return np.all(np.isfinite(x))
def all_finite(x: Union[dict, torch.Tensor, np.ndarray, list]) -> bool:
    """Tell whether every element of ``x`` is finite.

    Args:
        x: a tensor, an array, or a dict/list whose values/items are checked
            one by one.

    Returns:
        True if every checked element is finite, False otherwise. An empty
        dict or list yields True.

    Raises:
        NotImplementedError: if ``x`` is not a dict, tensor, ndarray or list.
    """
    if isinstance(x, dict):
        return all(_all_finite(value) for value in x.values())
    if isinstance(x, (torch.Tensor, np.ndarray)):
        # The helper returns a zero-dim tensor / numpy scalar; coerce it so
        # this branch honours the declared ``bool`` return like the others.
        return bool(_all_finite(x))
    if isinstance(x, list):
        return all(_all_finite(item) for item in x)
    raise NotImplementedError("That type is not yet implemented.")