J - ZarrStore and Parallel Simulations#

Authors: Noemi Anau Montel, James Alvey, Christoph Weniger

Last update: 15 September 2023

Purpose: Introduction to on-disk storage of training data using Zarr.

Key take-away messages: Swyft provides efficient code for the storage and retrieval of training data from disk.

Code#

[1]:
import numpy as np
import pylab as plt
import torch
import swyft
DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu'
[2]:
torch.manual_seed(0)
np.random.seed(0)

We use our toy simulator from the previous notebooks.

[3]:
class Simulator(swyft.Simulator):
    def __init__(self, Nbins = 100, sigma = 0.2):
        super().__init__()
        self.transform_samples = swyft.to_numpy32
        self.Nbins = Nbins
        self.y = np.linspace(-1, 1, Nbins)
        self.sigma = sigma

    def calc_m(self, z):
        m = np.ones_like(self.y)*z[0] + self.y*z[1] + self.y**2*z[2]
        return m

    def build(self, graph):
        # Draw parameters z uniformly from [-1, 1]^3, compute the quadratic
        # model m(z), and add Gaussian noise to obtain the data x
        z = graph.node('z', lambda: np.random.rand(3)*2 - 1)
        m = graph.node('m', self.calc_m, z)
        x = graph.node('x', lambda m: m + np.random.randn(self.Nbins)*self.sigma, m)

In order to set up a store on disk, we need information about the types and shapes of the simulation variables. These can be collected using the get_shapes_and_dtypes method of the simulator class, which invokes the simulator once and inspects the output.

[4]:
sim = Simulator()
shapes, dtypes = sim.get_shapes_and_dtypes()
print("shapes:", shapes)
print("dtypes:", dtypes)
shapes: {'z': (3,), 'm': (100,), 'x': (100,)}
dtypes: {'z': dtype('float32'), 'm': dtype('float32'), 'x': dtype('float32')}

We can then instantiate a new Zarr store. In the example below, the store has space for 10000 simulations, which are saved in chunks of 64 simulations per file (this speeds up reading from disk). Each variable is saved in its own set of files. For good performance, avoid situations with a very large number of low-dimensional variables.

[5]:
store = swyft.ZarrStore("./example_zarr_store")
store.init(10000, 64, shapes, dtypes)  # Only initializes once; subsequent calls generate a warning message
[5]:
<swyft.lightning.data.ZarrStore at 0x284cac640>

Running the simulator can now be done by invoking the simulate method. The batch_size indicates how many simulations are run and aggregated before they are written to disk in a single operation. It should be adjusted depending on whether file access or simulation time is the bottleneck.

[6]:
store.simulate(sim, batch_size = 1000)  # This function can be run in parallel in many threads; once the store is full, it does nothing

Concurrent simulations on multiple machines: the simulate method can be run simultaneously in multiple threads or on multiple machines, all filling the same simulation store in parallel. A file locking mechanism ensures that stored simulations are not overwritten. A minimal worker sketch is shown below.
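For illustration, a stand-alone worker script could look as follows. This is only a sketch: the file name fill_store.py and the import of the Simulator class from a module are assumptions, not part of the Swyft API. Each copy of the script, launched as a separate process or cluster job, connects to the same store path and contributes simulations until the store is full.

# fill_store.py -- hypothetical worker script; launch one copy per
# machine or cluster job to fill the store in parallel
import swyft
from simulator import Simulator  # assumes the Simulator class above lives in simulator.py

# Every worker connects to the same on-disk store, which has already
# been initialized with the correct shapes and dtypes
store = swyft.ZarrStore("./example_zarr_store")
sim = Simulator()
store.simulate(sim, batch_size = 1000)  # becomes a no-op once the store is full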

Finally, a datamodule for training neural networks can be set up as before, simply replacing the in-memory samples with the store.

[7]:
dm = swyft.SwyftDataModule(store, batch_size = 32)
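As an aside, the dataloader warnings visible in the training log further below can usually be addressed by using more dataloader workers. The following sketch assumes that SwyftDataModule forwards a num_workers argument to the underlying PyTorch DataLoaders; check the signature in your Swyft version.

# Assumption: SwyftDataModule passes `num_workers` on to the PyTorch DataLoaders
dm = swyft.SwyftDataModule(store, batch_size = 32, num_workers = 4)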

Training a network is then straightforward.

[9]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Linear(100, 10)
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 10, num_params = 3, varnames = 'z')

    def forward(self, A, B):
        f = self.embedding(A['x'])
        logratios = self.logratios(f, B['z'])
        return logratios

trainer = swyft.SwyftTrainer(accelerator = DEVICE, precision = 64)
network = Network()
trainer.fit(network, dm)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/setup.py:200: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.
  rank_zero_warn(
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:94: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
  rank_zero_warn(
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint

  | Name      | Type                   | Params
-----------------------------------------------------
0 | embedding | Linear                 | 1.0 K
1 | logratios | LogRatioEstimator_1dim | 54.0 K
-----------------------------------------------------
55.0 K    Trainable params
0         Non-trainable params
55.0 K    Total params
0.440     Total estimated model params size (MB)
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/Users/cweniger/opt/anaconda3/envs/native2/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Reloading best model: /Users/cweniger/Documents/swyft/notebooks/lightning_logs/version_28/checkpoints/epoch=14-step=3750.ckpt

Congratulations! You have trained a network using Zarr-based disk storage, which remains efficient even for large data volumes.

Exercises#

  1. Inspect the content of the Zarr store, in the folder ‘example_zarr_store’, on the file system. (In Google Colab notebooks: click on the file folder on the left and navigate to /content/example_zarr_store.) In which folder is the simulation data stored? How many files are there per variable? How does this relate to the numbers specified during initialization?

[ ]:
# Results
  2. Initialize a second Zarr store with a different chunk size (don’t make it too small). Does this affect the number of files generated during simulation as expected?

[ ]:
# Results