Source code for swyft.weightedmarginals

from dataclasses import dataclass
from typing import Tuple, TypeVar

import numpy as np
import pandas as pd

from swyft.types import (
    MarginalIndex,
    MarginalToArray,
    MarginalToDataFrame,
    StrictMarginalIndex,
)
from swyft.utils.marginals import filter_marginals_by_dim, tupleize_marginal_indices

# Type variable bound to WeightedMarginalSamples (given as a forward-reference
# string because the class is defined further down). It lets a method's return
# annotation name "the same (sub)class as the receiver".
WeightedMarginalSamplesType = TypeVar("WeightedMarginalSamplesType", bound="WeightedMarginalSamples")


@dataclass
class WeightedMarginalSamples:
    """Per-marginal log-weights together with the parameter samples they refer to.

    ``weights`` maps a marginal index (a tuple of parameter column indices) to
    that marginal's log-weights. The accessors below select the matching
    columns of ``v`` and optionally convert the pair into a DataFrame.
    """

    # marginal index -> array of log-weights (``get_df`` exponentiates these)
    weights: MarginalToArray
    # parameter samples; marginals are selected as columns via ``v[:, index]``.
    # Presumably shape (n_samples, n_parameters) -- TODO confirm with callers.
    v: np.ndarray

    @staticmethod
    def _select_marginal_index(marginal_index: MarginalIndex) -> Tuple[int, ...]:
        """Normalize ``marginal_index`` and return the single marginal it names.

        Args:
            marginal_index: any form accepted by ``tupleize_marginal_indices``.
                It must name exactly one marginal.

        Returns:
            The selected marginal as a tuple of parameter indices.

        Raises:
            ValueError: if ``marginal_index`` does not name exactly one marginal.
        """
        marginal_indices = tupleize_marginal_indices(marginal_index)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(marginal_indices) != 1:
            raise ValueError(
                "weighted marginal samples can only be recovered one index at a time"
            )
        return marginal_indices[0]

    def get_logweight(self, marginal_index: MarginalIndex) -> np.ndarray:
        """Access the log-weights of a single marginal.

        Args:
            marginal_index: which marginal to select (one at a time).

        Returns:
            The log-weight array stored in ``weights`` for that marginal.

        Raises:
            ValueError: if more than one marginal is requested.
        """
        marginal_index = self._select_marginal_index(marginal_index)
        return self.weights[marginal_index]

    def get_logweight_marginal(
        self, marginal_index: MarginalIndex
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Access the log-weights and the parameter values of a single marginal.

        Args:
            marginal_index: which marginal to select (one at a time).

        Returns:
            ``(logweight, marginal)``. ``marginal`` holds the columns of ``v``
            named by the marginal index, in index order.

        Raises:
            ValueError: if more than one marginal is requested.
        """
        marginal_index = self._select_marginal_index(marginal_index)
        logweight = self.get_logweight(marginal_index)
        marginal = self.v[:, marginal_index]
        return logweight, marginal

    def get_df(self, marginal_index: MarginalIndex) -> pd.DataFrame:
        """Convert one weighted marginal into a DataFrame.

        The columns are the marginal's parameter indices, followed by
        ``'weight'`` (``exp(logweight)``) and ``'logweight'``.

        Args:
            marginal_index: which marginal to select (one at a time).

        Returns:
            DataFrame with one row per sample.

        Raises:
            ValueError: if more than one marginal is requested.
        """
        marginal_index = self._select_marginal_index(marginal_index)
        logweight, marginal = self.get_logweight_marginal(marginal_index)
        weight = np.exp(logweight)
        # Append weight and logweight as two extra trailing columns.
        data = np.concatenate(
            [marginal, weight[..., None], logweight[..., None]], axis=-1
        )
        columns = [*marginal_index, "weight", "logweight"]
        return pd.DataFrame(data=data, columns=columns)

    def get_df_dict(self) -> MarginalToDataFrame:
        """Map every marginal index in ``weights`` to its ``get_df`` DataFrame."""
        return {
            marginal_index: self.get_df(marginal_index)
            for marginal_index in self.marginal_indices
        }

    def filter_by_dim(
        self: WeightedMarginalSamplesType, dim: int
    ) -> WeightedMarginalSamplesType:
        """Return a new instance whose ``weights`` are filtered by ``dim``.

        Filtering is delegated to ``filter_marginals_by_dim(self.weights, dim)``.
        Presumably this keeps marginals with ``dim`` parameters (TODO confirm).
        The new instance shares the same ``v`` array; it is not copied.

        Args:
            dim: dimension passed to ``filter_marginals_by_dim``.

        Returns:
            An instance of the same class as ``self``.
        """
        weights = filter_marginals_by_dim(self.weights, dim)
        return self.__class__(weights, self.v)

    @property
    def marginal_indices(self) -> StrictMarginalIndex:
        """Tuple of all marginal indices (the keys of ``weights``), in key order."""
        return tuple(self.weights.keys())