This page was generated from notebooks/00 - Swyft in 15 minutes.ipynb.
Swyft in 15 Minutes#
We discuss seven key steps of a typical Swyft workflow.
1. Installing Swyft#
We can use pip to install the lightning branch (latest development branch) of Swyft.
[2]:
#!pip install git+https://github.com/undark-lab/swyft.git@lightning
[3]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import torch
import swyft
2. Define the Simulator#
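Next, we define a simulator class whose build method specifies the computational graph of the model. The graph has three nodes. The first is a two-dimensional parameter z, with each component drawn uniformly from [-1, 1]. The second is a noise-free straight line f = z[0] + z[1]·x, evaluated on a grid of 10 points in [-1, 1] (self.x). The third is the observation x, obtained by adding Gaussian noise with standard deviation 0.1 to f. We then draw 10,000 samples from the simulator.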
[4]:
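class Simulator(swyft.Simulator):
    def __init__(self):
        super().__init__()
        self.transform_samples = swyft.to_numpy32
        self.x = np.linspace(-1, 1, 10)  # grid on which the line is evaluated

    def build(self, graph):
        # z: two parameters, each uniform on [-1, 1]
        z = graph.node('z', lambda: np.random.rand(2)*2-1)
        # f: noise-free line z[0] + z[1]*x on the grid
        f = graph.node('f', lambda z: z[0] + z[1]*self.x, z)
        # x: observation with Gaussian noise of standard deviation 0.1
        x = graph.node('x', lambda f: f + np.random.randn(10)*0.1, f)

sim = Simulator()
samples = sim.sample(N = 10000)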
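To see what this graph computes without swyft's graph machinery, here is a plain NumPy sketch of the same generative model, vectorized over many draws. The function simulate, its dictionary layout and the float32 cast are illustrative assumptions. The cast is suggested only by the name swyft.to_numpy32. None of this is swyft's API.

```python
import numpy as np

X_GRID = np.linspace(-1, 1, 10)  # same grid as self.x in the Simulator class
NOISE_STD = 0.1                  # same noise level as the 'x' node

def simulate(n, rng):
    """Hypothetical stand-alone version of the Simulator graph for n draws."""
    z = rng.uniform(-1, 1, size=(n, 2))                # node 'z': prior draw
    f = z[:, :1] + z[:, 1:] * X_GRID                   # node 'f': z[0] + z[1]*x, shape (n, 10)
    x = f + NOISE_STD * rng.normal(size=f.shape)       # node 'x': noisy observation
    # Assumption: store samples as float32, as the name swyft.to_numpy32 suggests
    return {k: v.astype(np.float32) for k, v in {'z': z, 'f': f, 'x': x}.items()}

samples = simulate(10_000, np.random.default_rng(0))
print(samples['z'].shape, samples['x'].shape)
```

Each row of samples['x'] is one simulated observation, and the matching row of samples['z'] holds the parameters that produced it.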
3. Define the SwyftModule#
[5]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        # Compress the 10-dim observation x into 2 features
        self.embedding = torch.nn.Linear(10, 2)
        # 1-dim log-ratio estimators for z[0] and z[1] separately
        self.logratios1 = swyft.LogRatioEstimator_1dim(num_features = 2, num_params = 2, varnames = 'z')
        # One 2-dim log-ratio estimator for the joint (z[0], z[1])
        self.logratios2 = swyft.LogRatioEstimator_Ndim(num_features = 2, marginals = ((0, 1),), varnames = 'z')

    def forward(self, A, B):
        # A supplies the observations x, B the parameters z
        embedding = self.embedding(A['x'])
        logratios1 = self.logratios1(embedding, B['z'])
        logratios2 = self.logratios2(embedding, B['z'])
        return logratios1, logratios2
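Log-ratio estimators of this kind are trained, in the general neural ratio estimation (NRE) framework, as binary classifiers. Matched simulator pairs (x, z) get label 1. Mismatched pairs, obtained by shuffling z, get label 0. The Bayes-optimal classifier's logit is then log p(x|z)/p(x). The NumPy sketch below illustrates that generic objective on a 1-dim toy model (z uniform on [-1, 1], x = z plus Gaussian noise), with the exact log-ratio standing in for a network. The names nre_loss and true_logratio are hypothetical. This is not swyft's internal loss. Binary cross-entropy is non-negative, while the test loss swyft reports later in this notebook is negative, so swyft presumably uses a different normalization or offset.

```python
import numpy as np
from math import erf, sqrt

SIGMA = 0.1
_erf = np.vectorize(erf)

def softplus(t):
    # Numerically stable log(1 + exp(t))
    return np.logaddexp(0.0, t)

def nre_loss(logratio_fn, x, z, rng):
    """Generic NRE binary cross-entropy: joint pairs labelled 1, shuffled pairs labelled 0."""
    z_marg = z[rng.permutation(len(z))]  # break the pairing: approximate draws from p(x)p(z)
    r_joint = logratio_fn(x, z)
    r_marg = logratio_fn(x, z_marg)
    # -log sigmoid(r) = softplus(-r) and -log(1 - sigmoid(r)) = softplus(r)
    return np.mean(softplus(-r_joint)) + np.mean(softplus(r_marg))

def normal_cdf(t):
    return 0.5 * (1.0 + _erf(t / sqrt(2.0)))

def true_logratio(x, z):
    # log p(x|z) - log p(x) for z ~ Uniform(-1, 1) and x = z + N(0, SIGMA^2)
    log_lik = -0.5 * ((x - z) / SIGMA) ** 2 - np.log(SIGMA * sqrt(2.0 * np.pi))
    evidence = 0.5 * (normal_cdf((1.0 - x) / SIGMA) - normal_cdf((-1.0 - x) / SIGMA))
    return log_lik - np.log(evidence)

rng = np.random.default_rng(0)
z = rng.uniform(-1, 1, size=5000)
x = z + SIGMA * rng.normal(size=5000)

loss_uninformed = nre_loss(lambda x, z: np.zeros_like(x), x, z, rng)  # equals 2 log 2
loss_optimal = nre_loss(true_logratio, x, z, rng)
print(loss_uninformed, loss_optimal)
```

A classifier that ignores its inputs scores 2 log 2. The exact log-ratio does much better, which is the signal the network's gradients chase during training.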
4. Train the model#
[6]:
trainer = swyft.SwyftTrainer(accelerator = 'gpu', max_epochs = 2)  # use accelerator = 'cpu' if no GPU is available
/home/weniger/.conda/envs/lensing/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:166: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/weniger/.conda/envs/lensing/lib/python3.9/site ...
rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[7]:
dm = swyft.SwyftDataModule(samples, fractions = [0.8, 0.1, 0.1], batch_size = 64, num_workers = 3)
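The fractions argument splits the 10,000 samples into training, validation and test sets. The batch_size argument sets how many samples enter each optimizer step. The arithmetic is sketched below in NumPy under the assumption of a simple contiguous split. The real SwyftDataModule may shuffle or round differently, and the variable names are illustrative. The result is 8,000/1,000/1,000 samples and 125 training batches per epoch, so the two epochs configured in the trainer amount to about 250 optimizer steps.

```python
import numpy as np

n_samples, batch_size, max_epochs = 10_000, 64, 2
fractions = [0.8, 0.1, 0.1]  # train / validation / test

# Hypothetical contiguous split; SwyftDataModule may shuffle or round differently
sizes = [int(round(f * n_samples)) for f in fractions]
bounds = np.cumsum([0] + sizes)
idx = np.arange(n_samples)
train_idx, val_idx, test_idx = (idx[a:b] for a, b in zip(bounds[:-1], bounds[1:]))

batches_per_epoch = int(np.ceil(len(train_idx) / batch_size))  # a final partial batch counts as one
total_steps = batches_per_epoch * max_epochs
print(sizes, batches_per_epoch, total_steps)
```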
[8]:
network = Network()
trainer.fit(network, dm)
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name       | Type                   | Params
------------------------------------------------------
0 | embedding  | Linear                 | 22
1 | logratios1 | LogRatioEstimator_1dim | 34.9 K
2 | logratios2 | LogRatioEstimator_Ndim | 17.5 K
------------------------------------------------------
52.5 K    Trainable params
0         Non-trainable params
52.5 K    Total params
0.210     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
5. Visualize training#
This step is not covered in this notebook.
6. Perform validation tests#
[9]:
trainer.test(network, dm)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_loss -3.115851640701294
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[9]:
[{'test_loss': -3.115851640701294}]
[10]:
B = samples[:1000]
A = samples[:1000]
mass = trainer.test_coverage(network, A, B)
WARNING: This estimates the mass of highest-likelihood intervals.
/home/weniger/.conda/envs/lensing/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:173: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
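A coverage test asks whether credible regions contain the truth as often as they claim. For each simulated pair (x, z_true), compute the posterior mass of the region where the log-ratio is higher than at z_true. For a calibrated estimator these masses are uniformly distributed on [0, 1]. The warning above notes that swyft ranks by likelihood (equivalently, by the ratio) rather than by posterior density. With a uniform prior, as in this notebook, the two orderings coincide. The NumPy sketch below runs this check on a 1-dim toy problem, with the exact log-ratio standing in for the network. The helper hpd_mass and the toy settings are hypothetical, and swyft's own implementation may differ in detail.

```python
import numpy as np

def hpd_mass(r_true, r_samples):
    """Posterior mass where the log-ratio exceeds its value at the truth.

    r_true: shape (N,), log-ratio at each true parameter.
    r_samples: shape (N, M), log-ratio at M prior samples for each of the N observations.
    """
    # Self-normalized importance weights: prior samples reweighted by exp(log-ratio)
    w = np.exp(r_samples - r_samples.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    # Sum the weight of samples ranked above the truth
    return np.sum(w * (r_samples > r_true[:, None]), axis=1)

rng = np.random.default_rng(0)
sigma, n_obs, n_prior = 0.3, 1000, 2000
z_true = rng.uniform(-1, 1, size=n_obs)         # truths drawn from the prior
x_obs = z_true + sigma * rng.normal(size=n_obs)  # one observation per truth
z_prior = rng.uniform(-1, 1, size=n_prior)       # prior samples shared by all observations

def logratio(x, z):
    # Exact Gaussian log-likelihood, dropping terms that depend only on x
    return -0.5 * ((x - z) / sigma) ** 2

masses = hpd_mass(logratio(x_obs, z_true), logratio(x_obs[:, None], z_prior[None, :]))
print(masses.mean(), np.mean(masses < 0.68))  # near 0.5 and 0.68 if calibrated
```

Dropping terms that depend only on x is safe here. They shift every log-ratio in a row equally, which changes neither the ranking nor the normalized weights.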
7. Generate posteriors#
[11]:
z0 = np.array([0.3, 0.7])
x0 = sim.sample(conditions = {"z": z0})['x']
plt.plot(x0)
[11]:
[<matplotlib.lines.Line2D at 0x7fc899900b20>]
[12]:
prior_samples = sim.sample(targets = ['z'], N = 100000)
[13]:
predictions = trainer.infer(network, swyft.Sample(x = x0), prior_samples)
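Here trainer.infer passes the observation x0 and the prior samples through the trained network. Conceptually, prior samples weighted by exp(log-ratio) approximate the posterior. The NumPy sketch below shows this reweighting for the same line-fitting problem. To run without training, it assumes the exact Gaussian log-likelihood in place of the network. The variable names are illustrative, not swyft attributes.

```python
import numpy as np

rng = np.random.default_rng(1)
x_grid = np.linspace(-1, 1, 10)  # same grid as Simulator.x
sigma = 0.1                      # same noise level as the 'x' node

# Mock observation at z0 = (0.3, 0.7), as in the cell above
z0 = np.array([0.3, 0.7])
x0 = z0[0] + z0[1] * x_grid + sigma * rng.normal(size=10)

# Uniform prior samples, playing the role of prior_samples
prior = rng.uniform(-1, 1, size=(100_000, 2))

# Stand-in for the trained network: exact Gaussian log-likelihood (constants dropped)
pred = prior[:, :1] + prior[:, 1:] * x_grid  # shape (100000, 10)
logratio = -0.5 * np.sum(((x0 - pred) / sigma) ** 2, axis=1)

# Self-normalized weights turn prior samples into approximate posterior samples
w = np.exp(logratio - logratio.max())
w /= w.sum()
post_mean = w @ prior
post_std = np.sqrt(w @ (prior - post_mean) ** 2)
print(post_mean, post_std)
```

The weighted mean lands close to z0. The weighted standard deviations are a few hundredths, which is what ten points with noise 0.1 can constrain.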
[14]:
predictions[0].parnames
[14]:
array([['z[0]'],
       ['z[1]']], dtype='<U4')
[17]:
swyft.plot_1d(predictions, 'z[0]')
[18]:
swyft.plot_1d(predictions, 'z[1]')
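The 1-dim plots summarize the marginal posteriors graphically. To turn weighted samples into numbers, an equal-tailed credible interval can be read off a weighted quantile function. The sketch below is a generic NumPy helper, weighted_quantile, which is hypothetical and not a swyft function. It is demonstrated on uniform samples reweighted by a standard-normal density so the answer is known: about -1, 0 and 1 for q = 0.16, 0.5 and 0.84. How to extract the samples and log-ratios from predictions depends on swyft's result objects, so check its documentation for the attribute names.

```python
import numpy as np

def weighted_quantile(values, weights, q):
    """Equal-tailed quantiles of weighted 1-dim samples (midpoint CDF, linear interpolation)."""
    order = np.argsort(values)
    v = np.asarray(values)[order]
    w = np.asarray(weights, dtype=float)[order]
    cdf = (np.cumsum(w) - 0.5 * w) / w.sum()
    return np.interp(q, cdf, v)

# Demo: uniform "prior" samples reweighted by an unnormalized N(0, 1) density
rng = np.random.default_rng(2)
z = rng.uniform(-5, 5, size=200_000)
logw = -0.5 * z ** 2
w = np.exp(logw - logw.max())
lo, med, hi = weighted_quantile(z, w, [0.16, 0.5, 0.84])
print(lo, med, hi)
```

This gives an equal-tailed interval, not a highest-density one. The two agree for symmetric, unimodal marginals but can differ otherwise.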