This page was generated from notebooks/0A - SwyftModule.ipynb.

A - Building inference networks with SwyftModule#

Authors: Noemi Anau Montel, James Alvey, Christoph Weniger

Last update: 15 September 2023

Purpose: We go through the basic steps of performing parameter inference with Swyft.

Note: As is always the case when training artificial neural networks, obtaining optimal results depends on properly setting the training and network parameters. This will be discussed later.

Key take-away messages: Understand swyft’s main classes for data and dataloaders (swyft.Samples, swyft.SwyftDataModule), networks (swyft.SwyftModule) and training/inference (swyft.SwyftTrainer).

Code#

Setup#

First we need some imports.

[1]:
import numpy as np
import matplotlib.pyplot as plt
import torch

import swyft
DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu'
[2]:
torch.manual_seed(0)
np.random.seed(0)

Training data#

Now we generate training data. As a simple example, we consider the model

\[x = z + \epsilon\]

where the parameter \(z \sim \mathcal{U}(-1, 1)\) is drawn from a uniform distribution, and \(\epsilon \sim \mathcal{N}(\mu = 0, \sigma = 0.2)\) is small additive noise. We are interested in the posterior of \(z\) given a measurement of \(x\).

Let us generate some samples first, here by using basic numpy functionality.

[3]:
N = 10_000  # Number of samples
z = np.random.rand(N, 1)*2-1  # Uniform prior over [-1, 1]
x = z + np.random.randn(N, 1)*0.2

Note that the shape of the z and x arrays is (n_samples, 1). The first dimension corresponds to the number of samples. Subsequent dimensions correspond to data and parameter shapes (here simply one in both cases).
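
A quick check of this convention (plain numpy, nothing swyft-specific):

print(z.shape, x.shape)  # (10000, 1) (10000, 1): (n_samples, event shape)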

It is instructive to visualize the training data.

  • Blue dots: generated \((x, z)\) pairs (“jointly sampled”)

  • Red dots: scrambled \((x, z)\) pairs (“marginally sampled”)

[4]:
plt.scatter(z[:,0], np.random.permutation(x[:,0]), alpha = 0.5, marker='.', color='r', label = "Marginal samples")
plt.scatter(z[:,0], x[:,0], alpha = 0.5, marker='.', label = "Joint samples")
plt.xlabel("z"); plt.ylabel("x")
plt.legend(loc = 1)
[4]:
<matplotlib.legend.Legend at 0x17ed3cc40>
[Figure: scatter plot of joint (blue) vs. marginal (red) samples in the (z, x) plane]

Training data that is kept in memory is stored in a swyft.Samples object.

[5]:
samples = swyft.Samples(x = x, z = z)
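
The container provides dictionary-style access to the underlying arrays; this is the same access pattern the inference network uses below via A['x'] and B['z']. A minimal check:

print(samples['x'].shape)  # (10000, 1)
print(samples['z'].shape)  # (10000, 1)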

Inference network#

The inference network is an instance of swyft.SwyftModule. It estimates ratios of the form

\[r(a; b) = \frac{p(a, b)}{p(a)p(b)}\;,\]

where \(a\) and \(b\) refer to any components of the training data. In the below example, we set \(a \to x\) and \(b \to z\).
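
Note that, by Bayes’ theorem, this ratio can be rewritten as

\[r(a; b) = \frac{p(a, b)}{p(a)\,p(b)} = \frac{p(b \mid a)}{p(b)}\;,\]

so with \(a \to x\) and \(b \to z\) the network directly targets the posterior-to-prior ratio \(p(z \mid x)/p(z)\), which is exactly what we exploit during inference below.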

[6]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 1, num_params = 1, varnames = 'z', num_blocks = 4)

    def forward(self, A, B):
        return self.logratios(A['x'], B['z'])

Swyft comes with a few default networks. Here we use swyft.LogRatioEstimator_1dim, which is a dense network that estimates (potentially multiple) one-dimensional posteriors. In the present example, the lengths of the parameter vector (num_params) and of the data vector (num_features) are both one.
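
As an illustration only (a sketch with hypothetical sizes, not used in this notebook): for a five-dimensional data vector and two parameters, the same estimator class would return the two one-dimensional marginal posteriors.

class MarginalNetwork(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        # Hypothetical sizes: x has 5 features, z has 2 components;
        # this estimates ln r(x; z_i) for each component z_i separately.
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 5, num_params = 2, varnames = 'z')

    def forward(self, A, B):
        return self.logratios(A['x'], B['z'])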

Training#

Training is now done using the SwyftTrainer class, which extends pytorch_lightning.Trainer with methods like infer (see below). Since our training data is double precision in this example, we have to set precision = 64.

[7]:
trainer = swyft.SwyftTrainer(accelerator = DEVICE, precision = 64)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/setup.py:200: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.
  rank_zero_warn(
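
Alternatively (a sketch, not used in this notebook), one could cast the training data to single precision up front and keep the trainer’s default precision; we assume here that swyft.Samples accepts float32 arrays just like float64 ones.

samples_f32 = swyft.Samples(x = x.astype(np.float32), z = z.astype(np.float32))
trainer_f32 = swyft.SwyftTrainer(accelerator = DEVICE)  # default 32-bit precision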

The swyft.SwyftDataModule class provides convenience functions to generate data loaders for training and validation data. We reserve 20% of the samples for validation.

[8]:
dm = swyft.SwyftDataModule(samples, batch_size=128)
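
The dataloader warnings printed during training below suggest increasing the number of worker processes. Assuming that SwyftDataModule forwards a num_workers argument to its underlying DataLoaders (an assumption; check the swyft documentation for the exact signature), that would look like:

dm_parallel = swyft.SwyftDataModule(samples, batch_size = 128, num_workers = 4)  # num_workers assumed to be supported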

Finally, we instantiate our inference network and start training.

[9]:
network = Network()
trainer.fit(network, dm)
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:94: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
  rank_zero_warn(
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint

  | Name      | Type                   | Params
-----------------------------------------------------
0 | logratios | LogRatioEstimator_1dim | 34.6 K
-----------------------------------------------------
34.6 K    Trainable params
0         Non-trainable params
34.6 K    Total params
0.276     Total estimated model params size (MB)
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Reloading best model: /Users/cweniger/Documents/swyft/notebooks/lightning_logs/version_18/checkpoints/epoch=13-step=882.ckpt


Inference#

Let’s assume that we measured the value \(x=0.0\). We put this observation in a swyft.Sample object (representing a single sample).

[10]:
x0 = 0.0
obs = swyft.Sample(x = np.array([x0]))

Since the inference network estimates the (logarithm of the) posterior-to-prior ratio, we can obtain weighted posterior samples by running many prior samples through the inference network. To this end, we first generate prior samples.

[11]:
prior_samples = swyft.Samples(z = np.random.rand(3000, 1)*2-1)

Then we evaluate the inference network by using the infer method of the swyft.SwyftTrainer object.

[12]:
predictions = trainer.infer(network, obs, prior_samples)
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: EarlyStopping, ModelCheckpoint
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:173: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
  warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
[13]:
swyft.plot_posterior(predictions, 'z[0]', smooth = 2)
for offset in [-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6]:  # multiples of the noise level sigma = 0.2
    plt.axvline(x0+offset, color='g', ls = ':')
plt.axvline(x0)
[13]:
<matplotlib.lines.Line2D at 0x17fc662b0>
[Figure: estimated posterior for z given x0, with shaded highest-density credible regions and green reference lines at x0 ± 0.2, 0.4, 0.6]

Voilà. You performed a basic parameter inference task with neural ratio estimation. The gray regions show the 68.3%, 95.5% and 99.7% highest-density credible regions. The correct regions are indicated by the green vertical lines (drawn at multiples of the noise level \(\sigma = 0.2\) around \(x_0\)). The result will not be perfect, but we will discuss possible ways to improve it later.
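
For intuition about the weighted posterior samples used above, here is a minimal sketch of what the weighting does under the hood, using the logratios and params attributes of the predictions (introduced in the exercises below; we assume they are CPU torch tensors):

logr = predictions.logratios.numpy().flatten()  # ln r(x0; z) for each prior sample
zs = predictions.params.numpy().flatten()       # the corresponding z values
w = np.exp(logr - logr.max())                   # posterior weights, up to a constant
plt.hist(zs, weights = w, bins = 50, density = True)  # approximates the posterior of z
plt.xlabel("z"); plt.ylabel("weighted density");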

Exercises#

  1. The swyft.Samples object is compatible with array slicing operations. Extract the first 3 samples from the samples object by using numpy array slicing notation.

[14]:
# Your results go here

  2. The return type of swyft.LogRatioEstimator_1dim, and hence of the inference network in the example above, is swyft.LogRatioSamples.

  • Confirm that this is also the type of the predictions returned by the infer method.

  • Extract \(\ln r(x; z)\) (contained in the logratios variable) as well as the parameter \(z\) (contained in the params variable) from the predictions.

  • Plot \(\ln r(x;z)\) vs \(z\) using plt.scatter (make sure to pass everything as vectors).

[15]:
# Your results go here
