Source code for pypesto.visualize.dimension_reduction

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np

from .clust_color import RGBA

# Alias used only in type annotations. typing.TYPE_CHECKING is False at
# runtime, so this block is skipped when the module is imported normally:
# umap is not imported here and UmapTypeObject is not defined at runtime.
# This relies on `from __future__ import annotations` keeping annotations
# unevaluated.
if TYPE_CHECKING:
    try:
        import umap

        UmapTypeObject = umap.umap_.UMAP
    except ImportError:
        # umap is not installed: fall back to a placeholder for the alias
        UmapTypeObject = None


def projection_scatter_umap(
    umap_coordinates: np.ndarray, components: Sequence[int] = (0, 1), **kwargs
):
    """
    Show UMAP coordinates as a single scatter plot or as a grid of them.

    Exactly two selected components give one scatter plot; any other number
    of components is drawn as a cross-classification table of pairwise
    scatter plots.

    Parameters
    ----------
    umap_coordinates:
        array of umap coordinates (returned as first output by the routine
        get_umap_representation) to be shown as scatter plot
    components:
        Components to be plotted (corresponds to columns of umap_coordinates)
    **kwargs:
        Forwarded unchanged to the lowlevel plotting routine.

    Returns
    -------
    axs:
        Either one axes object, or a dictionary of plot axes (depending on the
        number of coordinates passed)
    """
    # axis labels are one-based, column indices are zero-based
    axis_labels = [f'UMAP component {column + 1}' for column in components]

    # keep only the requested columns
    selected = umap_coordinates[:, components]

    if len(axis_labels) != 2:
        # not a single pair: plot all pairwise combinations
        return ensemble_crosstab_scatter_lowlevel(
            selected, axis_labels, **kwargs
        )

    first_label, second_label = axis_labels
    return ensemble_scatter_lowlevel(
        selected, x_label=first_label, y_label=second_label, **kwargs
    )
def projection_scatter_umap_original(
    umap_object: UmapTypeObject,
    color_by: Optional[Sequence[float]] = None,
    components: Sequence[int] = (0, 1),
    **kwargs,
):
    """
    See `projection_scatter_umap` for more documentation.

    Wrapper around umap.plot.points. Similar to `projection_scatter_umap`,
    but uses the original plotting routine from umap.plot.

    Parameters
    ----------
    umap_object:
        umap object (returned as second output by get_umap_representation) to
        be shown as scatter plot
    color_by:
        A sequence/list of floats, which specify the color in the colormap
    components:
        Components to be plotted (corresponds to columns of the embedding
        stored in umap_object)
    **kwargs:
        Forwarded unchanged to umap.plot.points.

    Returns
    -------
    ax: matplotlib.Axes
        The plot axes, as returned by umap.plot.points.
    """
    # imported lazily so that umap stays an optional dependency
    import umap.plot

    # The plotting routine reads the embedding from the object itself, so the
    # embedding is swapped for the selected columns only while plotting.
    full_embedding = umap_object.embedding_
    umap_object.embedding_ = full_embedding[:, components]
    try:
        # use umap's original plotting routine to visualize
        return umap.plot.points(
            umap_object, values=color_by, theme='viridis', **kwargs
        )
    finally:
        # Restore the complete embedding: leaving the reduced one in place
        # would make later calls with other components select wrong columns
        # (or fail with an IndexError).
        umap_object.embedding_ = full_embedding
def projection_scatter_pca(
    pca_coordinates: np.ndarray, components: Sequence[int] = (0, 1), **kwargs
):
    """
    Show PCA coordinates as a single scatter plot or as a grid of them.

    Exactly two selected components give one scatter plot; any other number
    of components is drawn as a cross-classification table of pairwise
    scatter plots.

    Parameters
    ----------
    pca_coordinates:
        array of pca coordinates (returned as first output by the routine
        get_pca_representation) to be shown as scatter plot
    components:
        Components to be plotted (corresponds to columns of pca_coordinates)
    **kwargs:
        Forwarded unchanged to the lowlevel plotting routine.

    Returns
    -------
    axs:
        Either one axes object, or a dictionary of plot axes (depending on the
        number of coordinates passed)
    """
    # axis labels are one-based, column indices are zero-based
    axis_labels = [f'PCA component {column + 1}' for column in components]

    # keep only the requested columns
    selected = pca_coordinates[:, components]

    if len(axis_labels) != 2:
        # not a single pair: plot all pairwise combinations
        return ensemble_crosstab_scatter_lowlevel(
            selected, axis_labels, **kwargs
        )

    first_label, second_label = axis_labels
    return ensemble_scatter_lowlevel(
        selected, x_label=first_label, y_label=second_label, **kwargs
    )
