# Source code for swyft.plot.plot

from typing import (
    Sequence,
    Union,
)

import numpy as np
import pylab as plt
import swyft

try:
    from scipy.integrate import simpson
except ImportError:  # older SciPy releases only ship the legacy name
    from scipy.integrate import simps as simpson
try:
    from scipy.integrate import simps
except ImportError:  # the legacy alias was removed in SciPy 1.14
    simps = simpson
from swyft.lightning.utils import get_pdf
from swyft.plot.mass import plot_empirical_z_score, get_alpha


#####################
# Auxiliary functions
#####################


def _grid_interpolate_samples(x, y, bins=1000, return_norm=False):
    idx = np.argsort(x)
    x, y = x[idx], y[idx]
    x_grid = np.linspace(x[0], x[-1], bins)
    y_grid = np.interp(x_grid, x, y)
    norm = simps(y_grid, x_grid)
    y_grid_normed = y_grid / norm
    if return_norm:
        return x_grid, y_grid_normed, norm
    else:
        return x_grid, y_grid_normed


def _get_HDI_thresholds(x, cred_level=[0.68268, 0.95450, 0.99730]):
    x = x.flatten()
    x = np.sort(x)[::-1]  # Sort backwards
    total_mass = x.sum()
    enclosed_mass = np.cumsum(x)
    idx = [np.argmax(enclosed_mass >= total_mass * f) for f in cred_level]
    levels = np.array(x[idx])
    return levels


def _contour1d(z, v, levels, ax=plt, linestyles=None, color=None, **kwargs):
    """Shade the parts of a 1-dim density that lie above each threshold.

    One translucent band is drawn per level, so regions above several
    thresholds appear progressively darker.

    Args:
        z: Grid positions.
        v: Density values at ``z``.
        levels: Iterable of density thresholds; any number of levels is accepted.
        ax: Object providing ``fill_between`` (an Axes, or the pylab module).
        linestyles: Currently unused; kept for interface compatibility.
        color: Fill colour.
        **kwargs: Currently unused; kept for interface compatibility.
    """
    # The band extends well beyond the curve so that it always fills the
    # panel vertically, whatever y-limits the caller sets afterwards.
    y0 = -1.0 * v.max()
    y1 = 5.0 * v.max()
    for level in levels:
        ax.fill_between(z, y0, y1, where=v >= level, color=color, alpha=0.1)


def _plot_2d(
    lrs_coll,
    parname1,
    parname2,
    ax=None,
    bins=100,
    color="k",
    cmap="gray_r",
    smooth=0.0,
    cred_level=(0.68268, 0.95450, 0.99730),
    truth=None,
    smooth_prior=False,
):
    """Plot 2-dimensional posterior.

    Draws the density returned by ``get_pdf`` as an image, overlays contours
    at the highest-density thresholds and, optionally, marks the true values.

    Args:
        lrs_coll: Collection of swyft.LogRatioSamples objects
        parname1: Name of parameter 1 (x-axis)
        parname2: Name of parameter 2 (y-axis)
        ax: Optional figure axis argument (defaults to the current axes)
        bins: Number of bins used for histograms.
        color: Contour colors
        cmap: Density colors
        smooth: Applied smoothing factor
        cred_level: Credible levels for contours
        truth: Dictionary with parameters names as keys and true values
        smooth_prior: Smooth and histogram prior instead of posterior (default False)
    """
    counts, xy = get_pdf(
        lrs_coll,
        [parname1, parname2],
        bins=bins,
        smooth=smooth,
        smooth_prior=smooth_prior,
    )
    # NOTE(review): xy is presumably the grid coordinates of `counts`, one
    # column per parameter - confirm against get_pdf.
    xbins = xy[:, 0]
    ybins = xy[:, 1]

    if ax is None:
        ax = plt.gca()

    extent = [xbins.min(), xbins.max(), ybins.min(), ybins.max()]

    # Contour levels have to be strictly increasing. For very peaked densities
    # several credible levels can share one threshold, so duplicates are
    # dropped (np.unique also sorts in ascending order).
    levels = np.unique(_get_HDI_thresholds(counts, cred_level=cred_level))
    ax.contour(
        counts.T,
        extent=extent,
        levels=levels,
        linestyles=[":", "--", "-"],
        colors=color,
    )
    ax.imshow(
        counts.T,
        extent=extent,
        cmap=cmap,
        origin="lower",
        aspect="auto",
    )
    ax.set_xlim(extent[:2])
    ax.set_ylim(extent[2:])

    if truth is not None:
        has_x = parname1 in truth
        has_y = parname2 in truth
        truth_style = dict(color="k", lw=1.0, zorder=10, ls=(1, (5, 1)))
        if has_x:
            ax.axvline(truth[parname1], **truth_style)
        if has_y:
            ax.axhline(truth[parname2], **truth_style)
        if has_x and has_y:
            ax.scatter(
                [truth[parname1]], [truth[parname2]], c="k", marker=".", s=100,
            )


