from dataclasses import dataclass, field
from toolz.dicttoolz import valmap
from typing import (
Callable,
Dict,
Hashable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Any,
)
import numpy as np
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
try:
from pytorch_lightning.trainer.supporters import CombinedLoader
except ImportError:
from pytorch_lightning.utilities import CombinedLoader
# from pytorch_lightning.cli import instantiate_class
import yaml
from swyft.lightning.data import *
from swyft.plot.mass import get_empirical_z_score
from swyft.lightning.utils import (
AdamW,
OnFitEndLoadBestModel,
SwyftParameterError,
_collection_mask,
_collection_flatten,
)
import scipy
from scipy.ndimage import gaussian_filter1d, gaussian_filter
# import torchist
#############
# SwyftModule
#############
class LossAggregationSteps:
    """Mixin providing the loss computation and the Lightning step hooks.

    The host class must be callable as ``self(A, B)`` (the network forward
    pass) and must provide ``self.log`` for metric logging.
    """

    def _get_logratios(self, out):
        """Collect all estimated log-ratios of a network output in one tensor.

        Args:
            out: Network output. Either a dict (entries whose key starts with
                ``"aux_"`` are skipped), a list/tuple (entries without a
                ``logratios`` attribute are skipped), or a single
                ``swyft.LogRatioSamples`` object.

        Returns:
            Tensor of shape ``(batch, n_total)`` holding every log-ratio
            flattened from dimension 1 on and concatenated along dimension 1,
            or ``None`` if ``out`` contains no log-ratios.
        """
        if isinstance(out, dict):
            tensors = [
                val.logratios.flatten(start_dim=1)
                for key, val in out.items()
                if key[:4] != "aux_"
            ]
        elif isinstance(out, (list, tuple)):
            tensors = [
                val.logratios.flatten(start_dim=1)
                for val in out
                if hasattr(val, "logratios")
            ]
        elif isinstance(out, swyft.LogRatioSamples):
            return out.logratios.flatten(start_dim=1)
        else:
            return None
        # torch.cat rejects an empty list; report "nothing to train on"
        # the same way for dicts as for lists/tuples.
        if not tensors:
            return None
        return torch.cat(tensors, dim=1)

    def _get_aux_losses(self, out):
        """Collect all auxiliary losses contained in a network output.

        Args:
            out: Network output (arbitrary collection, flattened with
                ``_collection_flatten``).

        Returns:
            Tensor with the ``loss`` of every ``swyft.AuxLoss`` entry,
            concatenated along dimension 1, or ``None`` if there is none.
        """
        aux = [v for v in _collection_flatten(out) if isinstance(v, swyft.AuxLoss)]
        if not aux:
            return None
        # NOTE(review): concatenating along dim=1 after unsqueeze(-1) requires
        # each loss to have at least one dimension - confirm with callers.
        return torch.cat([v.loss.unsqueeze(-1) for v in aux], dim=1)

    def _calc_loss(self, batch, randomized=True):
        """Calculate batch-averaged loss summed over ratio estimators.

        Note: The expected loss for an untrained classifier (with f = 0) is
        subtracted. The initial loss is hence usually close to zero.

        Args:
            batch: Either a list ``[A, B]`` of two sample batches (``B``
                provides the contrastive samples) or a single batch ``A``,
                in which case the contrastive samples are ``A`` rolled by
                one position along the batch dimension.
            randomized: Currently unused; kept for interface compatibility.

        Returns:
            Sum of the classification loss (if the network returned any
            log-ratios) and all auxiliary losses (if any).
        """
        if isinstance(batch, list):
            # Multiple dataloaders provided, using second one for contrastive
            # samples.
            A = batch[0]
            B = batch[1]
        else:
            # Only one dataloader provided, using same samples for
            # contrastive samples.
            A = batch
            B = valmap(lambda z: torch.roll(z, 1, dims=0), A)

        # Concatenate positive samples and negative (contrastive) examples.
        x = A
        z = {key: torch.cat([A[key], B[key]]) for key in B}

        num_pos = len(list(x.values())[0])  # Number of positive examples
        num_neg = len(list(z.values())[0]) - num_pos  # Number of negative examples

        out = self(x, z)  # Evaluate network

        loss_tot = 0

        # Concatenated, flattened tensor of all estimated log-ratios.
        logratios = self._get_logratios(out)
        if logratios is not None:
            # The first num_pos rows are the positive (label 1) examples.
            y = torch.zeros_like(logratios)
            y[:num_pos, ...] = 1
            pos_weight = torch.ones_like(logratios[0]) * num_neg / num_pos
            loss = F.binary_cross_entropy_with_logits(
                logratios, y, reduction="none", pos_weight=pos_weight
            )
            num_ratios = loss.shape[1]
            loss = loss.sum() / num_neg  # Batch-averaged loss
            # Subtract the loss of an untrained classifier (see docstring).
            loss = loss - 2 * np.log(2.0) * num_ratios
            loss_tot += loss

        aux_losses = self._get_aux_losses(out)
        if aux_losses is not None:
            loss_tot += aux_losses.sum()

        return loss_tot

    def training_step(self, batch, batch_idx):
        """Compute the loss of a training batch and log it as ``train_loss``."""
        loss = self._calc_loss(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss of a validation batch and log it as ``val_loss``."""
        loss = self._calc_loss(batch, randomized=False)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the loss of a test batch and log it as ``test_loss``."""
        loss = self._calc_loss(batch, randomized=False)
        self.log("test_loss", loss, on_epoch=True, on_step=False)
        return loss

    def predict_step(self, batch, *args, **kwargs):
        """Evaluate the network on the pair of batches ``(batch[0], batch[1])``."""
        A = batch[0]
        B = batch[1]
        return self(A, B)
[docs]
class SwyftModule(
    AdamW, OnFitEndLoadBestModel, LossAggregationSteps, pl.LightningModule
):
    r"""Central Swyft LightningModule for training log-ratio estimators.

    Derived classes are supposed to override the `forward` method in order to
    implement specific inference tasks.

    .. note::

       The forward method takes as arguments the sample batches `A` and `B`,
       which typically include all sample variables. Joined samples
       correspond to A=B, whereas marginal samples correspond to samples
       A != B.

    Example usage:

    .. code-block:: python

       class MyNetwork(swyft.SwyftModule):
           def __init__(self):
               super().__init__()
               self.mlp = swyft.LogRatioEstimator_1dim(4, 4)

           def forward(self, A, B):
               x = A['x']
               z = B['z']
               logratios = self.mlp(x, z)
               return logratios
    """

    def __init__(self):
        # Only chains to the mixins / LightningModule; no state of its own.
        super().__init__()
#################
# LogRatioSamples
#################
[docs]
@dataclass
class AuxLoss:
    r"""Dataclass for storing additional loss terms that are minimized during optimization."""

    # Value(s) of the auxiliary loss term.
    loss: torch.Tensor
    # Label of the loss term (presumably for identification only; it is not
    # read anywhere in this module).
    name: str
[docs]
@dataclass
class LogRatioSamples:
    r"""Dataclass for storing samples of estimated log-ratio values in Swyft.

    Args:
        logratios: Estimated log-ratios, :math:`(\text{minibatch}, *\text{logratios_shape})`
        params: Corresponding parameter values, :math:`(\text{minibatch}, *\text{logratios_shape}, *\text{params_shape})`
        parnames: Array of parameter names, :math:`(*\text{logratios_shape})`
        metadata: Optional meta-data from inference network etc.
    """

    logratios: torch.Tensor
    params: torch.Tensor
    # np.ndarray is the array *type*; np.array is only the factory function.
    parnames: np.ndarray
    metadata: dict = field(default_factory=dict)

    # NOTE: the deprecated `ratios` / `values` aliases (for `logratios` /
    # `params`) are disabled.

    def __len__(self):
        """Returns number of stored ratios (minibatch size).

        Raises:
            AssertionError: If `params` and `logratios` differ in length.
        """
        assert len(self.params) == len(self.logratios), "Inconsistent Ratios"
        return len(self.params)
# @property
# def weights(self):
# print("WARNING: weights is deprecated.")
# return self._get_weights(normalize = True)
# @property
# def unnormalized_weights(self):
# print("WARNING: unnormalized_weights is deprecated.")
# return self._get_weights(normalize = False)
# def _get_weights(self, normalize: bool = False):
# """Calculate weights based on ratios.
#
# Args:
# normalize: If true, normalize weights to sum to one. If false, return weights = exp(logratios).
# """
# logratios = self.logratios
# if normalize:
# logratio_max = logratios.max(axis=0).values
# weights = torch.exp(logratios-logratio_max)
# weights_total = weights.sum(axis=0)
# weights = weights/weights_total*len(weights)
# else:
# weights = torch.exp(logratios)
# return weights
# def sample(self, N, replacement = True):
# """Subsample params based on normalized weights.
#
# Args:
# N: Number of samples to generate
# replacement: Sample with replacement. Default is true, which corresponds to generating samples from the posterior.
#
# Returns:
# Tensor with samples (n_samples, ..., n_param_dims)
# """
# print("WARNING: sample method is deprecated.")
# weights = self._get_weights(normalized = True)
# if not replacement and N > len(self):
# N = len(self)
# samples = weights_sample(N, self.params, weights, replacement = replacement)
# return samples
#########
# Trainer
#########
[docs]
class SwyftTrainer(pl.Trainer):
    """Base class: pytorch_lightning.Trainer

    It provides training functionality for swyft.SwyftModule. The functionality
    is identical to `pytorch_lightning.Trainer`, see corresponding documentation
    for more details.

    Two additional methods are defined:

    - `infer` for performing parameter inference tasks with a trained network
    - `test_coverage` for performing coverage tests
    """

    def infer(
        self, model, A, B, return_sample_ratios: bool = True, batch_size: int = 1024
    ):
        """Run through model in inference mode.

        Args:
            model: Network to evaluate.
            A: Sample, Samples, or dataloader for samples A.
            B: Sample, Samples, or dataloader for samples B.
            return_sample_ratios: If true (default), return results as collated
                collection of `LogRatioSamples` objects. Otherwise, return batches.
            batch_size: batch_size used for Samples provided.

        Returns:
            Concatenated network output
        """

        def as_dataloader(data):
            # A single Sample becomes a one-element Samples collection;
            # anything that is neither is assumed to be a dataloader already.
            if isinstance(data, Sample):
                wrapped = Samples({k: [v] for k, v in data.items()})
                return wrapped.get_dataloader(batch_size=1)
            if isinstance(data, Samples):
                return data.get_dataloader(batch_size=batch_size)
            return data

        loaders = [as_dataloader(A), as_dataloader(B)]
        combined = CombinedLoader(loaders, mode="max_size_cycle")
        batches = self.predict(model, combined)

        if not return_sample_ratios:
            return batches

        def collate(select):
            # Concatenate one output entry (picked by `select`) over all
            # batches; parameter names are taken from the first batch.
            return LogRatioSamples(
                torch.cat([select(b).logratios for b in batches]),
                torch.cat([select(b).params for b in batches]),
                select(batches[0]).parnames,
            )

        first = batches[0]
        if isinstance(first, dict):
            return {
                key: collate(lambda b, key=key: b[key])
                for key in first.keys()
                if key[:4] != "aux_"
            }
        if isinstance(first, (list, tuple)):
            return [
                collate(lambda b, idx=idx: b[idx])
                for idx, entry in enumerate(first)
                if hasattr(entry, "logratios")  # Should we better check for Ratio class?
            ]
        return collate(lambda b: b)

    def test_coverage(self, model, A, B, batch_size=1024, logratio_noise=True):
        """Estimate empirical mass.

        Args:
            model: network
            A: truth samples
            B: prior samples
            batch_size: batch size used during network evaluation
            logratio_noise: Add a small amount of noise to log-ratio estimates,
                which stabilizes mass estimates for classification tasks.

        Returns:
            `CoverageSamples` object, or a tuple/list/dict of them mirroring
            the structure returned by `infer`.
        """
        print("WARNING: This estimates the mass of highest-likelihood intervals.")

        # Number of batches needed to run through B once (ceil division).
        full_batches, remainder = divmod(len(B), batch_size)
        repeat = full_batches + (remainder > 0)

        # Log-ratios of the truth samples paired with themselves ...
        pred0 = self.infer(
            model, A.get_dataloader(batch_size=32), A.get_dataloader(batch_size=32)
        )
        # ... and of every truth sample paired with the samples in B.
        pred1 = self.infer(
            model,
            A.get_dataloader(batch_size=1, repeat=repeat),
            B.get_dataloader(batch_size=batch_size),
        )

        def get_pms(p0, p1):
            n_examples = len(p0)
            # (n_examples, n_samples_per_example, *per_event_ratio_shape)
            ratios = p1.logratios.reshape(n_examples, -1, *p1.logratios.shape[1:])
            masses = torch.stack(
                [
                    _calc_mass(p0.logratios[i], ratios[i], add_noise=logratio_noise)
                    for i in range(n_examples)
                ],
                dim=0,
            )
            params = torch.stack([p0.params[i] for i in range(n_examples)], dim=0)
            return CoverageSamples(masses, params, p0.parnames)

        if isinstance(pred0, tuple):
            return tuple([get_pms(p0, pred1[i]) for i, p0 in enumerate(pred0)])
        if isinstance(pred0, list):
            return [get_pms(p0, pred1[i]) for i, p0 in enumerate(pred0)]
        if isinstance(pred0, dict):
            return {key: get_pms(p0, pred1[key]) for key, p0 in pred0.items()}
        return get_pms(pred0, pred1)
def _calc_mass(r0, r, add_noise=False):
if add_noise:
r = r + torch.rand_like(r) * 1e-3
r0 = r0 + torch.rand_like(r0) * 1e-3
p = torch.exp(r - r.max(axis=0).values)
p /= p.sum(axis=0)
m = r > r0
return (p * m).sum(axis=0)
#################
# CoverageSamples
#################
[docs]
@dataclass
class CoverageSamples:
    r"""Dataclass for storing probability mass samples from coverage tests.

    Args:
        prob_masses: Tensor of probability masses in the range [0, 1], :math:`(\text{minibatch}, *\text{logratios_shape})`
        params: Corresponding parameter values, :math:`(\text{minibatch}, *\text{logratios_shape}, *\text{params_shape})`
        parnames: Array of parameter names, :math:`(*\text{logratios_shape})`
    """

    prob_masses: torch.Tensor
    params: torch.Tensor
    # np.ndarray is the array *type*; np.array is only the factory function.
    parnames: np.ndarray

    def _get_matching_masses(self, parnames):
        """Return the masses of the estimator matching `parnames`.

        Args:
            parnames: A single parameter name or a sequence of names. The
                order of the names is irrelevant (compared as sets).

        Returns:
            ``prob_masses[:, i]`` for the first entry ``i`` of
            ``self.parnames`` with the same set of names, or ``None`` if no
            entry matches.
        """
        names = [parnames] if isinstance(parnames, str) else parnames
        wanted = set(names)  # loop-invariant, build once
        for i, pars in enumerate(self.parnames):
            if set(pars) == wanted:
                return self.prob_masses[:, i]
        return None

    def estimate_coverage(
        self, parnames: Union[str, Sequence[str]], z_max: float = 3.5, bins: int = 50
    ):
        """Estimate expected coverage of credible intervals on a grid of credibility values.

        Args:
            parnames: Names of parameters
            z_max: upper limit on the credibility level (default 3.5)
            bins (int): number of bins used when tabulating z-score

        Returns:
            np.array (bins, 4): Array columns correspond to [nominal z, empirical z, low_err empirical z, hi_err empirical z]

        Raises:
            SwyftParameterError: If no estimator matches `parnames`.
        """
        m = self._get_matching_masses(parnames)
        if m is None:
            raise SwyftParameterError("Requested parameters not available:", parnames)
        z0, z1, z2 = get_empirical_z_score(m, z_max, bins, interval_z_score=1.0)
        # Repeat the nominal z-scores over the leading dimensions of z1, then
        # give z0 and z1 a trailing axis so all three can be concatenated
        # along the last axis.
        z0 = np.tile(z0, (*z1.shape[:-1], 1))
        z0 = np.reshape(z0, (*z0.shape, 1))
        z1 = z1.reshape(*z1.shape, 1)
        z = np.concatenate([z0, z1, z2], axis=-1)
        return z