swyft.utils
- swyft.utils.array_to_tensor(array, dtype=None, device=None)
Convert an np.ndarray or torch.Tensor to a torch.Tensor with the given dtype on the given device. When dtype is None, the cast is unsafe: all floating-point arrays become torch.float32 (so float64 data loses precision) and all integer arrays become torch.int64.
- Parameters:
array (Union[ndarray, Tensor])
dtype (Optional[dtype])
device (Optional[Union[device, str]])
- Return type:
Tensor
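The casting rule above can be illustrated with a minimal, self-contained sketch that needs only NumPy and PyTorch, not swyft. The function name `array_to_tensor_sketch` is hypothetical, and leaving bool and complex inputs at their original dtype is an assumption: the documentation only specifies what happens to float and int arrays.

```python
from typing import Optional, Union

import numpy as np
import torch


def array_to_tensor_sketch(
    array: Union[np.ndarray, torch.Tensor],
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """Hypothetical re-implementation of the documented casting rule."""
    # Wrap NumPy data as a tensor (shares memory); tensors pass through unchanged.
    tensor = torch.from_numpy(array) if isinstance(array, np.ndarray) else array
    if dtype is None:
        if tensor.is_floating_point():
            dtype = torch.float32  # e.g. float64 -> float32, may lose precision
        elif not tensor.is_complex() and tensor.dtype != torch.bool:
            dtype = torch.int64  # every remaining integer type -> int64
        else:
            dtype = tensor.dtype  # assumption: bool/complex are left as they are
    # device=None keeps the tensor on whatever device it is already on.
    return tensor.to(dtype=dtype, device=device)


print(array_to_tensor_sketch(np.array([1.0, 2.0])).dtype)  # torch.float32
print(array_to_tensor_sketch(np.array([1, 2], dtype=np.int32)).dtype)  # torch.int64
```

Passing an explicit dtype bypasses the default rule, and a device of None leaves the tensor where it is, which is how torch.Tensor.to treats a None device.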
- swyft.utils.get_corner_marginal_indices(n_parameters)
Produce the marginal indices needed for a corner plot: every 1-D marginal and every 2-D marginal of the n_parameters parameters.
- Parameters:
n_parameters (int)
- Returns:
marginal_indices_1d, marginal_indices_2d
- Return type:
Tuple[Tuple[Tuple[int, …], …], Tuple[Tuple[int, …], …]]
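A corner plot shows each parameter on its own (the 1-D marginals) and each pair of parameters (the 2-D marginals). The sketch below builds both index collections with itertools.combinations; the function name `corner_marginal_indices_sketch` is hypothetical, and the ordering of the pairs is an assumption rather than something the documentation guarantees.

```python
from itertools import combinations
from typing import Tuple

MarginalIndices = Tuple[Tuple[int, ...], ...]


def corner_marginal_indices_sketch(
    n_parameters: int,
) -> Tuple[MarginalIndices, MarginalIndices]:
    """Hypothetical sketch: all 1-D and all 2-D marginals of n_parameters."""
    params = range(n_parameters)
    # Diagonal panels: one single-parameter marginal per parameter.
    marginal_indices_1d = tuple(combinations(params, 1))
    # Off-diagonal panels: every unordered pair (i, j) with i < j.
    marginal_indices_2d = tuple(combinations(params, 2))
    return marginal_indices_1d, marginal_indices_2d


one_d, two_d = corner_marginal_indices_sketch(3)
print(one_d)  # ((0,), (1,), (2,))
print(two_d)  # ((0, 1), (0, 2), (1, 2))
```

For n parameters this yields n 1-D marginals and n(n-1)/2 2-D marginals, matching the panel count of a corner plot.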
- swyft.utils.tupleize_marginal_indices(marginal_indices)
Reformat the input marginal_indices into a sorted, hashable standard form: a tuple of tuples.
A lone input tuple is treated as the indices of a single marginal; a list is treated as a collection of different marginals.
- Parameters:
marginal_indices (Union[int, Sequence[int], Sequence[Sequence[int]]])
- Return type:
Tuple[Tuple[int, …], …]
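The following sketch mirrors the normalization described above in plain Python. `tupleize_sketch` is a hypothetical name, not swyft's source; dropping duplicate indices inside a marginal, and treating a tuple of tuples as a collection (so the function can be applied to its own output), are assumptions not stated in the documentation.

```python
from typing import Sequence, Tuple, Union

MarginalIndexInput = Union[int, Sequence[int], Sequence[Sequence[int]]]
StrictMarginalIndex = Tuple[Tuple[int, ...], ...]


def tupleize_sketch(marginal_indices: MarginalIndexInput) -> StrictMarginalIndex:
    """Hypothetical sketch of the documented normalization."""
    if isinstance(marginal_indices, int):
        # A single int is one 1-D marginal.
        marginals = [marginal_indices]
    elif isinstance(marginal_indices, tuple) and all(
        isinstance(i, int) for i in marginal_indices
    ):
        # A lone tuple of ints is ONE marginal, e.g. (2, 0) -> ((0, 2),).
        marginals = [marginal_indices]
    else:
        # Lists (and, by assumption, tuples of tuples) are collections of marginals.
        marginals = list(marginal_indices)

    normalized = []
    for marginal in marginals:
        if isinstance(marginal, int):
            normalized.append((marginal,))
        else:
            # Sort indices inside each marginal; duplicates dropped (assumption).
            normalized.append(tuple(sorted(set(marginal))))
    # Sort the collection so equivalent inputs map to the same hashable key.
    return tuple(sorted(normalized))


print(tupleize_sketch((2, 0)))  # ((0, 2),)
print(tupleize_sketch([2, 0]))  # ((0,), (2,))
```

Because the result contains only tuples of ints, it can be used directly as a dictionary key or set member.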