from typing import Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from ..C import COLOR_HIT_BOTH_BOUNDS, COLOR_HIT_NO_BOUNDS, COLOR_HIT_ONE_BOUND
from ..ensemble import Ensemble
def ensemble_identifiability(
    ensemble: Ensemble,
    ax: Optional[plt.Axes] = None,
    size: Optional[Tuple[float, float]] = (12, 6),
):
    """
    Visualize identifiability of a parameter ensemble.

    Plot an overview of how many parameters hit the parameter bounds, based
    on an ensemble of parameters. Confidence intervals/credible ranges are
    computed via the ensemble mean plus/minus 1 standard deviation.
    This high-level routine expects an ensemble object as input.

    Parameters
    ----------
    ensemble:
        ensemble of parameter vectors (from pypesto.ensemble)
    ax:
        Axes object to use.
    size:
        Figure size (width, height) in inches. Is only applied when no ax
        object is specified

    Returns
    -------
    ax: matplotlib.Axes
        The plot axes.
    """
    # first get the data to check identifiability
    id_df = ensemble.check_identifiability()

    # check how many bounds are actually hit and which ones
    none_hit, lb_hit, ub_hit, both_hit = _prepare_identifiability_plot(id_df)

    # delegate to the low-level routine, which works with numpy arrays only;
    # keyword arguments keep the four same-typed arrays from being mixed up
    return ensemble_identifiability_lowlevel(
        none_hit=none_hit,
        lb_hit=lb_hit,
        ub_hit=ub_hit,
        both_hit=both_hit,
        ax=ax,
        size=size,
    )
def ensemble_identifiability_lowlevel(
    none_hit: np.ndarray,
    lb_hit: np.ndarray,
    ub_hit: np.ndarray,
    both_hit: np.ndarray,
    ax: Optional[plt.Axes] = None,
    size: Optional[Tuple[float, float]] = (16, 10),
):
    """
    Low-level identifiability routine.

    Plot an overview of how many parameters hit the parameter bounds, based
    on an ensemble of parameters. Confidence intervals/credible ranges are
    computed via the ensemble mean plus/minus 1 standard deviation.
    This low-level routine works with numpy arrays which define the
    confidence intervals/credible ranges of each parameter.

    Parameters
    ----------
    none_hit:
        2-dimensional array of confidence interval/credible ranges for
        identifiable parameters
    lb_hit:
        2-dimensional array of confidence interval/credible ranges for
        parameters which hit the lower parameter bound
    ub_hit:
        2-dimensional array of confidence interval/credible ranges for
        parameters which hit the upper parameter bound
    both_hit:
        2-dimensional array of confidence interval/credible ranges for
        parameters which hit both parameter bounds
    ax:
        Axes object to use.
    size:
        Figure size (width, height) in inches. Is only applied when no ax
        object is specified. ``None`` keeps matplotlib's default size.

    Returns
    -------
    ax: matplotlib.Axes
        The plot axes.

    Raises
    ------
    ZeroDivisionError
        If all four input arrays are empty, as the group fractions are
        computed relative to the total number of parameters.
    """
    # fraction of the horizontal axis occupied by each parameter group
    n_par = sum(
        arr.shape[0] for arr in (none_hit, lb_hit, ub_hit, both_hit)
    )
    x_both = len(both_hit) / n_par
    x_lb = len(lb_hit) / n_par
    x_ub = len(ub_hit) / n_par
    # computed directly rather than as 1 - (others), so the percentage
    # printed below does not pick up accumulated rounding error
    x_none = len(none_hit) / n_par

    (
        patches_both_hit,
        patches_lb_hit,
        patches_ub_hit,
        patches_none_hit,
    ) = _create_patches(none_hit, lb_hit, ub_hit, both_hit)

    # axes
    if ax is None:
        _, ax = plt.subplots(figsize=size)

    # (patch collection, x position of label, label text, label color),
    # in plotting order from left to right; _create_patches returns an
    # empty list (falsy) for a group without parameters
    groups = (
        (
            patches_both_hit,
            x_both / 2,
            'both bounds hit',
            COLOR_HIT_BOTH_BOUNDS,
        ),
        (
            patches_lb_hit,
            x_both + x_lb / 2,
            'lower bound hit',
            COLOR_HIT_ONE_BOUND,
        ),
        (
            patches_ub_hit,
            x_both + x_lb + x_ub / 2,
            'upper bound hit',
            COLOR_HIT_ONE_BOUND,
        ),
        (
            patches_none_hit,
            1 - x_none / 2,
            'no bounds hit',
            COLOR_HIT_NO_BOUNDS,
        ),
    )

    # add patch collections of all non-empty groups
    for collection, _, _, _ in groups:
        if collection:
            ax.add_collection(collection)

    # plot dashed lines separating the parameter groups
    vert = [-0.05, 1.05]
    for x_sep in (x_both, x_both + x_lb, x_both + x_lb + x_ub):
        ax.plot([x_sep, x_sep], vert, 'k--', linewidth=1.5)

    # label each non-empty group below the plot
    for collection, x_center, label, color in groups:
        if collection:
            ax.text(
                x_center,
                -0.05,
                label,
                color=color,
                rotation=-90,
                va='top',
                ha='center',
            )
    ax.text(
        0,
        -0.7,
        'identifiable parameters: {:4.1f}%'.format(x_none * 100),
        va='top',
    )

    # plot upper and lower bounds
    ax.text(-0.03, 1.0, 'upper\nbound', ha='right', va='center')
    ax.text(-0.03, 0.0, 'lower\nbound', ha='right', va='center')
    ax.plot([-0.02, 1.03], [0, 0], 'k:', linewidth=1.5)
    ax.plot([-0.02, 1.03], [1, 1], 'k:', linewidth=1.5)

    # use the Axes methods instead of the pyplot state machine, so that a
    # user-supplied ax is modified even if it is not the current axes
    ax.set_xticks([])
    ax.set_yticks([])

    # plot frame
    ax.plot([0, 0], vert, 'k-', linewidth=1.5)
    ax.plot([1, 1], vert, 'k-', linewidth=1.5)

    # beautify axes
    ax.set_xlim(-0.15, 1.1)
    ax.set_ylim(-0.78, 1.15)
    for side in ('right', 'left', 'bottom', 'top'):
        ax.spines[side].set_visible(False)

    return ax
