# Source code for swyft.plot.plot2

from typing import (
    Callable,
    Dict,
    Hashable,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    Any,
)

import numpy as np
import pylab as plt
from scipy.ndimage import gaussian_filter, gaussian_filter1d

try:
    from scipy.integrate import simpson
except ImportError:  # SciPy < 1.6 only ships the legacy name
    from scipy.integrate import simps as simpson
try:
    from scipy.integrate import simps
except ImportError:  # ``simps`` was removed in SciPy 1.14
    simps = simpson

import swyft
import swyft.lightning.utils


def grid_interpolate_samples(x, y, bins=1000, return_norm=False):
    """Resample scattered ``(x, y)`` points onto a regular grid and normalise.

    The points are sorted by ``x``, linearly interpolated onto ``bins``
    equally spaced grid points spanning ``[min(x), max(x)]`` and divided by
    their Simpson-rule integral, so the returned curve integrates to one.

    Args:
        x: 1-dim array-like of sample positions, in any order.
        y: 1-dim array-like of values at ``x``.
        bins: Number of grid points.
        return_norm: If True, additionally return the normalisation integral.

    Returns:
        ``(x_grid, y_grid_normed)``, or ``(x_grid, y_grid_normed, norm)``
        when ``return_norm`` is True.
    """
    # Fancy indexing below needs ndarrays; this also admits plain sequences.
    x = np.asarray(x)
    y = np.asarray(y)
    order = np.argsort(x)
    x, y = x[order], y[order]
    x_grid = np.linspace(x[0], x[-1], bins)
    y_grid = np.interp(x_grid, x, y)
    # ``simps`` is gone in SciPy >= 1.14 and ``simpson`` takes ``x`` by keyword.
    norm = simpson(y_grid, x=x_grid)
    y_grid_normed = y_grid / norm
    if return_norm:
        return x_grid, y_grid_normed, norm
    return x_grid, y_grid_normed


def get_HDI_thresholds(x, cred_level=(0.68268, 0.95450, 0.99730)):
    """Return the density thresholds that enclose the given mass fractions.

    The values of ``x`` are sorted from largest to smallest and accumulated;
    for every fraction ``f`` in ``cred_level`` the returned threshold is the
    first (largest) value at which the accumulated sum reaches ``f`` times
    the total sum.

    Args:
        x: Array-like of non-normalised densities / counts, any shape.
        cred_level: Iterable of mass fractions.

    Returns:
        1-dim array with one threshold per entry of ``cred_level``, in the
        same order as ``cred_level``.
    """
    # ``np.asarray`` admits nested lists; ``np.sort`` copies, so the caller's
    # data is never modified.
    densities = np.sort(np.asarray(x).ravel())[::-1]  # Sort backwards
    total_mass = densities.sum()
    enclosed_mass = np.cumsum(densities)
    idx = [np.argmax(enclosed_mass >= total_mass * f) for f in cred_level]
    return densities[idx]


def plot_2d(
    logratios,
    parname1,
    parname2,
    ax=None,
    bins=100,
    color="k",
    cmap="gray_r",
    smooth=0.0,
):
    """Plot 2-dimensional posteriors.

    The density is obtained from ``swyft.lightning.utils.get_pdf`` and drawn
    as an image with three contour lines at the thresholds returned by
    ``get_HDI_thresholds``.

    Args:
        logratios: Forwarded unchanged to ``get_pdf``.
        parname1: Parameter shown on the x-axis.
        parname2: Parameter shown on the y-axis.
        ax: Axes to draw on; the current axes are used when None.
        bins: Forwarded to ``get_pdf``.
        color: Colour of the contour lines.
        cmap: Colour map of the density image.
        smooth: Forwarded to ``get_pdf``.
    """
    density, grid = swyft.lightning.utils.get_pdf(
        logratios, [parname1, parname2], bins=bins, smooth=smooth
    )
    x_grid, y_grid = grid[:, 0], grid[:, 1]

    if ax is None:
        ax = plt.gca()

    x_range = [x_grid.min(), x_grid.max()]
    y_range = [y_grid.min(), y_grid.max()]
    hdi_levels = sorted(get_HDI_thresholds(density))

    # Contours first, then the image, both spanning the full grid range.
    ax.contour(
        density.T,
        extent=x_range + y_range,
        levels=hdi_levels,
        linestyles=[":", "--", "-"],
        colors=color,
    )
    ax.imshow(
        density.T,
        extent=x_range + y_range,
        cmap=cmap,
        origin="lower",
        aspect="auto",
    )
    ax.set_xlim(x_range)
    ax.set_ylim(y_range)
