Source code for pypesto.visualize.ordinal_categories

import warnings
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    import pypesto

import matplotlib.pyplot as plt
import numpy as np

try:
    import amici
    from petab.C import OBSERVABLE_ID

    from ..hierarchical.ordinal.calculator import OrdinalCalculator
    from ..hierarchical.ordinal.parameter import OrdinalParameter
    from ..hierarchical.ordinal.solver import (
        compute_interval_constraints,
        get_bounds_for_category,
        undo_inner_parameter_reparameterization,
    )
except ImportError:
    pass


from ..C import (
    AMICI_SIGMAY,
    AMICI_T,
    AMICI_Y,
    CENSORED,
    MEASUREMENT_TYPE,
    ORDINAL,
    QUANTITATIVE_DATA,
    QUANTITATIVE_IXS,
    REPARAMETERIZED,
    SCIPY_X,
)
from ..result import Result


def plot_categories_from_pypesto_result(
    pypesto_result: Result,
    start_index: int = 0,
    axes: Optional[plt.Axes] = None,
    **kwargs,
):
    """Plot the inner solutions from a pypesto result.

    The model is simulated with AMICI at the parameters of the chosen
    start, the ordinal inner problem is solved for that simulation, and
    the outcome is handed to :func:`plot_categories_from_inner_result`.

    Parameters
    ----------
    pypesto_result:
        The pypesto result.
    start_index:
        The index of the pypesto_result.optimize_result.list to plot.
    axes:
        The optional axes to plot on.
    kwargs:
        Additional arguments, forwarded to
        :func:`plot_categories_from_inner_result`.

    Returns
    -------
    axes:
        The axes returned by :func:`plot_categories_from_inner_result`,
        or ``None`` (after emitting a warning) if any AMICI simulation
        failed.

    Raises
    ------
    ValueError
        If none of the inner calculators of the objective is an
        ``OrdinalCalculator``.
    """
    objective = pypesto_result.problem.objective

    # Get the parameters from the pypesto result for the start_index.
    x_dct = dict(
        zip(
            objective.x_ids,
            pypesto_result.optimize_result.list[start_index]["x"],
        )
    )
    x_dct.update(objective.calculator.necessary_par_dummy_values)

    # Get the needed objects from the pypesto problem.
    edatas = objective.edatas
    parameter_mapping = objective.parameter_mapping
    amici_model = objective.amici_model
    amici_solver = objective.amici_solver
    petab_problem = objective.amici_object_builder.petab_problem
    n_threads = objective.n_threads

    # Fill in the parameters.
    amici.parameter_mapping.fill_in_parameters(
        edatas=edatas,
        problem_parameters=x_dct,
        scaled_parameters=True,
        parameter_mapping=parameter_mapping,
        amici_model=amici_model,
    )

    # Simulate the model with the parameters from the pypesto result.
    inner_rdatas = amici.runAmiciSimulations(
        amici_model,
        amici_solver,
        edatas,
        num_threads=min(n_threads, len(edatas)),
    )

    # If any amici simulation failed, raise warning and return None.
    if any(rdata.status != amici.AMICI_SUCCESS for rdata in inner_rdatas):
        warnings.warn(
            "Warning: Some AMICI simulations failed. Cannot plot inner "
            "solutions.",
            stacklevel=2,
        )
        return None

    # Get simulation and sigma.
    sim = [rdata[AMICI_Y] for rdata in inner_rdatas]
    sigma = [rdata[AMICI_SIGMAY] for rdata in inner_rdatas]
    timepoints = [rdata[AMICI_T] for rdata in inner_rdatas]
    observable_ids = amici_model.getObservableIds()
    condition_ids = [edata.id for edata in edatas]
    petab_condition_ordering = list(petab_problem.condition_df.index)

    # Get the observable ordering from the measurement_df.
    measurement_df_observable_ordering = list(
        petab_problem.measurement_df[OBSERVABLE_ID].unique()
    )

    # Use the first ordinal calculator among the inner calculators.
    optimal_scaling_calculator = next(
        (
            calculator
            for calculator in objective.calculator.inner_calculators
            if isinstance(calculator, OrdinalCalculator)
        ),
        None,
    )
    if optimal_scaling_calculator is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            "The objective of the given pypesto result has no ordinal "
            "inner calculator. Cannot plot ordinal categories."
        )

    # Get the inner solver and problem.
    inner_solver = optimal_scaling_calculator.inner_solver
    inner_problem = optimal_scaling_calculator.inner_problem

    inner_results = inner_solver.solve(inner_problem, sim, sigma)

    return plot_categories_from_inner_result(
        inner_problem,
        inner_solver,
        inner_results,
        sim,
        timepoints,
        observable_ids,
        condition_ids,
        petab_condition_ordering,
        measurement_df_observable_ordering,
        axes,
        **kwargs,
    )