#    xm = (xbins[:-1] + xbins[1:]) / 2
#    ym = (ybins[:-1] + ybins[1:]) / 2
#
#    cx = counts.sum(axis=1)
#    cy = counts.sum(axis=0)
#
#    mean = (sum(xm * cx) / sum(cx), sum(ym * cy) / sum(cy))
#
#    return dict(mean=mean, mode=None, HDI1=None, HDI2=None, HDI3=None, entropy=None)


def _plot_1d(
    lrs_coll,
    parname,
    ax=None,
    bins=100,
    color="k",
    contours=True,
    smooth=0.0,
    cred_level=(0.68268, 0.95450, 0.99730),
    truth=None,
    smooth_prior=False,
):
    """Plot 1-dimensional posteriors.

    Args:
        lrs_coll: Collection of swyft.LogRatioSamples objects
        parname: Name of parameter
        ax: Optional figure axis argument (defaults to the current axes)
        bins: Number of bins used for histograms.
        color: Contour colors
        contours: Indicate contours
        smooth: Applied smoothing factor
        cred_level: Credible levels for contours
        truth: Dictionary with parameters names as keys and true values;
            parameters missing from the dictionary are simply not marked.
        smooth_prior: Smooth and histogram prior instead of posterior (default False)
    """
    v, zm = get_pdf(
        lrs_coll, parname, bins=bins, smooth=smooth, smooth_prior=smooth_prior
    )
    zm = zm[:, 0]

    if ax is None:
        ax = plt.gca()

    if contours:
        # Thresholds are only needed when the credible bands are drawn.
        levels = sorted(_get_HDI_thresholds(v, cred_level=cred_level))
        _contour1d(zm, v, levels, ax=ax, color=color)
    ax.plot(zm, v, color=color)
    ax.set_xlim([zm.min(), zm.max()])
    ax.set_ylim([-v.max() * 0.05, v.max() * 1.1])

    # As in _plot_2d, a truth dictionary that lacks this parameter is not an
    # error; the marker is just omitted.
    if truth is not None and parname in truth:
        ax.axvline(truth[parname], color="k", lw=1.0, zorder=10, ls=(1, (5, 1)))


#####
# API
#####


