Source code for swyft.prior

from functools import partial
from importlib import import_module
from typing import Callable, Sequence, Type, TypeVar

import numpy as np
import torch
from toolz import compose
from toolz.dicttoolz import keyfilter
from torch.distributions import Normal, Uniform

from swyft.bounds import Bound, UnitCubeBound
from swyft.saveable import StateDictSaveable
from swyft.types import Array
from swyft.utils import array_to_tensor, tensor_to_array

PriorType = TypeVar("PriorType", bound="Prior")
PriorTruncatorType = TypeVar("PriorTruncatorType", bound="PriorTruncator")

class PriorTruncator(StateDictSaveable):
    """Sample from a truncated version of the prior and evaluate its log_prob.

    Args:
        prior: Parameter prior
        bound: Bound object

    .. note::
        The truncator pairs a swyft.Bound object, which samples from
        (subregions of) the hypercube, with a swyft.Prior, which maps those
        samples onto the parameters of interest.
    """

    def __init__(self, prior: "Prior", bound: Bound) -> None:
        """Combine a prior with a bound on the hypercube.

        Args:
            prior: Prior object.
            bound: Bound on hypercube. Set 'None' for untruncated priors, in
                which case a ``UnitCubeBound`` of matching dimension is used.
        """
        self.prior = prior
        self.bound = UnitCubeBound(prior.n_parameters) if bound is None else bound

    @property
    def cdf(self) -> Callable:
        """The wrapped prior's cdf callable."""
        return self.prior.cdf

    @property
    def icdf(self) -> Callable:
        """The wrapped prior's icdf callable."""
        return self.prior.icdf

    @property
    def n_parameters(self) -> int:
        """The wrapped prior's number of parameters."""
        return self.prior.n_parameters

    def sample(self, n_samples: int) -> np.ndarray:
        """Sample from truncated prior.

        Args:
            n_samples: Number of samples to return

        Returns:
            Samples: (n_samples, n_parameters)
        """
        # Draw inside the bound on the hypercube, then map through the icdf.
        return self.prior.icdf(self.bound.sample(n_samples))

    def log_prob(self, v: np.ndarray) -> np.ndarray:
        """Evaluate log probability.

        Args:
            v: (N, n_parameters) parameter points.

        Returns:
            log_prob: (N,)
        """
        hypercube = self.prior.cdf(v)
        # A row whose cdf values sum to +inf is treated as outside the bound.
        row_is_inf = hypercube.sum(axis=-1) == np.inf
        in_bound = np.where(row_is_inf, 0.0, self.bound(hypercube))
        # Prior density summed over parameters, renormalised by the bound volume.
        truncated_density = self.prior.log_prob(v).sum(axis=-1) - np.log(
            self.bound.volume
        )
        return np.where(in_bound == 0.0, -np.inf, truncated_density)

    def state_dict(self) -> dict:
        """Return the state dicts of the prior and the bound under 'prior'/'bound'."""
        return {
            "prior": self.prior.state_dict(),
            "bound": self.bound.state_dict(),
        }

    @classmethod
    def from_state_dict(cls, state_dict: dict) -> PriorTruncatorType:
        """Rebuild a truncator from the mapping produced by ``state_dict``."""
        restored_prior = Prior.from_state_dict(state_dict["prior"])
        restored_bound = Bound.from_state_dict(state_dict["bound"])
        return cls(restored_prior, restored_bound)
