# Source code for swyft.lightning.estimators

from typing import (
    Callable,
    Dict,
    Hashable,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)
import numpy as np
import torch
import torch.nn as nn
import swyft.networks

from swyft.lightning.core import *


def equalize_tensors(a, b):
    """Equalize tensors, for matching minibatch size of A and B."""
    n, m = len(a), len(b)
    if n == m:
        return a, b
    elif n == 1:
        shape = list(a.shape)
        shape[0] = m
        return a.expand(*shape), b
    elif m == 1:
        shape = list(b.shape)
        shape[0] = n
        return a, b.expand(*shape)
    elif n < m:
        assert m % n == 0, "Cannot equalize tensors with non-divisible batch sizes."
        shape = [1 for _ in range(a.dim())]
        shape[0] = m // n
        return a.repeat(*shape), b
    else:
        assert n % m == 0, "Cannot equalize tensors with non-divisible batch sizes."
        shape = [1 for _ in range(b.dim())]
        shape[0] = n // m
        return a, b.repeat(*shape)
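

# Usage sketch (hypothetical, not part of the swyft API): broadcasting a
# single observation against a batch of parameter samples. The helper name
# and all shapes below are illustrative assumptions.
def _example_equalize_tensors():
    a = torch.randn(1, 8)   # one observation with 8 summary features
    b = torch.randn(16, 3)  # sixteen 3-dim parameter samples
    a_eq, b_eq = equalize_tensors(a, b)
    assert a_eq.shape == (16, 8)  # observation expanded along the batch dim
    assert b_eq.shape == (16, 3)  # parameter batch left unchanged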


class LogRatioEstimator_Ndim(torch.nn.Module):
    """Channeled MLPs for estimating multi-dimensional posteriors."""

    def __init__(
        self,
        num_features,
        marginals,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
    ):
        super().__init__()
        self.marginals = marginals
        self.ptrans = swyft.networks.ParameterTransform(
            len(marginals), marginals, online_z_score=False
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        n_observation_features = num_features
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            n_observation_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
        )
        if isinstance(varnames, str):
            basename = varnames
            varnames = []
            for marg in marginals:
                varnames.append([basename + "[%i]" % i for i in marg])
        self.varnames = varnames

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        z = self.ptrans(z)
        ratios = self.classifier(x, z)
        w = LogRatioSamples(
            ratios,
            z,
            np.array(self.varnames),
            metadata={"type": "MarginalMLP", "marginals": self.marginals},
        )
        return w
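

# Usage sketch (hypothetical, not part of the module): estimating the 2-dim
# posterior of parameters (0, 1) and of parameters (2, 3) from a 10-dim
# feature vector. The function name, shapes, and varnames are illustrative
# assumptions only.
def _example_logratioestimator_ndim():
    lre = LogRatioEstimator_Ndim(
        num_features=10,
        marginals=((0, 1), (2, 3)),  # two 2-dim marginal blocks
        varnames="z",  # expands to [["z[0]", "z[1]"], ["z[2]", "z[3]"]]
    )
    x = torch.randn(16, 10)  # summary features, shape (B, num_features)
    z = torch.randn(16, 4)   # parameter samples, shape (B, num_params)
    return lre(x, z)         # LogRatioSamples, one log-ratio per marginal block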


# TODO: Introduce RatioEstimatorDense
class _RatioEstimatorMLPnd(torch.nn.Module):
    def __init__(
        self, x_dim, marginals, dropout=0.1, hidden_features=64, num_blocks=2
    ):
        super().__init__()
        self.marginals = marginals
        self.ptrans = swyft.networks.ParameterTransform(
            len(marginals), marginals, online_z_score=False
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        n_observation_features = x_dim
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            n_observation_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
        )

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        z = self.ptrans(z)
        ratios = self.classifier(x, z)
        w = LogRatioSamples(
            ratios, z, metadata={"type": "MarginalMLP", "marginals": self.marginals}
        )
        return w


# TODO: Deprecated class (reason: change of name)
class _RatioEstimatorMLP1d(torch.nn.Module):
    def __init__(
        self,
        x_dim,
        z_dim,
        varname=None,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
        use_batch_norm=True,
        ptrans_online_z_score=True,
    ):
        """
        Deprecated module for estimating 1-dim marginal posteriors; use
        LogRatioEstimator_1dim instead.

        Args:
            x_dim: Length of feature vector.
            z_dim: Length of parameter vector.
            varnames: List of parameter names. If a single string is provided,
                indices are attached automatically.
        """
        print("WARNING: Deprecated, use LogRatioEstimator_1dim instead.")
        super().__init__()
        self.marginals = [(i,) for i in range(z_dim)]
        self.ptrans = swyft.networks.ParameterTransform(
            len(self.marginals), self.marginals, online_z_score=ptrans_online_z_score
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        n_observation_features = x_dim
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            n_observation_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
        )
        if isinstance(varnames, list):
            self.varnames = np.array(varnames)
        else:
            self.varnames = np.array([varnames + "[%i]" % i for i in range(z_dim)])

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        zt = self.ptrans(z).detach()
        logratios = self.classifier(x, zt)
        w = LogRatioSamples(logratios, z, self.varnames, metadata={"type": "MLP1d"})
        return w


class LogRatioEstimator_1dim(torch.nn.Module):
    """Channeled MLPs for estimating one-dimensional posteriors."""

    def __init__(
        self,
        num_features: int,
        num_params,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
        use_batch_norm=True,
        ptrans_online_z_score=True,
    ):
        """
        Default module for estimating 1-dim marginal posteriors.

        Args:
            num_features: Length of feature vector.
            num_params: Length of parameter vector.
            varnames: List of parameter names. If a single string is provided,
                indices are attached automatically.
        """
        super().__init__()
        self.marginals = [(i,) for i in range(num_params)]
        self.ptrans = swyft.networks.ParameterTransform(
            len(self.marginals), self.marginals, online_z_score=ptrans_online_z_score
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        n_observation_features = num_features
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            n_observation_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
        )
        if isinstance(varnames, list):
            self.varnames = np.array([[v] for v in varnames])
        else:
            self.varnames = np.array(
                [[varnames + "[%i]" % i] for i in range(num_params)]
            )

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        zt = self.ptrans(z).detach()
        logratios = self.classifier(x, zt)
        w = LogRatioSamples(
            logratios, z.unsqueeze(-1), self.varnames, metadata={"type": "MLP1d"}
        )
        return w
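

# Usage sketch (hypothetical): one channeled MLP per parameter, all sharing
# the same 10-dim feature vector. The function name and shapes are
# illustrative assumptions.
def _example_logratioestimator_1dim():
    lre = LogRatioEstimator_1dim(num_features=10, num_params=3, varnames="z")
    x = torch.randn(16, 10)  # summary features, shape (B, num_features)
    z = torch.randn(16, 3)   # parameter samples, shape (B, num_params)
    return lre(x, z)         # LogRatioSamples, one 1-dim marginal per parameter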


class LogRatioEstimator_1dim_Gaussian(torch.nn.Module):
    """Estimates 1-dim marginal posteriors, assuming that they are Gaussian."""

    def __init__(
        self, num_params, varnames=None, momentum: float = 0.1, minstd: float = 1e-3
    ):
        r"""
        Default module for estimating 1-dim marginal posteriors, using Gaussian
        approximations.

        Args:
            num_params: Length of parameter vector.
            varnames: List of parameter names. If a single string is provided,
                indices are attached automatically.
            momentum: Momentum of the running estimates of variances and
                covariances.
            minstd: Minimum relative standard deviation of the prediction
                variable. The correlation coefficient is truncated to the
                range :math:`\rho = \pm \sqrt{1 - \text{minstd}^2}`.

        .. note::

           This module keeps running estimates of parameter variances and
           covariances; it has no learnable parameters. This can cause errors
           when the module is used in isolation, without other modules that
           carry learnable parameters.

           The covariance estimates are based on jointly drawn samples only:
           the first n_batch samples of z are assumed to be drawn jointly with
           x, where n_batch is the batch size of x.
        """
        super().__init__()
        self.momentum = momentum
        self.x_mean = None
        self.z_mean = None
        self.x_var = None
        self.z_var = None
        self.xz_cov = None
        self.minstd = minstd

        if isinstance(varnames, list):
            self.varnames = np.array([[v] for v in varnames])
        else:
            self.varnames = np.array(
                [[varnames + "[%i]" % i] for i in range(num_params)]
            )

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> "LogRatioSamples":
        """2-dim Gaussian approximation to marginals and joint, assuming that
        x and z have shape (B, N)."""
        if self.training or self.x_mean is None:
            batch_size = len(x)
            idx = np.arange(batch_size)

            # Per-batch MLE estimates without Bessel's correction, see
            # https://en.wikipedia.org/wiki/Estimation_of_covariance_matrices
            x_mean_batch = x[idx].mean(dim=0).detach()
            z_mean_batch = z[idx].mean(dim=0).detach()
            x_var_batch = ((x[idx] - x_mean_batch) ** 2).mean(dim=0).detach()
            z_var_batch = ((z[idx] - z_mean_batch) ** 2).mean(dim=0).detach()
            xz_cov_batch = (
                ((x[idx] - x_mean_batch) * (z[idx] - z_mean_batch))
                .mean(dim=0)
                .detach()
            )

            # Momentum-based update rule
            momentum = self.momentum
            self.x_mean = (
                x_mean_batch
                if self.x_mean is None
                else (1 - momentum) * self.x_mean + momentum * x_mean_batch
            )
            self.x_var = (
                x_var_batch
                if self.x_var is None
                else (1 - momentum) * self.x_var + momentum * x_var_batch
            )
            self.z_mean = (
                z_mean_batch
                if self.z_mean is None
                else (1 - momentum) * self.z_mean + momentum * z_mean_batch
            )
            self.z_var = (
                z_var_batch
                if self.z_var is None
                else (1 - momentum) * self.z_var + momentum * z_var_batch
            )
            self.xz_cov = (
                xz_cov_batch
                if self.xz_cov is None
                else (1 - momentum) * self.xz_cov + momentum * xz_cov_batch
            )

        # log r(x, z) = log p(x, z)/p(x)/p(z), with covariance matrix
        # [[x_var, xz_cov], [xz_cov, z_var]]
        x, z = swyft.equalize_tensors(x, z)
        xb = (x - self.x_mean) / self.x_var**0.5
        zb = (z - self.z_mean) / self.z_var**0.5
        rho = self.xz_cov / self.x_var**0.5 / self.z_var**0.5
        rho = torch.clip(
            rho,
            min=-((1 - self.minstd**2) ** 0.5),
            max=(1 - self.minstd**2) ** 0.5,
        )
        logratios = (
            -0.5 * torch.log(1 - rho**2)
            + rho / (1 - rho**2) * xb * zb
            - 0.5 * rho**2 / (1 - rho**2) * (xb**2 + zb**2)
        )
        out = LogRatioSamples(
            logratios, z.unsqueeze(-1), self.varnames, metadata={"type": "Gaussian1d"}
        )
        return out
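

# Usage sketch (hypothetical): here x is assumed to hold one scalar data
# summary per parameter, so x and z share the shape (B, N) expected by the
# elementwise Gaussian approximation. The function name and shapes are
# illustrative assumptions.
def _example_logratioestimator_1dim_gaussian():
    lre = LogRatioEstimator_1dim_Gaussian(num_params=3, varnames="z")
    lre.train()              # enable the momentum-based running moment updates
    x = torch.randn(16, 3)   # per-parameter data summaries, shape (B, N)
    z = torch.randn(16, 3)   # parameter samples, shape (B, N)
    return lre(x, z)         # LogRatioSamples under the Gaussian approximation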