This page was generated from notebooks/0F-Logging.ipynb.
Logging¶
First we need some imports.
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import numpy as np
from scipy import stats
import pylab as plt
import torch
import torchist
import swyft
import pytorch_lightning as pl
Training data¶
Now we generate training data. As a simple example, we consider the model
\[x = z + \epsilon\]
where the parameter \(z \sim \mathcal{N}(\mu = 0, \sigma = 1)\) is standard normally distributed, and \(\epsilon \sim \mathcal{N}(\mu = 0, \sigma = 0.1)\) is a small noise contribution. We are interested in the posterior of \(z\) given a measurement \(x\).
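Because both the prior and the noise are Gaussian, the posterior is available in closed form, which is handy for checking the trained estimator later. A minimal sketch of the conjugate update in plain NumPy (not part of the original notebook):

```python
import numpy as np

def analytic_posterior(x_obs, sigma_prior=1.0, sigma_noise=0.1):
    """Gaussian conjugate update for x = z + eps, z ~ N(0, sigma_prior^2)."""
    prec = 1.0 / sigma_prior**2 + 1.0 / sigma_noise**2  # posterior precision
    var = 1.0 / prec
    mean = var * x_obs / sigma_noise**2                 # posterior mean
    return mean, np.sqrt(var)

mean, std = analytic_posterior(0.5)
print(mean, std)  # posterior concentrates tightly around the observation
```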
[3]:
class Simulator(swyft.Simulator):
    def __init__(self):
        super().__init__()
        self.transform_samples = swyft.to_numpy32

    def build(self, graph):
        z = graph.node('z', lambda: np.random.randn(1))
        x = graph.node('x', lambda z: z + np.random.randn(1)*0.1, z)

sim = Simulator()
[4]:
samples = sim.sample(10000)
100%|██████████| 10000/10000 [00:00<00:00, 71467.83it/s]
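As a quick sanity check of the generative model, its summary statistics can be reproduced in plain NumPy, independent of swyft (an illustrative sketch, not part of the original notebook):

```python
import numpy as np

# Re-implement the model in plain NumPy:
# z ~ N(0, 1), x = z + eps with eps ~ N(0, 0.1^2), so Std[x] = sqrt(1.01).
rng = np.random.default_rng(0)
z = rng.standard_normal(10_000)
x = z + 0.1 * rng.standard_normal(10_000)
print(x.mean(), x.std())  # roughly 0 and sqrt(1.01)
```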
[5]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 1, num_params = 1, varnames = 'z')

    def forward(self, A, B):
        logratios = self.logratios(A['x'], B['z'])
        return logratios
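Conceptually, a 1-dim log-ratio estimator maps pairs of data features and parameters to a scalar log-ratio. The following standalone PyTorch sketch illustrates the idea only; it is not swyft's actual `LogRatioEstimator_1dim` implementation:

```python
import torch
import torch.nn as nn

class ToyLogRatioEstimator(nn.Module):
    """Illustrative stand-in: maps (x, z) pairs to a scalar log-ratio."""
    def __init__(self, num_features=1, num_params=1, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_features + num_params, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, z):
        # Concatenate features and parameters, return one log-ratio per pair.
        return self.net(torch.cat([x, z], dim=-1)).squeeze(-1)

est = ToyLogRatioEstimator()
x = torch.randn(128, 1)
z = torch.randn(128, 1)
print(est(x, z).shape)  # torch.Size([128])
```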
Trainer¶
Training is done using the SwyftTrainer
class, which extends pytorch_lightning.Trainer
with additional methods such as infer.
[6]:
logger = pl.loggers.TensorBoardLogger("./lightning_logs", name = "Test1")
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor = 'val_loss', save_top_k = 1)
callbacks = [
pl.callbacks.LearningRateMonitor(),
pl.callbacks.EarlyStopping("val_loss", patience = 3),
model_checkpoint
]
trainer = swyft.SwyftTrainer(accelerator = 'gpu', max_epochs = 100, precision = 64, logger = logger, callbacks = callbacks)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
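With the TensorBoardLogger configured above, training and validation losses (and the learning rate, via the LearningRateMonitor callback) are written under ./lightning_logs/Test1 and can be inspected in a browser. Assuming TensorBoard is installed:

```shell
# Launch TensorBoard on the log directory used above,
# then open http://localhost:6006 in a browser.
tensorboard --logdir ./lightning_logs
```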
The swyft.SwyftDataModule
class provides convenience functions to generate data loaders for training and validation data.
[7]:
dm = swyft.SwyftDataModule(samples, fractions = [0.8, 0.02, 0.1], num_workers = 3, batch_size = 128)
[8]:
network = Network()
[9]:
trainer.fit(network, dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-----------------------------------------------------
0 | logratios | LogRatioEstimator_1dim | 17.4 K
-----------------------------------------------------
17.4 K Trainable params
0 Non-trainable params
17.4 K Total params
0.139 Total estimated model params size (MB)
Epoch 0: 97%|█████████▋| 68/70 [00:01<00:00, 66.22it/s, loss=-0.543, v_num=53]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 0: 99%|█████████▊| 69/70 [00:01<00:00, 53.78it/s, loss=-0.543, v_num=53]
Epoch 0: 100%|██████████| 70/70 [00:01<00:00, 53.00it/s, loss=-0.543, v_num=53, val_loss=-.526]
Epoch 1: 97%|█████████▋| 68/70 [00:01<00:00, 65.55it/s, loss=-0.587, v_num=53, val_loss=-.526]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 1: 99%|█████████▊| 69/70 [00:01<00:00, 52.48it/s, loss=-0.587, v_num=53, val_loss=-.526]
Epoch 1: 100%|██████████| 70/70 [00:01<00:00, 51.18it/s, loss=-0.587, v_num=53, val_loss=-.526]
Epoch 2: 97%|█████████▋| 68/70 [00:01<00:00, 61.50it/s, loss=-0.571, v_num=53, val_loss=-.526]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 2: 99%|█████████▊| 69/70 [00:01<00:00, 46.57it/s, loss=-0.571, v_num=53, val_loss=-.526]
Epoch 2: 100%|██████████| 70/70 [00:01<00:00, 46.88it/s, loss=-0.571, v_num=53, val_loss=-.529]
Epoch 3: 97%|█████████▋| 68/70 [00:01<00:00, 61.21it/s, loss=-0.575, v_num=53, val_loss=-.529]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 3: 99%|█████████▊| 69/70 [00:01<00:00, 44.77it/s, loss=-0.575, v_num=53, val_loss=-.529]
Epoch 3: 100%|██████████| 70/70 [00:01<00:00, 45.15it/s, loss=-0.575, v_num=53, val_loss=-.517]
Epoch 4: 97%|█████████▋| 68/70 [00:01<00:00, 62.55it/s, loss=-0.569, v_num=53, val_loss=-.517]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 4: 99%|█████████▊| 69/70 [00:01<00:00, 49.84it/s, loss=-0.569, v_num=53, val_loss=-.517]
Epoch 4: 100%|██████████| 70/70 [00:01<00:00, 49.35it/s, loss=-0.569, v_num=53, val_loss=-.535]
Epoch 5: 97%|█████████▋| 68/70 [00:01<00:00, 64.67it/s, loss=-0.553, v_num=53, val_loss=-.535]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 5: 99%|█████████▊| 69/70 [00:01<00:00, 50.98it/s, loss=-0.553, v_num=53, val_loss=-.535]
Epoch 5: 100%|██████████| 70/70 [00:01<00:00, 51.08it/s, loss=-0.553, v_num=53, val_loss=-.547]
Epoch 6: 97%|█████████▋| 68/70 [00:01<00:00, 55.86it/s, loss=-0.565, v_num=53, val_loss=-.547]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 6: 99%|█████████▊| 69/70 [00:01<00:00, 43.95it/s, loss=-0.565, v_num=53, val_loss=-.547]
Epoch 6: 100%|██████████| 70/70 [00:01<00:00, 44.27it/s, loss=-0.565, v_num=53, val_loss=-.534]
Epoch 7: 97%|█████████▋| 68/70 [00:01<00:00, 61.57it/s, loss=-0.573, v_num=53, val_loss=-.534]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 7: 99%|█████████▊| 69/70 [00:01<00:00, 49.25it/s, loss=-0.573, v_num=53, val_loss=-.534]
Epoch 7: 100%|██████████| 70/70 [00:01<00:00, 49.40it/s, loss=-0.573, v_num=53, val_loss=-.520]
Epoch 8: 97%|█████████▋| 68/70 [00:01<00:00, 62.81it/s, loss=-0.573, v_num=53, val_loss=-.520]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/2 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Epoch 8: 99%|█████████▊| 69/70 [00:01<00:00, 46.57it/s, loss=-0.573, v_num=53, val_loss=-.520]
Epoch 8: 100%|██████████| 70/70 [00:01<00:00, 46.99it/s, loss=-0.573, v_num=53, val_loss=-.539]
Epoch 8: 100%|██████████| 70/70 [00:01<00:00, 46.34it/s, loss=-0.573, v_num=53, val_loss=-.539]
[10]:
model_checkpoint.to_yaml("./test4.yaml")
[11]:
ckpt_path = swyft.best_from_yaml("./test4.yaml")
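ModelCheckpoint.to_yaml dumps the tracked best_k_models mapping (checkpoint path → monitored val_loss) to the YAML file, and swyft.best_from_yaml selects the entry with the best score. The equivalent selection done by hand might look like this sketch, with paths and scores modeled on the run above for illustration:

```python
# Hypothetical contents of the mapping that ModelCheckpoint.to_yaml writes:
# checkpoint path -> monitored metric (here val_loss; lower is better).
best_k_models = {
    "./lightning_logs/Test1/version_53/checkpoints/epoch=5-step=408.ckpt": -0.547,
    "./lightning_logs/Test1/version_53/checkpoints/epoch=3-step=272.ckpt": -0.517,
}

# Pick the checkpoint that minimizes the monitored loss.
best_path = min(best_k_models, key=best_k_models.get)
print(best_path)  # the epoch=5 checkpoint, since -0.547 < -0.517
```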
[12]:
trainer.test(network, dm, ckpt_path = ckpt_path)
Restoring states from the checkpoint path at ./lightning_logs/Test1/version_53/checkpoints/epoch=5-step=408.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs/Test1/version_53/checkpoints/epoch=5-step=408.ckpt
Testing DataLoader 0: 100%|██████████| 9/9 [00:00<00:00, 100.01it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss         │    -0.5928590465564723    │
└───────────────────────────┴───────────────────────────┘
[12]:
[{'test_loss': -0.5928590465564723}]