# xm = (xbins[:-1] + xbins[1:]) / 2 # ym = (ybins[:-1] + ybins[1:]) / 2 # # cx = counts.sum(axis=1) # cy = counts.sum(axis=0) # # mean = (sum(xm * cx) / sum(cx), sum(ym * cy) / sum(cy)) # # return dict(mean=mean, mode=None, HDI1=None, HDI2=None, HDI3=None, entropy=None)
def plot_1d(
    logratios,
    parname,
    weights_key=None,
    ax=None,
    grid_interpolate=False,
    bins=100,
    color="k",
    contours=True,
    smooth=0.0,
):
    """Plot 1-dimensional posteriors.

    The density is obtained from ``swyft.lightning.utils.get_pdf`` and drawn
    as a line, optionally with shaded regions at the thresholds returned by
    ``get_HDI_thresholds``.

    Args:
        logratios: Forwarded unchanged to ``get_pdf``.
        parname: Parameter to plot.
        weights_key: Not used by this function.
        ax: Axes to draw on; the current axes are used when None.
        grid_interpolate: Not used by this function.
        bins: Forwarded to ``get_pdf``.
        color: Colour of the line and the shaded regions.
        contours: Whether to shade the regions above the thresholds.
        smooth: Forwarded to ``get_pdf``.
    """
    density, grid = swyft.lightning.utils.get_pdf(
        logratios, parname, bins=bins, smooth=smooth
    )
    grid = grid[:, 0]

    if ax is None:
        ax = plt.gca()

    hdi_levels = sorted(get_HDI_thresholds(density))
    if contours:
        contour1d(grid, density, hdi_levels, ax=ax, color=color)

    ax.plot(grid, density, color=color)
    ax.set_xlim([grid.min(), grid.max()])
    # Leave a small margin below zero and above the peak.
    peak = density.max()
    ax.set_ylim([-peak * 0.05, peak * 1.1])
# # Diagnostics
# mean = sum(w * x) / sum(w)
# mode = zm[v == v.max()][0]
# int2 = zm[v > levels[2]].min(), zm[v > levels[2]].max()
# int1 = zm[v > levels[1]].min(), zm[v > levels[1]].max()
# int0 = zm[v > levels[0]].min(), zm[v > levels[0]].max()
# entropy = -simps(v * np.log(v), zm)
# return dict(
#     mean=mean, mode=mode, HDI1=int2, HDI2=int1, HDI3=int0, entropy=entropy
# )


