This page was generated from notebooks/Swyft in 15 minutes.ipynb.

[1]:
%load_ext autoreload
%autoreload 2

Swyft in 15 Minutes

We discuss seven key steps of a typical Swyft workflow.

1. Installing Swyft

We can use pip to install the lightning branch (latest development branch) of Swyft.

[2]:
#!pip install git+https://github.com/undark-lab/swyft.git@lightning
[3]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import torch
import swyft

2. Define the Simulator

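Next we define a simulator class, which specifies the computational graph of our simulator. The two parameters z are drawn uniformly from [-1, 1]^2, the noise-free model f = z[0] + z[1]*t is evaluated on a grid t of 10 points in [-1, 1] (stored as self.x), and the observation x adds Gaussian noise with standard deviation 0.1 to f. We then draw 10,000 joint samples of (z, f, x).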

[4]:
class Simulator(swyft.Simulator):
    def __init__(self):
        super().__init__()
        self.transform_samples = swyft.to_numpy32
        self.x = np.linspace(-1, 1, 10)

    def build(self, graph):
        z = graph.node('z', lambda: np.random.rand(2)*2-1)
        f = graph.node('f', lambda z: z[0] + z[1]*self.x, z)
        x = graph.node('x', lambda f: f + np.random.randn(10)*0.1, f)

sim = Simulator()
samples = sim.sample(N = 10000)
100%|██████████| 10000/10000 [00:00<00:00, 40216.54it/s]
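The `build` method only declares named nodes and their parents; values are produced when the simulator is sampled. Later cells pass `targets` (which presumably simulates only what z needs) and `conditions` (which fixes z and simulates its descendants). The toy `Graph` class below is a hypothetical plain-Python illustration of how such a graph can be evaluated lazily. It mimics the `graph.node(name, fn, *parents)` pattern used above but is not swyft's implementation.

```python
import numpy as np

class Graph:
    """Toy computational graph (hypothetical; not swyft's implementation)."""

    def __init__(self):
        self.nodes = {}  # name -> (function, parent names)

    def node(self, name, fn, *parents):
        # Register a node; returning its name lets later nodes use it as a parent
        self.nodes[name] = (fn, parents)
        return name

    def sample(self, targets=None, conditions=None):
        values = dict(conditions or {})  # conditioned nodes are fixed, never simulated
        computed = []                    # records which nodes were actually evaluated

        def evaluate(name):
            if name not in values:
                fn, parents = self.nodes[name]
                values[name] = fn(*[evaluate(p) for p in parents])
                computed.append(name)
            return values[name]

        for name in targets or list(self.nodes):
            evaluate(name)
        return values, computed


rng = np.random.default_rng(0)
grid = np.linspace(-1, 1, 10)

g = Graph()
z = g.node('z', lambda: rng.uniform(-1, 1, 2))
f = g.node('f', lambda z: z[0] + z[1] * grid, z)
x = g.node('x', lambda f: f + 0.1 * rng.normal(size=10), f)

values, computed = g.sample(targets=['z'])
print(computed)  # ['z']: f and x are not needed, so they are never simulated

values, computed = g.sample(conditions={'z': np.array([0.3, 0.7])})
print(computed)  # ['f', 'x']: z is fixed, only its descendants are simulated
```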

3. Define the SwyftModule

[5]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Linear(10, 2)
        self.logratios1 = swyft.LogRatioEstimator_1dim(num_features = 2, num_params = 2, varnames = 'z')
        self.logratios2 = swyft.LogRatioEstimator_Ndim(num_features = 2, marginals = ((0, 1),), varnames = 'z')

    def forward(self, A, B):
        embedding = self.embedding(A['x'])
        logratios1 = self.logratios1(embedding, B['z'])
        logratios2 = self.logratios2(embedding, B['z'])
        return logratios1, logratios2
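Conceptually, log-ratio estimators like the two in `Network` follow the standard neural ratio estimation recipe: a binary classifier learns to tell jointly drawn pairs (x, z) from pairs where z has been shuffled across the batch, and the optimal classifier logit is the log-ratio log p(x|z)/p(x) = log p(z|x)/p(z). The NumPy sketch below shows that objective on a hypothetical 1D toy model (z ~ U(-1, 1), x = z + N(0, 0.1^2)) whose exact log-ratio is known. It is not swyft's code, and swyft normalizes its loss differently: the values it logs during training are negative, which plain binary cross-entropy never is.

```python
import math
import numpy as np

def log_sigmoid(t):
    # Numerically stable log(sigmoid(t)) = -log(1 + exp(-t))
    return -np.logaddexp(0.0, -t)

def nre_loss(logratio_fn, x, z, rng):
    """Balanced binary cross-entropy for neural ratio estimation (sketch).

    Joint pairs (x_i, z_i) get label 1; pairs with z shuffled across the
    batch approximate p(x)p(z) and get label 0.
    """
    z_shuffled = z[rng.permutation(len(z))]
    l_joint = logratio_fn(x, z)
    l_marginal = logratio_fn(x, z_shuffled)
    # -log sigmoid(l) for label 1, -log(1 - sigmoid(l)) = -log sigmoid(-l) for label 0
    return -np.mean(log_sigmoid(l_joint)) - np.mean(log_sigmoid(-l_marginal))

# Hypothetical 1D toy model: z ~ U(-1, 1), x = z + N(0, sigma^2)
sigma = 0.1
rng = np.random.default_rng(0)
z = rng.uniform(-1, 1, 2000)
x = z + sigma * rng.normal(size=z.shape)

# Standard normal CDF built from math.erf
Phi = np.vectorize(lambda t: 0.5 * (1 + math.erf(t / math.sqrt(2))))

def exact_logratio(x, z):
    # log r = log p(x|z) - log p(x); the evidence p(x) is analytic for this toy
    log_lik = -0.5 * ((x - z) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
    log_evidence = np.log(0.5 * (Phi((x + 1) / sigma) - Phi((x - 1) / sigma)))
    return log_lik - log_evidence

loss_zero = nre_loss(lambda x, z: np.zeros_like(x), x, z, rng)
loss_exact = nre_loss(exact_logratio, x, z, rng)
print(loss_zero)   # 2*log(2), about 1.386: an uninformative classifier
print(loss_exact)  # much lower: the exact log-ratio is the optimal logit
```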

4. Train the model

[6]:
trainer = swyft.SwyftTrainer(accelerator = 'gpu' if torch.cuda.is_available() else 'cpu', max_epochs = 2)  # fall back to CPU when CUDA is unavailable
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[7]:
dm = swyft.SwyftDataModule(samples, fractions = [0.8, 0.1, 0.1], batch_size = 64, num_workers = 3)
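The training log reports 141 steps per epoch (training plus validation batches) and 16 test batches. Those counts follow from simple arithmetic, assuming the 10,000 samples are split by `fractions` with floor rounding and the final partial batch is kept. `split_sizes` and `num_batches` are hypothetical helpers written for this illustration, not part of swyft.

```python
import math

def split_sizes(n, fractions):
    # Assumed rounding: floor for all but the last split, which takes the remainder
    sizes = [int(n * f) for f in fractions[:-1]]
    sizes.append(n - sum(sizes))
    return sizes

def num_batches(n, batch_size, drop_last=False):
    # A partial final batch counts as one batch unless it is dropped
    return n // batch_size if drop_last else math.ceil(n / batch_size)

n_train, n_val, n_test = split_sizes(10_000, [0.8, 0.1, 0.1])
train_batches = num_batches(n_train, 64)
val_batches = num_batches(n_val, 64)
test_batches = num_batches(n_test, 64)
print(n_train, n_val, n_test)                    # 8000 1000 1000
print(train_batches, val_batches, test_batches)  # 125 16 16
print(train_batches + val_batches)               # 141
```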
[8]:
network = Network()
trainer.fit(network, dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                   | Params
------------------------------------------------------
0 | embedding  | Linear                 | 22
1 | logratios1 | LogRatioEstimator_1dim | 34.9 K
2 | logratios2 | LogRatioEstimator_Ndim | 17.5 K
------------------------------------------------------
52.5 K    Trainable params
0         Non-trainable params
52.5 K    Total params
0.210     Total estimated model params size (MB)
Epoch 0: 100%|██████████| 141/141 [00:02<00:00, 53.16it/s, loss=-2.99, v_num=1e+7, val_loss=-3.03]
Epoch 1: 100%|██████████| 141/141 [00:02<00:00, 53.44it/s, loss=-3.11, v_num=1e+7, val_loss=-3.11]
`Trainer.fit` stopped: `max_epochs=2` reached.

5. Visualize training

This step is not covered in this notebook.

6. Perform validation tests

[9]:
trainer.test(network, dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing DataLoader 0: 100%|██████████| 16/16 [00:00<00:00, 162.68it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss             -3.0958783626556396    │
└───────────────────────────┴───────────────────────────┘
[9]:
[{'test_loss': -3.0958783626556396}]
[10]:
B = samples[:1000]
A = samples[:1000]
mass = trainer.test_coverage(network, A, B)
WARNING: This estimates the mass of highest-likelihood intervals.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 32/32 [00:00<00:00, 202.16it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 1000/1000 [00:13<00:00, 72.18it/s]
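A coverage test checks calibration: for each test sample, estimate how much posterior mass lies in the region ranked above the true parameter (as the printed warning notes, these are highest-likelihood intervals). For a calibrated estimator these masses are uniformly distributed, so a nominal 68% region contains the truth about 68% of the time. Below is a minimal sketch of that idea, assuming a hypothetical 1D toy model with a uniform prior and using the exact log-likelihood in place of a trained network. `hpd_mass` is an illustrative helper, and swyft's `test_coverage` may compute this differently.

```python
import numpy as np

def hpd_mass(logratios, logratio_true):
    """Posterior mass ranked above the true parameter, from weighted prior samples.

    logratios:     (n_obs, n_samples) log-ratios of prior samples per observation
    logratio_true: (n_obs,) log-ratio at each observation's true parameter
    With a uniform prior, ranking by ratio is ranking by posterior density.
    """
    w = np.exp(logratios - logratios.max(axis=1, keepdims=True))
    above = logratios > logratio_true[:, None]
    return (w * above).sum(axis=1) / w.sum(axis=1)

# Hypothetical 1D toy model: z ~ U(-1, 1), x = z + N(0, 0.1^2)
rng = np.random.default_rng(0)
n_obs, n_samples, sigma = 1000, 2000, 0.1
z_true = rng.uniform(-1, 1, n_obs)
x_obs = z_true + sigma * rng.normal(size=n_obs)
z_prior = rng.uniform(-1, 1, size=(n_obs, n_samples))

# Exact log-likelihood stands in for a trained network (same ranking, constant offset)
logr = -0.5 * ((x_obs[:, None] - z_prior) / sigma) ** 2
logr_true = -0.5 * ((x_obs - z_true) / sigma) ** 2
mass = hpd_mass(logr, logr_true)

# Empirical coverage: fraction of observations whose truth lies inside the nominal region
for level in (0.68, 0.95):
    print(level, np.mean(mass <= level))
```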

7. Generate posteriors

[11]:
z0 = np.array([0.3, 0.7])
x0 = sim.sample(conditions = {"z": z0})['x']
plt.plot(sim.x, x0)  # plot the observation against the grid it was simulated on
[Figure: the simulated observation x0 for z0 = (0.3, 0.7)]
[12]:
prior_samples = sim.sample(targets = ['z'], N = 100000)
100%|██████████| 100000/100000 [00:01<00:00, 98539.19it/s]
[13]:
predictions = trainer.infer(network, swyft.Sample(x = x0), prior_samples)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 98/98 [00:00<00:00, 127.59it/s]
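Conceptually, inference evaluates the trained log-ratio estimators at each prior sample for the observation x0 (we assume the stored predictions are these log-ratios). Since the posterior is proportional to prior times ratio, the prior samples reweighted by exp(log-ratio) form a weighted posterior sample. The NumPy sketch below illustrates that reweighting for the notebook's linear model, with the exact Gaussian log-likelihood standing in for the network. It does not call swyft, and its noise draw differs from the notebook's.

```python
import numpy as np

rng = np.random.default_rng(0)
grid = np.linspace(-1, 1, 10)   # same grid as Simulator.x
sigma = 0.1                     # same noise level as the simulator

# Observation for the true parameters z0 = (0.3, 0.7)
z0 = np.array([0.3, 0.7])
x0 = z0[0] + z0[1] * grid + sigma * rng.normal(size=10)

# Prior samples z ~ U(-1, 1)^2
z = rng.uniform(-1, 1, size=(100_000, 2))

# Stand-in for the trained network: the Gaussian log-likelihood up to constants,
# which differs from the true log-ratio only by z-independent terms
f = z[:, :1] + z[:, 1:] * grid                              # (N, 10) noise-free model
logratios = -0.5 * np.sum(((x0 - f) / sigma) ** 2, axis=1)  # (N,)

# Self-normalized importance weights: posterior is prior times exp(log-ratio)
w = np.exp(logratios - logratios.max())
w /= w.sum()

post_mean = w @ z                             # weighted posterior mean, shape (2,)
post_std = np.sqrt(w @ (z - post_mean) ** 2)  # weighted posterior std, shape (2,)
ess = 1.0 / np.sum(w ** 2)                    # effective number of weighted samples
print(post_mean, post_std, ess)
```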
[14]:
predictions[0].parnames
[14]:
array([['z[0]'],
       ['z[1]']], dtype='<U4')
[15]:
swyft.plot_1d(predictions, 'z[0]', ax = plt.gca())
[Figure: estimated 1D marginal posterior of z[0]]
[16]:
swyft.plot_1d(predictions, 'z[1]', ax = plt.gca())
[Figure: estimated 1D marginal posterior of z[1]]
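A 1D marginal posterior like those plotted for z[0] and z[1] can be approximated by a weighted histogram of prior samples. The sketch below uses a synthetic Gaussian-shaped log-ratio peaked at 0.3 as a stand-in for network output (an assumption); swyft's `plot_1d` may bin or smooth differently.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.uniform(-1, 1, 100_000)              # prior samples of one parameter
logratios = -0.5 * ((z - 0.3) / 0.05) ** 2   # synthetic stand-in for network output

# Unnormalized posterior weights; density=True normalizes the histogram to unit area
w = np.exp(logratios - logratios.max())
density, edges = np.histogram(z, bins=50, range=(-1, 1), weights=w, density=True)
centers = 0.5 * (edges[:-1] + edges[1:])
print(centers[np.argmax(density)])           # peak of the marginal, close to 0.3
```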