def plot_corner(
    lrs_coll,
    parnames,
    bins=100,
    figsize=None,
    color="k",
    labels=None,
    label_args=None,
    contours_1d: bool = True,
    fig=None,
    smooth=0.0,
    cred_level=(0.68268, 0.95450, 0.99730),
    truth=None,
    smooth_prior=False,
    # plot_diagonal=True  # TODO: Implement suppression of diagonals
):
    """Make a beautiful corner plot.

    The lower triangle shows 2-dim posteriors, the diagonal 1-dim posteriors,
    and the upper triangle is switched off. Panels for which swyft raises
    ``SwyftParameterError`` are left empty.

    Args:
        lrs_coll: Collection of swyft.LogRatioSamples objects
        parnames: List of parameters of interest (a single name is also accepted)
        bins: Number of bins used for histograms.
        figsize: Size of figure
        color: Color
        labels: Optional custom labels, either list or dict.
        label_args: Custom label arguments
        contours_1d: Plot 1-dim contours
        fig: Figure instance; must already hold ``len(parnames)**2`` axes
        smooth: histogram smoothing
        cred_level: Credible levels for contours (2-dim and 1-dim panels)
        truth: Dictionary with parameters names as keys and true values
        smooth_prior: Smooth and histogram prior instead of posterior (default False)

    Returns:
        The figure that was drawn on.

    Raises:
        ValueError: If ``labels`` is a list of the wrong length or of an
            unsupported type.
    """
    if isinstance(parnames, str):
        parnames = [parnames]
    if label_args is None:
        label_args = {}

    K = len(parnames)
    if fig is None:
        # squeeze=False keeps `axes` two-dimensional even when K == 1.
        fig, axes = plt.subplots(K, K, figsize=figsize, squeeze=False)
    else:
        axes = np.array(fig.get_axes()).reshape((K, K))

    if labels is None:
        labels = parnames
    elif isinstance(labels, (list, tuple)):
        if len(labels) != len(parnames):
            raise ValueError(
                "Length of labels list must correspond to number of parameters."
            )
    elif isinstance(labels, dict):
        labels = [labels.get(k, k) for k in parnames]
    else:
        raise ValueError("labels must be None, list or dict")

    for i in range(K):
        for j in range(K):
            ax = axes[i, j]

            # Switch off the upper right triangle
            if i < j:
                ax.set_yticklabels([])
                ax.set_xticklabels([])
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)
                continue

            # Formatting labels: tick labels only on the outer panels
            if j > 0 or i == 0:
                ax.set_yticklabels([])
            if i < K - 1:
                ax.set_xticklabels([])
            if i == K - 1:
                ax.set_xlabel(labels[j], **label_args)
            if j == 0 and i > 0:
                ax.set_ylabel(labels[i], **label_args)

            # Best effort: a panel whose parameters swyft rejects is skipped
            # instead of aborting the whole figure.
            try:
                if j < i:
                    _plot_2d(
                        lrs_coll,
                        parnames[j],
                        parnames[i],
                        ax=ax,
                        color=color,
                        bins=bins,
                        smooth=smooth,
                        cred_level=cred_level,
                        truth=truth,
                        smooth_prior=smooth_prior,
                    )
                else:
                    _plot_1d(
                        lrs_coll,
                        parnames[i],
                        ax=ax,
                        color=color,
                        bins=bins,
                        contours=contours_1d,
                        smooth=smooth,
                        cred_level=cred_level,
                        truth=truth,
                        smooth_prior=smooth_prior,
                    )
            except swyft.SwyftParameterError:
                pass

    # Tight things up
    fig.tight_layout()
    return fig
def plot_zz(
    coverage_samples,
    params: Union[str, Sequence[str]],
    z_max: float = 3.5,
    bins: int = 50,
    ax=None,
):
    """Make a zz plot.

    Estimates the coverage with ``swyft.estimate_coverage`` and hands the
    resulting table to ``plot_empirical_z_score``.

    Args:
        coverage_samples: Collection of CoverageSamples object
        params: Parameters of interest
        z_max: Maximum value of z.
        bins: Number of discretization bins.
        ax: Optional axes instance; the current axes are used when omitted.
    """
    coverage = swyft.estimate_coverage(
        coverage_samples, params, z_max=z_max, bins=bins
    )
    if not ax:
        ax = plt.gca()
    # First two columns are passed individually, the remaining ones together.
    first_col, second_col, remaining_cols = (
        coverage[:, 0],
        coverage[:, 1],
        coverage[:, 2:],
    )
    plot_empirical_z_score(ax, first_col, second_col, remaining_cols)
