from __future__ import annotations
import logging
from typing import Callable, Optional, TypeVar
import numpy as np
from sklearn.neighbors import BallTree
from swyft.saveable import StateDictSaveable
from swyft.types import ObsType
from swyft.weightedmarginals import WeightedMarginalSamples
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Type variable bounded by "Bound" (given as a forward-reference string); it is
# used in this module as the return annotation of the constructors that build
# Bound objects.
BoundType = TypeVar("BoundType", bound="Bound")
class Bound(StateDictSaveable):
    """A bound region on the hypercube.

    .. note::
        The Bound object provides methods to sample from subregions of the
        hypercube, to evaluate the volume of the constrained region, and to
        evaluate the bound.

    This base class only defines the interface: ``volume``, ``n_parameters``,
    ``sample`` and ``__call__`` all raise ``NotImplementedError`` here.
    """

    def __init__(self) -> None:
        pass

    @property
    def volume(self) -> float:
        """Volume of the bound region."""
        raise NotImplementedError

    @property
    def n_parameters(self) -> int:
        """Number of dimensions."""
        raise NotImplementedError

    def sample(self, n_samples: int) -> np.ndarray:
        """Sample.

        Args:
            n_samples: Number of samples.

        Returns:
            s (n_samples x n_parameters)
        """
        raise NotImplementedError

    def __call__(self, u: np.ndarray):
        """Check whether parameters are within bounds.

        Args:
            u (n_samples x n_parameters): Parameters on hypercube
        """
        raise NotImplementedError

    # TODO can we do away with this thing? I think yes.
    @classmethod
    def from_state_dict(cls, state_dict) -> BoundType:
        """Instantiate Bound object based on state_dict.

        The ``"tag"`` entry selects the concrete class, whose own
        ``from_state_dict`` then does the actual reconstruction.

        Args:
            state_dict (dict): State dictionary

        Raises:
            KeyError: If ``state_dict`` has no ``"tag"`` entry or the tag does
                not name a known bound class.
        """
        tag = state_dict["tag"]
        # Built at call time: the concrete classes are defined further down in
        # this module, after Bound itself.
        known_bounds = {
            "UnitCubeBound": UnitCubeBound,
            "RectangleBound": RectangleBound,
            "BallsBound": BallsBound,
            "CompositBound": CompositBound,
        }
        try:
            bound_cls = known_bounds[tag]
        except (KeyError, TypeError):
            # TypeError covers unhashable tags, so that any unrecognised tag
            # keeps surfacing as a KeyError.
            raise KeyError(
                f"Unknown bound tag {tag!r}; expected one of {sorted(known_bounds)}"
            ) from None
        return bound_cls.from_state_dict(state_dict)

    @staticmethod
    def from_marginal_posterior(
        n_samples: int,
        observation: ObsType,
        marginal_posterior: "swyft.inference.marginalposterior.MarginalPosterior",
        threshold: float = -13.0,
        batch_size: Optional[int] = None,
    ) -> BoundType:
        """Create a bound from a marginal posterior.

        Thin wrapper that forwards every argument unchanged to
        ``CompositBound.from_marginal_posterior``; the only difference is that
        ``threshold`` defaults to -13.0 here.

        Args:
            n_samples: number of samples to estimate with
            observation: single observation to define the bounds
            marginal_posterior: marginal posterior object
            threshold: above which log weight do we bound?
            batch_size: when evaluating the log_prob, what batch size to use

        Returns:
            whatever ``CompositBound.from_marginal_posterior`` returns
        """
        return CompositBound.from_marginal_posterior(
            n_samples,
            observation,
            marginal_posterior,
            threshold,
            batch_size,
        )
