Source code for swyft.lightning.data

import math
from typing import (
    Callable,
    Dict,
    Hashable,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)
import numpy as np
import torch
from torch.utils.data import random_split
import pytorch_lightning as pl
import zarr
import fasteners
import swyft
from swyft.lightning.simulator import Samples, Sample


######################
# Datasets and loaders
######################


class SwyftDataModule(pl.LightningDataModule):
    """DataModule to handle simulated data.

    Args:
        data: Simulation data
        val_fraction: Fraction of data used for validation.
        batch_size: Minibatch size.
        num_workers: Number of workers for dataloader.
        shuffle: Shuffle training data.

    Returns:
        pytorch_lightning.LightningDataModule
    """

    def __init__(
        self,
        data,
        # lengths: Union[Sequence[int], None] = None,
        # fractions: Union[Sequence[float], None] = None,
        val_fraction: float = 0.2,
        batch_size: int = 32,
        num_workers: int = 0,
        shuffle: bool = False,
        on_after_load_sample: Optional[Callable] = None,
    ):
        super().__init__()
        self.data = data

        # TODO: Clean up code
        lengths = None
        fractions = [1 - val_fraction, val_fraction]
        if lengths is not None and fractions is None:
            assert (
                len(lengths) == 2
            ), "SwyftDataModule only provides training and validation data."
            lengths = [lengths[0], lengths[1], 0]
            self.lengths = lengths
        elif lengths is None and fractions is not None:
            assert (
                len(fractions) == 2
            ), "SwyftDataModule only provides training and validation data."
            fractions = [fractions[0], fractions[1], 0]
            self.lengths = self._get_lengths(fractions, len(data))
        else:
            raise ValueError("Either lengths or fractions must be set, but not both.")

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.on_after_load_sample = on_after_load_sample

    @staticmethod
    def _get_lengths(fractions, N):
        fractions = np.array(fractions)
        fractions /= sum(fractions)
        mu = N * fractions
        n = np.floor(mu)
        n[0] += N - sum(n)
        return [int(v) for v in n]

    def setup(self, stage: str):
        if isinstance(self.data, Samples):
            dataset = self.data.get_dataset(
                on_after_load_sample=self.on_after_load_sample
            )
            splits = torch.utils.data.random_split(dataset, self.lengths)
            self.dataset_train, self.dataset_val, self.dataset_test = splits
        elif isinstance(self.data, swyft.ZarrStore):
            idxr1 = (0, self.lengths[0])
            idxr2 = (self.lengths[0], self.lengths[0] + self.lengths[1])
            idxr3 = (self.lengths[0] + self.lengths[1], len(self.data))
            self.dataset_train = self.data.get_dataset(
                idx_range=idxr1, on_after_load_sample=self.on_after_load_sample
            )
            self.dataset_val = self.data.get_dataset(
                idx_range=idxr2, on_after_load_sample=None
            )
            self.dataset_test = self.data.get_dataset(
                idx_range=idxr3, on_after_load_sample=None
            )
        else:
            raise ValueError("data must be a Samples object or a swyft.ZarrStore.")

    def train_dataloader(self):
        dataloader = torch.utils.data.DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )
        return dataloader

    def val_dataloader(self):
        dataloader = torch.utils.data.DataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return dataloader

    def test_dataloader(self):
        return
        # dataloader = torch.utils.data.DataLoader(
        #     self.dataset_test,
        #     batch_size=self.batch_size,
        #     shuffle=False,
        #     num_workers=self.num_workers,
        # )
        # return dataloader
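# Usage sketch (illustrative only, not part of the module): a SwyftDataModule
# can be built from an in-memory swyft.Samples object or from a swyft.ZarrStore;
# `samples` and `network` below are assumed to exist in user code.
#
#     dm = SwyftDataModule(samples, val_fraction=0.2, batch_size=64, shuffle=True)
#     dm.setup("fit")
#     train_loader = dm.train_dataloader()
#     val_loader = dm.val_dataloader()
#     # or hand the module directly to a Lightning trainer:
#     # trainer = pl.Trainer(max_epochs=10)
#     # trainer.fit(network, dm)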
class SamplesDataset(torch.utils.data.Dataset):
    """Simple torch dataset based on Samples."""

    def __init__(self, sample_store, on_after_load_sample=None):
        self._dataset = sample_store
        self._on_after_load_sample = on_after_load_sample

    def __len__(self):
        return len(self._dataset[list(self._dataset.keys())[0]])

    def __getitem__(self, i):
        d = {k: v[i] for k, v in self._dataset.items()}
        if self._on_after_load_sample is not None:
            d = self._on_after_load_sample(d)
        return d
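# Usage sketch (illustrative only): wrapping a dict-of-arrays sample store in a
# SamplesDataset; the keys "x" and "z" are placeholders for whatever the
# simulator produces.
#
#     store = {"x": np.random.rand(100, 10), "z": np.random.rand(100, 2)}
#     dataset = SamplesDataset(store)
#     len(dataset)   # -> 100
#     dataset[0]     # -> {"x": array of shape (10,), "z": array of shape (2,)}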
class RepeatDatasetWrapper(torch.utils.data.Dataset):
    """Dataset wrapper that repeats each underlying sample `repeat` times."""

    def __init__(self, dataset, repeat):
        self._dataset = dataset
        self._repeat = repeat

    def __len__(self):
        return len(self._dataset) * self._repeat

    def __getitem__(self, i):
        return self._dataset[i // self._repeat]
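# Usage sketch (illustrative only): RepeatDatasetWrapper makes each underlying
# sample appear `repeat` times per epoch, which can be useful when the base
# dataset is small compared to the desired epoch length.
#
#     repeated = RepeatDatasetWrapper(dataset, repeat=10)
#     len(repeated)  # -> 10 * len(dataset); index i maps to dataset[i // 10]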
###########
# ZarrStore
###########
class ZarrStore:
    r"""Storing training data in zarr archive."""

    def __init__(self, file_path, sync_path=None):
        if sync_path is None:
            sync_path = file_path + ".sync"
        synchronizer = zarr.ProcessSynchronizer(sync_path) if sync_path else None
        self.store = zarr.DirectoryStore(file_path)
        self.root = zarr.group(store=self.store, synchronizer=synchronizer)
        self.lock = fasteners.InterProcessLock(file_path + ".lock.file")
    def reset_length(self, N, clubber=False):
        """Resize store. N >= current store length."""
        if N < len(self) and not clubber:
            raise ValueError(
                """New length shorter than current store length.
                You can use clubber = True if you know what you are doing."""
            )
        for k in self.data.keys():
            shape = self.data[k].shape
            self.data[k].resize(N, *shape[1:])
        self.root["meta/sim_status"].resize(N,)
    def init(self, N, chunk_size, shapes=None, dtypes=None):
        if len(self) > 0:
            print("WARNING: Already initialized.")
            return self
        self._init_shapes(shapes, dtypes, N, chunk_size)
        return self

    def __len__(self):
        if "data" not in self.root.keys():
            return 0
        keys = self.root["data"].keys()
        ns = [len(self.root["data"][k]) for k in keys]
        N = ns[0]
        assert all([n == N for n in ns])
        return N

    def keys(self):
        return list(self.data.keys())

    def __getitem__(self, i):
        if isinstance(i, int):
            return Sample({k: self.data[k][i] for k in self.keys()})
        elif isinstance(i, slice):
            return Samples({k: self.data[k][i] for k in self.keys()})
        elif isinstance(i, str):
            return self.data[i]
        else:
            raise ValueError("Index must be an int, slice or str.")

    # TODO: Remove consistency checks
    def _init_shapes(self, shapes, dtypes, N, chunk_size):
        """Initializes shapes, or checks consistency."""
        for k in shapes.keys():
            s = shapes[k]
            dtype = dtypes[k]
            try:
                self.root.zeros(
                    "data/" + k, shape=(N, *s), chunks=(chunk_size, *s), dtype=dtype
                )
            except zarr.errors.ContainsArrayError:
                assert self.root["data/" + k].shape == (
                    N,
                    *s,
                ), "Inconsistent array sizes"
                assert self.root["data/" + k].chunks == (
                    chunk_size,
                    *s,
                ), "Inconsistent chunk sizes"
                assert self.root["data/" + k].dtype == dtype, "Inconsistent dtype"
        try:
            self.root.zeros(
                "meta/sim_status", shape=(N,), chunks=(chunk_size,), dtype="i4"
            )
        except zarr.errors.ContainsArrayError:
            assert self.root["meta/sim_status"].shape == (
                N,
            ), "Inconsistent array sizes"
        try:
            assert self.chunk_size == chunk_size, "Inconsistent chunk size"
        except KeyError:
            self.data.attrs["chunk_size"] = chunk_size

    @property
    def chunk_size(self):
        return self.data.attrs["chunk_size"]

    @property
    def data(self):
        return self.root["data"]

    def numpy(self):
        return {k: v[:] for k, v in self.root["data"].items()}

    def get_sample_store(self):
        return Samples(self.numpy())

    @property
    def meta(self):
        return {k: v for k, v in self.root["meta"].items()}

    @property
    def sims_required(self):
        return sum(self.root["meta"]["sim_status"][:] == 0)

    def simulate(self, sampler, max_sims=None, batch_size=10):
        total_sims = 0
        if isinstance(sampler, swyft.Simulator):
            sampler = sampler.sample
        while self.sims_required > 0:
            if max_sims is not None and total_sims >= max_sims:
                break
            num_sims = self._simulate_batch(sampler, batch_size)
            total_sims += num_sims

    def _simulate_batch(self, sample_fn, batch_size):
        # Run simulator
        num_sims = min(batch_size, self.sims_required)
        if num_sims == 0:
            return num_sims
        samples = sample_fn(num_sims)

        # Reserve slots
        with self.lock:
            sim_status = self.root["meta"]["sim_status"]
            data = self.root["data"]
            idx = np.arange(len(sim_status))[sim_status[:] == 0][:num_sims]
            index_slices = _get_index_slices(idx)

            for i_slice, j_slice in index_slices:
                sim_status[j_slice[0] : j_slice[1]] = 1
                for k, v in data.items():
                    data[k][j_slice[0] : j_slice[1]] = samples[k][
                        i_slice[0] : i_slice[1]
                    ]

        return num_sims

    def get_dataset(self, idx_range=None, on_after_load_sample=None):
        return ZarrStoreIterableDataset(
            self, idx_range=idx_range, on_after_load_sample=on_after_load_sample
        )

    def get_dataloader(
        self,
        num_workers=0,
        batch_size=1,
        pin_memory=False,
        drop_last=True,
        idx_range=None,
        on_after_load_sample=None,
    ):
        ds = self.get_dataset(
            idx_range=idx_range, on_after_load_sample=on_after_load_sample
        )
        dl = torch.utils.data.DataLoader(
            ds,
            num_workers=num_workers,
            batch_size=batch_size,
            drop_last=drop_last,
            pin_memory=pin_memory,
        )
        return dl
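# Usage sketch (illustrative only): creating a ZarrStore, filling it with
# simulations and reading it back; `simulator` is assumed to be a
# swyft.Simulator whose samples contain the keys declared in shapes/dtypes.
#
#     store = ZarrStore("./my_store")
#     store.init(N=10_000, chunk_size=64,
#                shapes={"x": (10,), "z": (2,)},
#                dtypes={"x": np.float32, "z": np.float32})
#     store.simulate(simulator, batch_size=100)   # fills empty slots
#     loader = store.get_dataloader(batch_size=64, num_workers=2)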
def _get_index_slices(idx):
    """Returns list of enumerated consecutive indices"""
    idx = np.array(idx)
    pointer = 0
    residual_idx = idx
    slices = []
    while len(residual_idx) > 0:
        mask = residual_idx - residual_idx[0] - np.arange(len(residual_idx)) == 0
        slc1 = [residual_idx[mask][0], residual_idx[mask][-1] + 1]
        slc2 = [pointer, pointer + sum(mask)]
        pointer += sum(mask)
        slices.append([slc2, slc1])
        residual_idx = residual_idx[~mask]
    return slices
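# Example (illustrative only): _get_index_slices groups consecutive store
# indices into pairs [[batch_start, batch_stop], [store_start, store_stop]],
# so that simulation results can be written in a few contiguous writes.
#
#     _get_index_slices([3, 4, 5, 9, 10])
#     # -> [[[0, 3], [3, 6]], [[3, 5], [9, 11]]]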
class ZarrStoreIterableDataset(torch.utils.data.dataloader.IterableDataset):
    def __init__(
        self, zarr_store: ZarrStore, idx_range=None, on_after_load_sample=None
    ):
        self.zs = zarr_store
        if idx_range is None:
            self.n_samples = len(self.zs)
            self.offset = 0
        else:
            self.offset = idx_range[0]
            self.n_samples = idx_range[1] - idx_range[0]
        self.chunk_size = self.zs.chunk_size
        self.n_chunks = int(math.ceil(self.n_samples / float(self.chunk_size)))
        self.on_after_load_sample = on_after_load_sample

    @staticmethod
    def get_idx(n_chunks, worker_info):
        if worker_info is not None:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            n_chunks_per_worker = int(math.ceil(n_chunks / float(num_workers)))
            idx = [
                worker_id * n_chunks_per_worker,
                min((worker_id + 1) * n_chunks_per_worker, n_chunks),
            ]
            idx = np.random.permutation(range(*idx))
        else:
            idx = np.random.permutation(n_chunks)
        return idx

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        idx = self.get_idx(self.n_chunks, worker_info)
        offset = self.offset
        for i0 in idx:
            # Read in chunks
            data_chunk = {}
            for k in self.zs.data.keys():
                data_chunk[k] = self.zs.data[k][
                    offset + i0 * self.chunk_size : offset + (i0 + 1) * self.chunk_size
                ]
                n = len(data_chunk[k])

            # Return separate samples
            for i in np.random.permutation(n):
                out = {k: v[i] for k, v in data_chunk.items()}
                if self.on_after_load_sample:
                    out = self.on_after_load_sample(out)
                yield out
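# Usage sketch (illustrative only): the iterable dataset streams samples
# chunk-by-chunk from the zarr store, shuffling chunk order and samples within
# each chunk rather than doing fully random access, which keeps disk reads
# contiguous. It is normally obtained via ZarrStore.get_dataset or
# ZarrStore.get_dataloader rather than constructed directly.
#
#     ds = ZarrStoreIterableDataset(store, idx_range=(0, 8_000))
#     dl = torch.utils.data.DataLoader(ds, batch_size=64, num_workers=2)
#     batch = next(iter(dl))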
# def get_ntrain_nvalid(
#     validation_amount: Union[float, int], len_dataset: int
# ) -> Tuple[int, int]:
#     """Divide a dataset into a training and validation set.
#
#     Args:
#         validation_amount: percentage or number of elements in the validation set
#         len_dataset: total length of the dataset
#
#     Raises:
#         TypeError: When the validation_amount is neither a float nor an int.
#
#     Returns:
#         (n_train, n_valid)
#     """
#     assert validation_amount > 0
#     if isinstance(validation_amount, float):
#         percent_validation = validation_amount
#         percent_train = 1.0 - percent_validation
#         n_valid, n_train = split_length_by_percentage(
#             len_dataset, (percent_validation, percent_train)
#         )
#         if n_valid % 2 != 0:
#             n_valid += 1
#             n_train -= 1
#     elif isinstance(validation_amount, int):
#         n_valid = validation_amount
#         n_train = len_dataset - n_valid
#         assert n_train > 0
#
#         if n_valid % 2 != 0:
#             n_valid += 1
#             n_train -= 1
#     else:
#         raise TypeError("validation_amount must be int or float")
#     return n_train, n_valid


# TODO: Deprecate
# class SwyftDataModule_deprecated(pl.LightningDataModule):
#     def __init__(
#         self,
#         on_after_load_sample=None,
#         store=None,
#         batch_size: int = 32,
#         validation_percentage=0.2,
#         manual_seed=None,
#         train_multiply=10,
#         num_workers=0,
#     ):
#         super().__init__()
#         self.store = store
#         self.on_after_load_sample = on_after_load_sample
#         self.batch_size = batch_size
#         self.num_workers = num_workers
#         self.validation_percentage = validation_percentage
#         self.train_multiply = train_multiply
#         print(
#             "Deprecation warning: Use dataloaders directly rather than this data module for transparency."
#         )
#
#     def setup(self, stage):
#         self.dataset = SamplesDataset(
#             self.store, on_after_load_sample=self.on_after_load_sample
#         )  # , x_keys = ['data'], z_keys=['z'])
#         n_train, n_valid = get_ntrain_nvalid(
#             self.validation_percentage, len(self.dataset)
#         )
#         self.dataset_train, self.dataset_valid = random_split(
#             self.dataset,
#             [n_train, n_valid],
#             generator=torch.Generator().manual_seed(42),
#         )
#         self.dataset_test = SamplesDataset(
#             self.store
#         )  # , x_keys = ['data'], z_keys=['z'])
#
#     def train_dataloader(self):
#         return torch.utils.data.DataLoader(
#             self.dataset_train, batch_size=self.batch_size, num_workers=self.num_workers
#         )
#
#     def val_dataloader(self):
#         return torch.utils.data.DataLoader(
#             self.dataset_valid, batch_size=self.batch_size, num_workers=self.num_workers
#         )
#
#     # TODO: Deprecate
#     # def predict_dataloader(self):
#     #     return torch.utils.data.DataLoader(
#     #         self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
#     #     )
#
#     def test_dataloader(self):
#         return torch.utils.data.DataLoader(
#             self.dataset_test, batch_size=self.batch_size, num_workers=self.num_workers
#         )
#
#     def samples(self, N, random=False):
#         dataloader = torch.utils.data.DataLoader(
#             self.dataset_train, batch_size=N, num_workers=0, shuffle=random
#         )
#         examples = next(iter(dataloader))
#         return Samples(examples)