def plot_pp(
    coverage_samples,
    params: Union[str, Sequence[str]],
    z_max: float = 3.5,
    bins: int = 50,
    ax=None,
):
    """Make a pp plot.

    Everything is drawn on ``ax`` (or the current axes when ``ax`` is None):
    a grey band, a black curve, the green dashed diagonal and both axis labels.

    Args:
        coverage_samples: Collection of CoverageSamples object
        params: Parameters of interest
        z_max: Maximum value of z.
        bins: Number of discretization bins.
        ax: Optional axes instance.
    """
    cov = swyft.estimate_coverage(coverage_samples, params, z_max=z_max, bins=bins)
    alphas = 1 - get_alpha(cov)
    if ax is None:
        ax = plt.gca()
    # Band between columns 2 and 3, curve from column 1, both against column 0.
    ax.fill_between(alphas[:, 0], alphas[:, 2], alphas[:, 3], color="0.8")
    ax.plot(alphas[:, 0], alphas[:, 1], "k")
    # Draw the reference diagonal and labels on `ax` itself; going through
    # `plt` would hit whatever axes happen to be current.
    ax.plot([0, 1], [0, 1], "g--")
    ax.set_xlabel("Nominal credibility [$1-p$]")
    ax.set_ylabel("Empirical coverage [$1-p$]")
def plot_posterior(
    lrs_coll,
    parnames=None,
    bins=100,
    figsize=None,
    color="k",
    labels=None,
    label_args=None,
    ncol=None,
    subplots_kwargs=None,
    fig=None,
    contours=True,
    smooth=1.0,
    cred_level=(0.68268, 0.95450, 0.99730),
    truth=None,
    smooth_prior=False,
):
    """Make beautiful 1-dim posteriors.

    Args:
        lrs_coll: Collection of swyft.LogRatioSamples objects
        parnames: Parameter name or list of parameter names of interest.
            Despite the ``None`` default a value has to be supplied.
        bins: Number of bins used for histograms.
        figsize: Optional size of figure
        color: Color
        labels: (Optional) Custom labels, either a list (one per parameter)
            or a dict mapping parameter names to labels
        label_args: (Optional) Custom label arguments
        ncol: (Optional) Number of panel columns (default: at most 4)
        subplots_kwargs: Optional arguments for subplots generation.
        fig: Optional figure instance; must hold at least one axes per parameter
        contours: Plot 1-dim contours
        smooth: Gaussian smoothing scale
        cred_level: Credible levels for contours
        truth: (Optional) Dictionary with parameters names as keys and true values
        smooth_prior: Smooth and histogram prior instead of posterior (default False)

    Returns:
        The figure that was drawn on.

    Raises:
        TypeError: If ``parnames`` is None.
        ValueError: If ``labels`` has the wrong length or type, or if ``fig``
            has fewer axes than parameters.
    """
    if parnames is None:
        raise TypeError(
            "parnames must be a parameter name or a list of parameter names"
        )
    # parnames should be single str or list of strings
    if isinstance(parnames, str):
        parnames = [parnames]
    if label_args is None:
        label_args = {}
    if subplots_kwargs is None:
        subplots_kwargs = {}

    # labels can be None (defaulting to parnames), a list of names, or a
    # dictionary mapping parnames on labels
    if labels is None:
        labels = parnames
    elif isinstance(labels, (list, tuple)):
        if len(labels) != len(parnames):
            raise ValueError(
                "Length of labels list must correspond to number of parameters."
            )
    elif isinstance(labels, dict):
        labels = [labels.get(k, k) for k in parnames]
    else:
        raise ValueError("labels must be None, list or dict")

    # If ncol is None, default to (max) 4 panels per row
    if ncol is None:
        ncol = min(len(parnames), 4)

    K = len(parnames)
    nrow = (K - 1) // ncol + 1

    if fig is None:
        fig, axes = plt.subplots(nrow, ncol, figsize=figsize, **subplots_kwargs)
    else:
        axes = fig.get_axes()

    # Ensure axes always is a flat array
    if isinstance(axes, np.ndarray):
        axes = axes.reshape(-1)
    elif isinstance(axes, list):
        axes = np.array(axes)
    else:
        axes = np.array([axes])
        ncol = nrow = 1

    if len(axes) < K:
        raise ValueError(
            f"Figure provides {len(axes)} axes but {K} parameters were requested."
        )

    # Only touch the panels of the computed grid; a supplied figure may hold
    # fewer (or more) axes than nrow * ncol.
    for k, ax in enumerate(axes[: ncol * nrow]):
        if k >= K:
            ax.set_visible(False)
            continue
        _plot_1d(
            lrs_coll,
            parnames[k],
            ax=ax,
            bins=bins,
            color=color,
            contours=contours,
            smooth=smooth,
            cred_level=cred_level,
            truth=truth,
            smooth_prior=smooth_prior,
        )
        ax.set_xlabel(labels[k], **label_args)
        ax.set_yticks([])
        ax.minorticks_on()

    # Tight things up
    fig.tight_layout()
    return fig