class UnitCubeBound(Bound, StateDictSaveable):
    """Bound covering the whole unit hypercube ``[0, 1]^n_parameters``."""

    def __init__(self, n_parameters):
        """Set up the bound for a given dimensionality.

        Args:
            n_parameters (int): Number of parameters.
        """
        self._n_parameters = n_parameters
        # Fixed at construction; this class never changes it afterwards.
        self._volume = 1.0

    @property
    def volume(self):
        """Volume of the region (stored as 1.0 at construction)."""
        return self._volume

    @property
    def n_parameters(self) -> int:
        """Number of parameters, as passed to the constructor."""
        return self._n_parameters

    def sample(self, n_samples):
        """Draw points uniformly from the unit hypercube.

        Args:
            n_samples (int): Number of samples

        Returns:
            Array of shape (n_samples, n_parameters) from ``np.random.rand``.
        """
        shape = (n_samples, self.n_parameters)
        return np.random.rand(*shape)

    def __call__(self, u):
        """Evaluate bound.

        Args:
            u (array): Input array.

        Returns:
            Ones and zeros: 1.0 where every entry along the last axis lies in
            [0, 1], otherwise 0.0 (a float mask, not a boolean one).
        """
        inside = (u <= 1.0) & (u >= 0.0)
        indicator = np.where(inside, 1.0, 0.0)
        # The product over the last axis is 1.0 only if all coordinates pass.
        return indicator.prod(axis=-1)

    def state_dict(self):
        """Return the tag and the number of parameters as a dict."""
        return {"tag": "UnitCubeBound", "n_parameters": self.n_parameters}

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild the bound from the ``"n_parameters"`` entry of ``state_dict``."""
        n_parameters = state_dict["n_parameters"]
        return cls(n_parameters)
class RectangleBound(Bound, StateDictSaveable):
    """Axis-aligned box on the hypercube: one (u_min, u_max) pair per parameter."""

    def __init__(self, rec_bounds):
        """Rectangle bound.

        Args:
            rec_bounds (n x 2 array-like): list of (u_min, u_max) values.
                Note: 0 <= u_min < u_max <= 1.
        """
        # np.asarray returns an existing ndarray as-is (no copy) and converts
        # nested lists/tuples, which the [i, j] indexing below cannot handle.
        self._rec_bounds = np.asarray(rec_bounds)

    @property
    def volume(self):
        """Product of the side lengths ``u_max - u_min`` over all parameters."""
        vol = 1.0
        for i in range(self.n_parameters):
            vol *= self._rec_bounds[i, 1] - self._rec_bounds[i, 0]
        return vol

    @property
    def n_parameters(self):
        """Number of parameters, i.e. the number of rows of ``rec_bounds``."""
        return len(self._rec_bounds)

    def sample(self, n_samples):
        """Draw points uniformly from the box.

        Args:
            n_samples (int): Number of samples

        Returns:
            Array of shape (n_samples, n_parameters).
        """
        u = np.random.rand(n_samples, self.n_parameters)
        # Rescale each unit-interval column to its [u_min, u_max] range.
        for i in range(self.n_parameters):
            lo, hi = self._rec_bounds[i, 0], self._rec_bounds[i, 1]
            u[:, i] *= hi - lo
            u[:, i] += lo
        return u

    def __call__(self, u):
        """Check which rows of ``u`` lie inside the box (edges included).

        Args:
            u (n_samples x n_parameters): Parameters on hypercube

        Returns:
            Boolean array of length ``len(u)``.
        """
        inside = np.ones(len(u), dtype=bool)
        for i, edges in enumerate(self._rec_bounds):
            inside &= np.where(
                u[:, i] >= edges[0], np.where(u[:, i] <= edges[1], True, False), False
            )
        return inside

    def state_dict(self):
        """Return the tag and the stored bounds array as a dict."""
        return dict(tag="RectangleBound", rec_bounds=self._rec_bounds)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild the bound from the ``"rec_bounds"`` entry of ``state_dict``."""
        return cls(state_dict["rec_bounds"])
