This page was generated from notebooks/0F-Logging.ipynb.

Logging

First we need some imports.

[1]:
%load_ext autoreload
%autoreload 2
[2]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import torch
import torchist
import swyft
import pytorch_lightning as pl

Training data

Now we generate training data. As a simple example, we consider the model

\[x = z + \epsilon\]

where the parameter \(z \sim \mathcal{U}(0, 1)\) is drawn uniformly from the unit interval, and \(\epsilon \sim \mathcal{N}(\mu = 0, \sigma = 0.1)\) is a small noise contribution. We are interested in the posterior of \(z\) given a measurement of \(x\).
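As a point of comparison for the simulator defined next, the forward model can be sketched in plain Python. This is only an illustration and does not use swyft; the names `simulate` and `draw_z` are hypothetical. The prior draw is passed in as a callable, and the example uses a uniform draw on [0, 1) to mirror the `np.random.rand(1)` call in the simulator cell.

```python
import random
import statistics


def simulate(n, draw_z, noise_sigma=0.1, seed=0):
    """Draw n pairs (z, x) from x = z + eps with eps ~ N(0, noise_sigma)."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        z = draw_z(rng)                       # parameter drawn from the prior
        x = z + rng.gauss(0.0, noise_sigma)   # noisy measurement of z
        pairs.append((z, x))
    return pairs


# Uniform prior on [0, 1); swap in `lambda rng: rng.gauss(0.0, 1.0)`
# to use a standard normal prior instead.
pairs = simulate(10_000, draw_z=lambda rng: rng.random())

# The residual x - z is pure noise, so its spread should be close to 0.1.
residual_std = statistics.pstdev([x - z for z, x in pairs])
print(f"residual std = {residual_std:.3f}")
```

Because the noise is ten times narrower than the prior range, a measurement of \(x\) should constrain \(z\) tightly.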

[3]:
class Simulator(swyft.Simulator):
    def __init__(self):
        super().__init__()
        self.transform_samples = swyft.to_numpy32

    def build(self, graph):
        z = graph.node('z', lambda: np.random.rand(1))
        x = graph.node('x', lambda z: z + np.random.randn(1)*0.1, z)

sim = Simulator()
[4]:
samples = sim.sample(10000)
100%|██████████| 10000/10000 [00:00<00:00, 71467.83it/s]
[5]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 1, num_params = 1, varnames = 'z')

    def forward(self, A, B):
        logratios = self.logratios(A['x'], B['z'])
        return logratios

Trainer

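Training is done with the SwyftTrainer class, which extends pytorch_lightning.Trainer with additional methods such as infer. Here we also attach a TensorBoard logger and a set of callbacks for learning-rate monitoring, early stopping and checkpointing.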

[6]:
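# Write TensorBoard event files to ./lightning_logs/Test1/version_<n>.
logger = pl.loggers.TensorBoardLogger("./lightning_logs", name = "Test1")
# Keep only the single checkpoint with the lowest validation loss.
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor = 'val_loss', save_top_k = 1)
callbacks = [
    pl.callbacks.LearningRateMonitor(),
    # Stop once val_loss has not improved for 3 consecutive validation checks.
    pl.callbacks.EarlyStopping("val_loss", patience = 3),
    model_checkpoint
]
trainer = swyft.SwyftTrainer(accelerator = 'gpu', max_epochs = 100, precision = 64, logger = logger, callbacks = callbacks)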
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

The swyft.SwyftDataModule class wraps the samples and provides data loaders for the training, validation and test data.

[7]:
dm = swyft.SwyftDataModule(samples, fractions = [0.8, 0.02, 0.1], num_workers = 3, batch_size = 128)
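The fractions passed above sum to 0.92 rather than 1. The sketch below works out split sizes and batch counts under the assumption that the fractions are rescaled to sum to one. This is an assumption, not something stated on this page, and `split_sizes` and `n_batches` are hypothetical helpers rather than swyft functions. The assumption is consistent with the 68 training, 2 validation and 9 test batches reported in the output further down; without rescaling, a test set of 1000 samples would fill only 8 batches of 128.

```python
import math


def split_sizes(n_samples, fractions):
    """Split n_samples according to fractions, rescaled to sum to one (assumed)."""
    total = sum(fractions)
    return [int(n_samples * f / total) for f in fractions]


def n_batches(n_samples, batch_size):
    """Number of batches when the last, partially filled batch is kept."""
    return math.ceil(n_samples / batch_size)


sizes = split_sizes(10_000, [0.8, 0.02, 0.1])
batches = [n_batches(s, 128) for s in sizes]

for name, size, count in zip(["train", "val", "test"], sizes, batches):
    print(f"{name}: {size} samples, {count} batches")
```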
[8]:
network = Network()
[9]:
trainer.fit(network, dm)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | logratios | LogRatioEstimator_1dim | 17.4 K
-----------------------------------------------------
17.4 K    Trainable params
0         Non-trainable params
17.4 K    Total params
0.139     Total estimated model params size (MB)
Epoch 0: 100%|██████████| 70/70 [00:01<00:00, 53.00it/s, loss=-0.543, v_num=53, val_loss=-.526]
Epoch 1: 100%|██████████| 70/70 [00:01<00:00, 51.18it/s, loss=-0.587, v_num=53, val_loss=-.526]
Epoch 2: 100%|██████████| 70/70 [00:01<00:00, 46.88it/s, loss=-0.571, v_num=53, val_loss=-.529]
Epoch 3: 100%|██████████| 70/70 [00:01<00:00, 45.15it/s, loss=-0.575, v_num=53, val_loss=-.517]
Epoch 4: 100%|██████████| 70/70 [00:01<00:00, 49.35it/s, loss=-0.569, v_num=53, val_loss=-.535]
Epoch 5: 100%|██████████| 70/70 [00:01<00:00, 51.08it/s, loss=-0.553, v_num=53, val_loss=-.547]
Epoch 6: 100%|██████████| 70/70 [00:01<00:00, 44.27it/s, loss=-0.565, v_num=53, val_loss=-.534]
Epoch 7: 100%|██████████| 70/70 [00:01<00:00, 49.40it/s, loss=-0.573, v_num=53, val_loss=-.520]
Epoch 8: 100%|██████████| 70/70 [00:01<00:00, 46.34it/s, loss=-0.573, v_num=53, val_loss=-.539]
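Training stops after epoch 8, well short of max_epochs = 100, because of the EarlyStopping callback with patience = 3. The sketch below replays that rule on the val_loss values read off the end of each epoch in the progress output above. It is a simplified stand-in for the callback and not its implementation; in particular it assumes strict improvement with a zero minimum delta, and the function name `early_stopping` is hypothetical.

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a 'lower is better' metric.

    stop_epoch is None if the patience is never exhausted.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# val_loss at the end of epochs 0..8, copied from the output above.
val_losses = [-0.526, -0.526, -0.529, -0.517, -0.535, -0.547, -0.534, -0.520, -0.539]

stop_epoch, best_epoch = early_stopping(val_losses, patience=3)
print(stop_epoch, best_epoch)  # 8 5
```

The best value is reached at epoch 5, and epochs 6, 7 and 8 fail to improve on it, which exhausts the patience of 3.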
[10]:
# Write the tracked checkpoint paths and their val_loss scores to a YAML file.
model_checkpoint.to_yaml("./test4.yaml")
[11]:
# Read that file back and pick the path of the best checkpoint.
ckpt_path = swyft.best_from_yaml("./test4.yaml")
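Conceptually, picking the best checkpoint amounts to choosing the entry with the lowest monitored score. The following minimal sketch assumes the YAML file holds a mapping from checkpoint path to val_loss. The helper `best_checkpoint` is hypothetical and is not how swyft implements best_from_yaml, and the two paths are illustrative entries built from the epoch 4 and epoch 5 validation losses in the training output (with save_top_k = 1 one would expect a single entry in practice).

```python
def best_checkpoint(scores, mode="min"):
    """Return the checkpoint path with the best score.

    scores: dict mapping checkpoint path -> monitored metric value.
    mode:   "min" if lower is better (as for a loss), "max" otherwise.
    """
    if not scores:
        raise ValueError("no checkpoints recorded")
    pick = min if mode == "min" else max
    return pick(scores, key=scores.get)


# Illustrative mapping; the paths are placeholders, not real files.
scores = {
    "checkpoints/epoch=4-step=340.ckpt": -0.535,
    "checkpoints/epoch=5-step=408.ckpt": -0.547,
}

best_path = best_checkpoint(scores)
print(best_path)
```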
[12]:
# Evaluate on the test split using the weights restored from the best checkpoint.
trainer.test(network, dm, ckpt_path = ckpt_path)
Restoring states from the checkpoint path at ./lightning_logs/Test1/version_53/checkpoints/epoch=5-step=408.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs/Test1/version_53/checkpoints/epoch=5-step=408.ckpt
Testing DataLoader 0: 100%|██████████| 9/9 [00:00<00:00, 100.01it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss             -0.5928590465564723    │
└───────────────────────────┴───────────────────────────┘
[12]:
[{'test_loss': -0.5928590465564723}]
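The restored checkpoint is named epoch=5-step=408.ckpt, which matches the epoch with the lowest val_loss in the training output. As a sanity check, the sketch below parses the epoch and step from the file name and compares the step count with the number of optimizer steps expected after six epochs. The figure of 68 training batches per epoch is an assumption read off the progress bar (70 batches in total, 2 of them validation), and `parse_checkpoint_name` is a hypothetical helper.

```python
import re


def parse_checkpoint_name(path):
    """Extract (epoch, step) from a path ending in 'epoch=<E>-step=<S>.ckpt'."""
    match = re.search(r"epoch=(\d+)-step=(\d+)\.ckpt$", path)
    if match is None:
        raise ValueError(f"unrecognised checkpoint name: {path}")
    return int(match.group(1)), int(match.group(2))


ckpt = "./lightning_logs/Test1/version_53/checkpoints/epoch=5-step=408.ckpt"
epoch, step = parse_checkpoint_name(ckpt)

# Assumption: 68 training batches per epoch (70 total minus 2 validation).
train_batches_per_epoch = 68

# Epochs are counted from zero, so epoch 5 is the sixth completed epoch.
expected_step = (epoch + 1) * train_batches_per_epoch
print(epoch, step, expected_step)  # 5 408 408
```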