import warnings
from typing import (
    Callable,
    Dict,
    Hashable,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import numpy as np
import torch
import torch.nn as nn

import swyft.networks
from swyft.lightning.core import *
def equalize_tensors(a, b):
    """Equalize tensors, for matching minibatch size of A and B.

    If the leading (batch) dimensions already agree, both tensors are
    returned unchanged.  A singleton batch is broadcast with ``expand``
    (no copy); otherwise the smaller batch is tiled with ``repeat``,
    which requires the larger batch size to be an integer multiple of
    the smaller one.

    Args:
        a: Tensor with the batch dimension first.
        b: Tensor with the batch dimension first.

    Returns:
        Tuple ``(a, b)`` with equal leading dimensions.

    Raises:
        ValueError: If batch sizes differ, neither is 1, and one does
            not divide the other.  (Raised explicitly rather than via
            ``assert``, which is stripped under ``python -O``.)
    """
    n, m = len(a), len(b)
    if n == m:
        return a, b
    if n == 1:
        shape = list(a.shape)
        shape[0] = m
        return a.expand(*shape), b
    if m == 1:
        shape = list(b.shape)
        shape[0] = n
        return a, b.expand(*shape)
    if n < m:
        # Tile `a` along the batch dimension up to length m.
        if m % n != 0:
            raise ValueError("Cannot equalize tensors with non-divisible batch sizes.")
        shape = [1] * a.dim()
        shape[0] = m // n
        return a.repeat(*shape), b
    # n > m: tile `b` along the batch dimension up to length n.
    if n % m != 0:
        raise ValueError("Cannot equalize tensors with non-divisible batch sizes.")
    shape = [1] * b.dim()
    shape[0] = n // m
    return a, b.repeat(*shape)
class LogRatioEstimator_Ndim(torch.nn.Module):
    """Channeled MLPs for estimating multi-dimensional posteriors.

    Each requested marginal gets its own classifier channel; the feature
    vector is combined with the corresponding transformed parameter block.
    """

    def __init__(
        self,
        num_features,
        marginals,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
    ):
        super().__init__()
        self.marginals = marginals
        self.ptrans = swyft.networks.ParameterTransform(
            len(marginals), marginals, online_z_score=False
        )
        num_marginals, num_block_params = self.ptrans.marginal_block_shape
        self.classifier = swyft.networks.MarginalClassifier(
            num_marginals,
            num_features + num_block_params,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
        )
        # A single string is expanded into one indexed name per entry of
        # each marginal, e.g. "z" -> [["z[0]", "z[1]"], ...].
        if isinstance(varnames, str):
            base = varnames
            varnames = [[base + "[%i]" % i for i in marg] for marg in marginals]
        self.varnames = varnames

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        zb = self.ptrans(z)
        logratios = self.classifier(x, zb)
        return LogRatioSamples(
            logratios,
            zb,
            np.array(self.varnames),
            metadata={"type": "MarginalMLP", "marginals": self.marginals},
        )
# TODO: Introduce RatioEstimatorDense
class _RatioEstimatorMLPnd(torch.nn.Module):
    """Internal channeled-MLP ratio estimator over arbitrary marginals."""

    def __init__(self, x_dim, marginals, dropout=0.1, hidden_features=64, num_blocks=2):
        super().__init__()
        self.marginals = marginals
        self.ptrans = swyft.networks.ParameterTransform(
            len(marginals), marginals, online_z_score=False
        )
        num_marginals, num_block_params = self.ptrans.marginal_block_shape
        self.classifier = swyft.networks.MarginalClassifier(
            num_marginals,
            x_dim + num_block_params,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
        )

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        zb = self.ptrans(z)
        return LogRatioSamples(
            self.classifier(x, zb),
            zb,
            metadata={"type": "MarginalMLP", "marginals": self.marginals},
        )
# TODO: Deprecated class (reason: Change of name)
class _RatioEstimatorMLP1d(torch.nn.Module):
    """Deprecated: use :class:`LogRatioEstimator_1dim` instead."""

    def __init__(
        self,
        x_dim,
        z_dim,
        varname=None,  # unused; kept for backward compatibility of the signature
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
        use_batch_norm=True,
        ptrans_online_z_score=True,
    ):
        """
        Default module for estimating 1-dim marginal posteriors.

        Args:
            x_dim: Length of feature vector.
            z_dim: Length of parameter vector.
            varnames: List of name of parameter vector. If a single string is
                provided, indices are attached automatically.
        """
        # Route the deprecation notice through the warnings machinery (instead
        # of a bare print) so it can be filtered and is deduplicated.
        warnings.warn(
            "Deprecated, use LogRatioEstimator_1dim instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        # One singleton marginal per parameter component.
        self.marginals = [(i,) for i in range(z_dim)]
        self.ptrans = swyft.networks.ParameterTransform(
            len(self.marginals), self.marginals, online_z_score=ptrans_online_z_score
        )
        n_marginals, n_block_parameters = self.ptrans.marginal_block_shape
        n_observation_features = x_dim
        self.classifier = swyft.networks.MarginalClassifier(
            n_marginals,
            n_observation_features + n_block_parameters,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
        )
        if isinstance(varnames, list):
            self.varnames = np.array(varnames)
        else:
            # Assumes `varnames` is a string; indices are attached per component.
            self.varnames = np.array([varnames + "[%i]" % i for i in range(z_dim)])

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        # Transformed parameters are detached before classification.
        zt = self.ptrans(z).detach()
        logratios = self.classifier(x, zt)
        return LogRatioSamples(logratios, z, self.varnames, metadata={"type": "MLP1d"})
class LogRatioEstimator_1dim(torch.nn.Module):
    """Channeled MLPs for estimating one-dimensional posteriors.

    Args:
        num_features: Number of features
    """

    def __init__(
        self,
        num_features: int,
        num_params,
        varnames=None,
        dropout=0.1,
        hidden_features=64,
        num_blocks=2,
        use_batch_norm=True,
        ptrans_online_z_score=True,
    ):
        """
        Default module for estimating 1-dim marginal posteriors.

        Args:
            num_features: Length of feature vector.
            num_params: Length of parameter vector.
            varnames: List of name of parameter vector. If a single string is
                provided, indices are attached automatically.
        """
        super().__init__()
        # One singleton marginal per parameter component.
        self.marginals = [(dim,) for dim in range(num_params)]
        self.ptrans = swyft.networks.ParameterTransform(
            len(self.marginals), self.marginals, online_z_score=ptrans_online_z_score
        )
        num_marginals, num_block_params = self.ptrans.marginal_block_shape
        self.classifier = swyft.networks.MarginalClassifier(
            num_marginals,
            num_features + num_block_params,
            hidden_features=hidden_features,
            dropout_probability=dropout,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
        )
        if isinstance(varnames, list):
            self.varnames = np.array([[name] for name in varnames])
        else:
            self.varnames = np.array(
                [[varnames + "[%i]" % dim] for dim in range(num_params)]
            )

    def forward(self, x, z):
        x, z = equalize_tensors(x, z)
        # Transformed parameters are detached before classification.
        zt = self.ptrans(z).detach()
        logratios = self.classifier(x, zt)
        return LogRatioSamples(
            logratios, z.unsqueeze(-1), self.varnames, metadata={"type": "MLP1d"}
        )
class LogRatioEstimator_1dim_Gaussian(torch.nn.Module):
    """Estimating posteriors assuming that they are Gaussian.

    Maintains running (momentum-averaged) estimates of the means, variances
    and covariance of features and parameters, and evaluates the analytic
    log-ratio of a bivariate Gaussian per channel.
    """

    def __init__(self, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        # Running moment estimates, lazily initialized on the first forward
        # pass.  NOTE(review): these are plain attributes, not registered
        # buffers, so they are not saved in the state dict and do not follow
        # ``.to(device)`` — confirm whether that is intended.
        self.x_mean = None
        self.z_mean = None
        self.x_var = None
        self.z_var = None
        self.xz_cov = None

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """2-dim Gaussian approximation to marginals and joint, assuming (B, N)."""
        # Use the warnings machinery (deduplicated, filterable) instead of
        # printing on every single forward call.
        warnings.warn("Deprecated, might be broken", DeprecationWarning, stacklevel=2)
        x, z = equalize_tensors(x, z)
        if self.training or self.x_mean is None:
            # Covariance estimates must be based on joined samples only
            # NOTE: This makes assumptions about the structure of mini batches during training (J, M, M, J, J, M, M, J, ...)
            # TODO: Change to (J, M, J, M, J, M, ...) in the future
            batch_size = len(x)
            # idx = np.array([[i, i+3] for i in np.arange(0, batch_size, 4)]).flatten()
            idx = np.arange(
                batch_size // 2
            )  # TODO: Assuming (J, J, J, J, M, M, M, M) etc
            # Estimation w/o Bessel's correction, using simple MLE estimate (https://en.wikipedia.org/wiki/Estimation_of_covariance_matrices)
            x_mean_batch = x[idx].mean(dim=0).detach()
            z_mean_batch = z[idx].mean(dim=0).detach()
            x_var_batch = ((x[idx] - x_mean_batch) ** 2).mean(dim=0).detach()
            z_var_batch = ((z[idx] - z_mean_batch) ** 2).mean(dim=0).detach()
            xz_cov_batch = (
                ((x[idx] - x_mean_batch) * (z[idx] - z_mean_batch)).mean(dim=0).detach()
            )

            # Momentum-based update rule; first batch seeds the estimates.
            momentum = self.momentum
            self.x_mean = (
                x_mean_batch
                if self.x_mean is None
                else (1 - momentum) * self.x_mean + momentum * x_mean_batch
            )
            self.x_var = (
                x_var_batch
                if self.x_var is None
                else (1 - momentum) * self.x_var + momentum * x_var_batch
            )
            self.z_mean = (
                z_mean_batch
                if self.z_mean is None
                else (1 - momentum) * self.z_mean + momentum * z_mean_batch
            )
            self.z_var = (
                z_var_batch
                if self.z_var is None
                else (1 - momentum) * self.z_var + momentum * z_var_batch
            )
            self.xz_cov = (
                xz_cov_batch
                if self.xz_cov is None
                else (1 - momentum) * self.xz_cov + momentum * xz_cov_batch
            )

        # log r(x, z) = log p(x, z)/p(x)/p(z), with covariance given by [[x_var, xz_cov], [xz_cov, z_var]]
        xb = (x - self.x_mean) / self.x_var**0.5
        zb = (z - self.z_mean) / self.z_var**0.5
        rho = self.xz_cov / self.x_var**0.5 / self.z_var**0.5
        r = (
            -0.5 * torch.log(1 - rho**2)
            + rho / (1 - rho**2) * xb * zb
            - 0.5 * rho**2 / (1 - rho**2) * (xb**2 + zb**2)
        )
        # out = torch.cat([r.unsqueeze(-1), z.unsqueeze(-1).detach()], dim=-1)
        out = LogRatioSamples(r, z, metadata={"type": "Gaussian1d"})
        return out