class BallsBound(Bound, StateDictSaveable):
    """Union of equal-radius balls around a set of points, cut to the unit cube."""

    def __init__(self, points, scale=1.0):
        """Simple mask based on coverage balls around inducing points.

        Args:
            points (Array): shape (num_points, n_dim)
            scale (float): Scale ball size (default 1.0)
        """
        assert len(points.shape) == 2, "points must have shape (num_points, n_dim)"
        self.X = points
        self._n_parameters = self.X.shape[-1]
        self.bt = BallTree(self.X, leaf_size=2)
        self.epsilon = self._set_epsilon(self.X, self.bt, scale)
        self._volume = self._get_volume(self.X, self.epsilon, self.bt)

    @property
    def volume(self) -> float:
        """Volume estimate computed once at construction (or restored from state)."""
        return self._volume

    @property
    def n_parameters(self) -> int:
        """Number of parameters, i.e. the last dimension of the points array."""
        return self._n_parameters

    @staticmethod
    def _set_epsilon(X, bt, scale):
        """Pick the common ball radius from nearest-neighbour distances.

        The radius is 1.5 * scale * the median distance from each point to its
        k-th nearest neighbour in the tree.

        Args:
            X: points, shape (num_points, n_dim)
            bt: BallTree built on ``X``
            scale: multiplicative factor on the radius
        """
        dims = X.shape[-1]
        # X is queried against a tree built on X itself, so each point's first
        # "neighbour" is the point itself and k = 4 reaches the third other
        # point. dims + 3 gives 4, 5, 6 for 1, 2, 3 dimensions and extrapolates
        # that rule to higher dimensions.
        k = dims + 3
        dist, _ = bt.query(X, k=k)
        return np.median(dist[:, -1]) * scale * 1.5

    @staticmethod
    def _unit_ball_volume(d):
        """Volume of the d-dimensional unit ball.

        Uses the recurrence V_d = (2 * pi / d) * V_{d-2} with V_0 = 1, V_1 = 2.
        """
        vol = 1.0 if d % 2 == 0 else 2.0
        for k in range(2 + d % 2, d + 1, 2):
            vol *= 2 * np.pi / k
        return vol

    @staticmethod
    def _draw_in_balls(X, epsilon):
        """Draw one point uniformly from the ball around every row of ``X``.

        Returns only the draws that fall inside the unit hypercube, so the
        result has at most ``len(X)`` rows.
        """
        d = X.shape[-1]
        # Uniform direction: a normalised Gaussian vector.
        n = np.random.randn(*X.shape)
        norm = (n**2).sum(axis=1) ** 0.5
        n = n / norm.reshape(-1, 1)
        # Radius ~ U^(1/d) * epsilon makes the point uniform within the ball.
        r = np.random.rand(len(X)) ** (1 / d) * epsilon
        Y = X + n * r.reshape(-1, 1)
        in_bounds = ((Y >= 0.0) & (Y <= 1.0)).all(axis=1)
        return Y[in_bounds]

    @staticmethod
    def _get_volume(X, epsilon, bt):
        """Monte Carlo estimate of the volume of the union of balls in the cube.

        Each of the 100 rounds draws one point per ball, keeps those inside the
        unit cube, and weights every kept point by one over the number of ball
        centres within ``epsilon`` of it, so overlapping balls are not counted
        more than once. The mean over rounds is returned; a debug message is
        logged when the relative standard error exceeds 1%.
        """
        n_samples = 100
        d = X.shape[-1]
        ball_volume = BallsBound._unit_ball_volume(d) * epsilon**d
        vol_est = []
        for _ in range(n_samples):
            Y = BallsBound._draw_in_balls(X, epsilon)
            counts = bt.query_radius(Y, epsilon, count_only=True)
            vol_est.append(ball_volume * np.sum(1.0 / counts))
        vol_est = np.array(vol_est)
        out, err = vol_est.mean(), vol_est.std() / np.sqrt(n_samples)
        rel = err / out
        if rel > 0.01:
            log.debug("WARNING: Rel volume uncertainty is %.4g", rel)
        return out

    def sample(self, n_samples):
        """Draw points uniformly from the bound region by rejection sampling.

        Args:
            n_samples (int): Number of samples

        Returns:
            Array of shape (n_samples, n_parameters).
        """
        if n_samples == 0:
            # Nothing to draw; np.vstack would fail on an empty list below.
            return np.empty((0, self.X.shape[-1]))
        counter = 0
        samples = []
        while counter < n_samples:
            Y = self._draw_in_balls(self.X, self.epsilon)
            counts = self.bt.query_radius(Y, r=self.epsilon, count_only=True)
            # Accept with probability 1 / (number of balls covering the point)
            # so that regions covered by several balls are not over-sampled.
            p = 1.0 / counts
            w = np.random.rand(len(p))
            Y = Y[p >= w]
            samples.append(Y)
            counter += len(Y)
        samples = np.vstack(samples)
        # The loop usually overshoots; keep a random subset of exactly n_samples.
        ind = np.random.choice(len(samples), size=n_samples, replace=False)
        return samples[ind]

    def __call__(self, u):
        """Check which rows of ``u`` lie strictly within ``epsilon`` of a point.

        Args:
            u (n_samples x n_parameters): Parameters on hypercube

        Returns:
            Boolean array of length ``len(u)``.
        """
        u = u.reshape(len(u), -1)
        dist, _ = self.bt.query(u, k=1)
        return (dist < self.epsilon)[:, 0]

    def state_dict(self):
        """Return tag, points, radius and volume estimate as a dict."""
        return dict(
            tag="BallsBound", points=self.X, epsilon=self.epsilon, volume=self._volume
        )

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild the bound from a state dict.

        ``__init__`` is bypassed: the stored ``epsilon`` and ``volume`` are
        reused instead of being recomputed, and only the BallTree is rebuilt
        from the stored points.
        """
        obj = cls.__new__(cls)
        obj.X = state_dict["points"]
        assert len(obj.X.shape) == 2, "points must have shape (num_points, n_dim)"
        obj._n_parameters = obj.X.shape[-1]
        obj.epsilon = state_dict["epsilon"]
        obj._volume = state_dict["volume"]
        obj.bt = BallTree(obj.X, leaf_size=2)
        return obj
class CompositBound(Bound, StateDictSaveable):
    """Composit bound object. Product of multiple bounds."""

    def __init__(self, bounds_map, n_parameters):
        """Store the per-index bounds and the total parameter count.

        Args:
            bounds_map (dict): Dictionary mapping indices like (0, 3) etc --> bounds
            n_parameters (int): Length of parameter vector.
        """
        self._bounds = bounds_map
        self._n_parameters = n_parameters

    def sample(self, n_samples):
        """Sample every component bound into its own columns.

        Columns that no entry of the bounds map refers to keep the fill
        value -1.

        Args:
            n_samples (int): Number of samples

        Returns:
            Array of shape (n_samples, n_parameters).
        """
        out = np.full((n_samples, self.n_parameters), -1.0)
        for index, bound in self._bounds.items():
            columns = np.array(index)
            out[:, columns] = bound.sample(n_samples)
        return out

    @property
    def volume(self) -> float:
        """Product of the volumes of all component bounds."""
        total = 1.0
        for bound in self._bounds.values():
            total *= bound.volume
        return total

    @property
    def n_parameters(self) -> int:
        """Length of the full parameter vector, as passed to the constructor."""
        return self._n_parameters

    def __call__(self, u):
        """Check which rows of ``u`` satisfy every component bound.

        Each component is evaluated on its own columns of ``u``; a row passes
        when the component results add up to the number of components.

        Args:
            u (n_samples x n_parameters): Parameters on hypercube
        """
        checks = [
            bound(u[:, np.array(index)]) for index, bound in self._bounds.items()
        ]
        return sum(checks) == len(checks)

    # Design notes:
    # - Function: Generate sample from posterior
    #   - Constraints are based on p(u|z)/p(u), and should be (different from
    #     what we have in the paper???)
    #   - That means I need weights without prior corrections, can be an option
    #     to switch this on or off
    #   - Samples should be samples from the hypercube
    #   - Use sampled points above a threshold for generating Rec bound and
    #     BallBounds, directly based on points
    # - Function: Return isolated ratio function & bound object from Posteriors
    #   object
    #   - Can be used in sampling

    @classmethod
    def from_weighted_samples(
        cls,
        weighted_samples: WeightedMarginalSamples,
        cdf: Callable,
        n_parameters: int,
        threshold: float,
    ) -> BoundType:
        """Build a bound from weighted samples and the cdf.

        For every marginal index, the samples whose log weight lies within
        ``threshold`` of that marginal's maximum log weight are kept, and a
        BallsBound is built on their cdf-transformed coordinates.

        Args:
            weighted_samples: log weighted samples
            cdf: transforms from v to u
            n_parameters: number of total parameters
            threshold: above which log weight do we bound? -13 is standard.

        Returns:
            a bound object based on the above
        """
        u = cdf(weighted_samples.v)
        per_marginal = {}
        for marginal_index in weighted_samples.marginal_indices:
            logw, _ = weighted_samples.get_logweight_marginal(marginal_index)
            # Compare relative to the best sample of this marginal.
            keep = logw - logw.max() > threshold
            kept_points = u[keep][:, marginal_index]
            per_marginal[marginal_index] = BallsBound(kept_points)
        return cls(per_marginal, n_parameters)

    @classmethod
    def from_marginal_posterior(
        cls,
        n_samples: int,
        observation: ObsType,
        marginal_posterior: "swyft.inference.marginalposterior.MarginalPosterior",
        threshold: float,
        batch_size: Optional[int] = None,
    ) -> BoundType:
        """Build a bound from a marginal posterior by sampling to estimate the log_prob contours.

        Weighted samples are drawn from ``marginal_posterior`` and passed on to
        ``from_weighted_samples`` together with the prior's cdf and parameter
        count.

        Args:
            n_samples: number of samples to estimate with
            observation: single observation to define the bounds
            marginal_posterior: marginal posterior object
            threshold: above which log weight do we bound? -13 is standard
            batch_size: when evaluating the log_prob, what batch size to use

        Returns:
            a bound object based on the above
        """
        prior = marginal_posterior.prior
        drawn = marginal_posterior.weighted_sample(n_samples, observation, batch_size)
        return cls.from_weighted_samples(
            weighted_samples=drawn,
            cdf=prior.cdf,
            n_parameters=prior.n_parameters,
            threshold=threshold,
        )

    def state_dict(self):
        """Return tag, parameter count and the state dicts of all components."""
        component_states = {
            index: bound.state_dict() for index, bound in self._bounds.items()
        }
        return dict(
            tag="CompositBound",
            n_parameters=self.n_parameters,
            bounds=component_states,
        )

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild the bound, restoring each component via ``Bound.from_state_dict``."""
        restored = {
            index: Bound.from_state_dict(component)
            for index, component in state_dict["bounds"].items()
        }
        return cls(restored, state_dict["n_parameters"])