# Source code for swyft.lightning.estimators

from typing import (
    Callable,
    Dict,
    Hashable,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)
import numpy as np
import torch
import torch.nn as nn
import swyft.networks

from swyft.lightning.core import *


def equalize_tensors(a, b):
    """Equalize tensors, for matching minibatch size of A and B.

    The leading (batch) dimensions of ``a`` and ``b`` are made equal:

    * equal batch sizes: both tensors are returned unchanged;
    * a batch size of 1 is broadcast with ``Tensor.expand`` (a view, no copy);
    * otherwise the smaller tensor is tiled with ``Tensor.repeat``, which
      requires the larger batch size to be a multiple of the smaller one.

    Args:
        a: Tensor with batch size ``n`` along dimension 0.
        b: Tensor with batch size ``m`` along dimension 0.

    Returns:
        Tuple ``(a, b)`` whose leading dimensions agree.

    Raises:
        ValueError: If the batch sizes cannot be matched, i.e. one batch is
            empty while the other has more than one element, or neither batch
            size divides the other.
    """
    n, m = len(a), len(b)
    if n == m:
        return a, b
    if n == 1:
        return a.expand(m, *a.shape[1:]), b
    if m == 1:
        return a, b.expand(n, *b.shape[1:])
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and a zero batch size would otherwise crash with ZeroDivisionError.
    if n == 0 or m == 0:
        raise ValueError(
            "Cannot equalize an empty batch with a batch of size %i." % max(n, m)
        )
    if max(n, m) % min(n, m) != 0:
        raise ValueError("Cannot equalize tensors with non-divisible batch sizes.")
    if n < m:
        return a.repeat(m // n, *([1] * (a.dim() - 1))), b
    return a, b.repeat(n // m, *([1] * (b.dim() - 1)))
class LogRatioEstimator_Ndim(torch.nn.Module):
    """Channeled MLPs for estimating multi-dimensional posteriors.

    Args:
        num_features: Length of the feature vector ``x``.
        marginals: Sequence of index tuples; each tuple selects the parameter
            indices that make up one marginal.
        varnames: Parameter names. A single string ``s`` is expanded to
            ``s + "[i]"`` for every index ``i`` of every marginal (one list per
            marginal); any other value is stored unchanged.
        dropout: Dropout probability forwarded to the classifier.
        hidden_features: Hidden width forwarded to the classifier.
        num_blocks: Number of blocks forwarded to the classifier.
        Lmax: Forwarded unchanged to ``swyft.networks.MarginalClassifier``.
    """

    def __init__(
        self,
        num_features,
        marginals,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
        Lmax=0,
    ):
        super().__init__()
        self.marginals = marginals
        self.ptrans = swyft.networks.ParameterTransform(
            len(marginals), marginals, online_z_score=False
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        # Classifier input per marginal: all observation features plus the
        # parameters of that marginal's block.
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            num_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
            Lmax=Lmax,
        )
        if isinstance(varnames, str):
            prefix = varnames
            varnames = [
                [prefix + "[%i]" % idx for idx in marginal] for marginal in marginals
            ]
        self.varnames = varnames

    def forward(self, x, z):
        """Return ``LogRatioSamples`` for features ``x`` and parameters ``z``.

        The samples store the output of ``self.ptrans(z)``, not the raw ``z``.
        """
        x, z = equalize_tensors(x, z)
        z_blocks = self.ptrans(z)
        logratios = self.classifier(x, z_blocks)
        meta = {"type": "MarginalMLP", "marginals": self.marginals}
        return LogRatioSamples(
            logratios, z_blocks, np.array(self.varnames), metadata=meta
        )
# TODO: Introduce RatioEstimatorDense
class _RatioEstimatorMLPnd(torch.nn.Module):
    """Legacy MLP-based estimator for multi-dimensional marginals (private).

    Args:
        x_dim: Length of feature vector.
        marginals: Sequence of index tuples, one per marginal.
        dropout: Dropout probability forwarded to the classifier.
        hidden_features: Hidden width forwarded to the classifier.
        num_blocks: Number of blocks forwarded to the classifier.
    """

    def __init__(self, x_dim, marginals, dropout=0.1, hidden_features=64, num_blocks=2):
        super().__init__()
        self.marginals = marginals
        self.ptrans = swyft.networks.ParameterTransform(
            len(marginals), marginals, online_z_score=False
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        n_observation_features = x_dim
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            n_observation_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
        )

    def forward(self, x, z):
        """Return ``LogRatioSamples`` built from the transformed parameters."""
        x, z = equalize_tensors(x, z)
        z = self.ptrans(z)
        ratios = self.classifier(x, z)
        # NOTE(review): unlike the other estimators in this module, no parameter
        # names are passed to LogRatioSamples; confirm its signature allows that.
        w = LogRatioSamples(
            ratios, z, metadata={"type": "MarginalMLP", "marginals": self.marginals}
        )
        return w


# TODO: Deprecated class (reason: Change of name)
class _RatioEstimatorMLP1d(torch.nn.Module):
    def __init__(
        self,
        x_dim,
        z_dim,
        varname=None,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
        use_batch_norm=True,
        ptrans_online_z_score=True,
    ):
        """
        Default module for estimating 1-dim marginal posteriors.

        Deprecated: use ``LogRatioEstimator_1dim`` instead. Instantiation emits
        a ``DeprecationWarning``.

        Args:
            x_dim: Length of feature vector.
            z_dim: Length of parameter vector.
            varname: Unused; kept only for backward compatibility.
            varnames: List of name of parameter vector. If a single string is provided, indices are attached automatically.
            dropout: Dropout probability forwarded to the classifier.
            hidden_features: Hidden width forwarded to the classifier.
            num_blocks: Number of blocks forwarded to the classifier.
            use_batch_norm: Forwarded to the classifier.
            ptrans_online_z_score: Forwarded to the parameter transform.
        """
        import warnings

        # Emit a real warning instead of printing to stdout. This lets callers
        # filter it, and stacklevel=2 attributes it to the instantiating code.
        warnings.warn(
            "Deprecated, use LogRatioEstimator_1dim instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        self.marginals = [(i,) for i in range(z_dim)]
        self.ptrans = swyft.networks.ParameterTransform(
            len(self.marginals), self.marginals, online_z_score=ptrans_online_z_score
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        n_observation_features = x_dim
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            n_observation_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
        )
        if isinstance(varnames, list):
            self.varnames = np.array(varnames)
        else:
            self.varnames = np.array([varnames + "[%i]" % i for i in range(z_dim)])

    def forward(self, x, z):
        """Return ``LogRatioSamples`` for features ``x`` and raw parameters ``z``."""
        x, z = equalize_tensors(x, z)
        # The transformed parameters only feed the classifier (no gradient
        # through the transform); the raw z is stored in the samples.
        zt = self.ptrans(z).detach()
        logratios = self.classifier(x, zt)
        w = LogRatioSamples(logratios, z, self.varnames, metadata={"type": "MLP1d"})
        return w
class LogRatioEstimator_1dim(torch.nn.Module):
    """Channeled MLPs for estimating one-dimensional posteriors.

    Each parameter ``i`` is treated as its own marginal ``(i,)``.

    Args:
        num_features: Number of features
    """

    def __init__(
        self,
        num_features: int,
        num_params,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
        use_batch_norm=True,
        ptrans_online_z_score=True,
        Lmax=0,
    ):
        """
        Default module for estimating 1-dim marginal posteriors.

        Args:
            num_features: Length of feature vector.
            num_params: Length of parameter vector.
            varnames: List of name of parameter vector. If a single string is provided, indices are attached automatically.
            dropout: Dropout probability forwarded to the classifier.
            hidden_features: Hidden width forwarded to the classifier.
            num_blocks: Number of blocks forwarded to the classifier.
            use_batch_norm: Forwarded to the classifier.
            ptrans_online_z_score: Forwarded to the parameter transform.
            Lmax: Forwarded to the classifier.
        """
        super().__init__()
        self.marginals = [(idx,) for idx in range(num_params)]
        self.ptrans = swyft.networks.ParameterTransform(
            len(self.marginals), self.marginals, online_z_score=ptrans_online_z_score
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            num_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
            Lmax=Lmax,
        )
        # One single-element name list per parameter.
        if isinstance(varnames, list):
            names = [[name] for name in varnames]
        else:
            names = [[varnames + "[%i]" % idx] for idx in range(num_params)]
        self.varnames = np.array(names)

    def forward(self, x, z):
        """Return ``LogRatioSamples``; raw ``z`` is stored with a trailing axis."""
        x, z = equalize_tensors(x, z)
        # Transformed parameters feed the classifier without gradient.
        z_transformed = self.ptrans(z).detach()
        logratios = self.classifier(x, z_transformed)
        return LogRatioSamples(
            logratios, z.unsqueeze(-1), self.varnames, metadata={"type": "MLP1d"}
        )
class LogRatioEstimator_1dim_Gaussian(torch.nn.Module):
    """Estimating posteriors assuming that they are Gaussian.

    DEPRECATED: Use LogRatioEstimator_Gaussian instead.
    """

    def __init__(
        self, num_params, varnames=None, momentum: float = 0.1, minstd: float = 1e-3
    ):
        r"""
        Default module for estimating 1-dim marginal posteriors, using Gaussian approximations.

        Args:
            num_params: Length of parameter vector.
            varnames: List of name of parameter vector. If a single string is provided, indices are attached automatically.
            momentum: Momentum for running estimate for variance and covariances.
            minstd: Minimum relative standard deviation of prediction variable.
                The correlation coefficient will be truncated in the range :math:`\rho = \pm \sqrt{1-\text{minstd}^2}`

        .. note::

           This module performs running estimates of parameter variances and
           covariances. There are no learnable parameters. This can cause errors when using the module
           in isolation without other modules with learnable parameters.

           The covariance estimates are based on joined samples only. The first
           n_batch samples of z are assumed to be jointly drawn, where n_batch is
           the batch size of x.
        """
        super().__init__()
        self.momentum = momentum
        # NOTE(review): the running statistics are plain attributes, not
        # registered buffers. They are therefore not part of state_dict() and
        # are not moved by .to(); confirm this is intended.
        self.x_mean = None
        self.z_mean = None
        self.x_var = None
        self.z_var = None
        self.xz_cov = None
        self.minstd = minstd
        if isinstance(varnames, list):
            self.varnames = np.array([[v] for v in varnames])
        else:
            self.varnames = np.array(
                [[varnames + "[%i]" % i] for i in range(num_params)]
            )

    @staticmethod
    def _running_average(old, new, momentum):
        """Momentum-weighted running average; the first observation initialises it."""
        return new if old is None else (1 - momentum) * old + momentum * new

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> "LogRatioSamples":
        """2-dim Gaussian approximation to marginals and joint, assuming (B, N).

        In training mode, or before any statistics exist, the running means,
        variances and the x-z covariance are updated from the first ``len(x)``
        rows of ``x`` and ``z``. The closed-form bivariate-Gaussian log-ratio
        ``log p(x, z) - log p(x) - log p(z)`` is then evaluated.
        """
        if self.training or self.x_mean is None:
            idx = np.arange(len(x))
            x_batch, z_batch = x[idx], z[idx]
            # Estimation w/o Bessel's correction, using simple MLE estimate (https://en.wikipedia.org/wiki/Estimation_of_covariance_matrices)
            x_mean_batch = x_batch.mean(dim=0).detach()
            z_mean_batch = z_batch.mean(dim=0).detach()
            x_var_batch = ((x_batch - x_mean_batch) ** 2).mean(dim=0).detach()
            z_var_batch = ((z_batch - z_mean_batch) ** 2).mean(dim=0).detach()
            xz_cov_batch = (
                ((x_batch - x_mean_batch) * (z_batch - z_mean_batch))
                .mean(dim=0)
                .detach()
            )

            avg, momentum = self._running_average, self.momentum
            self.x_mean = avg(self.x_mean, x_mean_batch, momentum)
            self.x_var = avg(self.x_var, x_var_batch, momentum)
            self.z_mean = avg(self.z_mean, z_mean_batch, momentum)
            self.z_var = avg(self.z_var, z_var_batch, momentum)
            self.xz_cov = avg(self.xz_cov, xz_cov_batch, momentum)

        # log r(x, z) = log p(x, z)/p(x)/p(z), with covariance given by [[x_var, xz_cov], [xz_cov, z_var]]
        x, z = equalize_tensors(x, z)
        xb = (x - self.x_mean) / self.x_var**0.5
        zb = (z - self.z_mean) / self.z_var**0.5
        rho = self.xz_cov / self.x_var**0.5 / self.z_var**0.5
        # Keep |rho| < 1 so that log(1 - rho**2) and 1 / (1 - rho**2) stay finite.
        rho_max = (1 - self.minstd**2) ** 0.5
        rho = torch.clip(rho, min=-rho_max, max=rho_max)
        logratios = (
            -0.5 * torch.log(1 - rho**2)
            + rho / (1 - rho**2) * xb * zb
            - 0.5 * rho**2 / (1 - rho**2) * (xb**2 + zb**2)
        )
        out = LogRatioSamples(
            logratios, z.unsqueeze(-1), self.varnames, metadata={"type": "Gaussian1d"}
        )
        return out

    def get_z_estimate(self, x):
        """Return the Gaussian conditional-mean estimate of z given x.

        For jointly Gaussian (x, z):
        ``E[z | x] = z_mean + Cov(x, z) / Var(x) * (x - x_mean)``.
        The previous implementation divided by the standard deviation of x
        instead of its variance, which is dimensionally inconsistent.

        Args:
            x: Features; must broadcast against the running statistics.

        Returns:
            Tensor with the estimated z for each entry of x.
        """
        return self.z_mean + self.xz_cov / self.x_var * (x - self.x_mean)
class LogRatioEstimator_Autoregressive(nn.Module):
    r"""Conventional autoregressive model, based on swyft.LogRatioEstimator_1dim.

    Two ``LogRatioEstimator_1dim`` networks are used. Both see ``num_params``
    masked copies of ``[features, zA]``, where copy ``i`` keeps only the first
    ``i`` entries of ``zA`` (assumes ``zA`` has ``num_params`` trailing
    entries -- TODO confirm).

    * ``cl1`` sees the features ``xA``.
    * ``cl2`` sees ``xA * 0`` in their place.

    The total log-ratio is ``sum(l1) - max(sum(l2), 0)``, detached from the
    graph.

    Args:
        num_features: Length of the feature vector ``xA``.
        num_params: Number of parameters.
        varnames: Parameter names, forwarded to both sub-estimators.
        dropout: Dropout probability for both sub-estimators.
        num_blocks: Number of blocks for both sub-estimators.
        hidden_features: Hidden width for both sub-estimators.
    """

    def __init__(
        self,
        num_features,
        num_params,
        varnames,
        dropout=0.1,
        num_blocks=2,
        hidden_features=64,
    ):
        super().__init__()
        common = dict(
            num_features=num_features + num_params,
            num_params=num_params,
            varnames=varnames,
            dropout=dropout,
            num_blocks=num_blocks,
            hidden_features=hidden_features,
            Lmax=0,
        )
        self.cl1 = swyft.LogRatioEstimator_1dim(**common)
        self.cl2 = swyft.LogRatioEstimator_1dim(**common)
        self.num_params = num_params

    def _masked_inputs(self, x, z):
        """Stack ``num_params`` copies of ``[x, z]``; copy ``i`` has ``z[..., i:]`` zeroed."""
        stacked = torch.cat([x, z], dim=-1).unsqueeze(1).repeat((1, self.num_params, 1))
        mask = torch.ones(self.num_params, stacked.shape[-1], device=stacked.device)
        for row in range(self.num_params):
            # Zero the trailing (num_params - row) columns, i.e. z[row:].
            mask[row, row - self.num_params :] = 0
        return stacked * mask

    def forward(self, xA, zA, zB):
        """Return a dict with the total and both partial ``LogRatioSamples``."""
        xA, zB = swyft.equalize_tensors(xA, zB)
        xA, zA = swyft.equalize_tensors(xA, zA)

        partials1 = self.cl1(self._masked_inputs(xA, zA), zB)
        partials2 = self.cl2(self._masked_inputs(xA * 0, zA), zB)

        summed1 = partials1.logratios.sum(-1)
        summed2 = partials2.logratios.sum(-1)
        # Negative values of the second sum are replaced by zero.
        summed2 = torch.where(summed2 > 0, summed2, 0)
        total = (summed1 - summed2).detach().unsqueeze(-1)
        lrs_total = swyft.LogRatioSamples(total, partials1.params, partials1.parnames)
        return {
            "lrs_total": lrs_total,
            "lrs_partials1": partials1,
            "lrs_partials2": partials2,
        }
class LogRatioEstimator_Gaussian(torch.nn.Module):
    """Estimating posteriors with Gaussian approximation.

    Running (momentum-averaged) estimates of the joint mean and covariance of
    ``[a, b]`` are kept. The log-ratio ``log N(a, b) - log N(b) - log N(a)``
    is evaluated in double precision.

    Args:
        num_params: Length of parameter vector.
        varnames: List of name of parameter vector. If a single string is provided, indices are attached automatically.
        momentum: Momentum of covariance and mean estimates
        minstd: Minimum standard deviation to enforce numerical stability
    """

    def __init__(
        self, num_params, varnames=None, momentum: float = 0.02, minstd: float = 1e-10
    ):
        super().__init__()
        self._momentum = momentum
        # NOTE(review): plain attributes, not registered buffers, so they are
        # not saved in state_dict(); confirm this is intended.
        self._mean = None
        self._cov = None
        self._minstd = minstd
        if isinstance(varnames, list):
            names = [[name] for name in varnames]
        else:
            names = [[varnames + "[%i]" % k] for k in range(num_params)]
        self.varnames = np.array(names)

    @staticmethod
    def _get_mean_cov(x, correction=1):
        """Return the mean ``(*, D)`` and covariance ``(*, D, D)`` of ``x`` over dim 0.

        ``correction`` is subtracted from the sample count in the denominator
        (0: maximum-likelihood estimate, 1: Bessel-corrected estimate).
        """
        mean = x.mean(dim=0)
        centered = x - mean
        outer_sum = torch.einsum(
            centered.unsqueeze(-1), [0, ...], centered.unsqueeze(-2), [0, ...], [...]
        )
        return mean, outer_sum / (len(x) - correction)

    @property
    def cov(self):
        """Running covariance plus ``minstd**2`` on the diagonal."""
        jitter = torch.eye(self._mean.shape[-1]).to(self._cov.device) * self._minstd**2
        return self._cov + jitter

    @property
    def mean(self):
        """Running mean of the concatenated ``[a, b]`` vectors."""
        return self._mean

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        """Gaussian approximation to marginals and joint, assuming (B, N).

        a shape: (B, N, D1)
        b shape: (B, N, D2)
        """
        a_dim = a.shape[-1]

        if self.training or self._mean is None:
            rows = np.arange(len(a))
            joint_batch = torch.cat([a[rows], b[rows]], dim=-1).detach()
            # Maximum-likelihood estimate (no Bessel correction).
            mean_batch, cov_batch = self._get_mean_cov(joint_batch, correction=0)
            m = self._momentum
            if self._mean is None:
                self._mean = mean_batch
            else:
                self._mean = (1 - m) * self._mean + m * mean_batch
            if self._cov is None:
                self._cov = cov_batch
            else:
                self._cov = (1 - m) * self._cov + m * cov_batch

        cov = self.cov

        a, b = swyft.equalize_tensors(a, b)
        joint = torch.cat([a, b], dim=-1).double()
        MVN = torch.distributions.multivariate_normal.MultivariateNormal

        log_joint = MVN(self._mean, covariance_matrix=cov.double()).log_prob(joint)
        log_b = MVN(
            self._mean[..., a_dim:], covariance_matrix=cov[..., a_dim:, a_dim:].double()
        ).log_prob(joint[..., a_dim:])
        log_a = MVN(
            self._mean[..., :a_dim], covariance_matrix=cov[..., :a_dim, :a_dim].double()
        ).log_prob(joint[..., :a_dim])

        return swyft.LogRatioSamples(log_joint - log_b - log_a, a, self.varnames)