# Source code for pypesto.objective.aggregated

import inspect
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any

import numpy as np

from ..C import (
    FVAL,
    GRAD,
    HESS,
    HESSP,
    RDATAS,
    RES,
    SRES,
    ModeType,
)
from .base import ObjectiveBase, ResultDict


class AggregatedObjective(ObjectiveBase):
    """Aggregates multiple objectives into one objective."""

    def __init__(
        self,
        objectives: Sequence[ObjectiveBase],
        x_names: Sequence[str] = None,
    ):
        """
        Initialize objective.

        Parameters
        ----------
        objectives:
            Sequence of pypesto.ObjectiveBase instances
        x_names:
            Sequence of names of the (optimized) parameters.
            (Details see documentation of x_names in
            :class:`pypesto.ObjectiveBase`)

        Raises
        ------
        TypeError
            If `objectives` is not a Sequence or contains elements that are
            not `ObjectiveBase` instances.
        ValueError
            If `objectives` is empty.
        """
        # input typechecks
        if not isinstance(objectives, Sequence):
            raise TypeError(
                f"Objectives must be a Sequence, was {type(objectives)}."
            )
        if not all(
            isinstance(objective, ObjectiveBase) for objective in objectives
        ):
            raise TypeError(
                "Objectives must only contain elements of type "
                "pypesto.Objective"
            )
        if not objectives:
            raise ValueError("Length of objectives must be at least one")
        self._objectives = objectives
        super().__init__(x_names=x_names)

    def __deepcopy__(self, memodict=None):
        """
        Create copy of objective.

        The same memo dictionary is forwarded to every nested `deepcopy` call.
        Objects shared between the aggregated objectives (or their
        attributes) therefore remain shared in the copy, as `copy.deepcopy`
        semantics require.
        """
        if memodict is None:
            # One memo for all nested copies; passing None would give each
            # nested call its own fresh memo and break sharing.
            memodict = {}
        other = AggregatedObjective(
            objectives=[
                deepcopy(objective, memodict) for objective in self._objectives
            ],
            x_names=deepcopy(self.x_names, memodict),
        )
        # Register the copy so attributes referring back to `self` map to it.
        memodict[id(self)] = other
        for key in set(self.__dict__.keys()) - {"_objectives", "x_names"}:
            other.__dict__[key] = deepcopy(self.__dict__[key], memodict)
        return other

    def check_mode(self, mode: ModeType) -> bool:
        """See `ObjectiveBase` documentation."""
        return all(
            objective.check_mode(mode) for objective in self._objectives
        )

    def check_sensi_orders(
        self,
        sensi_orders: tuple[int, ...],
        mode: ModeType,
    ) -> bool:
        """See `ObjectiveBase` documentation."""
        return all(
            objective.check_sensi_orders(sensi_orders, mode)
            for objective in self._objectives
        )

    def call_unprocessed(
        self,
        x: np.ndarray,
        sensi_orders: tuple[int, ...],
        mode: ModeType,
        kwargs_list: Sequence[dict[str, Any]] = None,
        return_dict: bool = False,
        **kwargs,
    ) -> ResultDict:
        """
        See `ObjectiveBase` for more documentation.

        Main method to overwrite from the base class. It handles and
        delegates the actual objective evaluation. Each objective is
        evaluated, and the results are combined via `aggregate_results`.

        Parameters
        ----------
        kwargs_list:
            Objective-specific keyword arguments, where the dictionaries are
            ordered by the objectives. The caller's dictionaries are not
            modified.
        return_dict:
            Forwarded to every objective whose `call_unprocessed` signature
            has a `return_dict` parameter. For other objectives a
            `DeprecationWarning` is emitted and the flag is not passed.
        kwargs:
            Keyword arguments passed to every objective.

        Raises
        ------
        ValueError
            If `kwargs_list` does not contain one entry per objective.
        """
        if kwargs_list is None:
            # A fresh dict per objective. `[{}] * n` would alias a single
            # dict, so `return_dict` added for one objective would leak into
            # objectives that do not accept it.
            kwargs_list = [{} for _ in self._objectives]
        elif len(kwargs_list) != len(self._objectives):
            raise ValueError(
                "The length of `kwargs_list` must match the number of "
                "objectives you are aggregating."
            )
        else:
            # Shallow copies, so the caller's dictionaries stay untouched.
            kwargs_list = [
                dict(objective_kwargs) for objective_kwargs in kwargs_list
            ]

        for objective_, objective_kwargs in zip(self._objectives, kwargs_list):
            if (
                "return_dict"
                in inspect.signature(objective_.call_unprocessed).parameters
            ):
                objective_kwargs["return_dict"] = return_dict
            else:
                warnings.warn(
                    "Please add `return_dict` to the argument list of your "
                    "objective's `call_unprocessed` method. "
                    f"Current objective: `{type(objective_)}`.",
                    DeprecationWarning,
                    stacklevel=1,
                )
        return aggregate_results(
            [
                objective.call_unprocessed(
                    x,
                    sensi_orders,
                    mode,
                    **kwargs,
                    **cur_kwargs,
                )
                for objective, cur_kwargs in zip(self._objectives, kwargs_list)
            ]
        )

    def initialize(self):
        """See `ObjectiveBase` documentation."""
        for objective in self._objectives:
            objective.initialize()

    def get_config(self) -> dict:
        """
        Return basic information of the objective configuration.

        Extends the base configuration with one ``objective_<i>`` entry per
        aggregated objective, holding that objective's own config.
        """
        info = super().get_config()
        for n_obj, obj in enumerate(self._objectives):
            info[f"objective_{n_obj}"] = obj.get_config()
        return info
def aggregate_results(rvals: Sequence[ResultDict]) -> ResultDict:
    """
    Aggregate the results from the provided ResultDicts into a single one.

    - FVAL, GRAD, HESS and HESSP are summed, but only when every result
      provides them.
    - RDATAS lists are concatenated in order.
    - RES blocks are stacked horizontally and SRES blocks vertically, in the
      order of `rvals`, when the first result provides them.

    Parameters
    ----------
    rvals:
        results to aggregate

    Returns
    -------
    The aggregated result dictionary. It always contains RDATAS, possibly as
    an empty list.

    Raises
    ------
    KeyError
        If the first result provides RES/SRES but a later one does not.
    """

    def _stack(key, stack):
        """Stack ``rval[key]`` over all results, or None if rvals[0] lacks it."""
        if key not in rvals[0]:
            return None
        blocks = [np.asarray(rval[key]) for rval in rvals]
        # A single block is returned as-is (no atleast_1d/2d promotion).
        # Several blocks are stacked in one call instead of re-allocating
        # the growing array for every result.
        return blocks[0] if len(blocks) == 1 else stack(blocks)

    # sum over fval/grad/hess/hessp, if available in all rvals
    result = {
        key: sum(rval[key] for rval in rvals)
        for key in [FVAL, GRAD, HESS, HESSP]
        if all(key in rval for rval in rvals)
    }

    # extract rdatas and flatten
    result[RDATAS] = []
    for rval in rvals:
        if RDATAS in rval:
            result[RDATAS].extend(rval[RDATAS])

    # stack res and sres, then fill them into result if present
    res = _stack(RES, np.hstack)
    if res is not None:
        result[RES] = res
    sres = _stack(SRES, np.vstack)
    if sres is not None:
        result[SRES] = sres

    return result