This page was generated from notebooks/0G-Bounds.ipynb.
Bounds
First we need some imports.
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import torch
import torchist
import swyft
import pytorch_lightning as pl
Training data
Now we generate training data. As a simple example, we consider the model
\[x = z + \epsilon\]
where the parameter \(z \sim \mathcal{N}(\mu = 0, \sigma = 1)\) follows a standard normal distribution and \(\epsilon \sim \mathcal{N}(\mu = 0, \sigma = 0.1)\) is a small noise contribution. We are interested in the posterior of \(z\) given a measurement of the data \(x\).
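Because both distributions are Gaussian, the exact posterior is available in closed form, which is handy for checking the learned results later. With prior variance 1 and noise variance 0.1² = 0.01, the posterior is Gaussian with mean x/(1 + 0.01) and variance 0.01/(1 + 0.01). The sketch below is plain NumPy, not part of swyft, and the function name is our own; it evaluates this for the observation x = 0.3 used further down.

```python
import numpy as np

PRIOR_SIGMA = 1.0  # standard deviation of z
NOISE_SIGMA = 0.1  # standard deviation of epsilon


def exact_posterior(x):
    """Mean and standard deviation of p(z | x) for x = z + eps (conjugate Gaussian update)."""
    prior_var = PRIOR_SIGMA**2
    noise_var = NOISE_SIGMA**2
    post_var = prior_var * noise_var / (prior_var + noise_var)
    post_mean = x * prior_var / (prior_var + noise_var)
    return post_mean, np.sqrt(post_var)


mean, std = exact_posterior(0.3)
print(f"posterior mean = {mean:.4f}, std = {std:.4f}")  # posterior mean = 0.2970, std = 0.0995
```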
[3]:
class Simulator(swyft.Simulator):
    def __init__(self, bounds = None):
        super().__init__()
        self.transform_samples = swyft.to_numpy32
        # Standard normal prior on z, optionally truncated to the rectangle stored in `bounds`
        self.z_sampler = swyft.RectBoundSampler(stats.norm([0], [1]), bounds = bounds.params[0,0] if bounds else None)

    def build(self, graph):
        # Draw z from the (bounded) prior, then add Gaussian noise with sigma = 0.1 to obtain x
        z = graph.node('z', self.z_sampler)
        x = graph.node('x', lambda z: z + np.random.randn(1)*0.1, z)
sim = Simulator(bounds = None)
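Passing bounds to swyft.RectBoundSampler restricts the prior to a box. As an illustration only, the sketch below draws from a standard normal truncated to an interval [low, high] by inverse-CDF sampling with Python's statistics.NormalDist. The helper name and the sampling approach are our assumptions, not swyft's implementation; the interval used is the one obtained at the end of this notebook.

```python
import random
from statistics import NormalDist

import numpy as np


def sample_truncated_normal(low, high, n, mu=0.0, sigma=1.0, seed=0):
    """Hypothetical helper: draw n samples from N(mu, sigma) restricted to [low, high]."""
    dist = NormalDist(mu, sigma)
    rng = random.Random(seed)
    # Map the interval to CDF space, sample uniformly there, and map back with the inverse CDF.
    u_low, u_high = dist.cdf(low), dist.cdf(high)
    return np.array([dist.inv_cdf(rng.uniform(u_low, u_high)) for _ in range(n)])


z = sample_truncated_normal(-1.7873, 2.2994, n=1000)
print(z.min(), z.max())  # both lie inside [-1.7873, 2.2994]
```

Sampling in CDF space keeps every draw inside the interval without rejection, which stays efficient even when the box covers only a small part of the prior mass.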
[4]:
samples = sim.sample(10000)
100%|██████████| 10000/10000 [00:08<00:00, 1221.85it/s]
[5]:
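class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        # One-dimensional ratio estimator for the single parameter z, using the 1-d data x as features
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 1, num_params = 1, varnames = 'z')

    def forward(self, A, B):
        # A supplies the observations x, B the parameters z
        logratios = self.logratios(A['x'], B['z'])
        return logratios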
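A log-ratio estimator of this kind targets the likelihood-to-evidence ratio r(x; z) = p(x | z) / p(x) = p(z | x) / p(z). For this toy model the target is known exactly: p(x | z) = N(x; z, 0.1) and p(x) = N(x; 0, √1.01). The NumPy sketch below uses our own helper names and serves only as a reference point. It evaluates the exact log-ratio and checks that log prior + log-ratio reproduces the Gaussian posterior.

```python
import numpy as np


def log_normal_pdf(v, mu, sigma):
    # log of the normal density N(v; mu, sigma)
    return -0.5 * ((v - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))


def exact_logratio(x, z, noise_sigma=0.1):
    """log p(x|z) - log p(x) for x = z + eps, z ~ N(0, 1), eps ~ N(0, noise_sigma)."""
    evidence_sigma = np.sqrt(1.0 + noise_sigma**2)
    return log_normal_pdf(x, z, noise_sigma) - log_normal_pdf(x, 0.0, evidence_sigma)


z_grid = np.linspace(-1.0, 1.5, 501)
x_obs = 0.3
# Bayes' theorem: log p(z|x) = log p(z) + log r(x; z)
log_post = log_normal_pdf(z_grid, 0.0, 1.0) + exact_logratio(x_obs, z_grid)
log_post_exact = log_normal_pdf(z_grid, x_obs / 1.01, np.sqrt(0.01 / 1.01))
print(np.max(np.abs(log_post - log_post_exact)))  # close to zero (floating-point rounding only)
```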
Trainer
Training is now done using the SwyftTrainer class, which extends pytorch_lightning.Trainer with methods like infer (see below).
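Ratio estimators like the one above are commonly trained as binary classifiers. The classifier distinguishes jointly drawn pairs (x, z) from pairs in which z has been shuffled, so that those pairs follow p(x)p(z). With balanced classes, the optimal classifier logit is the log-ratio. The sketch below is a NumPy illustration of that objective under this assumption, not swyft's training code. It evaluates the binary cross-entropy twice: once with the exact log-ratio of this toy model as the logit, and once with a constant zero logit.

```python
import numpy as np

rng = np.random.default_rng(0)


def exact_logratio(x, z, noise_sigma=0.1):
    # log N(x; z, noise_sigma) - log N(x; 0, sqrt(1 + noise_sigma^2)); the 2*pi terms cancel
    ev = np.sqrt(1.0 + noise_sigma**2)
    return (-0.5 * ((x - z) / noise_sigma) ** 2 - np.log(noise_sigma)
            + 0.5 * (x / ev) ** 2 + np.log(ev))


def bce_loss(logit_joint, logit_marginal):
    # label 1 for joint pairs, 0 for shuffled pairs; log(1 + e^{-l}) == logaddexp(0, -l)
    return np.mean(np.logaddexp(0.0, -logit_joint)) + np.mean(np.logaddexp(0.0, logit_marginal))


n = 10_000
z = rng.standard_normal(n)
x = z + 0.1 * rng.standard_normal(n)
z_shuffled = rng.permutation(z)  # breaks the x-z dependence, giving samples from p(x)p(z)

loss_exact = bce_loss(exact_logratio(x, z), exact_logratio(x, z_shuffled))
loss_zero = bce_loss(np.zeros(n), np.zeros(n))  # uninformative classifier: 2 log 2
print(loss_exact, loss_zero)
```

The exact log-ratio gives a clearly lower loss than the uninformative classifier, which is why minimising this kind of loss pushes a network towards the log-ratio.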
[6]:
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor = 'val_loss', save_top_k = 1)
logger = pl.loggers.TensorBoardLogger("./lightning_logs", name = "Test1")
trainer = swyft.SwyftTrainer(accelerator = 'gpu', max_epochs = 20, precision = 64, logger = logger, callbacks = [model_checkpoint])
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
The swyft.SwyftDataModule class provides convenience functions to generate data loaders for training and validation data.
[7]:
dm = swyft.SwyftDataModule(samples, fractions = [0.8, 0.02, 0.1], num_workers = 3, batch_size = 256)
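The fractions argument presumably gives the training, validation and test shares. They sum to 0.92 here. The training log below reports 34 training batches and 35 steps per epoch, which implies one validation batch of size 256. These counts are consistent with the fractions being rescaled to sum to one. The helper below is a hypothetical sketch of that arithmetic, an assumption rather than swyft's code.

```python
import math


def split_lengths(n_samples, fractions):
    """Hypothetical: turn (possibly unnormalised) fractions into integer split sizes."""
    total = sum(fractions)
    lengths = [int(n_samples * f / total) for f in fractions[:-1]]
    lengths.append(n_samples - sum(lengths))  # the remainder goes to the last split
    return lengths


def num_batches(length, batch_size):
    return math.ceil(length / batch_size)  # the last batch may be partial


train, val, test = split_lengths(10_000, [0.8, 0.02, 0.1])
print(train, val, test)  # 8695 217 1088
print(num_batches(train, 256), num_batches(val, 256))  # 34 1
```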
[8]:
network = Network()
[9]:
trainer.fit(network, dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-----------------------------------------------------
0 | logratios | LogRatioEstimator_1dim | 17.4 K
-----------------------------------------------------
17.4 K Trainable params
0 Non-trainable params
17.4 K Total params
0.139 Total estimated model params size (MB)
/home/weniger/miniconda3b/envs/zero/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1894: PossibleUserWarning: The number of training batches (34) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
rank_zero_warn(
Epoch 0: 100%|██████████| 35/35 [00:00<00:00, 35.75it/s, loss=-0.932, v_num=54, val_loss=-.676]
Epoch 1: 100%|██████████| 35/35 [00:01<00:00, 32.47it/s, loss=-0.966, v_num=54, val_loss=-.927]
Epoch 2: 100%|██████████| 35/35 [00:01<00:00, 34.45it/s, loss=-0.965, v_num=54, val_loss=-.941]
Epoch 3: 100%|██████████| 35/35 [00:00<00:00, 36.50it/s, loss=-0.959, v_num=54, val_loss=-.943]
Epoch 4: 100%|██████████| 35/35 [00:01<00:00, 33.03it/s, loss=-0.972, v_num=54, val_loss=-.925]
Epoch 5: 100%|██████████| 35/35 [00:00<00:00, 39.04it/s, loss=-0.976, v_num=54, val_loss=-.941]
Epoch 6: 100%|██████████| 35/35 [00:00<00:00, 38.60it/s, loss=-0.974, v_num=54, val_loss=-.947]
Epoch 7: 100%|██████████| 35/35 [00:01<00:00, 32.99it/s, loss=-0.967, v_num=54, val_loss=-.942]
Epoch 8: 100%|██████████| 35/35 [00:00<00:00, 37.35it/s, loss=-0.966, v_num=54, val_loss=-.944]
Epoch 9: 100%|██████████| 35/35 [00:00<00:00, 36.01it/s, loss=-0.956, v_num=54, val_loss=-.947]
Epoch 10: 100%|██████████| 35/35 [00:00<00:00, 41.58it/s, loss=-0.976, v_num=54, val_loss=-.937]
Epoch 11: 100%|██████████| 35/35 [00:01<00:00, 33.30it/s, loss=-0.965, v_num=54, val_loss=-.935]
Epoch 12: 100%|██████████| 35/35 [00:01<00:00, 33.83it/s, loss=-0.972, v_num=54, val_loss=-.932]
Epoch 13: 100%|██████████| 35/35 [00:00<00:00, 36.06it/s, loss=-0.972, v_num=54, val_loss=-.946]
Epoch 14: 100%|██████████| 35/35 [00:00<00:00, 36.56it/s, loss=-0.965, v_num=54, val_loss=-.944]
Epoch 15: 100%|██████████| 35/35 [00:00<00:00, 38.36it/s, loss=-0.969, v_num=54, val_loss=-.943]
Epoch 16: 100%|██████████| 35/35 [00:01<00:00, 32.26it/s, loss=-0.968, v_num=54, val_loss=-.948]
Epoch 17: 100%|██████████| 35/35 [00:01<00:00, 33.66it/s, loss=-0.989, v_num=54, val_loss=-.946]
Epoch 18: 100%|██████████| 35/35 [00:01<00:00, 34.25it/s, loss=-0.972, v_num=54, val_loss=-.943]
Epoch 19: 100%|██████████| 35/35 [00:00<00:00, 36.06it/s, loss=-0.992, v_num=54, val_loss=-.945]
`Trainer.fit` stopped: `max_epochs=20` reached.
[10]:
prior_samples = sim.sample(N = 10000, targets = ['z'])
100%|██████████| 10000/10000 [00:08<00:00, 1237.94it/s]
[11]:
obs = swyft.Sample(x = np.array([0.3]))
[12]:
logratios = trainer.infer(network, obs, prior_samples)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/weniger/miniconda3b/envs/zero/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:173: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
Predicting DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 155.68it/s]
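trainer.infer evaluates the learned log-ratios for the observation paired with each prior sample. Since p(z | x) = r(x; z) p(z), weighting each prior sample by exp(log r) turns the prior sample set into a weighted posterior sample. A 1-d posterior histogram can then be built from these weights. The NumPy sketch below illustrates the idea, with the exact log-ratio of this toy model standing in for the network. It is an assumption about how such weights are used, not a description of swyft's internals.

```python
import numpy as np

rng = np.random.default_rng(1)


def exact_logratio(x, z, noise_sigma=0.1):
    # exact log p(x|z) - log p(x) for the toy model; stands in for the trained network
    ev = np.sqrt(1.0 + noise_sigma**2)
    return (-0.5 * ((x - z) / noise_sigma) ** 2 - np.log(noise_sigma)
            + 0.5 * (x / ev) ** 2 + np.log(ev))


z_prior = rng.standard_normal(100_000)  # plays the role of prior_samples['z']
logr = exact_logratio(0.3, z_prior)     # plays the role of the inferred log-ratios
w = np.exp(logr - logr.max())           # subtract the max for numerical stability
w /= w.sum()                            # self-normalised importance weights

post_mean = np.sum(w * z_prior)
post_std = np.sqrt(np.sum(w * (z_prior - post_mean) ** 2))
hist, edges = np.histogram(z_prior, bins=30, range=(-0.2, 0.8), weights=w)
print(post_mean, post_std)  # close to the exact values 0.2970 and 0.0995
```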
[13]:
swyft.plot_1d(logratios, "z[0]", ax = plt.gca(), smooth = 1, bins = 30)
[14]:
bounds = swyft.lightning.bounds.get_rect_bounds(logratios, threshold = 1e-6)
bounds
[14]:
RectangleBounds(params=tensor([[[-1.7873, 2.2994]]], dtype=torch.float64), parnames=array([['z[0]']], dtype='<U4'))
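A rectangular bound of this kind can be read as the smallest interval that contains every prior sample whose ratio exceeds threshold times the largest ratio. The sketch below implements that reading under our own assumption, again using the exact log-ratio in place of the network. The helper name is hypothetical, and swyft.lightning.bounds.get_rect_bounds may differ in detail. For x = 0.3 the exact ratio falls to 1e-6 of its peak at |z - 0.3| ≈ 0.53, so this reading gives roughly [-0.23, 0.83], narrower than the interval the trained estimator returned above.

```python
import numpy as np

rng = np.random.default_rng(2)


def exact_logratio(x, z, noise_sigma=0.1):
    # exact log p(x|z) - log p(x) for the toy model
    ev = np.sqrt(1.0 + noise_sigma**2)
    return (-0.5 * ((x - z) / noise_sigma) ** 2 - np.log(noise_sigma)
            + 0.5 * (x / ev) ** 2 + np.log(ev))


def rect_bounds_from_samples(samples, logratios, threshold=1e-6):
    """Hypothetical: min/max of the samples whose ratio exceeds threshold * max ratio."""
    keep = logratios - logratios.max() > np.log(threshold)
    return samples[keep].min(), samples[keep].max()


z_prior = rng.standard_normal(10_000)
low, high = rect_bounds_from_samples(z_prior, exact_logratio(0.3, z_prior))
# Expected near 0.3 -/+ sqrt(2 * 0.01 * ln(1e6)), i.e. about -0.226 and 0.826
print(low, high)
```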