This page was generated from notebooks/Examples - 3. Saving and Loading.ipynb.

Saving and Loading

[1]:
# Reload edited Python modules (e.g. a local swyft checkout) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
[2]:
from os import environ
GOOGLE_COLAB = True if "COLAB_GPU" in environ else False
if GOOGLE_COLAB:
    !pip install git+https://github.com/undark-lab/swyft.git
[3]:
import numpy as np
import pylab as plt
import torch

import swyft
[4]:
device = "cpu"
n_training_samples = 3000
n_parameters = 2
# Only the 1d indices are used in this notebook; the 2d ones are unpacked but unused.
marginal_indices_1d, marginal_indices_2d = swyft.utils.get_corner_marginal_indices(n_parameters)
observation_key = "y"

# NOTE(review): not used further in this notebook as shown.
n_posterior_samples_for_truncation = 10_000
n_rejection_samples = 10_000

# Seed the global NumPy and PyTorch RNGs so that the noise hook (np.random.randn)
# and, presumably, swyft's own sampling and training are reproducible on
# Restart & Run All.
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
[5]:
def model(v):
    """Noiseless linear forward model.

    Unpacks ``v`` into a slope and an intercept. It evaluates slope * x + intercept
    on 100 evenly spaced points x in [-1, 1] and returns it under ``observation_key``.
    """
    slope, intercept = v
    grid = np.linspace(-1, 1, 100)
    return {observation_key: slope * grid + intercept}

# Ground-truth parameters (all zeros) and the corresponding noiseless observation.
v_o = np.zeros(n_parameters)
observation_o = model(v_o)

# Shape of every simulator output, keyed like the observation dict.
observation_shapes = {name: array.shape for name, array in observation_o.items()}
n_observation_features = observation_shapes[observation_key][0]
[6]:
simulator = swyft.Simulator(
    model,
    n_parameters,
    sim_shapes=observation_shapes,
)

# Uniform prior with bounds low=-1 and high=1 for every parameter.
low = np.full(n_parameters, -1.0)
high = np.full(n_parameters, 1.0)
prior = swyft.get_uniform_prior(low, high)

store = swyft.Store.memory_store(simulator)
# Sample counts in the store are Poisson distributed, so we request ~2% more than
# n_training_samples. This keeps the Dataset from trying to draw more samples than
# the store actually holds.
n_store_samples = n_training_samples + 0.02 * n_training_samples
store.add(n_store_samples, prior)
store.simulate()

def noise(obs, v, sigma=0.05):
    """Simhook that adds i.i.d. Gaussian noise with standard deviation ``sigma``.

    Only the ``observation_key`` entry of ``obs`` is kept in the returned dict.
    ``v`` (the parameters) is accepted but not used.
    """
    clean = obs[observation_key]
    return {observation_key: clean + sigma * np.random.randn(*clean.shape)}
Creating new store.
Store: Adding 3077 new samples to simulator store.
[7]:
# Training set of n_training_samples points drawn from the store, with `noise`
# passed as the simhook (presumably applied to each simulated observation).
dataset = swyft.Dataset(n_training_samples, prior, store, simhook=noise)
[8]:
# Visual sanity check: overlay the first 10 noisy training observations.
fig, ax = plt.subplots()
for i in range(10):
    # assumes dataset[i][0] is the observation dict — TODO confirm against swyft.Dataset
    ax.plot(dataset[i][0][observation_key])
ax.set(
    xlabel="feature index",
    ylabel=observation_key,
    title="First 10 noisy training observations",
)
plt.show()
../_images/notebooks_Examples_-_3._Saving_and_Loading_8_0.png
[9]:
# Architecture of the 1d marginal classifier. Loading a saved estimator later
# requires building a network with these same settings.
network_kwargs = dict(
    observation_key=observation_key,
    marginal_indices=marginal_indices_1d,
    observation_shapes=observation_shapes,
    n_parameters=n_parameters,
    hidden_features=32,
    num_blocks=2,
)
network_1d = swyft.get_marginal_classifier(**network_kwargs)

mre_1d = swyft.MarginalRatioEstimator(
    marginal_indices=marginal_indices_1d, network=network_1d, device=device,
)
[10]:
# Train the 1d marginal ratio estimator on the noisy dataset. No training options
# are passed, so swyft's defaults are used; training progress is printed below.
mre_1d.train(dataset)
training: lr=5e-05, epoch=25, validation loss=0.41132
[11]:
# Combine the trained ratio estimator with the prior into a posterior, then draw
# n_rejection_samples posterior samples conditioned on the noiseless observation.
posterior_1d = swyft.MarginalPosterior(mre_1d, prior)
samples_1d = posterior_1d.sample(n_rejection_samples, observation_o)
[12]:
# 1d marginal histograms (with KDE) of the posterior samples; `truth` marks v_o.
# Assigning to `_` suppresses the display of the return value.
_ = swyft.hist1d(samples_1d, kde=True, truth=v_o)
../_images/notebooks_Examples_-_3._Saving_and_Loading_12_0.png

Saving

Saving is straightforward: the prior, the dataset and the ratio estimator each have a `save` method that takes a filename.

[13]:
# All artifacts share a single prefix so they are easy to find (and clean up) together.
# Previously the prior used "example3." while the others used "examples3.".
file_prefix = "example3"
prior_filename = f"{file_prefix}.prior.pt"
dataset_filename = f"{file_prefix}.dataset.pt"
mre_1d_filename = f"{file_prefix}.mre_1d.pt"

# Each object exposes a `save(filename)` method; files are written to the working directory.
prior.save(prior_filename)
dataset.save(dataset_filename)
mre_1d.save(mre_1d_filename)

Loading

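Loading differs by object type. The prior and dataset are simple: each is restored with its `load` class method (the dataset is also given the store and the simhook). The ratio estimator is a bit more involved. One must first instantiate a network with the same architecture as the saved one, then load the saved version into it with `MarginalRatioEstimator.load`.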

[14]:
# Prior and Dataset are restored with their `load` class methods. The dataset is
# explicitly given the in-memory store and the `noise` simhook again.
prior_loaded = swyft.Prior.load(prior_filename)
dataset_loaded = swyft.Dataset.load(filename=dataset_filename, store=store, simhook=noise)


# A fresh network must be built with the same architecture as the saved one;
# MarginalRatioEstimator.load then loads the saved version into it.
network_new_kwargs = dict(
    observation_key=observation_key,
    marginal_indices=marginal_indices_1d,
    observation_shapes=observation_shapes,
    n_parameters=n_parameters,
    hidden_features=32,
    num_blocks=2,
)
network_new = swyft.get_marginal_classifier(**network_new_kwargs)

mre_1d_loaded = swyft.MarginalRatioEstimator.load(
    network=network_new, device=device, filename=mre_1d_filename,
)

Loading posteriors

There is no option to save or load a posterior, because one can simply be recreated from the prior and the `MarginalRatioEstimator`. If a `Bound` object was used, it can also be saved and loaded (Bound objects are simple to save and load). The loaded bound is then passed to `MarginalPosterior` via its `bound` argument; here we pass `bound=None`.

[15]:
# Rebuild the posterior from the loaded estimator and prior. bound=None because no
# Bound object is used in this notebook.
post_new = swyft.MarginalPosterior(mre_1d_loaded, prior_loaded, bound=None)
samples_new = post_new.sample(n_rejection_samples, observation_o)
[16]:
# Should match the marginals plotted before saving, up to sampling noise.
_ = swyft.hist1d(samples_new, kde=True, truth=v_o)
../_images/notebooks_Examples_-_3._Saving_and_Loading_19_0.png
[ ]:

This page was generated from notebooks/Examples - 3. Saving and Loading.ipynb.