def plot_pair(
    lrs_coll,
    parnames=None,
    bins=100,
    figsize=None,
    color="k",
    labels=None,
    label_args=None,
    ncol=None,
    subplots_kwargs=None,
    fig=None,
    smooth=1.0,
    cred_level=(0.68268, 0.95450, 0.99730),
    truth=None,
    smooth_prior=False,
) -> None:
    """Make beautiful 2-dim posteriors.

    Args:
        lrs_coll: Collection of swyft.LogRatioSamples objects
        parnames: A pair of parameter names or a list of such pairs.
            Despite the ``None`` default a value has to be supplied.
        bins: Number of bins used for histograms.
        figsize: Optional size of figure
        color: Color
        labels: (Optional) Dictionary mapping parameter names to custom
            labels; names without an entry are used as their own label
        label_args: (Optional) Custom label arguments
        ncol: (Optional) Number of panel columns (default: at most 4)
        subplots_kwargs: Optional arguments for subplots generation.
        fig: Optional figure instance; must hold at least one axes per pair
        smooth: Gaussian smoothing scale
        cred_level: Credible levels for contours
        truth: (Optional) Dictionary with parameters names as keys and true values
        smooth_prior: Smooth and histogram prior instead of posterior (default False)

    Raises:
        TypeError: If ``parnames`` is None.
        ValueError: If ``labels`` is neither None nor a dict, or if ``fig``
            has fewer axes than parameter pairs.
    """
    if parnames is None:
        raise TypeError(
            "parnames must be a pair of parameter names or a list of pairs"
        )
    # parnames should be a single pair or a list of pairs
    if isinstance(parnames[0], str):
        parnames = [parnames]
    if label_args is None:
        label_args = {}
    if subplots_kwargs is None:
        subplots_kwargs = {}

    # labels can be None (defaulting to parnames) or a dictionary mapping
    # parnames on labels
    if labels is None:
        labels = parnames
    elif isinstance(labels, dict):
        # Look up each name of each pair in the label dictionary.
        labels = [
            [labels.get(pair[i], pair[i]) for i in (0, 1)] for pair in parnames
        ]
    else:
        raise ValueError("labels must be None or dict")

    # If ncol is None, default to (max) 4 panels per row
    if ncol is None:
        ncol = min(len(parnames), 4)

    K = len(parnames)
    nrow = (K - 1) // ncol + 1

    if fig is None:
        fig, axes = plt.subplots(nrow, ncol, figsize=figsize, **subplots_kwargs)
    else:
        axes = fig.get_axes()

    # Ensure axes always is a flat array. Figure.get_axes() returns a list,
    # which must not be wrapped into a second dimension.
    if isinstance(axes, np.ndarray):
        axes = axes.reshape(-1)
    elif isinstance(axes, list):
        axes = np.array(axes)
    else:
        axes = np.array([axes])
        ncol = nrow = 1

    if len(axes) < K:
        raise ValueError(
            f"Figure provides {len(axes)} axes but {K} parameter pairs were requested."
        )

    # Only touch the panels of the computed grid; a supplied figure may hold
    # fewer (or more) axes than nrow * ncol.
    for k, ax in enumerate(axes[: ncol * nrow]):
        if k >= K:
            ax.set_visible(False)
            continue
        _plot_2d(
            lrs_coll,
            parnames[k][0],
            parnames[k][1],
            ax=ax,
            bins=bins,
            color=color,
            smooth=smooth,
            cred_level=cred_level,
            truth=truth,
            smooth_prior=smooth_prior,
        )
        ax.set_xlabel(labels[k][0], **label_args)
        ax.set_ylabel(labels[k][1], **label_args)
        ax.minorticks_on()

    # Tight things up
    fig.tight_layout()