# Source code for pypesto.optimize.ess.cess

"""Cooperative Enhanced Scatter Search."""
import logging
import multiprocessing
import os
import time
from typing import Dict, List, Optional
from warnings import warn

import numpy as np

import pypesto.optimize
from pypesto import Problem
from pypesto.startpoint import StartpointMethod

from .ess import ESSExitFlag, ESSOptimizer
from .function_evaluator import FunctionEvaluator
from .refset import RefSet

logger = logging.getLogger(__name__)

class CESSOptimizer:
    r"""
    Cooperative Enhanced Scatter Search Optimizer (CESS).

    A cooperative scatter search algorithm based on [VillaverdeEge2012]_.
    In short, multiple scatter search instances with different
    hyperparameters are running in different threads/processes, and exchange
    information. Some instances focus on diversification while others focus
    on intensification. Communication happens at fixed time intervals.

    Proposed hyperparameter values in [VillaverdeEge2012]_:

    * ``dim_refset``: ``[0.5 n_parameter, 20 n_parameters]``
    * ``local_n2``: ``[0, 100]``
    * ``balance``: ``[0, 0.5]``
    * ``n_diverse``: ``[5 n_par, 20 n_par]``
    * ``max_eval``: such that :math:`\tau = log10(max_eval / n_par)` is in
      [2.5, 3.5], with a recommended default value of 2.5.

    .. [VillaverdeEge2012] 'A cooperative strategy for parameter estimation in
       large scale systems biology models', Villaverde, A.F., Egea, J.A. &
       Banga, J.R. BMC Syst Biol 2012, 6, 75.

    Attributes
    ----------
    ess_init_args:
        List of argument dictionaries passed to
        :func:`ESSOptimizer.__init__`. The length of this list is the
        number of parallel ESS processes.
        Resource limits such as ``max_eval`` apply to a single CESS
        iteration, not to the full search.
    max_iter:
        Maximum number of CESS iterations.
    max_walltime_s:
        Maximum walltime in seconds. Will only be checked between local
        optimizations and other simulations, and thus, may be exceeded by
        the duration of a local search. Defaults to no limit.
    fx_best:
        The best objective value seen so far.
    x_best:
        Parameter vector corresponding to ``fx_best``.
    starttime:
        Starting time of the most recent optimization.
    i_iter:
        Current iteration number.
    exit_flag:
        Reason the most recent optimization stopped (``None`` until a run
        has terminated).
    """

    def __init__(
        self,
        ess_init_args: List[Dict],
        max_iter: int,
        max_walltime_s: float = np.inf,
    ):
        """Construct.

        Parameters
        ----------
        ess_init_args:
            List of argument dictionaries passed to
            :func:`ESSOptimizer.__init__`. The length of this list is the
            number of parallel ESS processes.
            Resource limits such as ``max_eval`` apply to a single CESS
            iteration, not to the full search.
        max_iter:
            Maximum number of CESS iterations.
        max_walltime_s:
            Maximum walltime in seconds. Will only be checked between local
            optimizations and other simulations, and thus, may be exceeded by
            the duration of a local search. Defaults to no limit.
        """
        self.max_walltime_s = max_walltime_s
        self.ess_init_args = ess_init_args
        self.max_iter = max_iter

        self._initialize()

    def _initialize(self):
        """(Re-)initialize the per-run state."""
        self.starttime = time.time()
        self.i_iter = 0
        # Overall best parameters found so far
        self.x_best: Optional[np.ndarray] = None
        # Overall best function value found so far
        self.fx_best: float = np.inf
        # Set by `_keep_going` once a stopping criterion is hit; reset here
        # so a flag from a previous run does not leak into a new one.
        self.exit_flag: Optional[ESSExitFlag] = None

    def minimize(
        self,
        problem: Problem,
        startpoint_method: StartpointMethod = None,
    ) -> pypesto.Result:
        """Minimize the given objective using CESS.

        Parameters
        ----------
        problem:
            Problem to run ESS on.
        startpoint_method:
            Method for choosing starting points.
            **Deprecated. Use ``problem.startpoint_method`` instead.**

        Returns
        -------
        A result whose first optimizer result is the overall best point,
        followed by one entry per element of each final RefSet.
        """
        if startpoint_method is not None:
            warn(
                "Passing `startpoint_method` directly is deprecated, "
                "use `problem.startpoint_method` instead.",
                DeprecationWarning,
            )

        self._initialize()

        evaluator = FunctionEvaluator(
            problem=problem,
            startpoint_method=startpoint_method,
        )

        # one RefSet per ESS instance, sized by that instance's settings
        refsets = [
            RefSet(evaluator=evaluator, dim=ess_init_args['dim_refset'])
            for ess_init_args in self.ess_init_args
        ]
        for refset, ess_init_args in zip(refsets, self.ess_init_args):
            refset.initialize_random(
                n_diverse=ess_init_args.get('n_diverse', 10 * problem.dim)
            )

        while True:
            logger.info("-" * 50)
            self._report_iteration()
            self.i_iter += 1

            # run scatter searches
            ess_optimizers = self._run_scatter_searches(
                refsets=refsets,
            )
            # collect refsets from the different ESS runs
            refsets = [result.refset for result in ess_optimizers]

            # update best values from ESS results
            for result in ess_optimizers:
                self._maybe_update_global_best(result.x_best, result.fx_best)

            if not self._keep_going(i_eval=evaluator.n_eval):
                break

            # create refsets for the next iteration
            self._update_refsets(refsets=refsets, evaluator=evaluator)

        # TODO merge results
        self._report_final()

        # TODO what should the result look like?
        return self._create_result(problem, refsets)

    def _report_iteration(self):
        """Log the current iteration."""
        # Print the table header once, before the first iteration.
        # (Previously compared `max_iter` to 0, so the header never appeared.)
        if self.i_iter == 0:
            logger.info("iter | best |")

        with np.printoptions(
            edgeitems=30,
            linewidth=100000,
            formatter={"float": lambda x: "%.3g" % x},
        ):
            logger.info(f"{self.i_iter:4} | {self.fx_best:+.2E} | ")

    def _report_final(self):
        """Log scatter search summary."""
        with np.printoptions(
            edgeitems=30,
            linewidth=100000,
            formatter={"float": lambda x: "%.3g" % x},
        ):
            logger.info(
                f"CESS finished with {self.exit_flag!r} "
                f"after {self.i_iter} iterations, "
                f"{time.time() - self.starttime:.3g}s. "
                # f"Num local solutions: {len(self.local_solutions)}."
            )
            # logger.info(f"Final refset: {np.sort(self.refset.fx)} ")
            logger.info(f"Best fval {self.fx_best}")

    def _create_result(
        self, problem: pypesto.Problem, refsets: List[RefSet]
    ) -> pypesto.Result:
        """Create the result object.

        Currently, this returns the overall best value and the final RefSet.
        """
        common_result_fields = {
            'exitflag': self.exit_flag,
            # meaningful? this is the overall time, and identical for all
            #  reported points
            'time': time.time() - self.starttime,
            # TODO
            # 'n_fval': self.evaluator.n_eval,
            'optimizer': str(self),
        }
        i_result = 0
        result = pypesto.Result(problem=problem)

        # save global best
        optimizer_result = pypesto.OptimizerResult(
            id=str(i_result),
            x=self.x_best,
            fval=self.fx_best,
            message="Global best",
            **common_result_fields,
        )
        # TODO DW: Create a single History with the global best?
        result.optimize_result.append(optimizer_result)

        # save refsets
        for i_refset, refset in enumerate(refsets):
            for i in range(refset.dim):
                i_result += 1
                result.optimize_result.append(
                    pypesto.OptimizerResult(
                        id=str(i_result),
                        x=refset.x[i],
                        fval=refset.fx[i],
                        message=f"RefSet[{i_refset}][{i}]",
                        **common_result_fields,
                    )
                )

        # TODO DW: also save local solutions?
        #  (need to track fvals or re-evaluate)

        return result

    def _run_scatter_searches(
        self,
        refsets: List[RefSet],
    ) -> List[ESSOptimizer]:
        """Start all scatter searches in different processes."""
        # set default value of max_eval if not present.
        # only set it on a copy, as the original dicts may be re-used
        # for different optimization problems.
        # reasonable value proposed in [VillaverdeEge2012]:
        #  2.5 < tau < 3.5, default: 2.5
        # User-supplied keys override the default because they are merged
        # in afterwards.
        ess_init_args = [
            dict(
                {
                    'max_eval': int(
                        10**2.5 * refsets[0].evaluator.problem.dim
                    )
                },
                **ess_kwargs,
            )
            for ess_kwargs in self.ess_init_args
        ]

        # 'spawn' start method; the bound method (and thus `self`) and each
        # refset are pickled and sent to the worker processes.
        ctx = multiprocessing.get_context('spawn')
        with ctx.Pool(len(self.ess_init_args)) as pool:
            ess_optimizers = pool.starmap(
                self._run_single_ess,
                (
                    [ess_kwargs, refset]
                    for (ess_kwargs, refset) in zip(ess_init_args, refsets)
                ),
                chunksize=1,
            )
        return list(ess_optimizers)

    def _run_single_ess(
        self,
        ess_kwargs,
        refset: RefSet,
    ) -> ESSOptimizer:
        """
        Run ESS.

        Helper for `starmap`.
        """
        # different random seeds per process?
        np.random.seed((os.getpid() * int(time.time() * 1000)) % 2**32)

        ess = ESSOptimizer(**ess_kwargs)
        ess.minimize(refset=refset)
        return ess

    def _keep_going(self, i_eval) -> bool:
        """Check exit criteria.

        Parameters
        ----------
        i_eval:
            Number of function evaluations so far (currently not used by
            any criterion).

        Returns
        -------
        ``True`` if none of the exit criteria is met, ``False`` otherwise.
        Sets ``self.exit_flag`` when returning ``False``.
        """
        # TODO DW which further stopping criteria: gtol, fatol, frtol?

        # elapsed iterations
        if self.i_iter >= self.max_iter:
            self.exit_flag = ESSExitFlag.MAX_ITER
            return False

        # elapsed time
        if time.time() - self.starttime >= self.max_walltime_s:
            self.exit_flag = ESSExitFlag.MAX_TIME
            return False

        return True

    def _maybe_update_global_best(self, x, fx):
        """Update the global best value if the provided value is better."""
        if fx < self.fx_best:
            # Take a real copy: `x[:]` on a NumPy array is only a view, so
            # later in-place modifications of `x` would alter `x_best`.
            self.x_best = np.array(x, copy=True)
            self.fx_best = fx

    def _update_refsets(
        self, refsets: List[RefSet], evaluator: FunctionEvaluator
    ):
        """
        Update refsets.

        Create new refsets based on the combined final refsets of the
        previous CESS iteration.

        Updates ``refsets`` in place.
        """
        # gather final refset entries
        x = np.vstack([refset.x for refset in refsets])
        fx = np.concatenate([refset.fx for refset in refsets])

        # reset function evaluation counter
        evaluator.n_eval = 0
        evaluator.n_eval_round = 0

        for i, ess_init_args in enumerate(self.ess_init_args):
            refsets[i] = RefSet(
                dim=ess_init_args['dim_refset'], evaluator=evaluator
            )
            refsets[i].initialize_from_array(x_diverse=x, fx_diverse=fx)
            refsets[i].sort()