def plot_posterior(
    samples,
    pois,
    weights_key=None,
    ax=None,
    grid_interpolate=False,
    bins=100,
    color="k",
    contours=True,
    **kwargs
):
    """Plot a 1- or 2-dimensional weighted posterior and return diagnostics.

    Args:
        samples: Mapping read through ``samples["weights"]``,
            ``samples["v"]`` and, when ``grid_interpolate`` is set,
            ``samples["log_priors"]``.
        pois: A single int or a sequence of one or two column indices into
            ``samples["v"]``.
        weights_key: Key into ``samples["weights"]``; defaults to
            ``tuple(sorted(pois))``. If the exact key is missing, a key
            containing the requested parameter(s) is used instead.
        ax: Axes to draw on; the current axes are used when None.
        grid_interpolate: 1-dim only. Interpolate the prior-reweighted
            samples on a grid instead of histogramming them.
        bins: Histogram bins.
        color: Colour of lines and shaded regions.
        contours: 1-dim only. Shade the regions above the thresholds.
        **kwargs: 1-dim only. Forwarded to ``ax.plot``.

    Returns:
        A dict with keys ``mean``, ``mode``, ``HDI1``, ``HDI2``, ``HDI3`` and
        ``entropy`` (only ``mean`` is filled in for 2-dim plots), or None if
        no weights were found or ``pois`` has neither one nor two entries.

    Raises:
        KeyError: For a single parameter when no key of
            ``samples["weights"]`` contains it.
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(pois, int):
        pois = (pois,)

    w = None

    # FIXME: Clean up ad hoc code
    if weights_key is None:
        weights_key = tuple(sorted(pois))
    try:
        w = samples["weights"][tuple(weights_key)]
    except KeyError:
        if len(weights_key) == 1:
            # First key that contains the parameter wins.
            for k in samples["weights"].keys():
                if weights_key[0] in k:
                    weights_key = k
                    break
            w = samples["weights"][tuple(weights_key)]
        elif len(weights_key) == 2:
            # No ``break``: the last key containing both parameters wins.
            for k in samples["weights"].keys():
                if set(weights_key).issubset(k):
                    weights_key = k
                    w = samples["weights"][k]
    if w is None:
        return

    if len(pois) == 1:
        x = samples["v"][:, pois[0]]

        if grid_interpolate:
            # Grid interpolate samples
            log_prior = samples["log_priors"][pois[0]]
            w_eff = np.exp(np.log(w) + log_prior)  # p(z|x) = r(x, z) p(z)
            zm, v = grid_interpolate_samples(x, w_eff)
        else:
            v, e = np.histogram(x, weights=w, bins=bins, density=True)
            zm = (e[1:] + e[:-1]) / 2

        levels = sorted(get_HDI_thresholds(v))
        if contours:
            contour1d(zm, v, levels, ax=ax, color=color)
        ax.plot(zm, v, color=color, **kwargs)
        ax.set_xlim([x.min(), x.max()])
        ax.set_ylim([-v.max() * 0.05, v.max() * 1.1])

        # Diagnostics
        mean = sum(w * x) / sum(w)
        mode = zm[v == v.max()][0]
        # ``levels`` is ascending, so levels[2] is the highest threshold.
        int2 = zm[v > levels[2]].min(), zm[v > levels[2]].max()
        int1 = zm[v > levels[1]].min(), zm[v > levels[1]].max()
        int0 = zm[v > levels[0]].min(), zm[v > levels[0]].max()
        # Empty bins contribute 0 (limit of p log p); evaluating log(0)
        # directly would turn the whole integral into NaN.
        plogp = np.zeros_like(v, dtype=float)
        positive = v > 0
        plogp[positive] = v[positive] * np.log(v[positive])
        # ``simps`` is gone in SciPy >= 1.14; ``simpson`` takes x by keyword.
        entropy = -simpson(plogp, x=zm)
        return dict(
            mean=mean, mode=mode, HDI1=int2, HDI2=int1, HDI3=int0, entropy=entropy
        )
    elif len(pois) == 2:
        # FIXME: use interpolation when grid_interpolate == True
        x = samples["v"][:, pois[0]]
        y = samples["v"][:, pois[1]]
        counts, xbins, ybins, _ = ax.hist2d(x, y, weights=w, bins=bins, cmap="gray_r")
        levels = sorted(get_HDI_thresholds(counts))
        try:
            ax.contour(
                counts.T,
                extent=[xbins.min(), xbins.max(), ybins.min(), ybins.max()],
                levels=levels,
                linestyles=[":", "--", "-"],
                colors=color,
            )
        except ValueError:
            print("WARNING: 2-dim contours not well-defined.")
        ax.set_xlim([x.min(), x.max()])
        ax.set_ylim([y.min(), y.max()])

        xm = (xbins[:-1] + xbins[1:]) / 2
        ym = (ybins[:-1] + ybins[1:]) / 2

        cx = counts.sum(axis=1)
        cy = counts.sum(axis=0)

        mean = (sum(xm * cx) / sum(cx), sum(ym * cy) / sum(cy))

        return dict(mean=mean, mode=None, HDI1=None, HDI2=None, HDI3=None, entropy=None)


def plot_1d_old(
    samples,
    pois,
    truth=None,
    bins=100,
    figsize=(15, 10),
    color="k",
    labels=None,
    label_args=None,
    ncol=None,
    subplots_kwargs=None,
    fig=None,
    contours=True,
):
    """Make beautiful 1-dim posteriors.

    Args:
        samples: Samples from `swyft.Posteriors.sample`
        pois: List of parameters of interest
        truth: Ground truth vector
        bins: Number of bins used for histograms.
        figsize: Size of figure
        color: Color
        labels: Custom labels (default is parameter names)
        label_args: Custom label arguments (default: none)
        ncol: Number of panel columns (default: one column per parameter)
        subplots_kwargs: Subplot kwargs (default: none)
        fig: Existing figure whose axes are reused, in creation order
        contours: Shade the regions above the HDI thresholds

    Returns:
        Tuple ``(fig, diags)`` where ``diags`` maps ``(poi,)`` to the value
        returned by ``plot_posterior`` for that parameter.
    """
    # None instead of ``{}`` defaults avoids sharing one dict across calls.
    label_args = {} if label_args is None else label_args
    subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs

    grid_interpolate = False
    diags = {}

    K = len(pois)
    if ncol is None:
        ncol = K
    nrow = (K - 1) // ncol + 1

    if fig is None:
        fig, axes = plt.subplots(nrow, ncol, figsize=figsize, **subplots_kwargs)
    else:
        axes = fig.get_axes()
    # ``plt.subplots`` returns a bare Axes, a 1-dim or a 2-dim array depending
    # on the layout, and ``fig.get_axes()`` a list. Flattening in row-major
    # order makes panel ``k`` addressable as ``axes[k]`` in every case.
    axes = np.asarray(axes, dtype=object).ravel()

    lb = 0.125
    tr = 0.9
    whspace = 0.15
    fig.subplots_adjust(
        left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace
    )

    if labels is None:
        labels = [samples["parameter_names"][poi] for poi in pois]

    for k, poi in enumerate(pois):
        ax = axes[k]
        ret = plot_posterior(
            samples,
            poi,
            ax=ax,
            grid_interpolate=grid_interpolate,
            color=color,
            bins=bins,
            contours=contours,
        )
        ax.set_xlabel(labels[k], **label_args)
        if truth is not None:
            ax.axvline(truth[poi], ls=":", color="r")
        diags[(poi,)] = ret
    return fig, diags
def corner(
    logratios,
    parnames,
    bins=100,
    truth=None,
    figsize=(10, 10),
    color="k",
    labels=None,
    label_args=None,
    contours_1d: bool = True,
    fig=None,
    labeler=None,
    smooth=0.0,
):
    """Make a beautiful corner plot.

    Args:
        logratios: Forwarded unchanged to ``plot_1d`` / ``plot_2d``.
        parnames: List of parameters of interest
        bins: Number of bins used for histograms.
        truth: Ground truth vector (currently not drawn)
        figsize: Size of figure
        color: Color
        labels: Custom labels, one per entry of ``parnames``; used only when
            ``labeler`` is None (default is parameter names)
        label_args: Custom label arguments (default: none)
        contours_1d: Plot 1-dim contours
        fig: Figure instance holding ``K * K`` axes to reuse
        labeler: Mapping from parameter name to label; names without an
            entry are used as their own label. Takes precedence over
            ``labels``.
        smooth: Forwarded to ``plot_1d`` / ``plot_2d``.

    Returns:
        The figure that was drawn on.
    """
    # None instead of a ``{}`` default avoids sharing one dict across calls.
    label_args = {} if label_args is None else label_args

    K = len(parnames)
    if fig is None:
        fig, axes = plt.subplots(K, K, figsize=figsize)
    else:
        axes = fig.get_axes()
    # For K == 1 ``plt.subplots`` returns a bare Axes; reshaping makes the
    # ``axes[i, j]`` lookup below valid for every K and for reused figures.
    axes = np.asarray(axes, dtype=object).reshape((K, K))

    lb = 0.125
    tr = 0.9
    whspace = 0.1
    fig.subplots_adjust(
        left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace
    )

    if labeler is not None:
        labels = [labeler.get(k, k) for k in parnames]
    elif labels is None:
        # Explicit ``labels`` used to be overwritten here unconditionally.
        labels = parnames

    for i in range(K):
        for j in range(K):
            ax = axes[i, j]

            # Switch off upper right triangle
            if i < j:
                ax.set_yticklabels([])
                ax.set_xticklabels([])
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)
                continue

            # Formatting labels
            if j > 0 or i == 0:
                ax.set_yticklabels([])
            if i < K - 1:
                ax.set_xticklabels([])
            if i == K - 1:
                ax.set_xlabel(labels[j], **label_args)
            if j == 0 and i > 0:
                ax.set_ylabel(labels[i], **label_args)

            # 2-dim plots
            if j < i:
                try:
                    plot_2d(
                        logratios,
                        parnames[j],
                        parnames[i],
                        ax=ax,
                        color=color,
                        bins=bins,
                        smooth=smooth,
                    )
                except swyft.SwyftParameterError:
                    # Deliberate best effort: the panel is left empty.
                    # NOTE(review): presumably raised when ``logratios`` has
                    # no estimate for this parameter pair - confirm.
                    pass

            # 1-dim plots
            if j == i:
                try:
                    plot_1d(
                        logratios,
                        parnames[i],
                        ax=ax,
                        color=color,
                        bins=bins,
                        contours=contours_1d,
                        smooth=smooth,
                    )
                except swyft.SwyftParameterError:
                    # Deliberate best effort, as for the 2-dim panels.
                    pass
    return fig
def contour1d(z, v, levels, ax=plt, linestyles=None, color=None, **kwargs):
    """Shade the parts of a 1-dim curve that lie above three thresholds.

    One translucent band is drawn per threshold, so regions above several
    thresholds appear progressively darker.

    Args:
        z: Positions along the x-axis.
        v: Curve values at ``z``.
        levels: Sequence whose first three entries are used as thresholds.
        ax: Object providing ``fill_between`` (defaults to the pylab module).
        linestyles: Not used by this function.
        color: Fill colour.
        **kwargs: Not used by this function.
    """
    peak = v.max()
    # The bands extend well beyond the curve in both directions.
    lower, upper = -1.0 * peak, 5.0 * peak
    for idx in range(3):
        ax.fill_between(
            z, lower, upper, where=v > levels[idx], color=color, alpha=0.1
        )


if __name__ == "__main__":
    pass
def plot_zz(
    coverage_samples,
    params: Union[str, Sequence[str]],
    z_max: float = 3.5,
    bins: int = 50,
    ax=None,
):
    """Make a zz plot.

    Args:
        coverage_samples: Collection of CoverageSamples object
        params: Parameters of interest
        z_max: Maximum value of z.
        bins: Number of discretization bins.
        ax: Axes to draw on; the current axes are used when not given.
    """
    cov = swyft.estimate_coverage(coverage_samples, params, z_max=z_max, bins=bins)
    if not ax:
        ax = plt.gca()
    # First two columns are passed individually, the remaining ones together.
    first_col, second_col, remaining_cols = cov[:, 0], cov[:, 1], cov[:, 2:]
    swyft.plot.mass.plot_empirical_z_score(ax, first_col, second_col, remaining_cols)
def plot_pp(
    coverage_samples,
    params: Union[str, Sequence[str]],
    z_max: float = 3.5,
    bins: int = 50,
    ax=None,
):
    """Make a pp plot.

    Args:
        coverage_samples: Forwarded to ``swyft.estimate_coverage``.
        params: Parameters of interest
        z_max: Forwarded to ``swyft.estimate_coverage``.
        bins: Forwarded to ``swyft.estimate_coverage``.
        ax: Axes to draw on; the current axes are used when None.
    """
    cov = swyft.estimate_coverage(coverage_samples, params, z_max=z_max, bins=bins)
    alphas = 1 - swyft.plot.mass.get_alpha(cov)
    if ax is None:
        ax = plt.gca()
    ax.fill_between(alphas[:, 0], alphas[:, 2], alphas[:, 3], color="0.8")
    ax.plot(alphas[:, 0], alphas[:, 1], "k")
    # Draw the diagonal and the labels on ``ax`` itself; going through
    # ``plt`` would put them on the current axes, which need not be ``ax``.
    ax.plot([0, 1], [0, 1], "g--")
    ax.set_xlabel("Nominal credibility [$1-p$]")
    ax.set_ylabel("Empirical coverage [$1-p$]")
# swyft.plot.mass.plot_empirical_z_score(ax, cov[:,0], cov[:,1], cov[:,2:]) # def plot_scores(mass): # s = mass.get_z_scores() # for j, k in enumerate(s.keys()): # for i in [0, 1, 2]: # y = s[k][i,1] # yerr = np.array([s[k][i,2]-y, y-s[k][i,0]]).reshape(2, 1) # plt.errorbar(j, [y], yerr = yerr, marker='.', color='k') # labels = [list(v) for v in s.keys()] # plt.xticks(range(len(labels)), labels) # plt.axhline(1., color='k', ls=':') # plt.axhline(2., color='k', ls=':') # plt.axhline(3., color='k', ls=':')