class InterpolatedTabulatedDistribution:
    """Factorized distribution defined by a tabulated, interpolated icdf."""

    def __init__(self, icdf: Callable, n_parameters: int, n_grid_points: int) -> None:
        r"""Create a distribution based off of a icdf. The distribution is defined by interpolating grid points.

        Args:
            icdf: inverse cumulative density function, aka ppf and uv
            n_parameters: number of parameters, dimensionality of the prior
            n_grid_points: number of grid points

        .. warning::
            Internally the mapping u -> v is tabulated on a linear grid on the
            interval [0, 1], with `n` grid points. In extreme cases, this can
            lead to approximation errors that can be mitigated by increasing
            `n`.
        """
        self.n_parameters = n_parameters
        self._grid = np.linspace(0, 1.0, n_grid_points)
        self._table = self._generate_table(icdf, self._grid, n_parameters)

    @staticmethod
    def _generate_table(
        uv: Callable, grid: np.ndarray, n_parameters: int
    ) -> np.ndarray:
        """Evaluate ``uv`` at every grid value; returns (n_parameters, n_grid) table."""
        # Each call feeds the same grid value to every parameter dimension.
        table = [uv(np.ones(n_parameters) * x) for x in grid]
        return np.array(table).T

    @staticmethod
    def _as_float_array(values: np.ndarray) -> np.ndarray:
        """Return ``values`` as a floating array (floating inputs are untouched).

        The outputs below are allocated with ``np.empty_like``, which inherits
        the input dtype. With an integer input, interpolated values in [0, 1]
        would be truncated and ``inf`` could not be stored, so non-floating
        inputs are promoted to float64 first.
        """
        values = np.asarray(values)
        if not np.issubdtype(values.dtype, np.floating):
            values = values.astype(np.float64)
        return values

    def cdf(self, v: np.ndarray) -> np.ndarray:
        """Map onto hypercube: v -> u

        Args:
            v: (N, n_parameters) physical parameter array

        Returns:
            u: (N, n_parameters) hypercube parameter array; entries of ``v``
                outside the tabulated range are set to ``inf``.
        """
        v = self._as_float_array(v)
        u = np.empty_like(v)
        for i in range(self.n_parameters):
            u[:, i] = np.interp(
                v[:, i], self._table[i], self._grid, left=np.inf, right=np.inf
            )
        return u

    def icdf(self, u: np.ndarray) -> np.ndarray:
        """Map from hypercube: u -> v

        Args:
            u: (N, n_parameters) hypercube parameter array

        Returns:
            v: (N, n_parameters) physical parameter array; entries of ``u``
                outside the grid are set to ``inf``.
        """
        u = self._as_float_array(u)
        v = np.empty_like(u)
        for i in range(self.n_parameters):
            v[:, i] = np.interp(
                u[:, i], self._grid, self._table[i], left=np.inf, right=np.inf
            )
        return v

    def log_prob(self, v: np.ndarray, du: float = 1e-6) -> np.ndarray:
        """Log probability.

        Args:
            v: (N, n_parameters) physical parameter array
            du: Step-size of numerical derivatives

        Returns:
            log_prob: (N, n_parameters) factors of pdf
        """
        v = self._as_float_array(v)
        dv = np.empty_like(v)
        u = self.cdf(v)
        half_step = du / 2
        for i in range(self.n_parameters):
            # Central finite difference of the tabulated icdf around u.
            upper = np.interp(
                u[:, i] + half_step, self._grid, self._table[i], left=None, right=None
            )
            lower = np.interp(
                u[:, i] - half_step, self._grid, self._table[i], left=None, right=None
            )
            dv[:, i] = upper - lower
        # pdf = du / dv; the 1e-300 offset keeps log() finite when dv == 0.
        # Points flagged inf by cdf (outside the table) get log_prob = -inf.
        log_prob = np.where(u == np.inf, -np.inf, np.log(du) - np.log(dv + 1e-300))
        return log_prob