def ensemble_crosstab_scatter_lowlevel(
    dataset: np.ndarray,
    component_labels: Optional[Sequence[str]] = None,
    **kwargs,
):
    """
    Plot cross-classification table of scatter plots for different coordinates.

    Lowlevel routine for multiple UMAP and PCA plots, but can also be used to
    visualize, e.g., parameter traces across optimizer runs.

    Parameters
    ----------
    dataset:
        array of data points to be shown as scatter plot; every pair of
        columns gets its own scatter plot
    component_labels:
        labels for the x-axes and the y-axes, one per column of dataset.
        Defaults to 'component 1', 'component 2', ...
    **kwargs:
        Forwarded to `ensemble_scatter_lowlevel`. An `ax` entry is ignored.

    Returns
    -------
    axs:
        A dictionary of plot axes, keyed by (x column index, y column index).
    """
    n_components = dataset.shape[1]

    # The default of None cannot be indexed below, so create generic labels.
    if component_labels is None:
        component_labels = [
            f'component {i_comp + 1}' for i_comp in range(n_components)
        ]

    axs = _create_crosstab_axes(n_components)

    # We don't even try to plot this into an existing axes object.
    # Overplotting a multi-axes figure is asking for trouble...
    kwargs.pop('ax', None)

    for x_comp in range(n_components - 1):
        for y_comp in range(x_comp + 1, n_components):
            # only the first column x_comp == 0 gets y-axis labels and
            # only the last row y_comp == n_components - 1 gets x-axis labels
            y_label = component_labels[y_comp] if x_comp == 0 else ''
            x_label = (
                component_labels[x_comp] if y_comp == n_components - 1 else ''
            )

            # call lowlevel routine on the two wanted columns
            ensemble_scatter_lowlevel(
                dataset[:, [x_comp, y_comp]],
                x_label=x_label,
                y_label=y_label,
                ax=axs[(x_comp, y_comp)],
                **kwargs,
            )

    # return dict of axes
    return axs
def ensemble_scatter_lowlevel(
    dataset: np.ndarray,
    ax: Optional[plt.Axes] = None,
    size: Optional[Tuple[float, float]] = (12, 6),
    x_label: str = 'component 1',
    y_label: str = 'component 2',
    color_by: Optional[Sequence[float]] = None,
    color_map: str = 'viridis',
    background_color: RGBA = (0.0, 0.0, 0.0, 1.0),
    marker_type: str = '.',
    scatter_size: float = 0.5,
    invert_scatter_order: bool = False,
):
    """
    Create a scatter plot.

    Parameters
    ----------
    dataset:
        array of data points in reduced dimension; column 0 is plotted on the
        x-axis and column 1 on the y-axis
    ax:
        Axes object to use. A new figure is created if None.
    size:
        Figure size (width, height) in inches. Is only applied when no ax
        object is specified. None keeps the size of the newly created figure.
    x_label:
        The x-axis label
    y_label:
        The y-axis label
    color_by:
        A sequence/list of floats, which specify the color in the colormap.
        If None, all points get the same value.
    color_map:
        A colormap name known to pyplot
    background_color:
        Background color of the axes object (defaults to black)
    marker_type:
        Type of plotted markers
    scatter_size:
        Size of plotted markers
    invert_scatter_order:
        Specifies the order of plotting the scatter points, can be important
        in case of overplotting

    Returns
    -------
    ax: matplotlib.Axes
        The plot axes.
    """
    if ax is None:
        fig, ax = plt.subplots()
        # size is declared Optional, so None must not be unpacked
        if size is not None:
            fig.set_size_inches(*size)
    # make ax the current axes, the pyplot calls below act on it
    plt.sca(ax)

    if color_by is None:
        color_by = np.ones(dataset.shape[0])

    if invert_scatter_order:
        dataset = dataset[::-1]
        # Per-point colors have to be reversed together with the points,
        # otherwise every point is drawn with the color of another point.
        colors = np.asarray(color_by)
        if colors.ndim >= 1 and colors.shape[0] == dataset.shape[0]:
            color_by = colors[::-1]

    plt.scatter(
        dataset[:, 0],
        dataset[:, 1],
        c=color_by,
        cmap=color_map,
        marker=marker_type,
        s=scatter_size,
    )

    # beautify
    ax.set_facecolor(background_color)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()

    return ax
def _create_crosstab_axes(n_comp: int):
    """
    Build the axes of a cross-classification table in the current figure.

    One subplot is created for every pair (x, y) of component indices with
    x < y, arranged on a grid with n_comp - 1 rows and columns.

    Parameters
    ----------
    n_comp:
        number of component to be mutually compared

    Returns
    -------
    axs:
        A dictionary of plot axes, keyed by the tuple (x, y).
    """
    grid_side = n_comp - 1

    # row y - 1 and column x of the grid, as a one-based subplot index
    return {
        (x_idx, y_idx): plt.subplot(
            grid_side, grid_side, (y_idx - 1) * grid_side + x_idx + 1
        )
        for x_idx in range(0, grid_side)
        for y_idx in range(x_idx + 1, n_comp)
    }