Source code for swyft.plot.violin

from typing import Optional, Sequence

import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes

from swyft.types import MarginalToArray
from swyft.utils.marginals import filter_marginals_by_dim


def create_violin_df_from_marginal_dict(
    marginals: MarginalToArray, method: Optional[str] = None
) -> pd.DataFrame:
    """Map a marginal sample dict to the long-form df format for violin plots.

    Only the entries kept by ``filter_marginals_by_dim(marginals, 1)`` are
    used (presumably the one-dimensional marginals -- confirm against that
    helper). Each kept entry contributes one row per sample.

    Args:
        marginals: marginal dictionary; each key must be indexable (its first
            element labels the marginal) and each value must provide
            ``flatten()`` and ``len()``.
        method: name of method used to estimate posterior; repeated verbatim
            on every row.

    Returns:
        violin dataframe with the columns ``"Marginal"``, ``"Parameter"`` and
        ``"Method"``. When no marginal is kept, an empty dataframe with those
        three columns is returned.
    """
    columns = ["Marginal", "Parameter", "Method"]
    marginals_1d = filter_marginals_by_dim(marginals, 1)

    frames = []
    for key, value in marginals_1d.items():
        n_rows = len(value)
        frames.append(
            pd.DataFrame(
                {
                    "Marginal": [key[0]] * n_rows,
                    "Parameter": value.flatten(),
                    "Method": [method] * n_rows,
                },
                columns=columns,
            )
        )

    # pd.concat raises "No objects to concatenate" on an empty list, so hand
    # back an empty frame with the expected schema instead.
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)


def violin(
    marginals: MarginalToArray,
    axes: Optional[Axes] = None,
    palette: str = "muted",
    labels: Optional[Sequence[str]] = None,
) -> Axes:
    """Create a seaborn violin plot.

    Args:
        marginals: marginals from the estimator, must be samples (NOT weighted
            samples).
        axes: matplotlib axes, forwarded to seaborn as ``ax``. May be None.
        palette: seaborn palette.
        labels: the string labels for the parameters. When given, they replace
            the x tick labels of the plot.

    Returns:
        The Axes returned by ``sns.violinplot``.
    """
    ax = sns.violinplot(
        x="Marginal",
        y="Parameter",
        data=create_violin_df_from_marginal_dict(marginals),
        palette=palette,
        split=True,
        scale="width",
        inner="quartile",
        ax=axes,
    )
    if labels is not None:
        ax.set_xticklabels(labels)
    return ax