# TODO this could be improved with some thought
# it merely wraps a torch distribution and keeps track of the arguments...
class Prior(StateDictSaveable):
    """Fully factorizable prior described by cdf, icdf and log_prob callables.

    Instances remember how they were built (``method`` and ``_state_dict``) so
    that ``from_state_dict`` can re-run the same constructor.
    """

    def __init__(
        self, cdf: Callable, icdf: Callable, log_prob: Callable, n_parameters: int
    ) -> None:
        r"""Fully factorizable prior.

        Args:
            cdf: cumulative density function, aka vu
            icdf: inverse cumulative density function, aka ppf and uv
            log_prob: log density function
            n_parameters: number of parameters / dimensionality of the prior

        .. note::
            The prior is defined through the mapping :math:`u\to v`, from the
            Uniform distribution, :math:`u\sim \text{Unif}(0, 1)` onto the
            parameters of interest, :math:`v`. This mapping corresponds to the
            inverse cumulative distribution function, and is internally used
            to perform inverse transform sampling. Sampling happens in the
            swyft.Bound object.
        """
        self.cdf = cdf
        self.icdf = icdf
        self.log_prob = log_prob
        self.n_parameters = n_parameters
        self.method = "__init__"
        # Keys other than "method" are replayed as keyword arguments of the
        # constructor named by "method" in from_state_dict.
        self._state_dict = {
            "method": self.method,
            "cdf": self.cdf,
            "icdf": self.icdf,
            "log_prob": self.log_prob,
            "n_parameters": self.n_parameters,
        }
        self.distribution = None
        self.get_split = None

    @classmethod
    def from_torch_distribution(
        cls: Type[PriorType],
        distribution: torch.distributions.Distribution,
    ) -> PriorType:
        r"""Create a prior from a batched pytorch distribution.

        For example, ``distribution = torch.distributions.Uniform(-1 * torch.ones(5), 1 * torch.ones(5))``.

        Args:
            distribution: pytorch distribution

        Returns:
            Prior
        """
        assert (
            len(distribution.batch_shape) == 1
        ), f"{distribution.batch_shape} must be one dimensional"
        assert (
            len(distribution.event_shape) == 0
        ), f"{distribution} must be factorizable and report the log_prob of every dimension (i.e. all dims are in batch_shape)"
        prior = cls(
            cdf=cls.conjugate_tensor_func(distribution.cdf),
            icdf=cls.conjugate_tensor_func(distribution.icdf),
            log_prob=cls.conjugate_tensor_func(distribution.log_prob),
            n_parameters=distribution.batch_shape.numel(),
        )
        prior.distribution = distribution
        prior.get_split = None
        prior.method = "from_torch_distribution"
        prior._state_dict = {
            "method": prior.method,
            "name": distribution.__class__.__name__,
            "module": distribution.__module__,
            # This depends on all relevant constructor arguments being listed
            # in distribution.__class__.arg_constraints.
            "kwargs": keyfilter(
                lambda x: x in distribution.__class__.arg_constraints,
                distribution.__dict__,
            ),
        }
        return prior

    @classmethod
    def composite_prior(
        cls: Type[PriorType],
        cdfs: Sequence[Callable],
        icdfs: Sequence[Callable],
        log_probs: Sequence[Callable],
        parameter_dimensions: Sequence[int],
    ) -> PriorType:
        """Create a prior by concatenating several priors along the parameter axis.

        The input array is split along axis 1 into consecutive chunks of the
        given widths, the i-th function is applied to the i-th chunk, and the
        results are concatenated along axis 1.

        Args:
            cdfs: one cdf per chunk
            icdfs: one icdf per chunk
            log_probs: one log_prob per chunk
            parameter_dimensions: number of columns handled by each chunk

        Returns:
            Prior with ``n_parameters == sum(parameter_dimensions)``
        """
        assert len(cdfs) == len(icdfs), "there must be as many icdfs as cdfs"
        assert len(cdfs) == len(log_probs), "there must be as many log_probs as cdfs"
        assert len(cdfs) == len(
            parameter_dimensions
        ), "there must be as many parameter_dimensions as cdfs"
        n_parameters = sum(parameter_dimensions)
        # Split points are the cumulative widths, excluding the final total.
        parameter_indices = np.cumsum(parameter_dimensions)[:-1]
        get_split = partial(np.split, indices_or_sections=parameter_indices, axis=1)
        zip_cdfs = partial(cls.zip_apply, cdfs)
        zip_icdfs = partial(cls.zip_apply, icdfs)
        zip_log_probs = partial(cls.zip_apply, log_probs)
        concatenate = partial(np.concatenate, axis=1)
        # compose applies right-to-left: split, apply per chunk, concatenate.
        prior = cls(
            cdf=compose(concatenate, zip_cdfs, get_split),
            icdf=compose(concatenate, zip_icdfs, get_split),
            log_prob=compose(concatenate, zip_log_probs, get_split),
            n_parameters=n_parameters,
        )
        prior.distribution = None
        prior.method = "composite_prior"
        prior.get_split = get_split
        prior._state_dict = {
            "method": prior.method,
            "cdfs": cdfs,
            "icdfs": icdfs,
            "log_probs": log_probs,
            "parameter_dimensions": parameter_dimensions,
        }
        return prior

    @staticmethod
    def conjugate_tensor_func(
        function: Callable[
            [
                torch.Tensor,
            ],
            torch.Tensor,
        ]
    ) -> Callable[[np.ndarray,], np.ndarray]:
        """Conjugate a function by converting the input array to a tensor,
        applying the function to the tensor, then converting the output tensor
        back to an array.

        Args:
            function: callable which takes a torch tensor
        """
        return compose(tensor_to_array, function, array_to_tensor)

    @staticmethod
    def zip_apply(functions: Sequence[Callable], arguments: Sequence) -> Sequence:
        """Apply the i-th function to the i-th argument and return the results as a list."""
        return [f(arg) for f, arg in zip(functions, arguments)]

    @classmethod
    def from_uv(
        cls: Type[PriorType],
        icdf: Callable,
        n_parameters: int,
        n_grid_points: int = 10_000,
    ) -> PriorType:
        """Create a prior which depends on ``InterpolatedTabulatedDistribution``,
        i.e. an interpolated representation of the icdf, cdf, and log_prob.

        .. warning::
            This constructor is disabled and always raises.

        Args:
            icdf: map from hypercube: u -> v. inverse cumulative density function (icdf)
            n_parameters: number of parameters / dimensionality of the prior
            n_grid_points: number of grid points from which to interpolate the icdf, cdf, and log_prob

        Raises:
            NotImplementedError: always.
        """
        # NOTE: the statements that used to follow this raise were unreachable
        # and referred to attributes (`distribution.v`, `distribution.u`) that
        # InterpolatedTabulatedDistribution does not define; they also assigned
        # `prior.state_dict = None`, which would have shadowed the method.
        # A revived implementation should wrap distribution.cdf / .icdf /
        # .log_prob and populate `_state_dict` like the other constructors.
        raise NotImplementedError("This was too inaccurate.")

    def state_dict(self) -> dict:
        """Return the mapping recorded by whichever constructor built this prior."""
        return self._state_dict

    @classmethod
    def from_state_dict(cls: Type[PriorType], state_dict: dict) -> PriorType:
        """Rebuild a prior by re-running the constructor named in ``state_dict``.

        Args:
            state_dict: mapping as returned by ``state_dict``; its "method"
                entry selects the constructor.

        Returns:
            Prior

        Raises:
            NotImplementedError: if "method" names an unsupported constructor.
        """
        method = state_dict["method"]
        if method == "__init__":
            kwargs = keyfilter(lambda x: x != "method", state_dict)
            return cls(**kwargs)
        elif method == "composite_prior":
            kwargs = keyfilter(lambda x: x != "method", state_dict)
            return getattr(cls, method)(**kwargs)
        elif method == "from_torch_distribution":
            name = state_dict["name"]
            module = state_dict["module"]
            kwargs = state_dict["kwargs"]
            # NOTE(security): the module and class are imported by name from
            # the state dict, so only load state dicts from trusted sources.
            distribution = getattr(import_module(module), name)
            distribution = distribution(**kwargs)
            return getattr(cls, method)(distribution)
        else:
            # Previously the exception was constructed but never raised, so an
            # unknown method silently returned None.
            raise NotImplementedError(
                f"Cannot restore a Prior created with method {method!r}."
            )
def get_uniform_prior(low: Array, high: Array) -> Prior:
    """Return a Prior wrapping torch ``Uniform(low, high)``.

    Args:
        low: lower limits, converted with ``array_to_tensor``
        high: upper limits, converted with ``array_to_tensor``
    """
    return Prior.from_torch_distribution(
        Uniform(array_to_tensor(low), array_to_tensor(high))
    )


def get_diagonal_normal_prior(loc: Array, scale: Array) -> Prior:
    """Return a Prior wrapping torch ``Normal(loc, scale)``.

    Args:
        loc: passed as the first argument of ``Normal``, after ``array_to_tensor``
        scale: passed as the second argument of ``Normal``, after ``array_to_tensor``
    """
    return Prior.from_torch_distribution(
        Normal(array_to_tensor(loc), array_to_tensor(scale))
    )


if __name__ == "__main__":
    pass