Source code for pypesto.startpoint.base

"""Startpoint base classes."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable

import numpy as np

from ..C import FVAL, GRAD
from ..objective import ObjectiveBase

if TYPE_CHECKING:
    import pypesto


class StartpointMethod(ABC):
    """Abstract interface for startpoint generators.

    Startpoints are needed in particular for multi-start optimization.
    This class only fixes the call signature; concrete sampling
    strategies are implemented by sub-classes.
    """

    @abstractmethod
    def __call__(
        self,
        n_starts: int,
        problem: pypesto.problem.Problem,
    ) -> np.ndarray:
        """Produce a set of startpoints.

        Parameters
        ----------
        n_starts:
            How many startpoints to produce.
        problem:
            The problem, providing e.g. dimensions, bounds, and guesses.

        Returns
        -------
        xs:
            The startpoints, one per row, shape (n_starts, n_par).
        """
class NoStartpoints(StartpointMethod):
    """Placeholder method that yields only nan points.

    Intended for situations where no startpoints are needed.
    """

    def __call__(
        self,
        n_starts: int,
        problem: pypesto.problem.Problem,
    ) -> np.ndarray:
        """Return a nan-filled matrix of shape (n_starts, problem.dim)."""
        return np.full(shape=(n_starts, problem.dim), fill_value=np.nan)
class CheckedStartpoints(StartpointMethod, ABC):
    """Startpoints checked for function value and/or gradient finiteness.

    Sub-classes implement :meth:`sample`; this class combines the sampled
    points with guesses from the problem and, if requested, replaces
    candidates whose function value and/or gradient is not finite.
    """

    def __init__(
        self,
        use_guesses: bool = True,
        check_fval: bool = False,
        check_grad: bool = False,
    ):
        """Initialize.

        Parameters
        ----------
        use_guesses:
            Whether to use guesses provided in the problem.
        check_fval:
            Whether to check function values at the startpoint, and resample
            if not finite.
        check_grad:
            Whether to check gradients at the startpoint, and resample
            if not finite.
        """
        self.use_guesses: bool = use_guesses
        self.check_fval: bool = check_fval
        self.check_grad: bool = check_grad

    def __call__(
        self,
        n_starts: int,
        problem: pypesto.problem.Problem,
    ) -> np.ndarray:
        """Generate checked startpoints.

        Guesses from the problem (if enabled) come first and are topped up
        with sampled points. If the guesses alone already cover `n_starts`,
        the first `n_starts` guesses are returned as they are, without any
        finiteness check.

        Parameters
        ----------
        n_starts:
            Number of starts.
        problem:
            Problem providing `x_guesses`, `dim`, `lb_init`, `ub_init` and
            `objective`.

        Returns
        -------
        xs:
            Startpoints, shape (n_starts, n_par).
        """
        # shape: (n_guesses, dim)
        x_guesses = problem.x_guesses
        if not self.use_guesses:
            x_guesses = np.zeros(shape=(0, problem.dim))

        dim = problem.dim
        lb, ub = problem.lb_init, problem.ub_init

        # number of required startpoints
        n_guesses = x_guesses.shape[0]
        n_required = n_starts - n_guesses

        if n_required <= 0:
            return x_guesses[:n_starts, :]

        # apply startpoint method
        x_sampled = self.sample(n_starts=n_required, lb=lb, ub=ub)

        # assemble: guesses first, sampled points afterwards
        xs = np.zeros(shape=(n_starts, dim))
        xs[0:n_guesses, :] = x_guesses
        xs[n_guesses:n_starts, :] = x_sampled

        # check, resample and order startpoints
        xs = self.check_and_resample(
            xs=xs, lb=lb, ub=ub, objective=problem.objective
        )

        return xs

    @abstractmethod
    def sample(
        self,
        n_starts: int,
        lb: np.ndarray,
        ub: np.ndarray,
    ) -> np.ndarray:
        """Actually sample startpoints.

        While in this implementation, `__call__` handles the checking of
        guesses and resampling, this method defines the actual sampling.

        Parameters
        ----------
        n_starts:
            Number of startpoints to generate.
        lb:
            Lower parameter bound.
        ub:
            Upper parameter bound.

        Returns
        -------
        xs:
            Startpoints, shape (n_starts, n_par).
        """

    @staticmethod
    def _all_requested_finite(ret: dict, sensi_orders: tuple) -> bool:
        """Tell whether all requested outputs in `ret` are finite."""
        fval_ok = 0 not in sensi_orders or np.isfinite(ret[FVAL])
        grad_ok = 1 not in sensi_orders or np.isfinite(ret[GRAD]).all()
        return bool(fval_ok and grad_ok)

    def check_and_resample(
        self,
        xs: np.ndarray,
        lb: np.ndarray,
        ub: np.ndarray,
        objective: ObjectiveBase,
    ) -> np.ndarray:
        """Check sampled points for fval, grad, and potentially resample ones.

        Each candidate is evaluated; as long as a requested quantity is not
        finite, a single replacement point is drawn via `sample` and
        evaluated instead. Note that this repeats until a permissible point
        is found. Finally, the points are ordered by function value.

        Parameters
        ----------
        xs:
            Startpoints candidates, shape (n_starts, n_par).
            Rows are overwritten in place with their replacements.
        lb:
            Lower parameter bound.
        ub:
            Upper parameter bound.
        objective:
            Objective function, for evaluation.

        Returns
        -------
        xs:
            Checked and potentially partially resampled startpoints,
            shape (n_starts, n_par).
        """
        if not self.check_fval and not self.check_grad:
            return xs

        # (0,), (1,) or (0, 1), depending on what is to be checked
        requested = ((0, self.check_fval), (1, self.check_grad))
        sensi_orders = tuple(order for order, wanted in requested if wanted)

        # track function values for ordering
        fvals = np.empty(shape=(xs.shape[0],))

        # iterate over all startpoint candidates
        for ix, x in enumerate(xs):
            # evaluate candidate
            objective.initialize()
            ret = objective(x, sensi_orders=sensi_orders, return_dict=True)

            # resample until all requested sensis are finite
            while not self._all_requested_finite(ret, sensi_orders):
                # `sample` yields shape (1, n_par); flatten so the objective
                # receives a 1-D parameter vector, exactly like the rows of
                # `xs` evaluated above
                x = np.asarray(
                    self.sample(n_starts=1, lb=lb, ub=ub)
                ).reshape(-1)

                # evaluate candidate
                objective.initialize()
                ret = objective(x, sensi_orders=sensi_orders, return_dict=True)

            # function value may be absent if only the gradient was requested
            fvals[ix] = ret.get(FVAL, np.nan)
            # assign permissible value
            xs[ix] = x

        # sort startpoints by function value; a stable sort keeps the
        # incoming order for ties and for missing (nan) function values,
        # instead of shuffling the points arbitrarily
        xs_order = np.argsort(fvals, kind="stable")
        xs = xs[xs_order, :]

        return xs
class FunctionStartpoints(CheckedStartpoints):
    """Startpoints generated by a user-supplied callable.

    The callable is invoked with the keyword arguments `n_starts`, `lb`
    and `ub`, i.e. the arguments of the `sample` method.
    """

    def __init__(
        self,
        function: Callable,
        use_guesses: bool = True,
        check_fval: bool = False,
        check_grad: bool = False,
    ):
        """Initialize.

        Parameters
        ----------
        function:
            Callable that samples the startpoints.
        use_guesses, check_fval, check_grad:
            Forwarded unchanged to CheckedStartpoints.
        """
        super().__init__(
            use_guesses=use_guesses,
            check_fval=check_fval,
            check_grad=check_grad,
        )
        self.function: Callable = function

    def sample(
        self,
        n_starts: int,
        lb: np.ndarray,
        ub: np.ndarray,
    ) -> np.ndarray:
        """Delegate sampling to the wrapped callable."""
        sampler = self.function
        return sampler(n_starts=n_starts, lb=lb, ub=ub)
def to_startpoint_method(
    maybe_startpoint_method: StartpointMethod | Callable | bool,
) -> StartpointMethod:
    """Convert the input to a StartpointMethod instance, or raise.

    Parameters
    ----------
    maybe_startpoint_method:
        One of: a StartpointMethod instance (returned unchanged), a
        Callable as expected by FunctionStartpoints (wrapped in one), or
        the value `False` (mapped to NoStartpoints).

    Returns
    -------
    startpoint_method:
        A StartpointMethod instance.

    Raises
    ------
    TypeError
        If the argument cannot be converted to a StartpointMethod.
    """
    candidate = maybe_startpoint_method

    # already of the desired type
    if isinstance(candidate, StartpointMethod):
        return candidate

    # plain callables are wrapped
    if isinstance(candidate, Callable):
        return FunctionStartpoints(candidate)

    # explicit opt-out of startpoint generation
    if candidate is False:
        return NoStartpoints()

    raise TypeError(
        "Could not parse startpoint method of type "
        f"{type(maybe_startpoint_method)} to a StartpointMethod.",
    )