from dataclasses import dataclass, field
from toolz.dicttoolz import valmap
from typing import (
Callable,
Dict,
Hashable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Any,
)
import numpy as np
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
try:
from pytorch_lightning.trainer.supporters import CombinedLoader
except ImportError:
from pytorch_lightning.utilities import CombinedLoader
# from pytorch_lightning.cli import instantiate_class
import yaml
from swyft.lightning.data import *
import swyft.lightning.simulator
import scipy
import scipy.interpolate  # scipy.interpolate.griddata is used below; a bare `import scipy` does not guarantee the submodule is loaded
from scipy.ndimage import gaussian_filter1d, gaussian_filter
import torchist
##################
# Parameter errors
##################
class SwyftParameterError(Exception):
"""General parameter error in Swyft."""
pass
############################
# Weights, PDFs and coverage
############################
def _pdf_from_weighted_samples(v, w, bins=50, smooth=0, smooth_prior=False):
"""Take weighted samples and turn them into a pdf on a grid.
Args:
bins
"""
ndim = v.shape[-1]
if not smooth_prior:
return _weighted_smoothed_histogramdd(v, w, bins=bins, smooth=smooth)
else:
        # Unit weights yield the (smoothed) prior density on the same grid.
        h, xy = _weighted_smoothed_histogramdd(v, w * 0 + 1, bins=bins, smooth=smooth)
if ndim == 2:
X, Y = np.meshgrid(xy[:, 0], xy[:, 1])
n = len(xy)
out = scipy.interpolate.griddata(
v, w, (X.flatten(), Y.flatten()), method="cubic", fill_value=0.0
).reshape(n, n)
out = out * h.numpy()
return out, xy
elif ndim == 1:
out = scipy.interpolate.griddata(
v[:, 0], w, xy[:, 0], method="cubic", fill_value=0.0
)
out = out * h.numpy()
return out, xy
        else:
            raise SwyftParameterError("smooth_prior is only supported for 1-dim and 2-dim marginals")
def _weighted_smoothed_histogramdd(v, w, bins=50, smooth=0):
    """Weighted histogram of samples `v` with weights `w`, normalized to a density and
    optionally smoothed with a Gaussian filter; returns (histogram, bin centers)."""
    ndim = v.shape[-1]
if ndim == 1:
low, upp = v.min(), v.max()
h = torchist.histogramdd(v, bins, weights=w, low=low, upp=upp)
h /= len(v) * (upp - low) / bins
edges = torch.linspace(low, upp, bins + 1)
x = (edges[1:] + edges[:-1]) / 2
if smooth > 0:
h = torch.tensor(gaussian_filter1d(h, smooth))
return h, x.unsqueeze(-1)
elif ndim == 2:
low = v.min(axis=0).values
upp = v.max(axis=0).values
h = torchist.histogramdd(v, bins=bins, weights=w, low=low, upp=upp)
h /= len(v) * (upp[0] - low[0]) * (upp[1] - low[1]) / bins**2
x = torch.linspace(low[0], upp[0], bins + 1)
y = torch.linspace(low[1], upp[1], bins + 1)
x = (x[1:] + x[:-1]) / 2
y = (y[1:] + y[:-1]) / 2
xy = torch.vstack([x, y]).T
if smooth > 0:
h = torch.tensor(gaussian_filter(h * 1.0, smooth))
        return h, xy
    else:
        raise SwyftParameterError("Only 1-dim and 2-dim histograms are supported.")
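# Quick sanity sketch (not part of the library): with unit weights the 1-dim
# histogram approximates a normalized density, so the bin values times the bin
# width sum to ~1.
#
#     v = torch.rand(10_000, 1)        # uniform samples on [0, 1]
#     w = torch.ones(10_000)
#     h, x = _weighted_smoothed_histogramdd(v, w, bins=50)
#     dx = (x[1, 0] - x[0, 0]).item()  # uniform bin width
#     print((h.sum() * dx).item())     # ~1.0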
def get_pdf(
lrs_coll,
params: Union[str, Sequence[str]],
aux=None,
bins: int = 50,
smooth: float = 0.0,
    smooth_prior: bool = False,
):
"""Generate binned PDF based on input
Args:
lrs_coll: Collection of LogRatioSamples objects.
params: Parameter names
bins: Number of bins
smooth: Apply Gaussian smoothing
smooth_prior: Smooth prior instead of posterior
Returns:
np.array, np.array: Returns densities and parameter grid.
"""
z, w = get_weighted_samples(lrs_coll, params)
if aux is not None:
z_aux, _ = get_weighted_samples(aux, params)
else:
z_aux = None
return _pdf_from_weighted_samples(
z, w, bins=bins, smooth=smooth, smooth_prior=smooth_prior
)
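# Hedged usage sketch: `get_pdf` only requires objects exposing `.parnames`,
# `.params` and `.logratios`, so a small stand-in is enough to exercise it.
# `_Toy` below is a hypothetical stand-in, not swyft's LogRatioSamples class.
#
#     @dataclass
#     class _Toy:
#         logratios: torch.Tensor  # (num_samples, num_ratios)
#         params: torch.Tensor     # (num_samples, num_ratios, num_params)
#         parnames: np.ndarray     # (num_ratios, num_params)
#
#     toy = _Toy(torch.randn(1000, 1), torch.rand(1000, 1, 1), np.array([["z[0]"]]))
#     density, grid = get_pdf(toy, "z[0]", bins=40)
#     print(density.shape, grid.shape)  # torch.Size([40]) torch.Size([40, 1])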
def _get_weights(logratios, normalize: bool = False):
    """Calculate weights from log-ratios.

    Args:
        logratios: Log-ratio tensor
        normalize: If True, rescale weights to sum to the number of samples (unit mean);
            if False, return weights = exp(logratios).
    """
if normalize:
logratio_max = logratios.max(axis=0).values
weights = torch.exp(logratios - logratio_max)
weights_total = weights.sum(axis=0)
weights = weights / weights_total * len(weights)
else:
weights = torch.exp(logratios)
return weights
def get_weighted_samples(lrs_coll, params: Union[str, Sequence[str]]):
"""Returns weighted samples for particular parameter combination.
Args:
params: (List of) parameter names
Returns:
(torch.Tensor, torch.Tensor): Parameter and weight tensors
"""
params = params if isinstance(params, list) else [params]
if not (isinstance(lrs_coll, list) or isinstance(lrs_coll, tuple)):
lrs_coll = [lrs_coll]
for l in lrs_coll:
for i, pars in enumerate(l.parnames):
if all(x in pars for x in params):
idx = [list(pars).index(x) for x in params]
params = l.params[:, i, idx]
weights = _get_weights(l.logratios, normalize=True)[:, i]
return params, weights
raise SwyftParameterError("Requested parameters not available:", *params)
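# Weighted samples can be turned into (approximately) unweighted posterior draws
# by multinomial resampling; a sketch, with `lrs_coll` any valid collection:
#
#     z, w = get_weighted_samples(lrs_coll, "z[0]")
#     idx = torch.multinomial(w, 500, replacement=True)
#     draws = z[idx]     # (500, 1) unweighted posterior draws
#     print(w.sum())     # ~len(w): weights are rescaled to unit mean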
def get_class_probs(lrs_coll, params: str):
"""Return class probabilities for discrete parameters.
Args:
lrs_coll: Collection of LogRatioSamples objects
params: Parameter of interest (must be (0, 1, ..., K-1) for K classes)
Returns:
np.Array: Vector of length K with class probabilities
"""
params, weights = get_weighted_samples(lrs_coll, params)
probs = np.array(
[weights[params[:, 0] == k].sum() for k in range(int(params[:, 0].max()) + 1)]
)
probs /= probs.sum()
return probs
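# Sketch for a discrete parameter taking values in {0, 1, 2} (using the
# hypothetical `_Toy` stand-in from the `get_pdf` sketch above):
#
#     labels = torch.randint(0, 3, (1000, 1, 1)).float()
#     toy = _Toy(torch.zeros(1000, 1), labels, np.array([["label"]]))
#     print(get_class_probs(toy, "label"))  # ~[1/3, 1/3, 1/3], sums to one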
# def weights_sample(N, values, weights, replacement = True):
# """Weight-based sampling with or without replacement."""
# sw = weights.shape
# sv = values.shape
# assert sw == sv[:len(sw)], "Overlapping left-handed weights and values shapes do not match: %s vs %s"%(str(sv), str(sw))
#
# w = weights.view(weights.shape[0], -1)
# idx = torch.multinomial(w.T, N, replacement = replacement).T
# si = tuple(1 for _ in range(len(sv)-len(sw)))
# idx = idx.view(N, *sw[1:], *si)
# idx = idx.expand(N, *sv[1:])
#
# samples = torch.gather(values, 0, idx)
# return samples
def estimate_coverage(cs_coll, params, z_max=3.5, bins=50):
"""Estimate coverage from collection of coverage_samples objects."""
return _collection_select(
cs_coll,
"Requested parameters not available: %s" % (params,),
"estimate_coverage",
params,
z_max=z_max,
bins=bins,
)
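# `estimate_coverage` dispatches via `_collection_select` (defined below): it
# walks nested lists/tuples/dicts and returns the first `estimate_coverage`
# result that does not raise SwyftParameterError. A hypothetical stand-in:
#
#     class _FakeCoverageSamples:
#         def estimate_coverage(self, params, z_max=3.5, bins=50):
#             if params != "z[0]":
#                 raise SwyftParameterError("wrong params")
#             return "coverage for z[0]"
#
#     print(estimate_coverage([_FakeCoverageSamples()], "z[0]"))  # 'coverage for z[0]'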
######
# Misc
######
def best_from_yaml(filepath):
    """Get the path to the best model from a ModelCheckpoint yaml file. Useful for reloading trained networks.

    Args:
        filepath: Filename of the yaml file (assumed to be saved with `to_yaml` from ModelCheckpoint)

    Returns:
        Path to the best model, or None if the file does not exist
    """
try:
with open(filepath) as f:
best_k_models = yaml.load(f, Loader=yaml.FullLoader)
except FileNotFoundError:
return None
val_loss = np.inf
path = None
for k, v in best_k_models.items():
if v < val_loss:
path = k
val_loss = v
return path
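# Usage sketch: `ModelCheckpoint.to_yaml()` writes a {checkpoint_path: val_loss}
# mapping; `best_from_yaml` returns the path with the smallest loss. The file
# name below is hypothetical.
#
#     with open("best_k_models.yaml", "w") as f:
#         yaml.dump({"epoch=3.ckpt": 0.52, "epoch=7.ckpt": 0.31}, f)
#     print(best_from_yaml("best_k_models.yaml"))  # 'epoch=7.ckpt'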
##################
# Collection utils
##################
def param_select(parnames, target_parnames, match_exactly: bool = False):
"""Find indices of parameters of interest.
The output can be used to for instance select parameter from the LogRatioSamples object like
obj.params[idx1][idx2]
Args:
parnames: :math:`(*logratios_shape, num_params)`
target_parnames: List of parameter names of interest
match_exactly: Only return exact matches (i.e. no partial matches)
Returns:
tuple, list: idx1 (logratio index), idx2 (parameter indices)
"""
assert (
len(parnames.shape) == 2
), "`param_select` is only implemented for 1-dim logratios_shape"
for i, pars in enumerate(parnames):
if all(target_parname in pars for target_parname in target_parnames):
idx = [list(pars).index(tp) for tp in target_parnames]
            # An exact match means `pars` contains no parameters beyond the targets.
            if not match_exactly or len(pars) == len(target_parnames):
return (i,), idx
    raise SwyftParameterError(
        "Requested parameters not found: %s" % target_parnames
    )
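# Sketch: `parnames` is a 2-dim array of shape (num_ratios, num_params); the
# returned indices can be chained to slice a matching params tensor.
#
#     parnames = np.array([["a", "b"], ["c", "d"]])
#     idx1, idx2 = param_select(parnames, ["d", "c"])
#     print(idx1, idx2)  # (1,) [1, 0]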
def _collection_mask(coll, mask_fn):
    """Recursively filter a nested collection, keeping leaves for which `mask_fn` is True."""

    def mask(item):
        # Containers are always kept (and recursed into); only leaves are filtered.
        if isinstance(item, (list, tuple, dict)):
            return True
        return mask_fn(item)
if isinstance(coll, list):
return [_collection_mask(item, mask_fn) for item in coll if mask(item)]
elif isinstance(coll, tuple):
return tuple([_collection_mask(item, mask_fn) for item in coll if mask(item)])
elif isinstance(coll, dict):
return {
k: _collection_mask(item, mask_fn) for k, item in coll.items() if mask(item)
}
else:
return coll if mask(coll) else None
def _collection_map(coll, map_fn):
    """Recursively apply `map_fn` to all leaves of a nested collection."""
if isinstance(coll, list):
return [_collection_map(item, map_fn) for item in coll]
elif isinstance(coll, tuple):
return tuple([_collection_map(item, map_fn) for item in coll])
elif isinstance(coll, dict):
return {k: _collection_map(item, map_fn) for k, item in coll.items()}
else:
return map_fn(coll)
def _collection_flatten(coll, acc=None):
"""Flatten a nested list of collections by returning a list of all nested items."""
if acc is None:
acc = []
    if isinstance(coll, (list, tuple)):
for v in coll:
_collection_flatten(v, acc)
elif isinstance(coll, dict):
for v in coll.values():
_collection_flatten(v, acc)
else:
acc.append(coll)
return acc
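# Sketch: the `_collection_*` helpers recurse through arbitrarily nested
# lists/tuples/dicts while preserving the container structure.
#
#     coll = {"a": [1, 2], "b": (3, {"c": 4})}
#     print(_collection_map(coll, lambda x: x * 10))  # {'a': [10, 20], 'b': (30, {'c': 40})}
#     print(_collection_flatten(coll))                # [1, 2, 3, 4]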
def _collection_select(coll, err, fn, *args, **kwargs):
    """Recursively call method `fn` (or the leaf itself if `fn` is empty) on the leaves
    of a nested collection; return the first result that does not raise
    SwyftParameterError, otherwise raise `err`."""
    if isinstance(coll, (list, tuple)):
        for item in coll:
            try:
                return _collection_select(item, err, fn, *args, **kwargs)
            except SwyftParameterError:
                pass
    elif isinstance(coll, dict):
        for item in coll.values():
            try:
                return _collection_select(item, err, fn, *args, **kwargs)
            except SwyftParameterError:
                pass
else:
try:
bar = getattr(coll, fn) if fn else coll
return bar(*args, **kwargs)
except SwyftParameterError:
pass
raise SwyftParameterError(err)
##############
# Transformers
##############
def to_numpy(*args, single_precision=False):
    """Recursively convert tensors (and nested Samples/dicts/lists/tuples thereof) to numpy arrays."""
if len(args) > 1:
result = []
for arg in args:
r = to_numpy(arg, single_precision=single_precision)
result.append(r)
return tuple(result)
x = args[0]
if isinstance(x, torch.Tensor):
if not single_precision:
return x.detach().cpu().numpy()
else:
x = x.detach().cpu()
if x.dtype == torch.float64:
x = x.float().numpy()
else:
x = x.numpy()
return x
elif isinstance(x, swyft.Samples):
return swyft.Samples(
{k: to_numpy(v, single_precision=single_precision) for k, v in x.items()}
)
elif isinstance(x, tuple):
return tuple(to_numpy(v, single_precision=single_precision) for v in x)
elif isinstance(x, list):
return [to_numpy(v, single_precision=single_precision) for v in x]
elif isinstance(x, dict):
return {k: to_numpy(v, single_precision=single_precision) for k, v in x.items()}
elif isinstance(x, np.ndarray):
if not single_precision:
return x
else:
if x.dtype == np.float64:
x = np.float32(x)
return x
else:
return x
def to_numpy32(*args):
    """Like `to_numpy`, but downcast double precision to single precision."""
    return to_numpy(*args, single_precision=True)
def to_torch(x):
    """Recursively convert arrays (and nested Samples/dicts thereof) to torch tensors."""
if isinstance(x, swyft.Samples):
return swyft.Samples({k: to_torch(v) for k, v in x.items()})
elif isinstance(x, dict):
return {k: to_torch(v) for k, v in x.items()}
else:
return torch.as_tensor(x)
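# Round-trip sketch: the converters recurse through Samples/dicts (and, for
# `to_numpy`, lists/tuples); `single_precision=True` downcasts float64 to float32.
#
#     batch = {"x": np.ones((2, 3)), "z": np.zeros(4)}
#     t = to_torch(batch)     # dict of torch tensors (float64 preserved)
#     back = to_numpy32(t)    # same dict structure, downcast to float32
#     print(back["x"].dtype)  # float32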
def collate_output(out):
    """Turn a list of dicts with tensor/array values into a single dict of stacked tensors or arrays."""
keys = out[0].keys()
result = {}
for key in keys:
if isinstance(out[0][key], torch.Tensor):
result[key] = torch.stack([x[key] for x in out])
else:
result[key] = np.stack([x[key] for x in out])
return result
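# Sketch: collate a list of per-sample dicts into a single batched dict.
#
#     out = [{"x": torch.ones(3), "z": np.zeros(2)} for _ in range(5)]
#     batch = collate_output(out)
#     print(batch["x"].shape, batch["z"].shape)  # torch.Size([5, 3]) (5, 2)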
############
# Optimizers
############
class AdamW:
"""AdamW with early stopping.
Attributes:
- learning_rate (default 1e-3)
- weight_decay (default 0.01)
- amsgrad (default False)
- early_stopping_patience (optional, default 5)
"""
learning_rate = 1e-3 # Required for learning rate tuning
def configure_callbacks(self):
        esp = getattr(self, "early_stopping_patience", 5)
        early_stop = EarlyStopping(monitor="val_loss", patience=esp)
checkpoint = ModelCheckpoint(monitor="val_loss")
return [early_stop, checkpoint]
def configure_optimizers(self):
weight_decay = getattr(self, "weight_decay", 0.01)
amsgrad = getattr(self, "amsgrad", False)
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.learning_rate,
weight_decay=weight_decay,
amsgrad=amsgrad,
)
return dict(optimizer=optimizer)
class AdamWOneCycleLR:
"""AdamW with early stopping and OneCycleLR scheduler.
Attributes:
- learning_rate (default 1e-3)
- early_stopping_patience (optional, default 5)
"""
learning_rate = 1e-3
def configure_callbacks(self):
        esp = getattr(self, "early_stopping_patience", 5)
        early_stop = EarlyStopping(monitor="val_loss", patience=esp)
checkpoint = ModelCheckpoint(monitor="val_loss")
return [early_stop, checkpoint]
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
total_steps = self.trainer.estimated_stepping_batches
lr_scheduler = {
"scheduler": torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=self.learning_rate, total_steps=total_steps
)
}
return dict(optimizer=optimizer, lr_scheduler=lr_scheduler)
class AdamWReduceLROnPlateau:
"""AdamW with early stopping and ReduceLROnPlateau scheduler.
Attributes:
- learning_rate (default 1e-3)
- early_stopping_patience (optional, default 5)
- lr_scheduler_factor (optional, default 0.1)
- lr_scheduler_patience (optional, default 3)
"""
learning_rate = 1e-3
def configure_callbacks(self):
        esp = getattr(self, "early_stopping_patience", 5)
        early_stop = EarlyStopping(monitor="val_loss", patience=esp)
checkpoint = ModelCheckpoint(monitor="val_loss")
return [early_stop, checkpoint]
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
lrsf = getattr(self, "lr_scheduler_factor", 0.1)
lrsp = getattr(self, "lr_scheduler_patience", 3)
lr_scheduler = {
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
                factor=lrsf,
                patience=lrsp,
),
"monitor": "val_loss",
}
return dict(optimizer=optimizer, lr_scheduler=lr_scheduler)
class OnFitEndLoadBestModel:
    """Mixin that reloads the best checkpoint (as tracked by ModelCheckpoint) when fitting ends."""

    best_model_path = ""
def on_fit_end(self):
self.best_model_path = self.trainer.checkpoint_callback.best_model_path
checkpoint = torch.load(self.best_model_path)
print("Reloading best model:", self.best_model_path)
self.load_state_dict(checkpoint["state_dict"])