Source code for swyft.inference.marginalratioestimator

from typing import (
    Callable,
    Dict,
    Hashable,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import numpy as np
import torch
import torch.nn as nn
from toolz.dicttoolz import valmap
from torch.utils.data import DataLoader, Dataset, random_split

import swyft
import swyft.utils
from swyft.saveable import StateDictSaveable, StateDictSaveableType
from swyft.types import Array, Device, MarginalIndex, MarginalToArray, ObsType, PathType
from swyft.utils.array import array_to_tensor, dict_array_to_tensor
from swyft.utils.marginals import tupleize_marginal_indices

SchedulerType = Union[
    torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
]


def split_length_by_percentage(length: int, percents: Sequence[float]) -> Sequence[int]:
    """Given the length of a sequence, return the indices which would divide it into `percents` parts.
    Any rounding errors go into the first part.

    Args:
        length
        percents

    Returns:
        length_of_parts
    """
    assert np.isclose(sum(percents), 1.0), f"{percents} does not sum to 1."
    lengths = [int(percent * length) for percent in percents]

    # Any extra from round off goes to the first split.
    difference = length - sum(lengths)
    lengths[0] += difference
    assert length == sum(
        lengths
    ), f"Splitting into {lengths} should equal total {length}."
    return lengths
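
# Illustrative example (not part of the original module): splitting a sequence of
# length 103 into 10%/90% parts gives [11, 92], since int(0.1 * 103) = 10 and
# int(0.9 * 103) = 92 only sum to 102, and the leftover element goes to the first part:
#
#     split_length_by_percentage(103, (0.1, 0.9))  # -> [11, 92]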


def get_ntrain_nvalid(
    validation_amount: Union[float, int], len_dataset: int
) -> Tuple[int, int]:
    """Divide a dataset into a training and validation set.

    Args:
        validation_amount: percentage or number of elements in the validation set
        len_dataset: total length of the dataset

    Raises:
        TypeError: When the validation_amount is neither a float or int.

    Returns:
        (n_train, n_valid)
    """
    assert validation_amount > 0
    if isinstance(validation_amount, float):
        percent_validation = validation_amount
        percent_train = 1.0 - percent_validation
        n_valid, n_train = split_length_by_percentage(
            len_dataset, (percent_validation, percent_train)
        )
        if n_valid % 2 != 0:
            n_valid += 1
            n_train -= 1
    elif isinstance(validation_amount, int):
        n_valid = validation_amount
        n_train = len_dataset - n_valid
        assert n_train > 0

        if n_valid % 2 != 0:
            n_valid += 1
            n_train -= 1
    else:
        raise TypeError("validation_amount must be int or float")
    return n_train, n_valid
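
# Illustrative examples (not part of the original module): a float is treated as the
# validation fraction, an int as the validation count, and n_valid is bumped to the
# next even number when necessary:
#
#     get_ntrain_nvalid(0.1, 1000)  # -> (900, 100)
#     get_ntrain_nvalid(25, 100)    # -> (74, 26), since 25 is odd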


def double_observation(f: torch.Tensor) -> torch.Tensor:
    """Double observation vector as (A, B, C, D) --> (A, A, B, B, C, C, D, D)

    Args:
        f: Observation vectors (n_batch, ...)

    Returns:
        Observation vectors (2*n_batch, ...)
    """
    return torch.repeat_interleave(f, 2, dim=0)


def double_parameters(parameters: torch.Tensor) -> torch.Tensor:
    """Double parameters as (A, B, C, D) --> (A, B, A, B, C, D, C, D) etc

    Args:
        parameters: Parameter vectors (n_batch, n_parameters).

    Returns:
        parameters with shape (2*n_batch, n_parameters).
    """
    n_batch, n_parameters = parameters.shape
    assert n_batch % 2 == 0, "n_batch must be divisible by two."
    out = torch.repeat_interleave(parameters.view(-1, 2 * n_parameters), 2, dim=0).view(
        -1, n_parameters
    )
    return out
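
# Illustrative example (not part of the original module): combining the two helpers on a
# batch of observations (A, B) and parameters (a, b) yields
#
#     double_observation -> (A, A, B, B)
#     double_parameters  -> (a, b, a, b)
#
# so the doubled batch interleaves the matched pairs (A, a), (B, b) with the mismatched
# pairs (A, b), (B, a); this is the contrastive structure consumed by `_loss` below,
# which regroups the network output with view(-1, 4, ...).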


MarginalRatioEstimatorType = TypeVar(
    "MarginalRatioEstimatorType", bound="MarginalRatioEstimator"
)


class MarginalRatioEstimator(StateDictSaveable):
    """Handles the training and evaluation of a ratio estimator.

    The ratios of interest are defined by the `marginal_indices` attribute. The network
    must take observation dictionaries and parameter arrays and produce an estimated
    `log_ratio` for every marginal of interest.
    """

    def __init__(
        self,
        marginal_indices: MarginalIndex,
        network: nn.Module,
        device: Device,
    ) -> None:
        """Define the marginals of interest with `marginal_indices` and the estimator architecture with `network`.

        Args:
            marginal_indices: marginals of interest defined by the parameter index
            network: a neural network which accepts `observation` and `parameters` and
                returns `len(marginal_indices)` ratios.
            device
        """
        self.marginal_indices = tupleize_marginal_indices(marginal_indices)
        self.device = device
        self.network = network
        self.network.to(self.device)

        self.epoch = None
        self.best_network_state_dict = None
        # Best validation loss seen so far; starts at +Inf so the first epoch always improves on it.
        self.min_loss = float("Inf")

        self.optimizer = None
        self.scheduler = None

    def train(
        self,
        dataset: Dataset,
        batch_size: int = 50,
        learning_rate: float = 5e-4,
        validation_percentage: float = 0.1,
        optimizer: Callable = torch.optim.Adam,
        scheduler: Optional[Callable] = torch.optim.lr_scheduler.ReduceLROnPlateau,
        scheduler_kwargs: dict = {"factor": 0.1, "patience": 5},
        early_stopping_patience: Optional[int] = 25,
        max_epochs: int = 2**31 - 1,
        nworkers: int = 0,
        non_blocking: bool = True,
        pin_memory: bool = True,
    ) -> list:
        """Train the ratio estimator on a `dataset` containing observation and parameter pairs.

        Note: if the network has already been trained, training will resume where it left off.
        This effectively ignores `optimizer`, `learning_rate`, `scheduler`, and `scheduler_kwargs`.

        Args:
            dataset: torch dataset whose items unpack as (`observation`, `_`, `v`), where `v` holds the parameters
            batch_size
            learning_rate
            validation_percentage: approximate fraction of `dataset` used for the validation set
            optimizer: optimizer from `torch.optim`. It can only accept two arguments:
                `parameters` and `lr`. Need more arguments? Use `functools.partial`.
            scheduler: scheduler from `torch.optim.lr_scheduler`
            scheduler_kwargs: the arguments which get passed to `scheduler`
            early_stopping_patience: after this many fruitless epochs, training stops
            max_epochs: maximum number of epochs to train
            nworkers: number of workers to divide `dataloader` duties between.
                0 implies one thread for training and dataloading.
            non_blocking: consult torch documentation, generally use `True`
            pin_memory: consult torch documentation, generally use `True`
        """
        if early_stopping_patience is None:
            early_stopping_patience = max_epochs

        if self.optimizer is None:
            self.optimizer = optimizer(self.network.parameters(), lr=learning_rate)

        if scheduler is not None and self.scheduler is None:
            self.scheduler = scheduler(self.optimizer, **scheduler_kwargs)

        n_train, n_valid = get_ntrain_nvalid(validation_percentage, len(dataset))
        dataset_train, dataset_valid = random_split(dataset, [n_train, n_valid])
        train_loader = DataLoader(
            dataset_train,
            batch_size=batch_size,
            num_workers=nworkers,
            pin_memory=pin_memory,
            drop_last=True,
        )
        valid_loader = DataLoader(
            dataset_valid,
            batch_size=min(batch_size, n_valid),
            num_workers=nworkers,
            pin_memory=pin_memory,
            drop_last=True,
        )

        n_validation_batches = len(valid_loader) if len(valid_loader) != 0 else 1
        validation_losses = []
        self.epoch, fruitless_epoch = 0, 0
        while self.epoch < max_epochs and fruitless_epoch < early_stopping_patience:
            # Training
            self.network.train()
            for observation, _, v in train_loader:
                self.optimizer.zero_grad()
                observation = swyft.utils.dict_to_device(
                    observation, device=self.device, non_blocking=non_blocking
                )
                v = v.to(self.device)
                loss = self._loss(observation, v).sum(dim=0)
                loss.backward()
                self.optimizer.step()

            self.epoch += 1

            # Evaluation
            self.network.eval()
            loss_sum = 0
            with torch.no_grad():
                for observation, _, v in valid_loader:
                    observation = swyft.utils.dict_to_device(
                        observation, device=self.device, non_blocking=non_blocking
                    )
                    v = v.to(self.device)
                    validation_loss = self._loss(observation, v).sum(dim=0)
                    loss_sum += validation_loss
                loss_avg = loss_sum / n_validation_batches
                print(
                    "\rtraining: lr=%.2g, epoch=%i, validation loss=%.4g"
                    % (self._get_last_lr(self.scheduler), self.epoch, loss_avg),
                    end="",
                    flush=True,
                )
                validation_losses.append(loss_avg)

            if self.epoch == 0 or self.min_loss > loss_avg:
                fruitless_epoch = 0
                self.min_loss = loss_avg
                self.best_network_state_dict = self.network.state_dict()
            else:
                fruitless_epoch += 1

            if self.scheduler is not None:
                self.scheduler.step(loss_avg)

        print("")
        return validation_losses

    def log_ratio(
        self,
        observation: ObsType,
        v: Array,
        batch_size: Optional[int] = None,
    ) -> MarginalToArray:
        """Evaluate the ratio estimator on a single `observation` with many `parameters`.

        The parameters correspond to `v`, i.e. the "physical" parameterization.
        (As opposed to `u`, which is mapped to the hypercube.)

        Args:
            observation: a single observation to estimate ratios on (cannot have a batch dimension!)
            v: parameters
            batch_size: divides the evaluation into batches of this size

        Returns:
            MarginalToArray: the ratios of each marginal in `marginal_indices`. Each marginal index is a key.
        """
        was_training = self.network.training
        self.network.eval()

        with torch.no_grad():
            observation = dict_array_to_tensor(observation, device=self.device)
            features = self.network.head(
                {key: value.unsqueeze(0) for key, value in observation.items()}
            )
            len_v = len(v)
            if batch_size is None or len_v <= batch_size:
                v = array_to_tensor(v, device=self.device)
                repeated_features = features.expand(v.size(0), *features.shape[1:])
                ratio = self.network.tail(repeated_features, v).cpu().numpy()
            else:
                ratio = []
                # Ceiling division so that the final (partial) batch is included without
                # producing an empty trailing batch when len_v is divisible by batch_size.
                n_batches = (len_v + batch_size - 1) // batch_size
                for i in range(n_batches):
                    parameter_batch = array_to_tensor(
                        v[i * batch_size : (i + 1) * batch_size, :],
                        device=self.device,
                    )
                    feature_batch = features.expand(
                        parameter_batch.size(0), *features.shape[1:]
                    )
                    ratio_batch = (
                        self.network.tail(feature_batch, parameter_batch).cpu().numpy()
                    )
                    ratio.append(ratio_batch)
                ratio = np.vstack(ratio)

        if was_training:
            self.network.train()
        else:
            self.network.eval()

        return {k: ratio[..., i] for i, k in enumerate(self.marginal_indices)}

    @staticmethod
    def _repeat_observation_to_match_v(
        observation: Dict[Hashable, torch.Tensor], v: torch.Tensor
    ) -> Dict[Hashable, torch.Tensor]:
        b, *_ = v.size()
        return valmap(lambda x: x.unsqueeze(0).expand(b, *x.size()), observation)

    @staticmethod
    def _get_last_lr(scheduler: SchedulerType) -> float:
        """Get the last learning rate from a `lr_scheduler`."""
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            if scheduler.best == float("Inf"):
                return scheduler.optimizer.param_groups[0]["lr"]
            else:
                return scheduler._last_lr[-1]
        elif isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler):
            # get_last_lr() returns a list of learning rates; report the last group's value.
            return scheduler.get_last_lr()[-1]
        else:
            raise NotImplementedError(
                f"Cannot determine learning_rate from {scheduler}"
            )

    def _loss(
        self, observation: Dict[Hashable, torch.Tensor], parameters: torch.Tensor
    ) -> torch.Tensor:
        """Calculate the marginal-wise losses.

        Args:
            observation: a batch of observations within a dictionary
            parameters: a batch of parameters

        Returns:
            torch.Tensor: the marginal-wise losses with length `len(self.marginal_indices)`
        """
        n_batch = parameters.size(0)
        assert (
            n_batch % 2 == 0
        ), "Loss function can only handle even-numbered batch sizes."
        assert all(
            [value.size(0) == n_batch for value in observation.values()]
        ), "The observation batch_size must agree with the parameter batch_size."

        # Repeat interleave observations and pair-shuffle parameters so that every group
        # of four network evaluations contains two matched (observation, parameter) pairs
        # and two mismatched pairs.
        observation_doubled = valmap(double_observation, observation)
        parameters_doubled = double_parameters(parameters)

        lnL = self.network(observation_doubled, parameters_doubled)
        lnL = lnL.view(-1, 4, lnL.shape[-1])

        # Binary cross-entropy: matched pairs (indices 0 and 3) are labelled as jointly
        # drawn, mismatched pairs (indices 1 and 2) as marginally drawn.
        loss = -torch.nn.functional.logsigmoid(lnL[:, 0])
        loss += -torch.nn.functional.logsigmoid(-lnL[:, 1])
        loss += -torch.nn.functional.logsigmoid(-lnL[:, 2])
        loss += -torch.nn.functional.logsigmoid(lnL[:, 3])
        loss = loss.sum(axis=0) / n_batch
        return loss

    def state_dict(self) -> dict:
        return {
            "marginal_indices": self.marginal_indices,
            "network": self.network.state_dict(),
            "epoch": self.epoch,
            "min_loss": self.min_loss,
            "best_network_state_dict": self.best_network_state_dict,
            "optimizer": self.optimizer.state_dict()
            if self.optimizer is not None
            else None,
            "scheduler": self.scheduler.state_dict()
            if self.scheduler is not None
            else None,
        }

    # TODO this could also save the __class__ and use importlib, thereby reducing the arguments to from_state_dict

    @classmethod
    def from_state_dict(
        cls,
        network: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        scheduler: Optional[SchedulerType],
        device: Device,
        state_dict: dict,
    ) -> MarginalRatioEstimatorType:
        """Instantiate a MarginalRatioEstimator from a state_dict, along with a few necessary python objects.

        Args:
            network: initialized network
            optimizer: same optimizer as used by the saved model
            scheduler: same scheduler as used by the saved model
            device
            state_dict

        Returns:
            MarginalRatioEstimatorType: loaded model
        """
        marginal_ratio_estimator = cls.__new__(cls)
        marginal_ratio_estimator.marginal_indices = state_dict["marginal_indices"]
        marginal_ratio_estimator.epoch = state_dict["epoch"]
        marginal_ratio_estimator.min_loss = state_dict["min_loss"]
        marginal_ratio_estimator.best_network_state_dict = state_dict[
            "best_network_state_dict"
        ]
        marginal_ratio_estimator.device = device
        marginal_ratio_estimator.network = network
        marginal_ratio_estimator.optimizer = optimizer
        marginal_ratio_estimator.scheduler = scheduler
        marginal_ratio_estimator.network.load_state_dict(state_dict["network"])

        if optimizer is not None and state_dict["optimizer"] is None:
            raise FileNotFoundError(
                "There was no data about the optimizer in the state_dict"
            )
        elif optimizer is not None and state_dict["optimizer"] is not None:
            marginal_ratio_estimator.optimizer.load_state_dict(state_dict["optimizer"])
        else:
            pass

        if scheduler is not None and state_dict["scheduler"] is None:
            raise FileNotFoundError(
                "There was no data about the scheduler in the state_dict"
            )
        elif scheduler is not None and state_dict["scheduler"] is not None:
            marginal_ratio_estimator.scheduler.load_state_dict(state_dict["scheduler"])
        else:
            pass

        return marginal_ratio_estimator

    @classmethod
    def load(
        cls: Type[StateDictSaveableType],
        network: torch.nn.Module,
        device: Device,
        filename: PathType,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> StateDictSaveableType:
        sd = torch.load(filename)
        return cls.from_state_dict(
            network=network,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            state_dict=sd,
        )
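
# Illustrative usage sketch (not part of the original module). The names `my_network`,
# `my_dataset`, `single_observation`, and `v_samples` are hypothetical placeholders:
# `my_network` stands for any nn.Module exposing `head`/`tail` submodules and a
# forward(observation, parameters) that returns one log-ratio per marginal, and
# `my_dataset` yields (observation, _, v) tuples as unpacked by `train` above.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     mre = MarginalRatioEstimator(
#         marginal_indices=[0, 1, [0, 1]],  # two 1-dim marginals and one 2-dim marginal
#         network=my_network,
#         device=device,
#     )
#     validation_losses = mre.train(my_dataset, batch_size=64, validation_percentage=0.1)
#     ratios = mre.log_ratio(single_observation, v_samples, batch_size=1024)
#     # `ratios` maps each tupleized marginal index to an array of log-ratio estimates.
#
#     # Assuming the state dict was written out with torch.save(mre.state_dict(), "mre.pt"),
#     # it can be restored onto a freshly initialized network object:
#     mre2 = MarginalRatioEstimator.load(network=my_network, device=device, filename="mre.pt")
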
if __name__ == "__main__":
    pass