def plot_categories_from_inner_result(
    inner_problem: "pypesto.hierarchical.ordinal.problem.OrdinalProblem",
    inner_solver: "pypesto.hierarchical.ordinal.solver.OrdinalInnerSolver",
    results: list[dict],
    simulation: list[np.ndarray],
    timepoints: list[np.ndarray],
    observable_ids: list[str] = None,
    condition_ids: list[str] = None,
    petab_condition_ordering: list[str] = None,
    measurement_df_observable_ordering: list[str] = None,
    axes: Optional[plt.Axes] = None,
    **kwargs,
):
    """Plot the inner solutions.

    One plot is drawn per group of the inner problem. A group whose data
    share a single timepoint is plotted against the conditions (only
    when no axes were given); a group with several timepoints is plotted
    against time.

    Parameters
    ----------
    inner_problem:
        The inner problem.
    inner_solver:
        The inner solver.
    results:
        The results from the inner solver, one per group.
    simulation:
        The model simulation.
    timepoints:
        The timepoints of the simulation.
    observable_ids:
        The observable ids. Together with given ``axes`` they select the
        axis of a group through the key ``"plot<k>"``, where ``k`` is the
        1-based position of the observable in
        ``measurement_df_observable_ordering``.
    condition_ids:
        The condition ids, in the order of ``simulation``.
    petab_condition_ordering:
        The condition ids in the order used for plots across conditions.
    measurement_df_observable_ordering:
        The observable ids in the order used to look up given axes.
    axes:
        The optional axes to plot on.
    kwargs:
        Additional arguments to pass to the figure. Only used if no axes
        are given.

    Returns
    -------
    axes:
        The axes.

    Raises
    ------
    ValueError
        If the number of results differs from the number of groups.
    """
    groups = inner_problem.groups
    n_groups = len(groups)
    if len(results) != n_groups:
        raise ValueError(
            "Number of results must be equal to number of groups of the inner subproblem."
        )

    options = inner_solver.options
    use_given_axes = axes is not None

    # Without given axes, create a figure with one subplot per group.
    if not use_given_axes:
        axes = _get_default_axes(n_groups, **kwargs)

    group_order = list(groups.keys())

    for result, group in zip(results, groups):
        # Pick the axis belonging to the current group.
        if use_given_axes and observable_ids is not None:
            observable_id = observable_ids[group - 1]
            meas_obs_idx = measurement_df_observable_ordering.index(
                observable_id
            )
            ax = axes[f"plot{meas_obs_idx + 1}"]
        else:
            ax = axes[group_order.index(group)]

        # Inner parameters and interval constraints of the group.
        xs = inner_problem.get_cat_ub_parameters_for_group(group)
        interval_range, interval_gap = compute_interval_constraints(
            xs, simulation, options
        )

        observable_index = group - 1
        measurement_type = groups[group][MEASUREMENT_TYPE]

        # Surrogate datapoints and category bounds, per condition.
        (
            simulation_all,
            surrogate_all,
            timepoints_all,
            upper_bounds_all,
            lower_bounds_all,
        ) = _get_data_for_plotting(
            xs,
            result[SCIPY_X],
            simulation,
            timepoints,
            interval_range,
            interval_gap,
            options,
            measurement_type,
        )

        # timepoints_all is a list of arrays; count the distinct values.
        n_distinct_timepoints = len(np.unique(np.concatenate(timepoints_all)))

        if n_distinct_timepoints > 1:
            # Plot across timepoints.
            time_series_args = (
                ax,
                observable_index,
                group,
                inner_problem,
                timepoints,
                timepoints_all,
                simulation,
                simulation_all,
                surrogate_all,
                lower_bounds_all,
                upper_bounds_all,
                measurement_type,
            )
            n_conditions = len(simulation_all)

            if n_conditions == 1:
                # A single condition needs no per-condition colors.
                _plot_observable_fit_for_one_condition(
                    *time_series_args, use_given_axes
                )
            elif n_conditions > 1:
                # Several conditions are told apart by color.
                _plot_observable_fit_for_multiple_conditions(
                    *time_series_args, condition_ids, use_given_axes
                )

            ax.legend()
            if not use_given_axes:
                ax.set_title(f"Group {group}, {measurement_type} data")
                ax.set_xlabel("Timepoints")
                ax.set_ylabel("Simulation/Surrogate data")
        elif n_distinct_timepoints == 1 and not use_given_axes:
            # Only one distinct timepoint: plot with respect to conditions.
            _plot_observable_fit_across_conditions(
                ax=ax,
                inner_problem=inner_problem,
                observable_index=observable_index,
                group=group,
                condition_ids=condition_ids,
                simulation=simulation,
                simulation_all=simulation_all,
                surrogate_all=surrogate_all,
                upper_bounds_all=upper_bounds_all,
                lower_bounds_all=lower_bounds_all,
                measurement_type=measurement_type,
                condition_ids_from_petab=petab_condition_ordering,
                use_given_axes=use_given_axes,
            )

    # Drop the subplots of the default grid that no group has used.
    if not use_given_axes:
        for unused_ax in axes[len(results) :]:
            unused_ax.remove()

    return axes
def _plot_category_rectangles_across_conditions(
    ax, category_timepoints_dict, unique_timepoints
) -> None:
    """Plot one gray band per category over the timepoints it covers.

    Parameters
    ----------
    ax:
        The axis to plot on.
    category_timepoints_dict:
        Maps an ``(upper_bound, lower_bound)`` tuple to the list of
        timepoints at which that category occurs. It is not modified.
    unique_timepoints:
        All distinct timepoints in ascending order.
    """
    for (
        upper_bound,
        lower_bound,
    ), timepoints in category_timepoints_dict.items():
        # Work on a copy so that the lists of the caller stay untouched.
        band_timepoints = list(timepoints)

        # If the largest timepoint is not the last unique timepoint,
        # extend the band to the next unique timepoint.
        max_timepoint_unique_ind = np.where(
            unique_timepoints == max(band_timepoints)
        )[0][0]
        if max_timepoint_unique_ind + 1 < len(unique_timepoints):
            band_timepoints.append(
                unique_timepoints[max_timepoint_unique_ind + 1]
            )

        # Plot the category rectangle
        ax.fill_between(
            band_timepoints,
            [upper_bound] * len(band_timepoints),
            [lower_bound] * len(band_timepoints),
            color="gray",
            alpha=0.5,
        )

    # Add to legend meaning of gray rectangles.
    ax.fill_between(
        [],
        [],
        [],
        color="gray",
        alpha=0.5,
        label="Categories",
    )


def _mark_unbounded_category(ax, x_position, y_base, y_tip) -> None:
    """Draw a gray arrow with an "INF" label from y_base up to y_tip."""
    ax.annotate(
        "",
        xy=(x_position, y_base),
        xytext=(x_position, y_tip),
        arrowprops={
            "arrowstyle": "<-",
            "color": "gray",
            "linewidth": 2,
        },
    )
    ax.text(
        x_position,
        y_tip,
        "INF",
        color="gray",
        fontsize=12,
    )
    # Extend the ax to contain the text
    ax.set_ylim(
        bottom=ax.get_ylim()[0],
        top=max(ax.get_ylim()[1], y_tip),
    )


def _plot_category_rectangles(
    ax,
    timepoints,
    upper_bounds,
    lower_bounds,
    surrogate_data,
    measurement_type,
) -> None:
    """Plot the category rectangles.

    Consecutive points with the same upper bound form one interval and
    are drawn as one gray band. A band that is not the last one is
    extended to the first timepoint of the following interval. An
    infinite upper bound is drawn at ``1.1 * max(surrogate_data)`` and
    marked with an arrow labelled "INF".

    Parameters
    ----------
    ax:
        The axis to plot on.
    timepoints:
        The x values of the datapoints.
    upper_bounds:
        The upper category bound of each datapoint. Not modified.
    lower_bounds:
        The lower category bound of each datapoint.
    surrogate_data:
        The surrogate datapoints; only used to scale infinite bounds.
    measurement_type:
        ``ORDINAL`` or ``CENSORED``; selects the legend label.
    """
    # Copy, because infinite bounds are overwritten below.
    upper_bounds = np.array(upper_bounds, dtype=float)
    n_points = len(timepoints)

    interval_length = 0
    for i in range(n_points):
        is_last = i + 1 == n_points
        if not is_last and upper_bounds[i + 1] == upper_bounds[i]:
            # The interval continues with the next point.
            interval_length += 1
            continue

        start = i - interval_length

        if upper_bounds[i] == np.inf:
            surrogate_max = max(surrogate_data)
            upper_bounds[start : i + 1] = 1.1 * surrogate_max
            # Middle of the drawn band, which reaches one point further
            # unless this is the last interval.
            middle_index = int((start + i + (0 if is_last else 1)) / 2)
            _mark_unbounded_category(
                ax,
                timepoints[middle_index],
                upper_bounds[i],
                upper_bounds[i] + 0.1 * surrogate_max,
            )

        if is_last:
            ax.fill_between(
                timepoints[start : i + 1],
                upper_bounds[start : i + 1],
                lower_bounds[start : i + 1],
                color="gray",
                alpha=0.5,
            )
        else:
            # Repeat the last bounds at the first timepoint of the next
            # interval.
            ax.fill_between(
                timepoints[start : i + 2],
                np.concatenate(
                    (upper_bounds[start : i + 1], [upper_bounds[i]])
                ),
                np.concatenate(
                    (lower_bounds[start : i + 1], [lower_bounds[i]])
                ),
                color="gray",
                alpha=0.5,
            )
        interval_length = 0

    # Add to legend meaning of rectangles
    if measurement_type == ORDINAL:
        ax.fill_between(
            [],
            [],
            [],
            color="gray",
            alpha=0.5,
            label="Categories",
        )
    elif measurement_type == CENSORED:
        ax.fill_between(
            [],
            [],
            [],
            color="gray",
            alpha=0.5,
            label="Censoring areas",
        )


def _get_data_for_plotting(
    inner_parameters: list["OrdinalParameter"],
    optimal_scaling_bounds: list,
    sim: list[np.ndarray],
    timepoints: list[np.ndarray],
    interval_range: float,
    interval_gap: float,
    options: dict,
    measurement_type: str,
):
    """Return data in the form suited for plotting.

    Parameters
    ----------
    inner_parameters:
        The inner parameters of one group.
    optimal_scaling_bounds:
        The inner solution of the group. For ordinal data with the
        ``REPARAMETERIZED`` option set, the reparameterization is undone
        first.
    sim:
        The model simulation, one array per condition.
    timepoints:
        The simulation timepoints, one array per condition.
    interval_range, interval_gap:
        The interval constraints of the group.
    options:
        The inner solver options.
    measurement_type:
        ``ORDINAL`` or ``CENSORED``.

    Returns
    -------
    simulation_all, surrogate_all, timepoints_all, upper_bounds_all, \
lower_bounds_all:
        Five lists holding one array per condition, each array sorted by
        ascending timepoint.
    """
    if options[REPARAMETERIZED] and measurement_type == ORDINAL:
        optimal_scaling_bounds = undo_inner_parameter_reparameterization(
            optimal_scaling_bounds,
            inner_parameters,
            interval_gap,
            interval_range,
        )

    simulation_all = []
    surrogate_all = []
    timepoints_all = []
    upper_bounds_all = []
    lower_bounds_all = []

    for condition_index in range(len(sim)):
        cond_simulation = []
        cond_surrogate = []
        cond_timepoints = []
        cond_upper_bounds = []
        cond_lower_bounds = []

        for inner_parameter in inner_parameters:
            if measurement_type == ORDINAL:
                upper_bound, lower_bound = get_bounds_for_category(
                    inner_parameter,
                    optimal_scaling_bounds,
                    interval_gap,
                    options,
                )
            elif measurement_type == CENSORED:
                # Bounds are stored pairwise, lower bound first.
                x_category = inner_parameter.category
                lower_bound = optimal_scaling_bounds[2 * x_category - 2]
                upper_bound = optimal_scaling_bounds[2 * x_category - 1]

            # Get the condition specific simulation, mask, and timepoints
            sim_i = sim[condition_index]
            mask_i = inner_parameter.ixs[condition_index]
            t_i = timepoints[condition_index]

            y_sim = sim_i[mask_i]

            # If there is no measurement in this
            # condition for this category, skip it
            if len(y_sim) == 0:
                continue

            if mask_i.ndim == 1:
                t_sim = t_i[mask_i]
            else:
                # Use the first mask column that contains an entry.
                observable_index = [
                    i for i in range(len(mask_i.T)) if any(mask_i.T[i])
                ][0]
                t_sim = t_i[mask_i.T[observable_index]]

            for y_sim_i in y_sim:
                # The surrogate is the simulation clipped to the bounds.
                if lower_bound > y_sim_i:
                    y_surrogate = lower_bound
                elif y_sim_i > upper_bound:
                    y_surrogate = upper_bound
                elif lower_bound <= y_sim_i <= upper_bound:
                    y_surrogate = y_sim_i
                else:
                    # Reached only if no comparison holds (e.g. NaN).
                    # NOTE(review): the simulation and timepoints of
                    # such a point are still added below, so the arrays
                    # would differ in length - confirm this cannot occur.
                    continue
                cond_surrogate.append(y_surrogate)
                cond_upper_bounds.append(upper_bound)
                cond_lower_bounds.append(lower_bound)
            cond_simulation.extend(y_sim)
            cond_timepoints.extend(t_sim)

        # Sort the surrogate datapoints and categories by timepoints,
        # ascending.
        cond_simulation = np.array(cond_simulation)
        cond_surrogate = np.array(cond_surrogate)
        cond_timepoints = np.array(cond_timepoints)
        cond_upper_bounds = np.array(cond_upper_bounds)
        cond_lower_bounds = np.array(cond_lower_bounds)

        sort_idx = np.argsort(cond_timepoints)

        # Add the sorted condition data to the lists of all conditions.
        simulation_all.append(cond_simulation[sort_idx])
        surrogate_all.append(cond_surrogate[sort_idx])
        timepoints_all.append(cond_timepoints[sort_idx])
        upper_bounds_all.append(cond_upper_bounds[sort_idx])
        lower_bounds_all.append(cond_lower_bounds[sort_idx])

    return (
        simulation_all,
        surrogate_all,
        timepoints_all,
        upper_bounds_all,
        lower_bounds_all,
    )


def _get_default_axes(n_groups, **kwargs):
    """Return a flat sequence of axes with the default layout.

    For one group this is a list with a single axis. Otherwise a grid
    with ``ceil(sqrt(n_groups))`` rows and as many columns as needed is
    created and returned flattened; it may hold more axes than groups.
    ``kwargs`` are passed to ``plt.subplots``.
    """
    # If there is only one group, make a figure with only one plot
    if n_groups == 1:
        fig, ax = plt.subplots(1, 1, **kwargs)
        return [ax]

    # Choose number of rows and columns to be used for the subplots
    n_rows = int(np.ceil(np.sqrt(n_groups)))
    n_cols = int(np.ceil(n_groups / n_rows))

    fig, axes = plt.subplots(n_rows, n_cols, **kwargs)

    # Increase the spacing between the subplots
    fig.subplots_adjust(hspace=0.35, wspace=0.25)

    return axes.flatten()


def _plot_observable_fit_across_conditions(
    ax,
    inner_problem,
    observable_index,
    group,
    condition_ids,
    simulation,
    simulation_all,
    surrogate_all,
    upper_bounds_all,
    lower_bounds_all,
    measurement_type,
    condition_ids_from_petab,
    use_given_axes,
):
    """Plot the observable fit across conditions.

    In case the observable has only one timepoint, the observable fit
    will be plotted against the conditions, ordered as in
    ``condition_ids_from_petab``. The simulation line, title and axis
    labels are only added if ``use_given_axes`` is false.
    """
    if measurement_type == CENSORED:
        # Get the conditions which have censored data
        # and the corresponding condition ids with their ordering
        censored_condition_ids = [
            condition_ids[i]
            for i, cond_sim in enumerate(simulation_all)
            if len(cond_sim) > 0
        ]
        petab_censored_conditions = [
            condition_id
            for condition_id in condition_ids_from_petab
            if condition_id in censored_condition_ids
        ]
        petab_censored_conditions_ordering = [
            censored_condition_ids.index(condition_id)
            for condition_id in petab_censored_conditions
        ]

        # Get all other conditions for quantitative data
        # and the corresponding condition ids with their ordering
        quantitative_condition_ids = [
            condition_id
            for condition_id in condition_ids
            if condition_id not in censored_condition_ids
        ]
        petab_quantitative_conditions = [
            condition_id
            for condition_id in condition_ids_from_petab
            if condition_id in quantitative_condition_ids
        ]
        petab_quantitative_condition_ordering = [
            quantitative_condition_ids.index(condition_id)
            for condition_id in petab_quantitative_conditions
        ]

    petab_condition_ordering = [
        condition_ids.index(condition_id)
        for condition_id in condition_ids_from_petab
    ]

    # Merge the simulation, surrogate, and bounds across conditions
    simulation_all = np.concatenate(simulation_all)
    surrogate_all = np.concatenate(surrogate_all)
    upper_bounds_all = np.concatenate(upper_bounds_all)
    lower_bounds_all = np.concatenate(lower_bounds_all)

    if measurement_type == CENSORED:
        # Change ordering of simulation, surrogate data and bounds to
        # petab condition ordering
        simulation_all = simulation_all[petab_censored_conditions_ordering]
        surrogate_all = surrogate_all[petab_censored_conditions_ordering]
        upper_bounds_all = upper_bounds_all[petab_censored_conditions_ordering]
        lower_bounds_all = lower_bounds_all[petab_censored_conditions_ordering]

        # The simulation line covers all conditions, not only the
        # censored ones.
        whole_simulation = np.concatenate(
            [sim_i[:, observable_index] for sim_i in simulation]
        )[petab_condition_ordering]

        if not use_given_axes:
            ax.plot(
                condition_ids_from_petab,
                whole_simulation,
                linestyle="-",
                marker=".",
                color="b",
                label="Simulation",
            )
        ax.plot(
            petab_censored_conditions,
            surrogate_all,
            "rx",
            label="Surrogate data",
        )
        _plot_category_rectangles(
            ax,
            petab_censored_conditions,
            upper_bounds_all,
            lower_bounds_all,
            surrogate_all,
            measurement_type,
        )

        quantitative_data = inner_problem.groups[group][QUANTITATIVE_DATA]
        quantitative_data = quantitative_data[
            petab_quantitative_condition_ordering
        ]
        ax.plot(
            petab_quantitative_conditions,
            quantitative_data,
            "gs",
            label="Quantitative data",
        )

    elif measurement_type == ORDINAL:
        # Change ordering of simulation, surrogate data and bounds to
        # petab condition ordering
        simulation_all = simulation_all[petab_condition_ordering]
        surrogate_all = surrogate_all[petab_condition_ordering]
        upper_bounds_all = upper_bounds_all[petab_condition_ordering]
        lower_bounds_all = lower_bounds_all[petab_condition_ordering]

        # Plot the categories and surrogate data across conditions
        if not use_given_axes:
            ax.plot(
                condition_ids_from_petab,
                simulation_all,
                linestyle="-",
                marker=".",
                color="b",
                label="Simulation",
            )
        ax.plot(
            condition_ids_from_petab,
            surrogate_all,
            "rx",
            label="Surrogate data",
        )
        _plot_category_rectangles(
            ax,
            condition_ids_from_petab,
            upper_bounds_all,
            lower_bounds_all,
            surrogate_all,
            measurement_type,
        )

    # Set the condition xticks on an angle
    ax.tick_params(axis="x", rotation=25)
    ax.legend()
    if not use_given_axes:
        ax.set_title(f"Group {group}, {measurement_type} data")
        ax.set_xlabel("Conditions")
        ax.set_ylabel("Simulation/Surrogate data")


def _plot_observable_fit_for_one_condition(
    ax,
    observable_index,
    group,
    inner_problem,
    timepoints,
    timepoints_all,
    simulation,
    simulation_all,
    surrogate_all,
    lower_bounds_all,
    upper_bounds_all,
    measurement_type,
    use_given_axes,
):
    """Plot the observable fit in case it has one condition.

    Only the first entry of each per-condition list is used. The
    simulation line is only added if ``use_given_axes`` is false.
    """
    if measurement_type == ORDINAL:
        if not use_given_axes:
            ax.plot(
                timepoints_all[0],
                simulation_all[0],
                linestyle="-",
                marker=".",
                color="b",
                label="Simulation",
            )
    elif measurement_type == CENSORED:
        quantitative_data = inner_problem.groups[group][QUANTITATIVE_DATA]
        quantitative_ixs = inner_problem.groups[group][QUANTITATIVE_IXS]
        quantitative_timepoints = timepoints[0][
            quantitative_ixs[0].T[observable_index]
        ]

        if not use_given_axes:
            # Show the whole simulation, not only the censored points.
            ax.plot(
                timepoints[0],
                simulation[0][:, observable_index],
                linestyle="-",
                marker=".",
                color="b",
                label="Simulation",
            )
        ax.plot(
            quantitative_timepoints,
            quantitative_data,
            "gs",
            label="Quantitative data",
        )

    ax.plot(
        timepoints_all[0],
        surrogate_all[0],
        "rx",
        label="Surrogate data",
    )

    # Plot the category rectangles
    _plot_category_rectangles(
        ax,
        timepoints_all[0],
        upper_bounds_all[0],
        lower_bounds_all[0],
        surrogate_all[0],
        measurement_type,
    )


def _split_by_lengths(flat_data, lengths):
    """Cut flat_data into consecutive chunks of the given lengths."""
    chunks = []
    offset = 0
    for length in lengths:
        chunks.append(flat_data[offset : offset + length])
        offset += length
    return chunks


