This page was generated from notebooks/0E-ZarrStore.ipynb.

Storing training data on disk via Zarr

First we need some imports.

[1]:
%load_ext autoreload
%autoreload 2
[2]:
import numpy as np
from scipy import stats
import pylab as plt
import torch
import torchist
import swyft

Training data

Now we generate training data. As a simple example, we consider the model

\[x = z + \epsilon\]

where the parameter \(z \sim \mathcal{U}(0, 1)\) is drawn uniformly from the unit interval (this is what np.random.rand does in the simulator below), and \(\epsilon \sim \mathcal{N}(\mu = 0, \sigma = 0.1)\) is a small noise contribution. We are interested in the posterior of \(z\) given a measurement of \(x\).
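
For a model this simple, the posterior can be checked without any neural network. The following is a minimal sketch, assuming the uniform prior on the unit interval that the simulator's np.random.rand call implements and the Gaussian noise with \(\sigma = 0.1\) stated above. The names grid_posterior, x_obs and n_grid are illustrative and not part of swyft.

```python
import math

def grid_posterior(x_obs, sigma=0.1, n_grid=1001):
    """Evaluate p(z | x_obs) on a regular grid over the prior support [0, 1]."""
    z_grid = [i / (n_grid - 1) for i in range(n_grid)]
    dz = z_grid[1] - z_grid[0]
    # Assumption: flat prior on [0, 1], so the unnormalised posterior
    # is proportional to the Gaussian likelihood of x_obs given z.
    unnorm = [math.exp(-0.5 * ((x_obs - z) / sigma) ** 2) for z in z_grid]
    norm = sum(unnorm) * dz  # Riemann-sum normalisation
    return z_grid, [u / norm for u in unnorm]

z_grid, post = grid_posterior(x_obs=0.5)
z_map = z_grid[post.index(max(post))]  # grid point with the highest posterior density
print(f"posterior peaks at z = {z_map:.3f}")
```

A grid-based reference like this is only feasible in low dimensions, but it can serve as a sanity check for the trained ratio estimator.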

[3]:
class Simulator(swyft.Simulator):
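    def __init__(self):
        super().__init__()
        # Convert simulated samples to 32-bit NumPy arrays
        self.transform_samples = swyft.to_numpy32

    def build(self, graph):
        # Parameter z: one draw from np.random.rand, i.e. uniform on [0, 1)
        z = graph.node('z', lambda: np.random.rand(1))
        # Observation x: z plus Gaussian noise with standard deviation 0.1
        x = graph.node('x', lambda z: z + np.random.randn(1)*0.1, z)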

sim = Simulator()
shapes, dtypes = sim.get_shapes_and_dtypes()
[4]:
store = swyft.ZarrStore("./zarr_store666")
store.init(10000, 64, shapes, dtypes)
WARNING: Already initialized.
[4]:
<swyft.lightning.stores.ZarrStore at 0x1463d9eeb0d0>
[5]:
store.simulate(sim, batch_size = 1000)  # This function can be run in parallel in many threads
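
The numbers used above can be made concrete with a little arithmetic. This is a rough sketch, assuming that the first two arguments of store.init are the total number of samples and the chunk size, and that store.simulate fills the store in consecutive batches of batch_size samples. How ZarrStore actually distributes work among parallel workers is not shown here. The helpers chunk_layout and batch_ranges are hypothetical and not part of swyft or Zarr.

```python
import math

def chunk_layout(n_samples, chunk_size):
    """Return (number of chunks, size of the last chunk) along the sample axis."""
    n_chunks = math.ceil(n_samples / chunk_size)
    last = n_samples - (n_chunks - 1) * chunk_size
    return n_chunks, last

def batch_ranges(n_samples, batch_size):
    """Split sample indices into consecutive [start, stop) ranges of at most batch_size."""
    return [(start, min(start + batch_size, n_samples))
            for start in range(0, n_samples, batch_size)]

# Values taken from the cells above (their interpretation is an assumption)
n_chunks, last = chunk_layout(10000, 64)
ranges = batch_ranges(10000, 1000)
print(n_chunks, last, len(ranges))
```

Under these assumptions, the store holds 157 chunks, the last of which contains only 16 samples, and one pass of simulate corresponds to 10 batches of 1000 simulations each.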
[6]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 1, num_params = 1, varnames = 'z')

    def forward(self, A, B):
        logratios = self.logratios(A['x'], B['z'])
        return logratios

Trainer

Training is done with the SwyftTrainer class, which extends pytorch_lightning.Trainer with additional methods such as infer.

[10]:
trainer = swyft.SwyftTrainer(accelerator = 'gpu', devices=1, max_epochs = 2, precision = 64)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

The swyft.SwyftDataModule class wraps the store and provides the data loaders for training and validation. The fractions argument controls how the stored samples are split.

[11]:
dm = swyft.SwyftDataModule(store, fractions = [0.8, 0.1, 0.1], batch_size = 32)
[12]:
network = Network()
[13]:
trainer.fit(network, dm)
/home/weniger/miniconda3b/envs/zero/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:616: UserWarning: Checkpoint directory /home/weniger/codes/swyft/notebooks/lightning_logs/version_10036357/checkpoints exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | logratios | LogRatioEstimator_1dim | 17.4 K
-----------------------------------------------------
17.4 K    Trainable params
0         Non-trainable params
17.4 K    Total params
0.139     Total estimated model params size (MB)
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
/home/weniger/miniconda3b/envs/zero/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:219: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 24 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

/home/weniger/miniconda3b/envs/zero/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:219: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 24 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0: : 64it [00:00, 93.70it/s, loss=-0.494, v_num=1e+7, val_loss=-.105]
Epoch 1: : 64it [00:00, 102.25it/s, loss=-0.544, v_num=1e+7, val_loss=-.496]
`Trainer.fit` stopped: `max_epochs=2` reached.