This page was generated from notebooks/00 - Swyft in 15 minutes.ipynb.

Swyft in 15 Minutes

We discuss seven key steps of a typical Swyft workflow.

1. Installing Swyft

We use pip to install the lightning branch (the latest development branch) of Swyft directly from GitHub. To run the command, remove the leading # in the cell below.

[2]:
#!pip install git+https://github.com/undark-lab/swyft.git@lightning
[3]:
import numpy as np
from scipy import stats
import pylab as plt
import torch
import swyft
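# Lightning accelerator: 'gpu' if CUDA is available, otherwise 'cpu'
DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu'
# Seed the global PyTorch and NumPy RNGs for reproducibility (the simulator draws from np.random)
torch.manual_seed(0)
np.random.seed(0)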

2. Define the Simulator

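Next, we define a simulator class that specifies the computational graph of the model. Both parameters z are drawn uniformly from [-1, 1]. The noise-free signal f = z[0] + z[1]·t is evaluated on a grid t of ten points in [-1, 1] (stored as self.x). The observation x is f plus Gaussian noise with standard deviation 0.1.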

[4]:
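class Simulator(swyft.Simulator):
    def __init__(self):
        super().__init__()
        # Store simulated quantities as float32 NumPy arrays
        self.transform_samples = swyft.to_numpy32
        # Grid of ten evaluation points on [-1, 1]
        self.x = np.linspace(-1, 1, 10)

    def build(self, graph):
        # Parameters: z[0], z[1] ~ Uniform(-1, 1)
        z = graph.node('z', lambda: np.random.rand(2)*2-1)
        # Noise-free linear signal f = z[0] + z[1] * grid
        f = graph.node('f', lambda z: z[0] + z[1]*self.x, z)
        # Observation: f plus Gaussian noise with standard deviation 0.1
        x = graph.node('x', lambda f: f + np.random.randn(10)*0.1, f)

sim = Simulator()
# Draw 10,000 joint samples of (z, f, x)
samples = sim.sample(N = 10000)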
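The following minimal plain-NumPy sketch shows the same generative model without Swyft's graph machinery. It draws the three nodes z, f and x in a vectorized way. The function name simulate and the returned dictionary are our own illustration, not part of Swyft's API. The float32 cast only mimics the intent of swyft.to_numpy32.

```python
import numpy as np

# Grid of ten points on [-1, 1], as in the Simulator class above
X_GRID = np.linspace(-1, 1, 10)


def simulate(n, rng):
    """Draw n samples of (z, f, x) from the toy model in plain NumPy (no Swyft)."""
    # Prior: both parameters uniform on [-1, 1]
    z = rng.uniform(-1, 1, size=(n, 2))
    # Noise-free linear signal f = z[0] + z[1] * grid, broadcast to shape (n, 10)
    f = z[:, :1] + z[:, 1:] * X_GRID
    # Observation: add Gaussian noise with standard deviation 0.1
    x = f + 0.1 * rng.standard_normal((n, X_GRID.size))
    # Cast to float32, mirroring the intent of swyft.to_numpy32
    return {k: v.astype(np.float32) for k, v in dict(z=z, f=f, x=x).items()}


samples = simulate(10000, np.random.default_rng(0))
print({k: v.shape for k, v in samples.items()})  # {'z': (10000, 2), 'f': (10000, 10), 'x': (10000, 10)}
```

Vectorizing over n draws all samples in one step. Swyft's graph-based simulator instead evaluates the nodes per sample, which in turn allows conditioning on specific node values, as done with conditions = {"z": z0} further down.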

3. Define the SwyftModule

[5]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        # Compress the 10-dimensional observation x into 2 features
        self.embedding = torch.nn.Linear(10, 2)
        # Log-ratio estimators for the two 1-dim marginals z[0] and z[1]
        self.logratios1 = swyft.LogRatioEstimator_1dim(num_features = 2, num_params = 2, varnames = 'z')
        # Log-ratio estimator for the 2-dim joint marginal (z[0], z[1])
        self.logratios2 = swyft.LogRatioEstimator_Ndim(num_features = 2, marginals = ((0, 1),), varnames = 'z')

    def forward(self, A, B):
        # Observations are taken from A, parameters from B
        embedding = self.embedding(A['x'])
        logratios1 = self.logratios1(embedding, B['z'])
        logratios2 = self.logratios2(embedding, B['z'])
        return logratios1, logratios2
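Swyft is based on neural ratio estimation. A classifier learns to tell jointly drawn pairs (x, z) from pairs in which z has been shuffled across the batch. At the optimum, its logit approximates log p(x, z) / (p(x) p(z)) = log p(z|x) / p(z).

The NumPy sketch below shows this standard binary cross-entropy construction. The names nre_bce_loss and the constant estimators are hypothetical. We do not claim this is exactly how Swyft pairs A and B or normalizes its loss, so the values are not comparable to the test_loss reported later.

```python
import numpy as np


def nre_bce_loss(logratio_fn, x, z, rng):
    """Binary cross-entropy for neural ratio estimation (sketch).

    logratio_fn(x, z) must return one log-ratio estimate per row.
    """
    # Jointly drawn pairs (x_i, z_i): class label 1
    r_joint = logratio_fn(x, z)
    # Shuffling z across the batch breaks the x-z dependence: marginal pairs, label 0
    z_marginal = z[rng.permutation(len(z))]
    r_marginal = logratio_fn(x, z_marginal)
    # -log sigmoid(r) = log(1 + e^-r) and -log(1 - sigmoid(r)) = log(1 + e^r),
    # both evaluated in a numerically stable way with np.logaddexp
    return np.mean(np.logaddexp(0.0, -r_joint)) + np.mean(np.logaddexp(0.0, r_marginal))


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 10))
z = rng.uniform(-1, 1, size=(64, 2))

# An uninformative estimator that always outputs 0 scores exactly 2 * log(2)
print(nre_bce_loss(lambda x, z: np.zeros(len(x)), x, z, rng))  # ≈ 1.3863
```

Any estimator that separates joint from shuffled pairs better than chance drives this loss below 2 log 2. That is the signal the embedding and the log-ratio estimators above are trained on.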

4. Train the model

[6]:
trainer = swyft.SwyftTrainer(accelerator = DEVICE)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/setup.py:200: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.
  rank_zero_warn(
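The log above reports an Apple MPS GPU that goes unused, because DEVICE only checks for CUDA. Below is a minimal sketch of a selection function that also considers MPS. pick_accelerator is a hypothetical helper. Whether every operation Swyft uses runs on MPS is an assumption we have not verified.

```python
def pick_accelerator(cuda_available: bool, mps_available: bool) -> str:
    """Return a PyTorch Lightning accelerator name, preferring CUDA, then MPS, then CPU."""
    if cuda_available:
        return "gpu"  # NVIDIA GPU via CUDA
    if mps_available:
        return "mps"  # Apple-silicon GPU, the option suggested by the warning above
    return "cpu"


# In the notebook this could replace the DEVICE line (torch.backends.mps needs PyTorch >= 1.12):
# DEVICE = pick_accelerator(torch.cuda.is_available(), torch.backends.mps.is_available())
print(pick_accelerator(cuda_available=False, mps_available=True))  # mps
```

The helper takes plain booleans rather than calling torch itself, so the selection logic can be tested on any machine.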
[7]:
dm = swyft.SwyftDataModule(samples, batch_size = 64)
[8]:
network = Network()
trainer.fit(network, dm)
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:94: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
  rank_zero_warn(
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint

  | Name       | Type                   | Params
------------------------------------------------------
0 | embedding  | Linear                 | 22
1 | logratios1 | LogRatioEstimator_1dim | 34.9 K
2 | logratios2 | LogRatioEstimator_Ndim | 17.5 K
------------------------------------------------------
52.5 K    Trainable params
0         Non-trainable params
52.5 K    Total params
0.210     Total estimated model params size (MB)
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Reloading best model: /Users/cweniger/Documents/swyft/notebooks/lightning_logs/version_58/checkpoints/epoch=38-step=4875.ckpt
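The fit log ends by reloading the best checkpoint (epoch=38). The callback messages name EarlyStopping and ModelCheckpoint. The sketch below shows the logic of patience-based early stopping with best-checkpoint tracking. The patience value and the loss sequence are made up, because Swyft's actual settings are not shown in this notebook.

```python
def early_stopping(val_losses, patience):
    """Return (best_epoch, stop_epoch) for patience-based early stopping.

    Training stops once the validation loss has failed to improve for
    `patience` consecutive epochs. The weights from best_epoch are then
    restored, analogous to "Reloading best model" in the log above.
    """
    best_loss, best_epoch, wait = float("inf"), None, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: this is where a checkpoint would be saved
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    # No early stop: training ran until the epoch limit (cf. max_epochs)
    return best_epoch, len(val_losses) - 1


# Hypothetical validation losses, not taken from this notebook
print(early_stopping([5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 3.3], patience=3))  # (2, 5)
```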

5. Visualize training

This step is not covered in this notebook.

6. Perform validation tests

[10]:
test_samples = sim.sample(N = 1000)
trainer.test(network, test_samples.get_dataloader(batch_size = 64))
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: EarlyStopping, ModelCheckpoint
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           -3.4464657306671143
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[10]:
[{'test_loss': -3.4464657306671143}]
[11]:
B = samples[:1000]
A = samples[:1000]
mass = trainer.test_coverage(network, A, B)
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: EarlyStopping, ModelCheckpoint
WARNING: This estimates the mass of highest-likelihood intervals.
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:173: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
  warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: EarlyStopping, ModelCheckpoint
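Here is the idea behind a coverage test. For each test pair (z_true, x), compute the posterior mass of the region whose likelihood ratio is at least that of z_true. If the posterior is calibrated, these masses are uniformly distributed, so a fraction α of them fall below α.

Below is a NumPy sketch on a hypothetical one-dimensional toy model with an exact log-ratio. It uses prior samples as importance samples. It illustrates the concept behind trainer.test_coverage and is not its implementation. The roles given to A (test pairs) and B (prior samples) are our reading of the call above.

```python
import numpy as np


def region_mass(logratios, logratio_true):
    """Estimated posterior mass of {z : logratio(z) >= logratio(z_true)}.

    logratios are evaluated on prior samples. Since posterior ∝ prior × exp(logratio),
    normalized exp(logratios) act as self-normalized importance weights.
    """
    w = np.exp(logratios - logratios.max())  # subtract the max for numerical stability
    return w[logratios >= logratio_true].sum() / w.sum()


def empirical_coverage(masses, levels):
    """Fraction of test cases whose true parameter lies inside the level-alpha region."""
    masses = np.asarray(masses)
    return np.array([(masses <= alpha).mean() for alpha in levels])


# Hypothetical 1-dim toy model (not the notebook's simulator): z ~ U(-1, 1), x = z + 0.3 * noise
rng = np.random.default_rng(0)
sigma = 0.3
z_prior = rng.uniform(-1, 1, size=4000)   # prior samples (our reading of B)
z_test = rng.uniform(-1, 1, size=1000)    # true parameters of the test pairs (A)
x_test = z_test + sigma * rng.standard_normal(z_test.shape)


def logratio(x, z):
    # Exact log-likelihood up to an x-dependent constant, which cancels in region_mass
    return -0.5 * ((x - z) / sigma) ** 2


masses = [region_mass(logratio(x, z_prior), logratio(x, z)) for x, z in zip(x_test, z_test)]
levels = [0.5, 0.68, 0.9]
print(empirical_coverage(masses, levels))  # each entry should be close to its nominal level
```

With a learned network in place of the exact log-ratio, deviations of the empirical coverage from the nominal levels indicate over- or under-confident posteriors.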

7. Generate posteriors

[12]:
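# True parameters for a mock observation
z0 = np.array([0.3, 0.7])
# Simulate with z fixed to z0 and keep only the observation x
x0 = sim.sample(conditions = {"z": z0})['x']
plt.plot(x0)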
[12]:
[<matplotlib.lines.Line2D at 0x1696ea4c0>]
[Figure: plot of the simulated observation x0]
[13]:
prior_samples = sim.sample(targets = ['z'], N = 100000)
[14]:
predictions = trainer.infer(network, swyft.Sample(x = x0), prior_samples)
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: EarlyStopping, ModelCheckpoint
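The posterior is proportional to the prior times exp(log-ratio), so log-ratios evaluated on prior samples act as importance weights. The sketch below reproduces that step for the notebook's linear model in plain NumPy. The exact Gaussian log-likelihood stands in for the trained network, an assumption made so the example runs without Swyft. It presumably mirrors how trainer.infer uses prior_samples, but the variable names and summary statistics are our own.

```python
import numpy as np

rng = np.random.default_rng(0)
x_grid = np.linspace(-1, 1, 10)
sigma = 0.1

# Mock observation for z0 = (0.3, 0.7), generated with the notebook's toy model
z0 = np.array([0.3, 0.7])
x0 = z0[0] + z0[1] * x_grid + sigma * rng.standard_normal(10)

# Prior samples, analogous to prior_samples = sim.sample(targets = ['z'], N = 100000)
z_prior = rng.uniform(-1, 1, size=(100_000, 2))

# Stand-in for the trained network: exact Gaussian log-likelihood, up to a constant
f_prior = z_prior[:, :1] + z_prior[:, 1:] * x_grid
logratios = -0.5 * np.sum(((x0 - f_prior) / sigma) ** 2, axis=1)

# Self-normalized importance weights turn prior samples into weighted posterior samples
w = np.exp(logratios - logratios.max())
w /= w.sum()
post_mean = w @ z_prior
post_std = np.sqrt(w @ (z_prior - post_mean) ** 2)
print(post_mean, post_std)
```

Only a small fraction of the 100,000 prior samples receive appreciable weight here, because the posterior is much narrower than the prior. This is one reason to draw many prior samples for this step.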
[ ]:
predictions[0].parnames
array([['z[0]'],
       ['z[1]']], dtype='<U4')
[19]:
swyft.plot_posterior(predictions, ['z[0]', 'z[1]']);
[Figure: posterior plots for z[0] and z[1]]