def _prepare_identifiability_plot(id_df: pd.DataFrame):
"""
Group model parameters based on an ensemble object .
Can group into four categories, based on the mean of the parameter
ensemble plus/minus 1 standard deviation: Parameters that hit both
bounds, parameters that hit only the lower [or upper] bound,
and parameters that hit no bounds. It returns them as four numpy arrays,
together with their confidence intervals/credible ranges.
Parameters
----------
id_df:
Pandas dataframe with information about parameter identifiability,
as created by pypesto.ensemble.check_identifiability()
Returns
-------
none_hit:
2-dimensional array of confidence interval/credible ranges for
identifiable parameters
lb_hit:
2-dimensional array of confidence interval/credible ranges for
parameters which hit the lower parameter bound
ub_hit:
2-dimensional array of confidence interval/credible ranges for
parameters which hit the upper parameter bound
both_hit:
2-dimensional array of confidence interval/credible ranges for
parameters which hit both parameter bounds
"""
# prepare
both_hit = []
lb_hit = []
ub_hit = []
none_hit = []
def _affine_transform(par_info):
# rescale parameters to bounds
lb = par_info['lowerBound']
ub = par_info['upperBound']
val_l = par_info['ensemble_mean'] - par_info['ensemble_std']
val_u = par_info['ensemble_mean'] + par_info['ensemble_std']
# check if parameter confidence intervals/credible ranges hit bound
if val_l <= lb:
lower_val = 0.0
else:
lower_val = (val_l - lb) / (ub - lb)
if val_u >= ub:
upper_val = 1.0
else:
upper_val = (val_u - lb) / (ub - lb)
return lower_val, upper_val
for par_id in list(id_df.index):
# check which of the parameters seems to be identifiable and group them
if (
id_df.loc[par_id, 'within lb: 1 std']
and id_df.loc[par_id, 'within ub: 1 std']
):
none_hit.append(_affine_transform(id_df.loc[par_id, :]))
elif id_df.loc[par_id, 'within lb: 1 std']:
ub_hit.append(_affine_transform(id_df.loc[par_id, :]))
elif id_df.loc[par_id, 'within ub: 1 std']:
lb_hit.append(_affine_transform(id_df.loc[par_id, :]))
else:
both_hit.append(_affine_transform(id_df.loc[par_id, :]))
return (
np.array(none_hit),
np.array(lb_hit),
np.array(ub_hit),
np.array(both_hit),
)
def _create_patches(
none_hit: np.ndarray,
lb_hit: np.ndarray,
ub_hit: np.ndarray,
both_hit: np.ndarray,
):
"""
Create patches for identifiability analysis.
Create matplotlib.patches.PatchCollection objects from numpy arrays with
confidence intervals/credible ranges, which visualize identifiability
properties of the model parameters.
Parameters
----------
none_hit:
2-dimensional array of confidence interval/credible ranges for
identifiable parameters
lb_hit:
2-dimensional array of confidence interval/credible ranges for
parameters which hit the lower parameter bound
ub_hit:
2-dimensional array of confidence interval/credible ranges for
parameters which hit the upper parameter bound
both_hit:
2-dimensional array of confidence interval/credible ranges for
parameters which hit both parameter bounds
Returns
-------
patches_both_hit:
patches showing parameters which hit both parameter bounds in the
ensemble (and are hence non-identifiable)
patches_lb_hit:
patches showing parameters which hit only the lower parameter bounds
in the ensemble (and are hence non-identifiable)
patches_ub_hit:
patches showing parameters which hit only the lower parameter bounds
in the ensemble (and are hence non-identifiable)
patches_none_hit
patches showing parameters which hit no parameter bounds in the
ensemble (and are hence identifiable)
"""
# get total number of parameters
n_par = sum(
[
none_hit.shape[0],
lb_hit.shape[0],
ub_hit.shape[0],
both_hit.shape[0],
]
)
# start patches at the left end and increment by h = 1/n_par
x = 0.0
h = 1.0 / n_par
# creates patches for parameters which hit both bounds
patches_both_hit = []
if both_hit.size > 0:
for _ in both_hit:
# create a list of rectangles
patches_both_hit.append(Rectangle((x, 0.0), h, 1.0))
x += h
patches_both_hit = PatchCollection(
patches_both_hit, facecolors=COLOR_HIT_BOTH_BOUNDS
)
# creates patches for parameters which hit lower bound
patches_lb_hit = []
# sort by normalizes length of confidence interval/credible range
if lb_hit.size > 0:
tmp_lb = np.sort(lb_hit[:, 1])[::-1]
for lb_par in tmp_lb:
# create a list of rectangles
patches_lb_hit.append(Rectangle((x, 0.0), h, lb_par))
x += h
patches_lb_hit = PatchCollection(
patches_lb_hit, facecolors=COLOR_HIT_ONE_BOUND
)
# creates patches for parameters which hit upper bound
patches_ub_hit = []
# sort by normalizes length of confidence interval/credible range
if ub_hit.size > 0:
tmp_ub = np.sort(ub_hit[:, 0])
for ub_par in tmp_ub:
# create a list of rectangles
patches_ub_hit.append(Rectangle((x, ub_par), h, 1.0 - ub_par))
x += h
patches_ub_hit = PatchCollection(
patches_ub_hit, facecolors=COLOR_HIT_ONE_BOUND
)
# creates patches for parameters which hit no bounds
patches_none_hit = []
# sort by normalizes length of confidence interval/credible range
if none_hit.size > 0:
tmp_none = np.argsort(none_hit[:, 1] - none_hit[:, 0])[::-1]
for none_par in tmp_none:
patches_none_hit.append(
# create a list of rectangles
Rectangle(
(x, none_hit[none_par, 0]),
h,
none_hit[none_par, 1] - none_hit[none_par, 0],
)
)
x += h
patches_none_hit = PatchCollection(
patches_none_hit, facecolors=COLOR_HIT_NO_BOUNDS
)
return patches_both_hit, patches_lb_hit, patches_ub_hit, patches_none_hit