# pylint: disable=no-member, not-callable
import logging
import os
import time
from pathlib import Path
from typing import Mapping, Optional, Sequence, Tuple, Union
import fasteners
import numcodecs
import numpy as np
import zarr
import swyft
from swyft.store.simulator import SimulationStatus, Simulator
from swyft.types import Array, ObsShapeType, ParameterNamesType, PathType
from swyft.utils import is_empty
log = logging.getLogger(__name__)
class Filesystem:
metadata = "metadata"
log_lambdas = "metadata/log_lambdas"
samples = "samples"
sims = "samples/sims"
v = "samples/v"
log_w = "samples/log_w"
simulation_status = "samples/simulation_status"
[docs]class Store:
"""Store of sample parameters and simulation outputs.
Based on Zarr, it should be instantiated via its methods `memory_store`,
`directory_store` or `load`.
Args:
zarr_store: Zarr store object.
simulator: simulator object.
sync_path: if specified, it will enable synchronization using file locks (files will be
stored in the given path). Must be accessible to all processes working on the store
and the underlying filesystem must support file locking.
chunksize: the parameters and simulation output will be stored as arrays with the
specified chunk size along the sample dimension (a single chunk will be used for the
other dimensions).
pickle_protocol: pickle protocol number used for storing intensity functions.
from_scratch: if False, load the sample store from the Zarr store provided.
"""
_filesystem = Filesystem
def __init__(
self,
zarr_store: Union[zarr.MemoryStore, zarr.DirectoryStore],
simulator: Optional[Simulator] = None,
sync_path: Optional[PathType] = None,
chunksize: int = 1,
pickle_protocol: int = 4,
from_scratch: bool = True,
) -> None:
self._zarr_store = zarr_store
self._simulator = simulator
self._pickle_protocol = pickle_protocol # TODO: to be deprecated, we will default to 4, which is supported since python 3.4
synchronizer = zarr.ProcessSynchronizer(sync_path) if sync_path else None
self._root = zarr.group(
store=self._zarr_store, synchronizer=synchronizer, overwrite=from_scratch
)
if not from_scratch:
if not {"samples", "metadata"} == self._root.keys():
raise KeyError(
"Invalid Zarr store. It should have keys: ['samples', 'metadata']."
)
print("Loading existing store.")
self._update()
else:
print("Creating new store.")
if simulator is None:
raise ValueError("A simulator is required to setup a new store.")
self._setup_new_zarr_store(
simulator.parameter_names,
simulator.sim_shapes,
self._root,
chunksize=chunksize,
sim_dtype=simulator.sim_dtype,
)
log.debug(" sim_shapes = %s" % str(simulator.sim_shapes))
# a second layer of synchronization is required to grow the store
self._lock = None
if sync_path is not None:
self._setup_lock(sync_path)
[docs] def add(
self, N: int, prior: "swyft.Prior", bound: Optional["swyft.Bound"] = None
) -> None:
"""Adds points to the store.
Args:
N: Number of samples
prior: Prior
bound: Bound object for prior truncation
.. warning::
Calling this method will alter the content of the store by adding
additional points. Currently this cannot be reverted, so use with
care when applying it to the DirectoryStore.
"""
pdf = swyft.PriorTruncator(prior, bound)
# Lock store while adding new points
self.lock()
self._update()
# Generate new points
v_prop = pdf.sample(np.random.poisson(N))
log_lambda_target = pdf.log_prob(v_prop) + np.log(N)
log_lambda_store = self.log_lambda(v_prop)
log_w = np.log(np.random.rand(len(v_prop))) + log_lambda_target
accept_new = log_w > log_lambda_store
v_new = v_prop[accept_new]
log_w_new = log_w[accept_new]
# Anything new?
if sum(accept_new) > 0:
# Add new entries to store
self._append_new_points(v_new, log_w_new)
print("Store: Adding %i new samples to simulator store." % sum(accept_new))
# Update intensity function
self.log_lambdas.resize(len(self.log_lambdas) + 1)
self.log_lambdas[-1] = dict(pdf=pdf.state_dict(), N=N)
log.debug(f" total size of simulator store {len(self)}.")
# Points added, unlock store
self.unlock()
self._update()
# @property
# def zarr_store(self):
# """Return ZarrStore object."""
# return self._zarr_store
def _setup_lock(self, sync_path: PathType) -> None:
"""Setup lock for concurrent access from multiple processes."""
path = os.path.join(sync_path, "cache.lock")
self._lock = fasteners.InterProcessLock(path)
[docs] def lock(self) -> None:
"""Lock store for the current process."""
if self._lock is not None:
log.debug("Store locked")
self._lock.acquire(blocking=True)
[docs] def unlock(self) -> None:
"""Unlock store so that other processes can access it."""
if self._lock is not None:
self._lock.release()
log.debug("Store unlocked")
def _setup_new_zarr_store(
self,
parameter_names: ParameterNamesType,
sim_shapes: ObsShapeType,
root: zarr.Group,
chunksize: int = 1,
sim_dtype: str = "f8",
) -> None: # Adding observational shapes to store
# Parameters
n_parameters = len(parameter_names)
v = root.zeros(
self._filesystem.v,
shape=(0, n_parameters),
chunks=(chunksize, n_parameters),
dtype="f8",
)
v.attrs["parameter_names"] = parameter_names
# Simulations
sims = root.create_group(self._filesystem.sims)
for name, shape in sim_shapes.items():
sims.zeros(
name, shape=(0, *shape), chunks=(chunksize, *shape), dtype=sim_dtype
)
# Random intensity weights
root.zeros(self._filesystem.log_w, shape=(0,), chunks=(chunksize,), dtype="f8")
# Pickled Intensity (prior * N) objects
root.create(
self._filesystem.log_lambdas,
shape=(0,),
dtype=object,
object_codec=numcodecs.Pickle(protocol=self._pickle_protocol),
)
# Simulation status code
root.zeros(
self._filesystem.simulation_status,
shape=(0,),
chunks=(chunksize,),
dtype="int",
)
self._update()
def _update(self) -> None:
self.sims = self._root[self._filesystem.sims]
self.v = self._root[self._filesystem.v]
self.log_w = self._root[self._filesystem.log_w]
self.log_lambdas = self._root[self._filesystem.log_lambdas]
self.sim_status = self._root[self._filesystem.simulation_status]
self.parameter_names = self._root[self._filesystem.v].attrs["parameter_names"]
[docs] def __len__(self) -> int:
"""Returns number of samples in the store."""
self._update()
return len(self.v)
[docs] def __getitem__(self, i: int) -> Tuple[Mapping[str, np.ndarray], np.ndarray]:
"""Returns data store entry with index :math:`i`."""
self._update()
sim = {}
for key, value in self.sims.items():
sim[key] = value[i]
par = self.v[i]
return (sim, par)
def _append_new_points(self, v: Array, log_w: Array) -> None:
"""Append z to zarr_store content and generate new slots for x."""
self._update()
n = len(v)
for key, value in self.sims.items():
shape = list(value.shape)
shape[0] += n
value.resize(*shape)
self._root[self._filesystem.v].append(v)
self._root[self._filesystem.log_w].append(log_w)
m = np.full(n, SimulationStatus.PENDING, dtype="int")
self.sim_status.append(m)
[docs] def log_lambda(self, z: np.ndarray) -> np.ndarray:
"""Intensity function of the store.
Args:
z: Array with the sample parameters. Should have shape (num. samples,
num. parameters per sample).
Returns:
Array with the sample intensities.
"""
self._update()
d = -np.inf * np.ones_like(z[:, 0])
if len(self.log_lambdas) == 0:
return d
for i in range(len(self.log_lambdas)):
pdf = swyft.PriorTruncator.from_state_dict(self.log_lambdas[i]["pdf"])
N = self.log_lambdas[i]["N"]
r = pdf.log_prob(z) + np.log(N)
d = np.where(r > d, r, d)
return d
[docs] def coverage(
self, N: int, prior: "swyft.Prior", bound: Optional["swyft.Bound"] = None
) -> float:
"""Returns fraction of already stored data points.
Args:
N: Number of samples
prior: Prior
bound: Bound object for prior truncation
Returns:
Fraction of samples that is already covered by content of the store.
.. note::
A coverage of zero means that all points need to be newly
simulated. A coverage of 1.0 means that all points are already
available for this (truncated) prior.
.. warning::
Results are Monte Carlo estimated and subject to sampling noise.
"""
pdf = swyft.PriorTruncator(prior, bound)
Nsamples = max(N, 1000) # At least 1000 test samples
self._update()
# Generate new points
v_prop = pdf.sample(np.random.poisson(Nsamples))
log_lambda_target = pdf.log_prob(v_prop) + np.log(N)
log_lambda_store = self.log_lambda(v_prop)
frac = np.where(
log_lambda_target > log_lambda_store,
np.exp(-log_lambda_target + log_lambda_store),
1.0,
).mean()
return frac
[docs] def sample(
self,
N: int,
prior: "swyft.Prior",
bound: Optional["swyft.Bound"] = None,
check_coverage: bool = True,
add: bool = False,
) -> np.ndarray:
"""Return samples from store.
Args:
N: Number of samples
prior: Prior
bound: Bound object for prior truncation
check_coverage: Check whether requested points are contained in the store.
add: If necessary, add requested points to the store.
Returns:
Indices: Index list pointing to the relevant store entries.
"""
if add:
if self.coverage(N, prior, bound=bound) < 1:
self.add(N, prior, bound=bound)
if check_coverage:
if self.coverage(N, prior, bound=bound) < 1.0:
raise RuntimeError(
"Store does not contain enough samples for your requested intensity function `N * prior`."
)
pdf = swyft.PriorTruncator(prior, bound)
self._update()
# Select points from cache
v_store = self.v[:]
log_w_store = self.log_w[:]
log_lambda_target = pdf.log_prob(v_store) + np.log(N)
accept_stored = log_w_store <= log_lambda_target
indices = np.array(range(len(accept_stored)))[accept_stored]
return indices
def _get_indices_to_simulate(
self, indices: Optional[Sequence[int]] = None
) -> np.ndarray:
"""
Determine which samples need to be simulated.
Args:
indices: array with the indices of the samples to
consider. If None, consider all samples.
Returns:
array with the sample indices
"""
status = self.get_simulation_status(indices)
require_simulation = status == SimulationStatus.PENDING
idx = np.flatnonzero(require_simulation)
return indices[idx] if indices is not None else idx
def _set_simulation_status(
self, indices: Sequence[int], status: SimulationStatus
) -> None:
"""
Flag the specified samples with the simulation status.
Args:
indices: array with the indices of the samples to flag
status: new status for the samples
"""
assert status in list(SimulationStatus), f"Unknown status {status}"
current_status = self.sim_status.oindex[indices]
if np.any(current_status == status):
log.warning(
f"Changing simulation status to {status}, but some simulations have already status {status}"
)
self.sim_status.oindex[indices] = status
[docs] def get_simulation_status(
self, indices: Optional[Sequence[int]] = None
) -> np.ndarray:
"""Determine the status of sample simulations.
Args:
indices: List of indices. If None, check the status of all
samples
Returns:
list of simulation statuses
"""
self._update()
return (
self.sim_status.oindex[indices]
if indices is not None
else self.sim_status[:]
)
[docs] def requires_sim(self, indices: Optional[Sequence[int]] = None) -> bool:
"""Check whether there are parameters which require simulation.
Args:
indices: List of indices. If None, check all samples.
Returns:
True if one or more samples require simulations, False otherwise.
"""
self._update()
return self._get_indices_to_simulate(indices).size > 0
def _get_indices_failed_simulations(self) -> np.ndarray:
self._update()
return np.flatnonzero(self.sim_status == SimulationStatus.FAILED)
@property
def any_failed(self) -> bool:
"""Check whether there are parameters which currently lead to a failed simulation."""
self._update()
return self._get_indices_failed_simulations().size > 0
def _add_sim(self, i: int, x: Mapping[str, Array]) -> None:
for k, v in x.items():
self.sims[k][i] = v
self._set_simulation_status(i, SimulationStatus.FINISHED)
def _failed_sim(self, i: int) -> None:
self._update()
self._set_simulation_status(i, SimulationStatus.FAILED)
[docs] def set_simulator(self, simulator: "swyft.Simulator") -> None:
"""(Re)set simulator.
Args:
simulator: Simulator.
"""
if self._simulator is not None:
log.warning("Simulator already set! Overwriting.")
self._simulator = simulator
[docs] def simulate(
self,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
wait_for_results: Optional[bool] = True,
) -> None:
"""Run simulator on parameter store with missing corresponding simulations.
Args:
indices: list of sample indices for which a simulation is required
batch_size: simulations will be submitted in batches of the specified size
wait_for_results: if True, return only when all simulations are done
"""
if self._simulator is None:
log.warning("No simulator specified. No simulations will run.")
return
self.lock()
self._update()
idx = self._get_indices_to_simulate(indices)
self._set_simulation_status(idx, SimulationStatus.RUNNING)
self.unlock()
# Run simulations and collect status
if len(idx) == 0:
log.debug("No simulations required.")
else:
# For the MemoryStore, we need to collect results in memory
collect_in_memory = (
True if isinstance(self._zarr_store, zarr.MemoryStore) else False
)
if collect_in_memory and not wait_for_results:
logging.warning(
"Asynchronous collection of results is not implemented with the MemoryStore"
)
self._simulator._run(
v=self.v,
sims={k: v.oindex for k, v in self.sims.items()},
sim_status=self.sim_status.oindex,
indices=idx,
collect_in_memory=collect_in_memory,
batch_size=batch_size,
)
if wait_for_results:
self.wait_for_simulations(indices)
[docs] def wait_for_simulations(self, indices: Sequence[int]) -> None:
"""Wait for a set of sample simulations to be finished.
Args:
indices: list of sample indices
"""
done = False
while not done:
time.sleep(1)
status = self.get_simulation_status(indices)
done = np.isin(status, [SimulationStatus.FINISHED, SimulationStatus.FAILED])
done = np.all(done)
[docs] @classmethod
def directory_store(
cls,
path: PathType,
simulator: Optional[Simulator] = None,
sync_path: Optional[PathType] = None,
overwrite: bool = False,
) -> "Store":
"""Instantiate a new Store based on a Zarr DirectoryStore.
Args:
path: path to storage directory
simulator: simulator object
sync_path: path for synchronization via file locks (files will be stored in the given path).
It must differ from path, it must be accessible to all processes working on the store,
and the underlying filesystem must support file locking.
overwrite: if True, and a store already exists at the specified path, overwrite it.
Returns:
Store based on a Zarr DirectoryStore
Example:
>>> store = swyft.Store.directory_store(PATH_TO_STORE)
"""
if not Path(path).exists() or overwrite:
zarr_store = zarr.DirectoryStore(path)
sync_path = sync_path or os.path.splitext(path)[0] + ".sync"
return cls(
zarr_store=zarr_store,
simulator=simulator,
sync_path=sync_path,
from_scratch=True,
)
else:
raise FileExistsError(
f"Path {path} exists - set overwrite=True to initialize a new store there."
)
[docs] @classmethod
def memory_store(cls, simulator: Simulator) -> "Store":
"""Instantiate a new Store based on a Zarr MemoryStore.
Args:
simulator: simulator object
Returns:
Store based on a Zarr MemoryStore
Note:
The store returned is in general expected to be faster than an equivalent
store based on the Zarr DirectoryStore, and thus useful for quick
explorations, or for loading data into memory before training.
Example:
>>> store = swyft.Store.memory_store(simulator)
"""
zarr_store = zarr.MemoryStore()
return cls(zarr_store=zarr_store, simulator=simulator, from_scratch=True)
[docs] def save(self, path: PathType) -> None:
"""Save the Store to disk using a Zarr DirectoryStore.
Args:
path: path where to create the Zarr root directory
"""
if isinstance(
self._zarr_store, zarr.DirectoryStore
) and self._zarr_store.path == os.path.abspath(path):
return
path = Path(path)
if path.exists() and not path.is_dir():
raise NotADirectoryError(f"{path} should be a directory")
elif path.exists() and not is_empty(path):
raise FileExistsError(f"{path} is not empty")
else:
path.mkdir(parents=True, exist_ok=True)
zarr_store = zarr.DirectoryStore(path)
zarr.convenience.copy_store(source=self._zarr_store, dest=zarr_store)
[docs] @classmethod
def load(
cls,
path: PathType,
simulator: Optional[Simulator] = None,
sync_path: Optional[PathType] = None,
) -> "Store":
"""Open an existing sample store using a Zarr DirectoryStore.
Args:
path: path to the Zarr root directory
simulator: simulator object
sync_path: path for synchronization via file locks (files will be stored in the given path).
It must differ from path, it must be accessible to all processes working on the store,
and the underlying filesystem must support file locking.
"""
if Path(path).exists():
store = zarr.DirectoryStore(path)
sync_path = sync_path or os.path.splitext(path)[0] + ".sync"
return cls(
zarr_store=store,
simulator=simulator,
sync_path=sync_path,
from_scratch=False,
)
else:
raise FileNotFoundError(f"There is no directory store at {path}.")
[docs] def to_memory(self) -> "Store":
"""Make an in-memory copy of the existing Store using a Zarr MemoryStore."""
memory_store = zarr.MemoryStore()
zarr.convenience.copy_store(source=self._zarr_store, dest=memory_store)
return Store(
zarr_store=memory_store, simulator=self._simulator, from_scratch=False
)