Source code for swyft.store.dataset

from typing import Callable, Hashable, Optional, Sequence, Tuple
from warnings import warn

import numpy as np
import torch

import swyft
from swyft.saveable import StateDictSaveable
from swyft.types import Array, ObsType, ParameterNamesType, PathType
from swyft.utils.array import array_to_tensor


class SimpleDataset(torch.utils.data.Dataset):
    """Dataset which merely keeps track of observation and parameter data in one place"""

    def __init__(self, observations: ObsType, us: Array, vs: Array) -> None:
        """
        Args:
            observations: dictionary of batched obserations
            us: array or tensor of unit cube parameters
            vs: array or tensor of natural parameters
        """
        super().__init__()
        b = us.shape[0]
        assert (
            vs.shape[0] == b
        ), "the us and vs arrays do not have the same batch dimension"
        assert all(
            [b == x.shape[0] for x in observations.values()]
        ), "the observation values do not have the same batch dimension as us and vs"
        self.observations = observations
        self.us = us
        self.vs = vs
        self._len = b

    def __getitem__(self, idx) -> Tuple[ObsType, torch.Tensor, torch.Tensor]:
        return (
            {
                key: array_to_tensor(val[idx, ...])
                for key, val in self.observations.items()
            },
            array_to_tensor(self.us[idx, ...]),
            array_to_tensor(self.vs[idx, ...]),
        )

    def __len__(self) -> int:
        return self._len


[docs]class Dataset(torch.utils.data.Dataset, StateDictSaveable): """Dataset for access to swyft.Store.""" def __init__( self, N: int, prior: swyft.Prior, store: "swyft.store.store.Store", bound: Optional[swyft.Bound] = None, simhook: Optional[Callable[..., ObsType]] = None, simkeys: Optional[Sequence[Hashable]] = None, ) -> None: """ Args: N: Number of samples. prior: Parameter prior. store: Store reference. simhook: Applied on-the-fly to each sample. simhook(x, v) simkeys: List of simulation keys that should be exposed (None means that all store sims are exposed) .. note:: swyft.Dataset is essentially a list of indices that point to corresponding entries in the swyft.Store. It is a daughter class of torch.utils.data.Dataset, and can be used directly for training. Due to the statistical nature of the Store, the returned number of samples is effectively drawn from a Poisson distribution with mean N. """ super().__init__() self._prior_truncator = swyft.PriorTruncator(prior, bound) self.indices = store.sample(N, prior, bound=bound) self._store = store self._simhook = simhook self._simkeys = simkeys if simkeys else list(self._store.sims) if self.requires_sim: warn( "The store requires simulation for some of the indices in this Dataset." )
[docs] def __len__(self) -> int: """Return length of dataset.""" return len(self.indices)
@property def prior(self) -> swyft.Prior: """Return prior of dataset.""" return self._prior_truncator.prior @property def bound(self) -> swyft.Bound: """Return bound of truncated prior of dataset (swyft.Bound).""" return self._prior_truncator.bound
[docs] def simulate( self, batch_size: Optional[int] = None, wait_for_results: bool = True ) -> None: """Trigger simulations for points in the dataset. Args: batch_size: Number of batched simulations. wait_for_results: What for simulations to complete before returning. """ self._store.simulate( self.indices, batch_size=batch_size, wait_for_results=wait_for_results )
@property def requires_sim(self) -> bool: """Check if simulations are required for points in the dataset.""" return self._store.requires_sim(self.indices) @property def v(self) -> np.ndarray: """Return all parameters as (n_points, n_parameters) array.""" return np.array([self._store.v[i] for i in self.indices]) @property def parameter_names(self) -> ParameterNamesType: """Return parameter names (inherited from store and simulator).""" return self._store.parameter_names
[docs] def __getitem__(self, idx) -> Tuple[ObsType, torch.Tensor, torch.Tensor]: """Return datastore entry.""" i = self.indices[idx] x = {k: self._store.sims[k][i] for k in self._simkeys} v = self._store.v[i] if self._simhook is not None: x = self._simhook(x, v) u = self.prior.cdf(v.reshape(1, -1)).flatten() return ( {key: array_to_tensor(val) for key, val in x.items()}, array_to_tensor(u), array_to_tensor(v), )
def state_dict(self) -> dict: return dict( indices=self.indices, prior_truncator=self._prior_truncator.state_dict(), simhook=bool(self._simhook), simkeys=self._simkeys, ) @classmethod def from_state_dict(cls, state_dict, store, simhook=None): obj = cls.__new__(cls) obj._prior_truncator = swyft.PriorTruncator.from_state_dict( state_dict["prior_truncator"] ) obj.indices = state_dict["indices"] obj._store = store obj._simhook = simhook obj._simkeys = state_dict["simkeys"] if simhook is None and state_dict["simhook"] is True: warn( "A simhook was specified when the dataset was saved, but is missing now." ) elif simhook is not None and state_dict["simhook"] is False: warn( "A simhook was specified, but no simhook was specified when the Dataset was saved." ) return obj
[docs] def save(self, filename: PathType) -> None: """ .. note:: The store and the simhook are not saved. They must be loaded independently by the user. """ sd = self.state_dict() torch.save(sd, filename)
[docs] @classmethod def load( cls, filename: PathType, store: "swyft.Store", simhook: Callable[..., ObsType] = None, ): """Load dataset. Args: filename store: Corresponding datastore. simhook: Simulation hook. .. warning:: Make sure that the store is the same that was originally used for creating the dataset. """ sd = torch.load(filename) return cls.from_state_dict(sd, store, simhook=simhook)