API

Core components

class swyft.lightning.core.AuxLoss(loss, name)

Dataclass for storing additional loss functions that are minimized during optimization.

Parameters:
  • loss (torch.Tensor) – Value of the auxiliary loss.

  • name (str) – Name of the auxiliary loss.

class swyft.lightning.core.CoverageSamples(prob_masses, params, parnames)

Dataclass for storing samples of probability masses from coverage tests.

Parameters:
  • prob_masses (torch.Tensor) – Tensor of probability masses in the range [0, 1], \((\text{minibatch}, *\text{logratios_shape})\)

  • params (torch.Tensor) – Corresponding parameter values, \((\text{minibatch}, *\text{logratios_shape}, *\text{params_shape})\)

  • parnames (array) – Array of parameter names, \((*\text{logratios_shape})\)

estimate_coverage(parnames, z_max=3.5, bins=50)

Estimate expected coverage of credible intervals on a grid of credibility values.

Parameters:
  • parnames (str | Sequence[str]) – Names of parameters

  • z_max (float) – Upper limit on the credibility level, expressed as a z-score (default 3.5)

  • bins (int) – Number of bins used when tabulating z-scores (default 50)

Returns:

Array columns correspond to [nominal z, empirical z, low_err empirical z, hi_err empirical z]

Return type:

np.ndarray of shape (bins, 4)
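To make the returned columns concrete, here is a stdlib-only sketch of the first two (nominal z and empirical z). It assumes each stored probability mass is the credibility level of the smallest highest-posterior-density region that still contains the true parameter value, and it omits the two error-bar columns. The helpers credibility_to_z, z_to_credibility and empirical_coverage are hypothetical; this illustrates the idea and is not swyft's implementation.

```python
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # standard normal, mu=0, sigma=1


def credibility_to_z(c):
    """Two-sided z-score of a central credibility level c in [0, 1)."""
    return _STD_NORMAL.inv_cdf(0.5 * (1.0 + c))


def z_to_credibility(z):
    """Inverse of credibility_to_z."""
    return 2.0 * _STD_NORMAL.cdf(z) - 1.0


def empirical_coverage(prob_masses, z_max=3.5, bins=50):
    """Return a list of (nominal z, empirical z) pairs on a grid of z values."""
    n = len(prob_masses)
    rows = []
    for k in range(bins):
        z_nom = z_max * k / (bins - 1)
        c_nom = z_to_credibility(z_nom)
        # Fraction of test cases whose truth lies inside the nominal region
        c_emp = sum(1 for m in prob_masses if m <= c_nom) / n
        # Clip so that inv_cdf stays finite when every truth is covered
        c_emp = min(c_emp, 1.0 - 1e-12)
        rows.append((z_nom, credibility_to_z(c_emp)))
    return rows
```

For a calibrated estimator the masses are uniform on [0, 1], so the empirical z tracks the nominal z; overconfident posteriors push the empirical z below it.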

class swyft.lightning.core.LogRatioSamples(logratios, params, parnames, metadata=<factory>)

Dataclass for storing samples of estimated log-ratio values in Swyft.

Parameters:
  • logratios (torch.Tensor) – Estimated log-ratios, \((\text{minibatch}, *\text{logratios_shape})\)

  • params (torch.Tensor) – Corresponding parameter values, \((\text{minibatch}, *\text{logratios_shape}, *\text{params_shape})\)

  • parnames (array) – Array of parameter names, \((*\text{logratios_shape})\)

  • metadata (dict) – Optional metadata, for example from the inference network.

__len__()

Returns number of stored ratios (minibatch size).

class swyft.lightning.core.SwyftModule(*args, **kwargs)

This is the central Swyft LightningModule for handling the training of logratio estimators.

Derived classes are supposed to override the forward method in order to implement specific inference tasks.

Note

The forward method takes the sample batches A and B as arguments, which typically include all sample variables. Jointly drawn samples correspond to A = B, whereas marginal samples correspond to A != B.

Example usage:

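import swyft

class MyNetwork(swyft.SwyftModule):
    def __init__(self):
        super().__init__()  # initialize the LightningModule before adding submodules
        self.mlp = swyft.LogRatioEstimator_1dim(4, 4)

    def forward(self, A, B):
        x = A['x']  # observations from sample batch A
        z = B['z']  # parameters from sample batch B
        logratios = self.mlp(x, z)
        return logratios
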
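The note about joint and marginal samples is the core of ratio estimation: a classifier learns to tell jointly drawn pairs from mismatched ones, and its logit then estimates the log-ratio. The framework-free sketch below illustrates this with a binary cross-entropy loss. The helpers make_pairs, log_sigmoid and bce_loss are hypothetical, and forming marginal pairs by a cyclic shift is an assumption, not necessarily what SwyftModule does internally.

```python
import math


def make_pairs(xs, zs):
    """Build joint (A = B) and marginal (A != B) pairs from one minibatch.

    Assumption for illustration: marginal pairs are formed by cyclically
    shifting z by one position, so each x is paired with another sample's z.
    """
    n = len(xs)
    joint = [(xs[i], zs[i]) for i in range(n)]
    marginal = [(xs[i], zs[(i + 1) % n]) for i in range(n)]
    return joint, marginal


def log_sigmoid(t):
    """Numerically stable log(1 / (1 + exp(-t)))."""
    if t >= 0:
        return -math.log1p(math.exp(-t))
    return t - math.log1p(math.exp(t))


def bce_loss(logratio_fn, joint, marginal):
    """Binary cross-entropy: joint pairs are class 1, marginal pairs class 0."""
    loss = -sum(log_sigmoid(logratio_fn(x, z)) for x, z in joint)
    loss -= sum(log_sigmoid(-logratio_fn(x, z)) for x, z in marginal)
    return loss / (len(joint) + len(marginal))


joint, marginal = make_pairs([0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0])
uninformed = bce_loss(lambda x, z: 0.0, joint, marginal)  # equals log(2)
informed = bce_loss(lambda x, z: 4.0 if x == z else -4.0, joint, marginal)
```

An estimator that cannot distinguish the two classes sits at log 2; one that separates them drives the loss toward zero.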
class swyft.lightning.core.SwyftTrainer(*args, **kwargs)

Base class: pytorch_lightning.Trainer

It provides training functionality for swyft.SwyftModule. Its functionality is identical to pytorch_lightning.Trainer; see the corresponding documentation for more details.

Two additional methods are defined:

  • infer for performing parameter inference tasks with a trained network

  • test_coverage for performing coverage tests

Parameters:
  • args (Any) –

  • kwargs (Any) –

Return type:

Any

infer(model, A, B, return_sample_ratios=True, batch_size=1024)

Run the model in inference mode.

Parameters:
  • A – Sample, Samples, or dataloader for samples A.

  • B – Sample, Samples, or dataloader for samples B.

  • return_sample_ratios (bool) – If True (default), return results as a collated collection of LogRatioSamples objects. Otherwise, return the raw batches.

  • batch_size (int) – Batch size used when Samples objects are provided.

Returns:

Concatenated network output
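The batching behind infer can be sketched without swyft: a single observation (A) is paired with many prior samples (B), the samples are processed in chunks of batch_size, and the per-batch outputs are concatenated. infer_batched and the model callable are hypothetical stand-ins for the trainer and the network. In terms of the signature above, the real call would look like trainer.infer(network, obs, prior_samples), where these names are placeholders.

```python
def infer_batched(model, obs, prior_samples, batch_size=1024):
    """Evaluate model(obs, batch) on consecutive batches and concatenate.

    model is any callable that takes one observation and a list of parameter
    values and returns one log-ratio per value.
    """
    out = []
    for start in range(0, len(prior_samples), batch_size):
        batch = prior_samples[start:start + batch_size]
        out.extend(model(obs, batch))  # concatenate per-batch outputs
    return out
```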

test_coverage(model, A, B, batch_size=1024, logratio_noise=True)

Estimate the empirical probability masses used in coverage tests.

Parameters:
  • model – Trained network

  • A – Truth samples

  • B – Prior samples

  • batch_size – Batch size used during network evaluation

  • logratio_noise – Add a small amount of noise to log-ratio estimates, which stabilizes mass estimates for classification tasks.

Returns:

Dict of CoverageSamples objects.
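One plausible way to compute the probability mass stored for a single truth sample is sketched below with the standard library only: prior samples B are reweighted by exp(logratio), and the posterior mass of all samples whose log-ratio exceeds the truth's log-ratio is accumulated. The optional jitter mirrors the purpose of logratio_noise (breaking ties in classification tasks). The function probability_mass and the Gaussian form of the noise are assumptions for illustration, not swyft's code.

```python
import math
import random


def probability_mass(logratio_truth, logratios_prior, noise=0.0, rng=None):
    """Posterior mass of all prior samples with a higher log-ratio than the truth.

    Reweighting prior samples by exp(logratio) turns them into weighted
    posterior samples. A small Gaussian jitter (noise > 0) breaks ties between
    identical log-ratios.
    """
    rng = rng or random.Random(0)

    def jitter():
        return rng.gauss(0.0, noise) if noise > 0 else 0.0

    lr_truth = logratio_truth + jitter()
    lrs = [lr + jitter() for lr in logratios_prior]
    m = max(lrs)
    weights = [math.exp(lr - m) for lr in lrs]  # subtract max for stability
    above = sum(w for w, lr in zip(weights, lrs) if lr > lr_truth)
    return above / sum(weights)
```

A mass near 0 means the true value sits at the posterior peak; a mass near 1 means it lies far in the tails.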

Inference networks#

class swyft.lightning.estimators.LogRatioEstimator_1dim(*args, **kwargs)

Channeled MLPs for estimating one-dimensional posteriors. This is the default module for estimating 1-dim marginal posteriors.

Parameters:
  • num_features (int) – Length of feature vector.

  • num_params – Length of parameter vector.

  • varnames – List of names of the parameter vector components. If a single string is provided, indices are attached automatically.
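The automatic naming of vector components can be sketched as follows. The name[i] format is an assumption about swyft's convention, and expand_varnames is a hypothetical helper. With swyft installed, the constructor call would presumably look like swyft.LogRatioEstimator_1dim(num_features=10, num_params=3, varnames='z'), using the parameter names listed above.

```python
def expand_varnames(varnames, num_params):
    """Return one name per parameter component.

    A single string such as "z" is expanded to "z[0]", "z[1]", ...; the
    bracket format is an assumed convention.
    """
    if isinstance(varnames, str):
        return [f"{varnames}[{i}]" for i in range(num_params)]
    names = list(varnames)
    if len(names) != num_params:
        raise ValueError("need exactly one name per parameter")
    return names
```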

class swyft.lightning.estimators.LogRatioEstimator_1dim_Gaussian(*args, **kwargs)

Estimates posteriors assuming that they are Gaussian.

DEPRECATED: Use LogRatioEstimator_Gaussian instead.

Default module for estimating 1-dim marginal posteriors, using Gaussian approximations.

Parameters:
  • num_params – Length of parameter vector.

  • varnames – List of names of the parameter vector components. If a single string is provided, indices are attached automatically.

  • momentum (float) – Momentum for the running estimates of variances and covariances.

  • minstd (float) – Minimum relative standard deviation of the prediction variable. The correlation coefficient is truncated to the range \(|\rho| \le \sqrt{1-\text{minstd}^2}\).

Note

This module performs running estimates of parameter variances and covariances. There are no learnable parameters. This can cause errors when using the module in isolation without other modules with learnable parameters.

The covariance estimates are based on jointly drawn samples only. The first n_batch samples of z are assumed to be drawn jointly with x, where n_batch is the batch size of x.
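For standardized, jointly Gaussian x and z with correlation coefficient \(\rho\), the log-ratio has the closed form \(\log r = -\tfrac{1}{2}\log(1-\rho^2) - \frac{\rho^2 (x^2 + z^2) - 2\rho x z}{2(1-\rho^2)}\). The sketch below evaluates it for scalars and applies the minstd truncation described above. It assumes x and z have already been standardized (for example with running mean and variance estimates); gaussian_logratio is a hypothetical helper, not the module's actual code.

```python
import math


def gaussian_logratio(x, z, rho, minstd=0.01):
    """log p(x, z) / (p(x) p(z)) for standardized, jointly Gaussian x and z.

    rho is clipped to |rho| <= sqrt(1 - minstd**2), following the
    description of minstd, which keeps 1 - rho**2 strictly positive.
    """
    rho_max = math.sqrt(1.0 - minstd ** 2)
    rho = max(-rho_max, min(rho, rho_max))
    one_minus = 1.0 - rho ** 2
    quad = rho ** 2 * (x ** 2 + z ** 2) - 2.0 * rho * x * z
    return -0.5 * math.log(one_minus) - quad / (2.0 * one_minus)
```

With rho = 0 the log-ratio vanishes everywhere, as expected for independent x and z.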

forward(x, z)

2-dim Gaussian approximation to marginals and joint, assuming inputs of shape (B, N).

Parameters:
  • x (torch.Tensor) –

  • z (torch.Tensor) –

Return type:

torch.Tensor

class swyft.lightning.estimators.LogRatioEstimator_Autoregressive(*args, **kwargs)

Conventional autoregressive model, based on swyft.LogRatioEstimator_1dim.

class swyft.lightning.estimators.LogRatioEstimator_Gaussian(*args, **kwargs)

Estimates posteriors using a Gaussian approximation.

Parameters:
  • num_params – Length of parameter vector.

  • varnames – List of names of the parameter vector components. If a single string is provided, indices are attached automatically.

  • momentum (float) – Momentum for the running estimates of means and covariances

  • minstd (float) – Minimum standard deviation to enforce numerical stability

forward(a, b)

Gaussian approximation to marginals and joint. Expected input shapes: a is (B, N, D1) and b is (B, N, D2).

Parameters:
  • a (torch.Tensor) –

  • b (torch.Tensor) –

class swyft.lightning.estimators.LogRatioEstimator_Ndim(*args, **kwargs)

Channeled MLPs for estimating multi-dimensional posteriors.

swyft.lightning.estimators.equalize_tensors(a, b)

Equalize tensors by matching the minibatch sizes of a and b.
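A list-based sketch of matching minibatch sizes. Broadcasting a batch of size 1, and tiling the shorter batch when one size divides the other, are assumptions about the behavior; equalize_batches is hypothetical and works on Python lists rather than tensors.

```python
def equalize_batches(a, b):
    """Match the minibatch sizes of two lists by repeating the shorter one.

    A batch of size 1 is broadcast; otherwise the larger size must be a
    multiple of the smaller one, and the shorter list is tiled.
    """
    n, m = len(a), len(b)
    if n == m:
        return a, b
    if n < m:
        if m % n != 0:
            raise ValueError("batch sizes are incompatible")
        return a * (m // n), b
    if n % m != 0:
        raise ValueError("batch sizes are incompatible")
    return a, b * (n // m)
```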

Simulator definition

class swyft.lightning.simulator.Graph

Defines the computational graph (DAG) and keeps track of simulation results.

node(parnames, fn, *args)

Register a sampling function.

Parameters:
  • parnames – Name or list of names of sampling variables.

  • fn – Callable that returns the (list of) sampling variable(s).

  • *args – Arguments and keyword arguments that are passed to fn upon evaluation. LazyValues are evaluated automatically if necessary.

Returns:

Node or tuple of nodes.

class swyft.lightning.simulator.Node(parname, mult_parnames, fn, *inputs)

Provides lazy evaluation functionality.

Instantiates a LazyValue object.

Parameters:
  • trace – Trace instance (to be populated with sample).

  • this_name – Name of the variable that this LazyValue represents.

  • fn_out_names – Name or list of names of variables that fn returns.

  • fn – Callable that returns sample or list of samples.

  • args – Arguments and keyword arguments provided to fn upon evaluation.

  • kwargs – Arguments and keyword arguments provided to fn upon evaluation.
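Lazy evaluation in a DAG can be illustrated with a minimal pair of classes: each node stores its function and inputs, evaluates them only when its value is requested, and caches results in a shared trace dictionary. This mirrors the roles of the trace, fn and args parameters listed above, but ToyGraph and ToyNode are hypothetical, not swyft classes, and support only one output name per node.

```python
class ToyNode:
    """Lazily evaluated node; results are cached in the shared trace dict."""

    def __init__(self, trace, name, fn, *inputs):
        self.trace = trace
        self.name = name
        self.fn = fn
        self.inputs = inputs

    def value(self):
        if self.name not in self.trace:
            # Evaluate parent nodes first (lazily), then this node
            args = [i.value() if isinstance(i, ToyNode) else i for i in self.inputs]
            self.trace[self.name] = self.fn(*args)
        return self.trace[self.name]


class ToyGraph:
    """Registers nodes in the style of Graph.node(parnames, fn, *args)."""

    def __init__(self):
        self.trace = {}

    def node(self, name, fn, *args):
        return ToyNode(self.trace, name, fn, *args)


calls = []


def draw_z():
    calls.append("z")
    return 2.0


def draw_x(z):
    calls.append("x")
    return z + 1.0


g = ToyGraph()
z = g.node("z", draw_z)
x = g.node("x", draw_x, z)
# Nothing has run yet; requesting x evaluates z first, then x
result = x.value()
```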

class swyft.lightning.simulator.Sample

In Swyft, a ‘sample’ is a dictionary with string-type keys and tensor/array-type values.

class swyft.lightning.simulator.Samples

Handles memory-based samples in Swyft. Samples are stored as a dictionary of arrays/tensors, with the number of samples as the first dimension. This class provides a few convenience methods for accessing the samples.

__getitem__(i)

For integer keys, returns a ‘row’ (a single sample); for string keys, returns a ‘column’ (all values of that variable).

__len__()

Number of samples.
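The row/column behavior of __getitem__ and __len__ can be mimicked with a small dict subclass. ToySamples is a hypothetical stand-in that uses Python lists for columns instead of arrays or tensors.

```python
class ToySamples(dict):
    """Dict of equal-length columns, indexed by row (int) or column (str)."""

    def __len__(self):
        # Number of samples is the length of any column, or 0 if empty
        if dict.__len__(self) == 0:
            return 0
        return len(next(iter(self.values())))

    def __getitem__(self, key):
        if isinstance(key, int):
            # Integer key: one row, i.e. a single sample
            return {k: v[key] for k, v in self.items()}
        # String key: one column, i.e. all values of one variable
        return dict.__getitem__(self, key)


samples = ToySamples(z=[0.1, 0.2, 0.3], x=[1.0, 2.0, 3.0])
```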

get_dataloader(batch_size=1, shuffle=False, on_after_load_sample=None, repeat=None, num_workers=0)

Convenience method that directly creates a dataloader object.

Parameters:
  • batch_size – Batch size of the dataloader

  • shuffle – Whether the dataloader shuffles the samples

  • on_after_load_sample – See get_dataset

  • repeat – If not None, wrap the dataset in a RepeatDatasetWrapper

  • num_workers – Number of workers for the dataloader

get_dataset(on_after_load_sample=None)

Create a SamplesDataset object.

Parameters:

on_after_load_sample – Callable that is applied to individual samples on the fly.

Returns:

SamplesDataset

class swyft.lightning.simulator.Simulator

Base class for defining a simulator in Swyft.

This class provides a framework for defining the computational graph of the simulation model, along with methods for executing it efficiently. The computational graph is built in terms of labeled nodes in the ‘build’ method. This method is only evaluated once, during the generation of the very first sample.

Example usage:

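import numpy as np
import swyft

class MySim(swyft.Simulator):
    def __init__(self):
        super().__init__()
        self.transform_samples = swyft.to_numpy32  # convert samples to single precision

    def build(self, graph):
        z = graph.node('z', lambda: np.random.rand(1))  # prior: z ~ U(0, 1)
        x = graph.node('x', lambda z: z + np.random.randn(1)*0.1, z)  # x = z + Gaussian noise (std 0.1)
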
build(graph)

To be overridden in derived classes (see example usage above).

Note

This method only runs once after Simulator instantiation, during the generation of the very first sample. Afterwards, the graph object itself (and the functions its nodes point to) is used for performing computations.

Parameters:

graph (Graph) – Graph object instance, to be populated with nodes during method execution.

Returns:

None

get_iterator(targets=None, conditions={})

Generates an iterator. Useful for iterative sampling.

Parameters:
  • targets – Optional list of target sample variables.

  • conditions – Dict or Callable.

get_resampler(targets)

Generates a resampler. Useful for noise hooks etc.

Parameters:

targets – List of target variables to simulate

Returns:

SimulatorResampler instance.

get_shapes_and_dtypes(targets=None)

Runs the simulator once and collects the shapes and data types of the nodes of the computational graph.

Parameters:

targets (Sequence[str] | None) – Optional list of target sample variables. If None, the full simulation model is run.

Returns:

Dictionary of shapes and dictionary of dtypes

Return type:

(Dict, Dict)

sample(N=None, targets=None, conditions={}, exclude=[])

Sample from the simulator.

Parameters:
  • N (int | None) – Number of samples to generate. If None, a single sample without sample dimension is returned.

  • targets (Sequence[str] | None) – Optional list of target sample variables to generate. If None, all targets are simulated.

  • conditions (Dict | Callable) – Dict or Callable, conditions on sample variables. A callable will be executed separately for each sample and is expected to return a dictionary with conditions.

  • exclude (Sequence[str] | None) – Optional list of parameters that are excluded from the returned samples. Can be used to reduce memory consumption.
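The interplay of N, conditions and exclude can be sketched with a hypothetical two-variable simulator (z drawn uniformly, x equal to z plus Gaussian noise, as in the MySim example). toy_sample is not swyft code: it returns columns with the number of samples as the first dimension, and it does not implement targets or the N=None case.

```python
import random


def toy_sample(N, conditions=None, exclude=(), seed=0):
    """Draw N samples of (z, x) with optional conditions and excluded keys."""
    rng = random.Random(seed)
    rows = []
    for _ in range(N):
        # A callable is evaluated once per sample; a dict applies to all samples
        cond = conditions() if callable(conditions) else dict(conditions or {})
        z = cond["z"] if "z" in cond else rng.random()
        x = cond["x"] if "x" in cond else z + 0.1 * rng.gauss(0.0, 1.0)
        row = {"z": z, "x": x}
        rows.append({k: v for k, v in row.items() if k not in exclude})
    # Column layout: one list per variable, indexed by sample
    return {k: [r[k] for r in rows] for k in rows[0]} if rows else {}
```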

transform_samples(sample)

Hook for applying a transformation to generated samples. Should be overridden by the user.

A typical use case is to change the data type of the samples to single precision (if applicable). Swyft provides convenience functions for this case; see the use of ‘to_numpy32’ in the example above.

Parameters:

sample (Sample) – Input sample

Returns:

Transformed sample.

Return type:

Sample

class swyft.lightning.simulator.SimulatorResampler(simulator, targets)

Handles rerunning part of the simulator. Typically used for on-the-fly calculations during training.

Instantiates SimulatorResampler

Parameters:
  • simulator – The simulator object

  • targets – List of target sample variables that will be resampled
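Re-running only part of a simulator, for example to draw fresh noise for a stored sample during training, can be sketched with a hypothetical toy model z → mu → x. resample_noise keeps the upstream values and recomputes only x; it is not the SimulatorResampler implementation. A callable like this could serve as the on_after_load_sample hook of get_dataset, which applies a function to individual samples on the fly.

```python
import random


def resample_noise(sample, rng, sigma=0.1):
    """Return a copy of sample with only the noisy target x redrawn."""
    new = dict(sample)  # keep z and mu from the stored sample
    new["x"] = sample["mu"] + sigma * rng.gauss(0.0, 1.0)
    return new


stored = {"z": 0.3, "mu": 0.6, "x": 0.65}
fresh = resample_noise(stored, random.Random(1))
```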

class swyft.lightning.simulator.Switch(parname, options, choice)

Provides lazy evaluation functionality.

Simulation data

class swyft.lightning.data.RepeatDatasetWrapper(*args, **kwargs)

class swyft.lightning.data.SamplesDataset(*args, **kwargs)

Simple torch dataset based on Samples.

class swyft.lightning.data.SwyftDataModule(*args, **kwargs)

DataModule to handle simulated data.

Parameters:
  • data – Simulation data

  • val_fraction (float) – Fraction of data used for validation.

  • batch_size (int) – Minibatch size.

  • num_workers (int) – Number of workers for dataloader.

  • shuffle (bool) – Shuffle training data.

  • on_after_load_sample (callable | None) – Callable that is applied to individual samples on the fly.

Returns:

pytorch_lightning.LightningDataModule

class swyft.lightning.data.ZarrStore(file_path, sync_path=None)

Stores training data in a zarr archive.

reset_length(N, clubber=False)

Resize the store. N must be greater than or equal to the current store length.

class swyft.lightning.data.ZarrStoreIterableDataset(*args, **kwargs)

Parameters:

zarr_store (ZarrStore) –

Utility functions

class swyft.lightning.utils.AdamW

AdamW with early stopping.

Attributes:
  • learning_rate (default 1e-3)

  • weight_decay (default 0.01)

  • amsgrad (default False)

  • early_stopping_patience (optional, default 5)

class swyft.lightning.utils.AdamWOneCycleLR

AdamW with early stopping and OneCycleLR scheduler.

Attributes:
  • learning_rate (default 1e-3)

  • early_stopping_patience (optional, default 5)

class swyft.lightning.utils.AdamWReduceLROnPlateau

AdamW with early stopping and ReduceLROnPlateau scheduler.

Attributes:
  • learning_rate (default 1e-3)

  • early_stopping_patience (optional, default 5)

  • lr_scheduler_factor (optional, default 0.1)

  • lr_scheduler_patience (optional, default 3)
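The early-stopping and plateau-scheduler attributes listed above interact as in the following simplified replay of validation losses. It follows the usual semantics of these techniques (stop after early_stopping_patience epochs without improvement; multiply the learning rate by lr_scheduler_factor once improvement has stalled for more than lr_scheduler_patience epochs) but ignores improvement thresholds, cooldown periods and minimum learning rates. simulate_schedule is a hypothetical helper, not swyft or PyTorch Lightning code.

```python
def simulate_schedule(val_losses, lr=1e-3, es_patience=5,
                      factor=0.1, lr_patience=3):
    """Replay validation losses through simplified plateau/early-stop logic.

    Returns (learning rate used in each epoch, index of the last epoch run).
    """
    best = float("inf")
    since_best = 0       # epochs without improvement, for early stopping
    since_lr_change = 0  # epochs without improvement, for the LR scheduler
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)
        if loss < best:
            best, since_best, since_lr_change = loss, 0, 0
            continue
        since_best += 1
        since_lr_change += 1
        if since_lr_change > lr_patience:
            lr *= factor  # plateau detected: reduce the learning rate
            since_lr_change = 0
        if since_best >= es_patience:
            return lrs, epoch  # early stopping
    return lrs, len(val_losses) - 1
```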

exception swyft.lightning.utils.SwyftParameterError

General parameter error in Swyft.

swyft.lightning.utils.best_from_yaml(filepath)

Get best model from tensorboard log. Useful for reloading trained networks.

Parameters:

filepath – Filename of yaml file (assumed to be saved with to_yaml from ModelCheckpoint)

Returns:

Path to the best model.

swyft.lightning.utils.collate_output(out)

Turn a list of dicts with tensor/array values into a dict of collated tensors or arrays.
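A minimal sketch of the collation step, assuming that collation means concatenating each key's values along the first (batch) dimension. Python lists stand in for tensors or arrays; with torch or numpy the inner step would be torch.cat or np.concatenate. collate_dicts is hypothetical.

```python
def collate_dicts(out):
    """Turn [{'a': [...], ...}, {'a': [...], ...}] into {'a': [...], ...}.

    Values for the same key are concatenated in order along the first axis.
    """
    if not out:
        return {}
    keys = out[0].keys()
    return {k: [v for d in out for v in d[k]] for k in keys}
```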

swyft.lightning.utils.estimate_coverage(cs_coll, params, z_max=3.5, bins=50)

Estimate coverage from a collection of CoverageSamples objects.

swyft.lightning.utils.get_class_probs(lrs_coll, params)

Return class probabilities for discrete parameters.

Parameters:
  • lrs_coll – Collection of LogRatioSamples objects

  • params (str) – Parameter of interest (its values must be 0, 1, …, K-1 for K classes)

Returns:

Vector of length K with class probabilities

Return type:

np.ndarray
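For a discrete parameter, class probabilities follow from prior samples weighted by their likelihood ratios, since \(p(k \mid x) \propto p(k)\, r(x, k)\) and drawing the samples from the prior supplies the factor \(p(k)\). A stdlib-only sketch, assuming the class values are prior draws in 0, …, K-1; class_probs is hypothetical.

```python
import math


def class_probs(class_values, logratios, K):
    """Class probabilities from prior samples reweighted by exp(logratio)."""
    m = max(logratios)
    totals = [0.0] * K
    for k, lr in zip(class_values, logratios):
        totals[int(k)] += math.exp(lr - m)  # subtract max for stability
    norm = sum(totals)
    return [t / norm for t in totals]
```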

swyft.lightning.utils.get_pdf(lrs_coll, params, aux=None, bins=50, smooth=0.0, smooth_prior=False)

Generate a binned PDF from the input log-ratio samples.

Parameters:
  • lrs_coll – Collection of LogRatioSamples objects.

  • params (str | Sequence[str]) – Parameter names

  • bins (int) – Number of bins

  • smooth (float) – Apply Gaussian smoothing

  • smooth_prior – Smooth prior instead of posterior

Returns:

Returns densities and parameter grid.

Return type:

np.array, np.array
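A binned PDF of this kind can be illustrated as a histogram of prior samples weighted by exp(logratio) and normalized to unit integral. The sketch omits smoothing and the aux argument; weighted_histogram is a hypothetical helper, not get_pdf itself.

```python
import math


def weighted_histogram(values, logratios, bins=10, lo=None, hi=None):
    """Binned posterior density from prior samples weighted by exp(logratio).

    Returns (densities, bin_centers); densities integrate to 1 over [lo, hi].
    """
    lo = min(values) if lo is None else lo
    hi = max(values) if hi is None else hi
    width = (hi - lo) / bins
    m = max(logratios)
    counts = [0.0] * bins
    for v, lr in zip(values, logratios):
        if lo <= v <= hi:
            i = min(int((v - lo) / width), bins - 1)  # v == hi goes to last bin
            counts[i] += math.exp(lr - m)
    total = sum(counts) * width
    densities = [c / total for c in counts]
    centers = [lo + (i + 0.5) * width for i in range(bins)]
    return densities, centers
```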

swyft.lightning.utils.get_weighted_samples(lrs_coll, params)

Return weighted samples for a particular parameter combination.

Parameters:

params (str | Sequence[str]) – (List of) parameter names

Returns:

Parameter and weight tensors

Return type:

(torch.Tensor, torch.Tensor)

swyft.lightning.utils.param_select(parnames, target_parnames, match_exactly=False)

Find indices of parameters of interest.

The output can be used, for instance, to select parameters from a LogRatioSamples object like

obj.params[idx1][idx2]

Parameters:
  • parnames – Array of parameter names, \((*\text{logratios_shape}, \text{num_params})\)

  • target_parnames – List of parameter names of interest

  • match_exactly (bool) – Only return exact matches (i.e. no partial matches)

Returns:

idx1 (logratio index), idx2 (parameter indices)

Return type:

tuple, list
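A sketch of the lookup logic, assuming parnames is given as one list of names per log-ratio (the 1-dim layout of \((*\text{logratios_shape}, \text{num_params})\)) and that match_exactly requires the name set of a log-ratio to equal the requested set. param_select_sketch is hypothetical and raises ValueError when nothing matches.

```python
def param_select_sketch(parnames, target_parnames, match_exactly=False):
    """Find (logratio index, parameter indices) for the requested names."""
    for idx1, names in enumerate(parnames):
        if match_exactly and set(names) != set(target_parnames):
            continue  # exact mode: skip estimators with extra parameters
        if all(t in names for t in target_parnames):
            return idx1, [names.index(t) for t in target_parnames]
    raise ValueError(f"no log-ratio covers {target_parnames}")
```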

Truncation bounds

class swyft.lightning.bounds.RectBoundSampler(distr, bounds=None)

Sampler for rectangular bound regions.

Parameters:
  • distr – Description of probability distribution, or list of distributions

  • bounds – Optional rectangular bounds restricting the sampling region
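One standard way to sample a distribution restricted to a rectangular region is inverse-CDF sampling on each axis: draw u uniformly between F(low) and F(high) and return \(F^{-1}(u)\). The sketch uses statistics.NormalDist as the example distribution; sample_truncated is hypothetical, and RectBoundSampler may be implemented differently.

```python
import random
from statistics import NormalDist


def sample_truncated(distr, low, high, n, rng=None):
    """Draw n samples from distr restricted to [low, high] via inverse CDF.

    distr must provide cdf() and inv_cdf(), as statistics.NormalDist does.
    """
    rng = rng or random.Random(0)
    u_lo, u_hi = distr.cdf(low), distr.cdf(high)
    # Uniform draws in [F(low), F(high)) map back into [low, high]
    return [distr.inv_cdf(u_lo + (u_hi - u_lo) * rng.random()) for _ in range(n)]
```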

class swyft.lightning.bounds.RectangleBounds(bounds, parnames)

Dataclass for storing rectangular bounds.

Parameters:
  • bounds (torch.Tensor) – Bounds

  • parnames (array) – Parameter names

swyft.lightning.bounds.collect_rect_bounds(lrs_coll, parname, parshape, threshold=1e-06)

Collect rectangular bounds for a parameter of interest.

Parameters:
  • lrs_coll – Collection of LogRatioSamples

  • parname (str) – Name of parameter vector/array of interest

  • parshape (tuple) – Shape of parameter vector/array

  • threshold (float) – Likelihood-ratio selection threshold

swyft.lightning.bounds.get_rect_bounds(logratios, threshold=1e-06)

Extract rectangular bounds.

Parameters:
  • lrs_coll – Collection of LogRatioSample objects

  • threshold (float) – Threshold value for likelihood ratios.
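The selection rule behind rectangular bounds can be sketched in one dimension: keep the samples whose likelihood ratio, relative to the largest ratio in the set, exceeds the threshold, and take their minimum and maximum. Normalizing by the maximum ratio is an assumption of this sketch, and rect_bounds_1d is hypothetical.

```python
import math


def rect_bounds_1d(values, logratios, threshold=1e-6):
    """Smallest interval containing all samples with r / r_max > threshold."""
    m = max(logratios)
    kept = [v for v, lr in zip(values, logratios) if math.exp(lr - m) > threshold]
    return min(kept), max(kept)
```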

Plotting

swyft.plot.plot_corner(lrs_coll, parnames, bins=100, figsize=None, color='k', labels=None, label_args={}, contours_1d=True, fig=None, smooth=0.0, cred_level=[0.68268, 0.9545, 0.9973], truth=None, smooth_prior=False)

Make a beautiful corner plot.

Parameters:
  • lrs_coll – Collection of swyft.LogRatioSamples objects

  • parnames – List of parameters of interest

  • bins – Number of bins used for histograms.

  • figsize – Size of figure

  • color – Color

  • labels – Optional custom labels, either list or dict.

  • label_args – Custom label arguments

  • contours_1d (bool) – Plot 1-dim contours

  • fig – Figure instance

  • smooth – histogram smoothing

  • cred_level – Credible levels for contours

  • truth – Dictionary with parameter names as keys and true values as values

  • smooth_prior – Smooth and histogram prior instead of posterior (default False)

Return type:

None
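The default cred_level values are the probability masses within 1, 2 and 3 standard deviations of a Gaussian. Contours at these levels are commonly drawn at density thresholds chosen so that the bins above each threshold hold the requested mass; the sketch below shows that computation for a histogram with equal-width bins. hpd_levels is hypothetical and not necessarily how plot_corner computes its contours.

```python
def hpd_levels(densities, cred_levels=(0.68268, 0.9545, 0.9973)):
    """Density thresholds whose super-level sets enclose each credible mass.

    densities: histogram bin values (equal-width bins), any normalization.
    """
    ordered = sorted(densities, reverse=True)
    total = sum(ordered)
    levels = []
    for c in cred_levels:
        acc = 0.0
        for d in ordered:
            acc += d  # add bins from the highest density downwards
            if acc >= c * total:
                levels.append(d)
                break
    return levels
```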

swyft.plot.plot_pair(lrs_coll, parnames=None, bins=100, figsize=None, color='k', labels=None, label_args={}, ncol=None, subplots_kwargs={}, fig=None, smooth=1.0, cred_level=[0.68268, 0.9545, 0.9973], truth=None, smooth_prior=False)

Make beautiful 2-dim posteriors.

Parameters:
  • lrs_coll – Collection of swyft.LogRatioSamples objects

  • parnames – (Optional) List of parameter pairs of interest

  • bins – Number of bins used for histograms.

  • figsize – Optional size of figure

  • color – Color

  • labels – (Optional) Custom labels

  • label_args – (Optional) Custom label arguments

  • ncol – (Optional) Number of panel columns

  • subplots_kwargs – Optional arguments for subplots generation.

  • fig – Optional figure instance

  • smooth – Gaussian smoothing scale

  • cred_level – Credible levels for contours

  • truth – (Optional) Dictionary with parameter names as keys and true values as values

  • smooth_prior – Smooth and histogram prior instead of posterior (default False)

Return type:

None

swyft.plot.plot_posterior(lrs_coll, parnames=None, bins=100, figsize=None, color='k', labels=None, label_args={}, ncol=None, subplots_kwargs={}, fig=None, contours=True, smooth=1.0, cred_level=[0.68268, 0.9545, 0.9973], truth=None, smooth_prior=False)

Make beautiful 1-dim posteriors.

Parameters:
  • lrs_coll – Collection of swyft.LogRatioSamples objects

  • parnames – (Optional) List of parameters of interest

  • bins – Number of bins used for histograms.

  • figsize – Optional size of figure

  • color – Color

  • labels – (Optional) Custom labels

  • label_args – (Optional) Custom label arguments

  • ncol – (Optional) Number of panel columns

  • subplots_kwargs – Optional arguments for subplots generation.

  • fig – Optional figure instance

  • contours – Plot 1-dim contours

  • smooth – Gaussian smoothing scale

  • cred_level – Credible levels for contours

  • truth – (Optional) Dictionary with parameter names as keys and true values as values

  • smooth_prior – Smooth and histogram prior instead of posterior (default False)

Return type:

None

swyft.plot.plot_pp(coverage_samples, params, z_max=3.5, bins=50, ax=None)

Make a pp plot.

Parameters:
  • coverage_samples – Collection of CoverageSamples objects

  • params (str | Sequence[str]) – Parameters of interest

  • z_max (float) – Maximum value of z.

  • bins (int) – Number of discretization bins.

  • ax – Optional axes instance.

swyft.plot.plot_zz(coverage_samples, params, z_max=3.5, bins=50, ax=None)

Make a zz plot.

Parameters:
  • coverage_samples – Collection of CoverageSamples objects

  • params (str | Sequence[str]) – Parameters of interest

  • z_max (float) – Maximum value of z.

  • bins (int) – Number of discretization bins.

  • ax – Optional axes instance.