This page was generated from notebooks/00 - Swyft in 15 minutes.ipynb.

Swyft in 15 Minutes

We discuss seven key steps of a typical Swyft workflow.

1. Installing Swyft

We can use pip to install the lightning branch (latest development branch) of Swyft.

[2]:
#!pip install git+https://github.com/undark-lab/swyft.git@lightning
[3]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import torch
import swyft

2. Define the Simulator

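Next we define a simulator class, which specifies the computational graph of our simulator. In this example, z is drawn uniformly from [-1, 1]^2, f is a straight line with intercept z[0] and slope z[1] evaluated on a 10-point grid, and x is f plus Gaussian noise with standard deviation 0.1.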

[4]:
class Simulator(swyft.Simulator):
    def __init__(self):
        super().__init__()
        self.transform_samples = swyft.to_numpy32
        self.x = np.linspace(-1, 1, 10)

    def build(self, graph):
        z = graph.node('z', lambda: np.random.rand(2)*2-1)
        f = graph.node('f', lambda z: z[0] + z[1]*self.x, z)
        x = graph.node('x', lambda f: f + np.random.randn(10)*0.1, f)

sim = Simulator()
samples = sim.sample(N = 10000)
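
For intuition, the three graph nodes amount to the following plain NumPy function. This is a sketch only: `simulate_once` and the seeded `rng` are illustrative names that are not part of Swyft, and `sim.sample` additionally takes care of batching and the float32 conversion.

```python
import numpy as np

rng = np.random.default_rng(0)   # seeded generator, for reproducibility only
grid = np.linspace(-1, 1, 10)    # same grid as self.x in the Simulator

def simulate_once(rng, grid, sigma=0.1):
    """Draw one (z, f, x) triple from the linear toy model (hypothetical helper)."""
    z = rng.uniform(-1, 1, size=2)                  # node 'z': uniform prior on [-1, 1]^2
    f = z[0] + z[1] * grid                          # node 'f': noise-free straight line
    x = f + rng.normal(0.0, sigma, size=grid.size)  # node 'x': add Gaussian noise
    return {"z": z, "f": f, "x": x}

sample = simulate_once(rng, grid)
print({name: value.shape for name, value in sample.items()})
```

Each call returns one parameter vector of length 2 and two arrays of length 10, mirroring a single entry of `samples`.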

3. Define the SwyftModule

[5]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Linear(10, 2)
        self.logratios1 = swyft.LogRatioEstimator_1dim(num_features = 2, num_params = 2, varnames = 'z')
        self.logratios2 = swyft.LogRatioEstimator_Ndim(num_features = 2, marginals = ((0, 1),), varnames = 'z')

    def forward(self, A, B):
        embedding = self.embedding(A['x'])
        logratios1 = self.logratios1(embedding, B['z'])
        logratios2 = self.logratios2(embedding, B['z'])
        return logratios1, logratios2
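
The shapes flowing through `forward` can be traced with a NumPy stand-in for the linear embedding. This is only an illustration: the random `W` and `b` below are placeholders for the weights of `torch.nn.Linear(10, 2)`, and the batch size of 64 is taken from the data module defined in the next section.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_bins, n_features = 64, 10, 2

x_batch = rng.normal(size=(batch_size, n_bins))   # stands in for A['x']
W = rng.normal(size=(n_features, n_bins))         # placeholder, same shape as Linear(10, 2).weight
b = np.zeros(n_features)                          # placeholder bias

embedding = x_batch @ W.T + b                     # affine map, as torch.nn.Linear computes
print(embedding.shape)
```

The embedding has two columns per observation, which is why both ratio estimators are constructed with `num_features = 2`.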

4. Train the model

[6]:
trainer = swyft.SwyftTrainer(accelerator = 'gpu', max_epochs = 2)
/home/weniger/.conda/envs/lensing/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/weniger/.conda/envs/lensing/lib/python3.9/site ...
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
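
The cell above requests a GPU. On a machine without CUDA, one option is to choose the accelerator at run time. A minimal sketch, assuming PyTorch's `torch.cuda.is_available()` and the Lightning accelerator strings `'cpu'` and `'gpu'`; `pick_accelerator` is a hypothetical helper name, not a Swyft function.

```python
def pick_accelerator():
    """Return 'gpu' if PyTorch reports a usable CUDA device, else 'cpu' (hypothetical helper)."""
    try:
        import torch
    except ImportError:
        return "cpu"   # PyTorch is missing, so fall back to the CPU label
    return "gpu" if torch.cuda.is_available() else "cpu"

accelerator = pick_accelerator()
print(accelerator)
```

The result could then be passed on as `swyft.SwyftTrainer(accelerator = accelerator, max_epochs = 2)`.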
[7]:
dm = swyft.SwyftDataModule(samples, fractions = [0.8, 0.1, 0.1], batch_size = 64, num_workers = 3)
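
As a rough guide to what these settings imply, the arithmetic below splits the 10000 simulated samples by the given fractions and counts batches per epoch. It assumes the fractions are applied in the order train, validation, test, and that a final partial batch is still used; neither assumption is confirmed on this page.

```python
import math

n_samples = 10_000
fractions = [0.8, 0.1, 0.1]   # assumed order: train, validation, test
batch_size = 64

sizes = [round(n_samples * frac) for frac in fractions]
batches = [math.ceil(size / batch_size) for size in sizes]

for name, size, n_batches in zip(["train", "val", "test"], sizes, batches):
    print(f"{name}: {size} samples, {n_batches} batches per epoch")
```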
[8]:
network = Network()
trainer.fit(network, dm)
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                   | Params
------------------------------------------------------
0 | embedding  | Linear                 | 22
1 | logratios1 | LogRatioEstimator_1dim | 34.9 K
2 | logratios2 | LogRatioEstimator_Ndim | 17.5 K
------------------------------------------------------
52.5 K    Trainable params
0         Non-trainable params
52.5 K    Total params
0.210     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
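
Two numbers in the summary table can be checked by hand. The sketch below assumes 4 bytes per parameter (float32 weights) for the size estimate; the variable names are illustrative.

```python
# Parameters of torch.nn.Linear(10, 2): one weight per (input, output) pair
# plus one bias per output.
in_features, out_features = 10, 2
linear_params = in_features * out_features + out_features
print(linear_params)             # 22

# Size estimate, assuming 4 bytes per parameter (float32).
total_params = 52.5e3            # rounded total from the summary table
size_mb = total_params * 4 / 1e6
print(round(size_mb, 3))         # 0.21
```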

5. Visualize training

N/A

6. Perform validation tests

[9]:
trainer.test(network, dm)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           -3.115851640701294
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[9]:
[{'test_loss': -3.115851640701294}]
[10]:
B = samples[:1000]
A = samples[:1000]
mass = trainer.test_coverage(network, A, B)
WARNING: This estimates the mass of highest-likelihood intervals.
/home/weniger/.conda/envs/lensing/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:173: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
  warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
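
To make the quantity behind `test_coverage` more concrete, here is a self-contained sketch on a different toy problem where the posterior is known exactly. It does not call Swyft, and Swyft's own implementation may differ in detail: the 1-dim Gaussian model, `hpd_mass`, and the 68% threshold are all choices made for this illustration. For each simulated pair (z, x) it computes the posterior mass of the highest-density interval that just contains the true z. For a well-calibrated posterior these masses are uniformly distributed, so about 68% of them should fall below 0.68.

```python
import math
import random

random.seed(0)

def hpd_mass(z_true, x):
    """Posterior mass of the highest-density interval that just contains z_true.

    Toy model (assumed for this sketch): z ~ N(0, 1), x = z + N(0, 1),
    so the exact posterior is N(x / 2, 1 / 2).
    """
    post_mean = x / 2
    post_std = math.sqrt(0.5)
    distance = abs(z_true - post_mean)
    # Mass of a symmetric interval of half-width `distance` under a Gaussian.
    return math.erf(distance / (post_std * math.sqrt(2)))

masses = []
for _ in range(5000):
    z = random.gauss(0, 1)        # draw from the prior
    x = z + random.gauss(0, 1)    # simulate an observation
    masses.append(hpd_mass(z, x))

nominal = 0.68
empirical = sum(m <= nominal for m in masses) / len(masses)
print(f"nominal {nominal:.2f}, empirical {empirical:.3f}")
```

An empirical value well below the nominal one would indicate overconfident posteriors, and a value well above it would indicate overly conservative ones.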

7. Generate posteriors

[11]:
z0 = np.array([0.3, 0.7])
x0 = sim.sample(conditions = {"z": z0})['x']
plt.plot(x0)
[11]:
[<matplotlib.lines.Line2D at 0x7fc899900b20>]
../_images/notebooks_00_-_Swyft_in_15_minutes_20_1.png
[12]:
prior_samples = sim.sample(targets = ['z'], N = 100000)
[13]:
predictions = trainer.infer(network, swyft.Sample(x = x0), prior_samples)
[14]:
predictions[0].parnames
[14]:
array([['z[0]'],
       ['z[1]']], dtype='<U4')
[17]:
swyft.plot_1d(predictions, 'z[0]')
../_images/notebooks_00_-_Swyft_in_15_minutes_24_0.png
[18]:
swyft.plot_1d(predictions, 'z[1]')
../_images/notebooks_00_-_Swyft_in_15_minutes_25_0.png
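
Because this toy model has a Gaussian likelihood, the posterior can also be computed directly and used as a reference for the plots above. The sketch below is independent of Swyft: it regenerates a mock observation with its own seed (so the numbers will not match `x0` exactly), weights prior samples by the exact likelihood, and reports weighted posterior means and standard deviations. `log_likelihood` and the other names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
grid = np.linspace(-1, 1, 10)
sigma = 0.1

# Mock observation for the same ground truth as above.
z_true = np.array([0.3, 0.7])
x_obs = z_true[0] + z_true[1] * grid + rng.normal(0.0, sigma, size=grid.size)

# Prior samples, uniform on [-1, 1]^2 as in the simulator.
z_prior = rng.uniform(-1, 1, size=(100_000, 2))

def log_likelihood(z, x, grid, sigma):
    """Gaussian log-likelihood of x for each row of z, up to an additive constant."""
    model = z[:, :1] + z[:, 1:] * grid          # shape (n_samples, n_bins)
    return -0.5 * np.sum((x - model) ** 2, axis=1) / sigma**2

# Self-normalized importance weights: posterior is proportional to prior times likelihood.
log_w = log_likelihood(z_prior, x_obs, grid, sigma)
weights = np.exp(log_w - log_w.max())
weights /= weights.sum()

post_mean = weights @ z_prior
post_std = np.sqrt(weights @ (z_prior - post_mean) ** 2)
print("posterior mean:", post_mean)
print("posterior std: ", post_std)
```

A network trained for only two epochs need not reproduce this reference closely, but large discrepancies in the position or width of the 1-dim posteriors would be a hint that more training or more simulations are needed.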