swyft.prior
- class swyft.prior.Prior(cdf, icdf, log_prob, n_parameters)
Fully factorizable prior.
- Parameters:
cdf (Callable) – cumulative distribution function (CDF), i.e. the map v -> u
icdf (Callable) – inverse cumulative distribution function (icdf), also known as the percent point function (ppf), i.e. the map u -> v
log_prob (Callable) – log density function
n_parameters (int) – number of parameters / dimensionality of the prior
Note
The prior is defined through the mapping \(u\to v\) from the uniform distribution, \(u\sim \text{Unif}(0, 1)\), onto the parameters of interest, \(v\). This mapping is the inverse cumulative distribution function, and it is used internally to perform inverse transform sampling. Sampling itself happens in the swyft.Bound object.
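The sketch below illustrates inverse transform sampling with the four ingredients that Prior takes (cdf, icdf, log_prob, n_parameters). It does not call swyft: the uniform-on-[low, high] prior and the names low, high and sample are illustrative assumptions.

```python
import numpy as np

# Hypothetical fully factorized prior: parameter i is uniform on [low[i], high[i]].
low = np.array([-1.0, 0.0, 10.0])
high = np.array([1.0, 5.0, 20.0])
n_parameters = len(low)

def cdf(v):
    """Map v -> u, elementwise per parameter."""
    return (v - low) / (high - low)

def icdf(u):
    """Map u -> v, i.e. the inverse CDF (ppf)."""
    return low + u * (high - low)

def log_prob(v):
    """Per-parameter log density; -inf outside the support."""
    inside = (v >= low) & (v <= high)
    return np.where(inside, -np.log(high - low), -np.inf)

def sample(n, rng):
    """Inverse transform sampling: draw u ~ Unif(0, 1), then map it through icdf."""
    u = rng.uniform(size=(n, n_parameters))
    return icdf(u)

rng = np.random.default_rng(0)
v = sample(1000, rng)
joint_log_prob = log_prob(v).sum(axis=-1)  # factorized prior: sum over parameters
```

Because the prior factorizes, the joint log density is just the sum of the per-parameter log densities along the last axis.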
- static conjugate_tensor_func(function)
Conjugate a function: convert the input array to a tensor, apply the function to the tensor, then convert the output tensor back to an array.
- Parameters:
function (Callable[[Tensor], Tensor]) – callable which takes a torch tensor
- Return type:
Callable[[ndarray], ndarray]
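A minimal sketch of the conjugation pattern (array -> tensor -> function -> array), assuming PyTorch and NumPy are installed. The helper name conjugate is hypothetical and is not claimed to match swyft's implementation line for line.

```python
from typing import Callable

import numpy as np
import torch

def conjugate(function: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[np.ndarray], np.ndarray]:
    """Wrap a tensor -> tensor function so that it accepts and returns NumPy arrays."""
    def wrapped(x: np.ndarray) -> np.ndarray:
        tensor = torch.as_tensor(x)            # array -> tensor (shares memory for CPU arrays)
        result = function(tensor)              # apply the torch function
        return result.detach().cpu().numpy()   # tensor -> array
    return wrapped

# Usage: turn torch.sigmoid into an array-in, array-out function.
sigmoid = conjugate(torch.sigmoid)
out = sigmoid(np.array([0.0, 2.0]))
```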
- classmethod from_torch_distribution(distribution)
Create a prior from a batched PyTorch distribution.
For example,
    import torch
    distribution = torch.distributions.Uniform(-1 * torch.ones(5), 1 * torch.ones(5))
- Parameters:
distribution (Distribution) – PyTorch distribution
- Returns:
Prior
- Return type:
PriorType
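The sketch below shows one plausible way to obtain array-valued cdf, icdf and log_prob functions, plus n_parameters, from the batched distribution in the example above. It assumes PyTorch and NumPy; the helper functions_from_distribution is hypothetical. The source does not say whether the resulting log_prob is per parameter or summed, so the sketch keeps per-parameter values and sums them explicitly.

```python
import numpy as np
import torch

def functions_from_distribution(distribution: torch.distributions.Distribution):
    """Hypothetical helper: derive array-in/array-out cdf, icdf, log_prob and n_parameters."""
    # A 1-D batch shape describes one independent factor per parameter.
    assert len(distribution.batch_shape) == 1, "expected a 1-D batch of parameters"

    def as_array_func(function):
        # array -> float32 tensor -> function -> array
        return lambda x: function(torch.as_tensor(x, dtype=torch.float32)).numpy()

    return (
        as_array_func(distribution.cdf),
        as_array_func(distribution.icdf),
        as_array_func(distribution.log_prob),  # per-parameter log densities
        distribution.batch_shape.numel(),
    )

distribution = torch.distributions.Uniform(-1 * torch.ones(5), 1 * torch.ones(5))
cdf, icdf, log_prob, n_parameters = functions_from_distribution(distribution)

u = np.random.default_rng(0).uniform(size=(4, n_parameters))
v = icdf(u)                        # inverse transform sampling, shape (4, 5)
joint = log_prob(v).sum(axis=-1)   # fully factorized prior: sum over parameters
```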
- classmethod from_uv(icdf, n_parameters, n_grid_points=10000)
Create a prior which depends on InterpolatedTabulatedDistribution, i.e. an interpolated representation of the icdf, cdf, and log_prob.
Warning
Internally, the mapping u -> v is tabulated on a linear grid on the interval [0, 1] with n_grid_points grid points. In extreme cases this can lead to approximation errors, which can sometimes be mitigated by increasing n_grid_points.
- Parameters:
icdf (Callable) – map from the unit hypercube, u -> v, i.e. the inverse cumulative distribution function (icdf)
n_parameters (int) – number of parameters / dimensionality of the prior
n_grid_points (int) – number of grid points from which to interpolate the icdf, cdf, and log_prob
- Returns:
Prior
- Return type:
PriorType
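To illustrate the idea of tabulating u -> v on a linear grid, here is a NumPy-only sketch that builds interpolated cdf, icdf and log_prob functions from an icdf alone. It is not swyft's InterpolatedTabulatedDistribution; the helper name tabulate_from_uv is hypothetical, and it assumes an elementwise, strictly increasing icdf that is finite on [0, 1].

```python
import numpy as np

def tabulate_from_uv(icdf, n_parameters, n_grid_points=10_000):
    """Sketch: interpolated cdf, icdf and log_prob from a tabulated map u -> v.

    Assumes icdf acts elementwise on arrays of shape (n, n_parameters), is
    strictly increasing in u, and is finite on the closed interval [0, 1].
    """
    u_grid = np.linspace(0.0, 1.0, n_grid_points)                # linear grid on [0, 1]
    u_table = np.repeat(u_grid[:, None], n_parameters, axis=1)    # (n_grid_points, n_parameters)
    v_table = icdf(u_table)                                       # tabulated u -> v

    # The density is p(v) = du/dv, so log p(v) = -log(dv/du); dv/du via finite differences.
    log_pdf_table = np.stack(
        [-np.log(np.gradient(v_table[:, j], u_grid)) for j in range(n_parameters)], axis=1
    )

    def interp_columns(x, xp, fp):
        # Linear interpolation, one parameter (column) at a time.
        # Note: np.interp clamps to the end values outside the tabulated range.
        x = np.atleast_2d(x)
        return np.stack(
            [np.interp(x[:, j], xp[:, j], fp[:, j]) for j in range(n_parameters)], axis=1
        )

    def cdf(v):
        return interp_columns(v, v_table, u_table)

    def icdf_interp(u):
        return interp_columns(u, u_table, v_table)

    def log_prob(v):
        return interp_columns(v, v_table, log_pdf_table)

    return cdf, icdf_interp, log_prob

# Example: parameter 0 uniform on [-1, 1], parameter 1 uniform on [0, 10].
offset = np.array([-1.0, 0.0])
scale = np.array([2.0, 10.0])
cdf, icdf, log_prob = tabulate_from_uv(lambda u: offset + scale * u, n_parameters=2, n_grid_points=1_000)

v = icdf(np.array([[0.25, 0.5]]))   # approximately [[-0.5, 5.0]]
```

In this sketch, the finite-difference derivative and the linear interpolation are the steps that lose accuracy for strongly curved maps, and both improve as n_grid_points grows.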
- class swyft.prior.PriorTruncator(prior, bound)
Samples from a truncated version of the prior and calculates the log_prob.
Note
The prior truncator combines a swyft.Bound object, which samples from (subregions of) the hypercube, with a swyft.Prior, which maps those samples onto the parameters of interest.
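To make this division of labour concrete, the sketch below pairs a hypothetical axis-aligned box in the unit hypercube (standing in for a swyft.Bound) with a prior's cdf, icdf and log_prob. The class BoxBound and its sample, contains and volume members are assumptions for illustration, not swyft's API. Because u is uniform, the prior mass of the region equals its volume in u-space, so the truncated log density is the prior log density minus log(volume) inside the region and -inf outside.

```python
import numpy as np

class BoxBound:
    """Hypothetical bound: an axis-aligned box [lower, upper] inside the unit hypercube."""

    def __init__(self, lower, upper):
        self.lower = np.asarray(lower, dtype=float)
        self.upper = np.asarray(upper, dtype=float)
        self.volume = float(np.prod(self.upper - self.lower))

    def sample(self, n, rng):
        # Uniform samples from the box, in u-space.
        return rng.uniform(self.lower, self.upper, size=(n, len(self.lower)))

    def contains(self, u):
        return np.all((u >= self.lower) & (u <= self.upper), axis=-1)

# Hypothetical prior: two parameters, each uniform on [0, 10].
def cdf(v):
    return v / 10.0

def icdf(u):
    return 10.0 * u

def log_prob(v):
    return np.full(v.shape, -np.log(10.0))  # per-parameter, valid inside [0, 10]

bound = BoxBound(lower=[0.2, 0.0], upper=[0.6, 0.5])
rng = np.random.default_rng(1)

# Sampling: draw u from the bound, then map it onto parameters with the icdf.
v = icdf(bound.sample(100, rng))

def truncated_log_prob(v):
    # Renormalize by the prior mass of the region, i.e. its volume in u-space.
    inside = bound.contains(cdf(v))
    joint = log_prob(v).sum(axis=-1) - np.log(bound.volume)
    return np.where(inside, joint, -np.inf)
```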
Instantiate a prior truncator (the combination of a prior and a bound).
- Parameters: