Source code for pypesto.visualize.misc

from __future__ import annotations

import logging
import warnings
from collections.abc import Iterable
from numbers import Number

import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np

from ..C import (
    ALL,
    ALL_CLUSTERED,
    COLOR,
    FIRST_CLUSTER,
    FREE_ONLY,
    LEN_RGB,
    LEN_RGBA,
    RGB,
    RGB_RGBA,
    RGBA_ALPHA,
    RGBA_MAX,
    RGBA_MIN,
    RGBA_WHITE,
)
from ..result import Result
from ..util import assign_clusters, delete_nan_inf
from .clust_color import assign_colors_for_list

logger = logging.getLogger(__name__)


def process_result_list(
    results: Result | list[Result],
    colors: COLOR | list[COLOR] | np.ndarray | None = None,
    legends: str | list[str] | None = None,
) -> tuple[list[Result], list[COLOR], list[str]]:
    """
    Normalize results, colors and legends into matching lists.

    A single result (passed bare or as a one-element list) keeps the given
    color and legend, each wrapped in a list. For several results, one
    color per result is produced by ``assign_colors_for_list`` and, if no
    legends were given, the labels ``"Result 0"``, ``"Result 1"``, ... are
    generated.

    Parameters
    ----------
    results:
        list of pypesto.Result objects or a single pypesto.Result
    colors:
        list of colors recognized by matplotlib, or single color
    legends:
        labels for line plots

    Returns
    -------
    results:
        list of pypesto.Result objects
    colors:
        One for each element in 'results'.
    legends:
        labels for line plots

    Raises
    ------
    ValueError
        If several results are passed together with a number of legends
        that differs from the number of results.
    TypeError
        If several results are passed and ``legends`` has no length.
    """
    if isinstance(results, list):
        single = len(results) == 1
    else:
        results, single = [results], True

    legends_invalid = False

    if single:
        # A list-valued color (e.g. an RGB triple) becomes one array entry.
        colors = [np.array(colors)] if isinstance(colors, list) else [colors]
        if not isinstance(legends, list):
            legends = [legends]
        try:
            str(legends[0])
        except TypeError:
            legends_invalid = True
    else:
        # several results: one color per result
        colors = assign_colors_for_list(len(results), colors)
        if legends is None:
            legends = [f"Result {idx}" for idx in range(len(results))]
        else:
            if isinstance(legends, str):
                legends = [legends]
            try:
                n_legends = len(legends)
            except TypeError:
                legends_invalid = True
            else:
                if n_legends != len(results):
                    raise ValueError(
                        "List of results passed and list of labels do "
                        "not have the same length."
                    )

    if legends_invalid:
        raise TypeError("Unexpected legend type.")
    return results, colors, legends
def process_offset_y(
    offset_y: float | None, scale_y: str, min_val: float
) -> float:
    """
    Compute offset for y-axis, depending on user settings.

    Parameters
    ----------
    offset_y:
        value for offsetting the later plotted values, in order to ensure
        positivity if a semilog-plot is used
    scale_y:
        Can be 'lin' or 'log10', specifying whether values should be
        plotted on linear or on log10-scale
    min_val:
        Smallest value to be plotted

    Returns
    -------
    offset_y: float
        value for offsetting the later plotted values, in order to ensure
        positivity if a semilog-plot is used. Without a user offset this is
        ``0.0`` for ``'lin'`` and ``1.0 - min_val`` for any other scale. A
        user offset is returned unchanged unless ``scale_y == 'log10'`` and
        it leaves ``min_val`` non-positive; then ``1.0 - min_val`` is
        returned and a warning is issued.
    """
    # Offset that maps the smallest value to exactly 1.0, so every plotted
    # value is strictly positive.
    fallback_offset = 1.0 - min_val

    if offset_y is None:
        # linear scaling doesn't need any offset
        if scale_y == "lin":
            return 0.0
        return fallback_offset

    # check whether the offset specified by the user is sufficient
    if scale_y == "log10" and min_val + offset_y <= 0.0:
        # Report the offset that is actually returned.
        warnings.warn(
            "Offset specified by user is insufficient. "
            "Ignoring specified offset and using "
            f"{fallback_offset} instead.",
            stacklevel=2,
        )
        return fallback_offset
    return offset_y
def process_y_limits(
    ax: matplotlib.axes.Axes,
    y_limits: None | Iterable[float] | np.ndarray,
) -> matplotlib.axes.Axes:
    """
    Apply user specified limits of y-axis.

    If ``y_limits`` is given, it is flattened and applied to ``ax``:

    * an empty value keeps the current limits,
    * a single value is used as the upper bound,
    * otherwise the first two values are used as lower and upper bound.

    On a log-scaled axis, a non-positive lower bound is replaced by the
    current lower limit. If both bounds are non-positive, the current limits
    are kept. Either case issues a warning.

    If ``y_limits`` is None and the current limits do not contain all
    plotted data, the limits are set to the data range plus a margin of 2 %
    of that range on each side. On a log axis the margin is measured in
    log10 space, and non-positive data leaves the limits untouched.

    Parameters
    ----------
    ax:
        Axes object to use.
    y_limits:
        y_limits, minimum and maximum, for current axes object

    Returns
    -------
    ax:
        Axes object to use.
    """
    if y_limits is not None:
        requested = np.asarray(y_limits).ravel()
        current = ax.get_ylim()
        if requested.size == 0:
            y_limits = list(current)
        elif requested.size == 1:
            # A single value is the upper bound. Extract the scalar so that
            # ``set_ylim`` receives two numbers rather than a number and an
            # array.
            y_limits = [current[0], requested[0]]
        else:
            y_limits = [requested[0], requested[1]]

        # check validity of bounds if plotting in log-scale
        if ax.get_yscale() == "log" and y_limits[0] <= 0.0:
            if y_limits[1] <= 0.0:
                y_limits = list(current)
                warnings.warn(
                    "Invalid bounds for plotting in "
                    "log-scale. Using defaults bounds.",
                    stacklevel=2,
                )
            else:
                y_limits = [current[0], y_limits[1]]
                warnings.warn(
                    "Invalid lower bound for plotting in "
                    "log-scale. Using only upper bound.",
                    stacklevel=2,
                )
        ax.set_ylim(y_limits)
        return ax

    # No limits passed: widen the limits if the data does not fit.
    ax_limits = ax.get_ylim()
    data_min, data_max = ax.dataLim.ymin, ax.dataLim.ymax
    data_fits = ax_limits[0] <= data_min and ax_limits[1] >= data_max
    if not np.isfinite([data_min, data_max]).all() or data_fits:
        return ax

    if ax.get_yscale() == "log":
        if data_min <= 0.0:
            # log10 of a non-positive bound is undefined; leave the limits
            # untouched instead of passing NaN limits to ``set_ylim``.
            return ax
        # Pad by 2 % of the range measured in log10 space. The range
        # ``log10(max) - log10(min)`` is never negative, so the padded
        # limits always enclose the data.
        log_min, log_max = np.log10(data_min), np.log10(data_max)
        pad = 0.02 * (log_max - log_min)
        new_limits = (
            np.power(10, log_min - pad),
            np.power(10, log_max + pad),
        )
    else:
        pad = 0.02 * (data_max - data_min)
        new_limits = (data_min - pad, data_max + pad)

    ax.set_ylim(new_limits)
    return ax
def rgba2rgb(fg: RGB_RGBA, bg: RGB_RGBA = None) -> RGB:
    """Combine two colors, removing transparency.

    The foreground is alpha-composited over the background. Porter and Duff
    equations are used for alpha compositing.

    Parameters
    ----------
    fg:
        Foreground color, RGB or RGBA.
    bg:
        Background color, RGB or RGBA. Defaults to ``RGBA_WHITE``. An RGB
        background is treated as having alpha ``RGBA_MAX``.

    Returns
    -------
    The combined color. ``fg`` is returned unchanged if it has no
    transparency (RGB, or alpha equal to ``RGBA_MAX``) or if ``bg`` is fully
    transparent. Otherwise a list of the ``LEN_RGB`` composited components
    is returned.

    Raises
    ------
    IndexError
        If ``bg`` or ``fg`` has neither RGB nor RGBA length.
    """
    if bg is None:
        bg = RGBA_WHITE

    if len(bg) == LEN_RGBA:
        bg_transparent = bg[RGBA_ALPHA] == RGBA_MIN
    elif len(bg) == LEN_RGB:
        bg = (*bg, RGBA_MAX)
        bg_transparent = False
    else:
        raise IndexError(
            f"A background color of unexpected length was provided: {bg}"
        )

    # Validate the length before reading fg[RGBA_ALPHA].
    if len(fg) not in (LEN_RGB, LEN_RGBA):
        raise IndexError(
            f"A foreground color of unexpected length was provided: {fg}"
        )

    # return the foreground if the background is fully transparent or the
    # foreground has no transparency
    if bg_transparent or len(fg) == LEN_RGB or fg[RGBA_ALPHA] == RGBA_MAX:
        return fg

    fg_alpha = fg[RGBA_ALPHA]
    bg_alpha = bg[RGBA_ALPHA]
    # Porter-Duff "over": weight each component by its effective alpha and
    # normalize by the resulting alpha.
    bg_weight = bg_alpha * (RGBA_MAX - fg_alpha)
    out_alpha = fg_alpha + bg_weight
    return [
        (fg[i] * fg_alpha + bg[i] * bg_weight) / out_alpha
        for i in range(LEN_RGB)
    ]


def process_start_indices(
    result: Result,
    start_indices: str | int | Iterable[int] = None,
) -> np.ndarray:
    """
    Process the start_indices.

    Create an array of indices if a number was provided, check that the
    indices do not exceed the max_index, and remove starts with non-finite
    fval.

    Parameters
    ----------
    start_indices:
        list of indices or int specifying an endpoint of the sequence of
        indices. Furthermore the following strings are possible:

            * 'all', this is the default, using all start indices.
            * 'all_clustered', this includes the best start and all clusters
              with the size > 1.
            * 'first_cluster', includes all starts that belong to the first
              cluster.
    result:
        Result to determine maximum allowed length and/or clusters.

    Returns
    -------
    Integer array of the selected start indices.

    Raises
    ------
    ValueError
        If ``start_indices`` is an unsupported string.
    """
    if start_indices is None:
        start_indices = ALL

    if isinstance(start_indices, str):
        if start_indices == ALL:
            start_indices = np.asarray(range(len(result.optimize_result)))
        elif start_indices == ALL_CLUSTERED:
            # NOTE(review): cluster indices refer to the finite fvals only;
            # this assumes non-finite fvals are sorted to the end of
            # optimize_result — TODO confirm.
            clust_ind, clust_size = assign_clusters(
                delete_nan_inf(result.optimize_result.fval)[1]
            )
            # get all clusters that have size >= 2 and cluster of best start:
            clust_gr2 = np.where(clust_size >= 2)[0]
            if 0 not in clust_gr2:
                clust_gr2 = np.append(clust_gr2, 0)
            start_indices = np.concatenate(
                [np.where(clust_ind == i_clust)[0] for i_clust in clust_gr2]
            )
        elif start_indices == FIRST_CLUSTER:
            clust_ind = assign_clusters(
                delete_nan_inf(result.optimize_result.fval)[1]
            )[0]
            start_indices = np.where(clust_ind == 0)[0]
        else:
            raise ValueError(
                f"Permissible values for start_indices are {ALL}, "
                f"{ALL_CLUSTERED}, {FIRST_CLUSTER}, an integer or a "
                f"list of indices. Got {start_indices}."
            )

    # if it is an integer n, select the first n starts
    if isinstance(start_indices, Number):
        start_indices = range(int(start_indices))

    n_starts = len(result.optimize_result)
    # filter out the indices that exceed the range of possible start indices
    start_indices = [
        start_index for start_index in start_indices if start_index < n_starts
    ]

    # filter out the indices that are not finite
    n_unfiltered = len(start_indices)
    start_indices = [
        start_index
        for start_index in start_indices
        if np.isfinite(result.optimize_result[start_index].fval)
    ]
    if len(start_indices) != n_unfiltered:
        logger.warning(
            "Some start indices were removed due to inf or nan function values."
        )

    return np.asarray(start_indices, dtype=int)


def process_parameter_indices(
    result: Result,
    parameter_indices: str | Iterable[int] = FREE_ONLY,
) -> list:
    """
    Process the parameter indices, always returning a new list of indices.

    Create a list of indices depending on the string that is provided, or
    convert the given sequence to a list.

    Parameters
    ----------
    result:
        Result to determine maximum allowed length and/or clusters.
    parameter_indices:
        list of indices or str specifying the desired indices. Default is
        `free_only`. Other option is 'all', which includes all estimated
        and fixed parameters.

    Returns
    -------
    A list of parameter indices.

    Raises
    ------
    ValueError
        If ``parameter_indices`` is an unsupported string.
    """
    if isinstance(parameter_indices, str):
        if parameter_indices == ALL:
            return list(range(result.problem.dim_full))
        if parameter_indices == FREE_ONLY:
            # Return a copy rather than the problem's own attribute, so that
            # mutating the returned list cannot affect ``result.problem``.
            return list(result.problem.x_free_indices)
        raise ValueError(
            "Permissible values for parameter_indices are "
            f"{ALL}, {FREE_ONLY} or a list of indices. "
            f"Got {parameter_indices}."
        )
    return list(parameter_indices)


def make_grid_shape(n_panels: int) -> tuple[int, int]:
    """
    Return a near-square ``(nrows, ncols)`` grid for ``n_panels`` subplots.

    Parameters
    ----------
    n_panels:
        Number of panels to arrange.

    Returns
    -------
    nrows, ncols:
        Grid with ``nrows = ceil(sqrt(n_panels))`` and
        ``ncols = ceil(n_panels / nrows)``, so ``nrows * ncols >= n_panels``
        and the aspect ratio is close to square.

    Raises
    ------
    ValueError
        If ``n_panels`` is smaller than 1.
    """
    if n_panels < 1:
        raise ValueError("n_panels must be at least 1.")
    nrows = int(np.ceil(np.sqrt(n_panels)))
    ncols = int(np.ceil(n_panels / nrows))
    return nrows, ncols


def get_ax(
    ax: matplotlib.axes.Axes | None = None,
    size: tuple[float, float] | None = None,
) -> matplotlib.axes.Axes:
    """
    Return an Axes, creating one of size ``size`` if ``ax`` is None.

    Parameters
    ----------
    ax:
        Existing matplotlib Axes. If provided, returned unchanged.
    size:
        Figure size ``(width, height)`` in inches; only used when ``ax``
        is None. If None, matplotlib's default figure size is used.

    Returns
    -------
    ax:
        A matplotlib Axes.
    """
    if ax is not None:
        return ax
    _, ax = plt.subplots(figsize=size, layout="constrained")
    return ax


def get_axes_array(
    axes: matplotlib.axes.Axes | np.ndarray | None = None,
    nrows: int = 1,
    ncols: int = 1,
    size: tuple[float, float] | None = None,
) -> np.ndarray:
    """
    Return a 2-D array of Axes, creating one if ``axes`` is None.

    Parameters
    ----------
    axes:
        Existing matplotlib Axes grid. If provided, it is normalized to a
        2-D object array and validated against ``(nrows, ncols)``. A single
        Axes is accepted for a 1x1 grid, and a 1-D array is accepted when
        ``nrows == 1`` or ``ncols == 1``.
    nrows, ncols:
        Expected grid shape.
    size:
        Figure size ``(width, height)`` in inches; only used when ``axes``
        is None.

    Returns
    -------
    axes:
        A 2-D NumPy object array containing matplotlib Axes.

    Raises
    ------
    ValueError
        If ``axes`` cannot be arranged into shape ``(nrows, ncols)``.
    """
    if axes is None:
        _, axes = plt.subplots(
            nrows,
            ncols,
            squeeze=False,
            figsize=size,
            layout="constrained",
        )
        return axes

    shape_error = f"Pass `axes` with shape ({nrows}, {ncols})."
    axes_array = np.asarray(axes, dtype=object)
    if axes_array.ndim == 0:
        axes_array = axes_array.reshape(1, 1)
    elif axes_array.ndim == 1:
        # A flat array is only unambiguous for a single row or column. Check
        # the size here so a mismatch yields our message instead of numpy's
        # reshape error.
        if (nrows != 1 and ncols != 1) or axes_array.size != nrows * ncols:
            raise ValueError(shape_error)
        axes_array = axes_array.reshape(nrows, ncols)

    if axes_array.shape != (nrows, ncols):
        raise ValueError(shape_error)
    return axes_array


def hide_unused_axes(
    axes: np.ndarray,
    n_used: int | None = None,
    used_indices: Iterable[int] | None = None,
    clear: bool = False,
) -> np.ndarray:
    """
    Hide unused axes in a 2-D grid and ensure used axes are visible.

    Parameters
    ----------
    axes:
        2-D NumPy array containing matplotlib Axes.
    n_used:
        Number of leading axes in ``axes.flat`` to keep visible.
    used_indices:
        Flat indices of the axes that should remain visible.
    clear:
        Whether to clear every axis before toggling visibility.

    Returns
    -------
    axes:
        The same 2-D NumPy array with updated visibility.

    Raises
    ------
    ValueError
        If ``axes`` is not 2-D, if not exactly one of ``n_used`` and
        ``used_indices`` is given, or if either is out of range.
    """
    axes_array = np.asarray(axes, dtype=object)
    if axes_array.ndim != 2:
        raise ValueError("Pass `axes` as a 2-D NumPy array.")
    if (n_used is None) == (used_indices is None):
        raise ValueError("Pass exactly one of `n_used` or `used_indices`.")

    n_axes = axes_array.size
    if used_indices is None:
        # n_used is guaranteed non-None by the exclusivity check above.
        if not 0 <= n_used <= n_axes:
            raise ValueError(f"`n_used` must be between 0 and {n_axes}.")
        visible_indices = set(range(n_used))
    else:
        visible_indices = set(used_indices)
        if any(index < 0 or index >= n_axes for index in visible_indices):
            raise ValueError(
                "Pass `used_indices` within the flattened axes range "
                f"[0, {n_axes - 1}]."
            )

    for index, ax in enumerate(axes_array.flat):
        if clear:
            ax.clear()
        ax.set_visible(index in visible_indices)
    return axes_array


def plot_diagonal_marginal(
    ax: matplotlib.axes.Axes,
    values: np.ndarray,
    diag_kind: str = "kde",
    color: str = "C0",
) -> None:
    """
    Plot a 1-D marginal on a diagonal scatter-matrix panel.

    With ``diag_kind == "kde"`` and more than one finite sample, a Gaussian
    KDE is drawn (filled curve plus line, y-label "Density"). If the KDE
    raises ``LinAlgError``, or for any other ``diag_kind``, a histogram is
    drawn instead (y-label "Count"). Non-finite samples are ignored. Nothing
    is drawn if no finite samples remain.

    Parameters
    ----------
    ax:
        Axes to draw into.
    values:
        Sample values; flattened to one dimension.
    diag_kind:
        Marginal visualization mode: ``"kde"`` or ``"hist"``.
    color:
        Base matplotlib color for the marginal.
    """
    from scipy.stats import gaussian_kde

    # NaN/inf samples would make the grid bounds non-finite and break both
    # the KDE and the histogram, so drop them up front.
    values = np.asarray(values, dtype=float).ravel()
    values = values[np.isfinite(values)]
    if values.size == 0:
        return

    if diag_kind == "kde" and values.size > 1:
        v_min, v_max = values.min(), values.max()
        data_range = v_max - v_min
        if data_range == 0:
            data_range = max(abs(float(values.mean())) * 0.1, 0.1)
        x_pad = data_range * 0.25
        x_grid = np.linspace(v_min - x_pad, v_max + x_pad, 300)
        try:
            y_grid = gaussian_kde(values)(x_grid)
        except np.linalg.LinAlgError:
            # presumably a degenerate (e.g. constant) sample; fall back to a
            # histogram below
            pass
        else:
            ax.fill_between(x_grid, y_grid, alpha=0.35, color=color)
            ax.plot(x_grid, y_grid, color=color, lw=1.5)
            ax.set_ylabel("Density")
            return

    ax.hist(values, bins="auto", color=color, alpha=0.6)
    ax.set_ylabel("Count")


#: Sentinel meaning "this kwarg was not passed at all."
#: Use it as the default of a deprecated kwarg: no caller can pass this
#: exact object by accident, so an explicit ``f(old_kwarg=None)`` stays
#: distinguishable from omitting ``old_kwarg`` and is warned about.
_UNSET = object()


def process_deprecated_kwarg(
    canonical_name: str,
    canonical_value,
    deprecated_name: str,
    deprecated_value=_UNSET,
    stacklevel: int = 3,
):
    """
    Reconcile a renamed keyword argument with its deprecated alias.

    The calling function must default the deprecated kwarg to
    :data:`_UNSET`. Otherwise an explicit ``f(old_kwarg=None)`` could not be
    told apart from not passing it.

    * Deprecated kwarg not passed: ``canonical_value`` is returned as-is.
    * Deprecated kwarg passed and ``canonical_value`` is None: a
      ``DeprecationWarning`` is emitted and ``deprecated_value`` is returned.
    * Deprecated kwarg passed and ``canonical_value`` is not None:
      ``ValueError`` is raised.

    Parameters
    ----------
    canonical_name:
        Name of the new kwarg, used in messages.
    canonical_value:
        Value passed under the new name (or ``None``).
    deprecated_name:
        Name of the old kwarg, used in messages.
    deprecated_value:
        Value passed under the old name; :data:`_UNSET` when not passed.
    stacklevel:
        Forwarded to :func:`warnings.warn`. The default of 3 attributes the
        warning to the caller of the public function that calls this helper.

    Returns
    -------
    value:
        The resolved value; ``None`` if neither kwarg carried a value.

    Raises
    ------
    ValueError
        If both the new and the deprecated kwarg were given.
    """
    if deprecated_value is not _UNSET:
        if canonical_value is not None:
            raise ValueError(
                f"Pass either `{canonical_name}` or the deprecated "
                f"`{deprecated_name}`, not both."
            )
        warnings.warn(
            f"`{deprecated_name}` is deprecated; use `{canonical_name}` instead.",
            DeprecationWarning,
            stacklevel=stacklevel,
        )
        canonical_value = deprecated_value
    return canonical_value