def _plot_observable_fit_for_multiple_conditions(
    ax,
    observable_index,
    group,
    inner_problem,
    timepoints,
    timepoints_all,
    simulation,
    simulation_all,
    surrogate_all,
    lower_bounds_all,
    upper_bounds_all,
    measurement_type,
    condition_ids,
    use_given_axes,
):
    """Plot the observable fit in case it has multiple conditions.

    Each condition is drawn in its own color. With given axes the colors
    are taken from the lines already on ``ax`` whose label contains
    "simulation", and no simulation lines are added.
    """
    # Get the colors from the plotted simulations
    if use_given_axes:
        colors = [
            line.get_color()
            for line in ax.lines
            if "simulation" in line.get_label()
        ]
    # Get as many colors as there are conditions
    else:
        colors = plt.cm.rainbow(np.linspace(0, 1, len(simulation_all)))

    if measurement_type == CENSORED:
        quantitative_data_flattened = inner_problem.groups[group][
            QUANTITATIVE_DATA
        ]
        quantitative_ixs = inner_problem.groups[group][QUANTITATIVE_IXS]
        quantitative_timepoints = [
            timepoints[cond_i][quantitative_ixs[cond_i].T[observable_index]]
            for cond_i in range(len(timepoints))
        ]

        # Separate the flattened quantitative data across conditions so
        # that each chunk matches the quantitative timepoints of its
        # condition. The offset has to advance from chunk to chunk.
        quantitative_data = _split_by_lengths(
            quantitative_data_flattened,
            [
                len(cond_timepoints)
                for cond_timepoints in quantitative_timepoints
            ],
        )

    # Plot the simulation and surrogate data for all conditions.
    for condition_index, condition_id, color in zip(
        range(len(simulation_all)), condition_ids, colors
    ):
        if measurement_type == ORDINAL:
            if not use_given_axes:
                ax.plot(
                    timepoints_all[condition_index],
                    simulation_all[condition_index],
                    linestyle="-",
                    marker=".",
                    color=color,
                    label=condition_id,
                )
        elif measurement_type == CENSORED:
            if not use_given_axes:
                ax.plot(
                    timepoints[condition_index],
                    simulation[condition_index][:, observable_index],
                    linestyle="-",
                    marker=".",
                    color=color,
                    label=condition_id,
                )
            ax.plot(
                quantitative_timepoints[condition_index],
                quantitative_data[condition_index],
                marker="s",
                color=color,
            )

        ax.plot(
            timepoints_all[condition_index],
            surrogate_all[condition_index],
            "x",
            color=color,
        )

    # Get all unique timepoints in ascending order
    unique_timepoints = np.unique(np.concatenate(timepoints_all))

    # Gather timepoints for each category in a dictionary
    # with upper, lower bound tuple as key and list of timepoints as value
    category_timepoints_dict = {}
    for condition_idx in range(len(simulation_all)):
        for upper_bound, lower_bound, timepoint in zip(
            upper_bounds_all[condition_idx],
            lower_bounds_all[condition_idx],
            timepoints_all[condition_idx],
        ):
            category_timepoints_dict.setdefault(
                (upper_bound, lower_bound), []
            ).append(timepoint)

    # Plot the category rectangles
    _plot_category_rectangles_across_conditions(
        ax,
        category_timepoints_dict,
        unique_timepoints,
    )

    # Add to legend meaning of x, and -o- markers.
    ax.plot(
        [],
        [],
        "x",
        color="black",
        label="Surrogate data",
    )
    if not use_given_axes:
        ax.plot(
            [],
            [],
            linestyle="-",
            marker=".",
            color="black",
            label="Simulation",
        )
    if measurement_type == CENSORED:
        ax.plot(
            [],
            [],
            marker="s",
            color="black",
            label="Quantitative data",
        )