from collections.abc import Sequence
from copy import deepcopy
from typing import Any
import numpy as np
from ..C import FVAL, GRAD, HESS, HESSP, RDATAS, RES, SRES, ModeType
from .base import ObjectiveBase, ResultDict
[docs]
class AggregatedObjective(ObjectiveBase):
    """
    Aggregates multiple objectives into one objective.

    Every evaluation is forwarded to each wrapped objective, and the
    individual results are merged via :func:`aggregate_results`.
    """

    def __init__(
        self,
        objectives: Sequence[ObjectiveBase],
        x_names: Sequence[str] = None,
    ):
        """
        Initialize objective.

        Parameters
        ----------
        objectives:
            Sequence of pypesto.ObjectiveBase instances
        x_names:
            Sequence of names of the (optimized) parameters.
            (Details see documentation of x_names in
            :class:`pypesto.ObjectiveBase`)

        Raises
        ------
        TypeError
            If `objectives` is not a Sequence, or if any of its elements
            is not an ObjectiveBase instance.
        ValueError
            If `objectives` is empty.
        """
        # input typechecks
        if not isinstance(objectives, Sequence):
            raise TypeError(
                f"Objectives must be a Sequence, was {type(objectives)}."
            )

        if not all(
            isinstance(objective, ObjectiveBase) for objective in objectives
        ):
            # NOTE: the trailing space after "type" is required, the two
            # literals are concatenated into a single message.
            raise TypeError(
                "Objectives must only contain elements of type "
                "pypesto.Objective"
            )

        if not objectives:
            raise ValueError("Length of objectives must be at least one")

        self._objectives = objectives

        super().__init__(x_names=x_names)

    def __deepcopy__(self, memodict=None):
        """
        Create copy of objective.

        Parameters
        ----------
        memodict:
            Memo dictionary handed in by :func:`copy.deepcopy`. It is
            forwarded to all nested copies, so that objects which are
            referenced several times are copied only once.

        Returns
        -------
        A new AggregatedObjective holding deep copies of the wrapped
        objectives and of all remaining instance attributes.
        """
        other = AggregatedObjective(
            objectives=[
                deepcopy(objective, memodict) for objective in self._objectives
            ],
            x_names=deepcopy(self.x_names, memodict),
        )
        # everything that was not already passed to the constructor
        for key in set(self.__dict__.keys()) - {"_objectives", "x_names"}:
            other.__dict__[key] = deepcopy(self.__dict__[key], memodict)

        return other

    def check_mode(self, mode: ModeType) -> bool:
        """
        See `ObjectiveBase` documentation.

        Returns True only if every aggregated objective supports `mode`.
        """
        return all(
            objective.check_mode(mode) for objective in self._objectives
        )

    def check_sensi_orders(
        self,
        sensi_orders: tuple[int, ...],
        mode: ModeType,
    ) -> bool:
        """
        See `ObjectiveBase` documentation.

        Returns True only if every aggregated objective supports the
        requested `sensi_orders` in the given `mode`.
        """
        return all(
            objective.check_sensi_orders(sensi_orders, mode)
            for objective in self._objectives
        )

    def call_unprocessed(
        self,
        x: np.ndarray,
        sensi_orders: tuple[int, ...],
        mode: ModeType,
        kwargs_list: Sequence[dict[str, Any]] = None,
        **kwargs,
    ) -> ResultDict:
        """
        See `ObjectiveBase` for more documentation.

        Main method to overwrite from the base class. It handles and
        delegates the actual objective evaluation.

        Parameters
        ----------
        kwargs_list:
            Objective-specific keyword arguments, where the dictionaries are
            ordered by the objectives.
        kwargs:
            Keyword arguments that are passed on to every objective.

        Returns
        -------
        The results of all objectives, merged by `aggregate_results`.

        Raises
        ------
        ValueError
            If `kwargs_list` is given but its length differs from the
            number of aggregated objectives.
        """
        if kwargs_list is None:
            # no objective-specific arguments: one empty dict per objective
            kwargs_list = [{} for _ in self._objectives]
        elif len(kwargs_list) != len(self._objectives):
            raise ValueError(
                "The length of `kwargs_list` must match the number of "
                "objectives you are aggregating."
            )
        return aggregate_results(
            [
                objective.call_unprocessed(
                    x,
                    sensi_orders,
                    mode,
                    **kwargs,
                    **cur_kwargs,
                )
                for objective, cur_kwargs in zip(self._objectives, kwargs_list)
            ]
        )

    def initialize(self):
        """See `ObjectiveBase` documentation."""
        for objective in self._objectives:
            objective.initialize()

    def get_config(self) -> dict:
        """
        Return basic information of the objective configuration.

        The configuration of the i-th aggregated objective is stored
        under the key ``objective_<i>``.
        """
        info = super().get_config()
        for n_obj, obj in enumerate(self._objectives):
            info[f"objective_{n_obj}"] = obj.get_config()
        return info
def aggregate_results(rvals: Sequence[ResultDict]) -> ResultDict:
    """
    Aggregate the results from the provided ResultDicts into a single one.

    FVAL, GRAD, HESS and HESSP are summed, RES is stacked horizontally and
    SRES is stacked vertically. Each of these keys is only part of the
    output if it is available in all `rvals`. The RDATAS entries of those
    `rvals` that provide them are concatenated into one list.

    Parameters
    ----------
    rvals:
        results to aggregate

    Returns
    -------
    The aggregated result dictionary. The RDATAS key is always present
    (possibly with an empty list).

    Raises
    ------
    ValueError
        If `rvals` is empty.
    """
    if not rvals:
        raise ValueError("Cannot aggregate an empty sequence of results.")

    # sum over fval/grad/hess, if available in all rvals
    result = {
        key: sum(rval[key] for rval in rvals)
        for key in [FVAL, GRAD, HESS, HESSP]
        if all(key in rval for rval in rvals)
    }

    # extract rdatas and flatten
    result[RDATAS] = [
        rdata for rval in rvals if RDATAS in rval for rdata in rval[RDATAS]
    ]

    # stack res horizontally and sres vertically, if available in all rvals
    for key, stack in ((RES, np.hstack), (SRES, np.vstack)):
        if not all(key in rval for rval in rvals):
            continue
        arrays = [np.asarray(rval[key]) for rval in rvals]
        # A single array is passed through as is: hstack/vstack would
        # promote it to at least 1-D/2-D and thereby change its shape.
        result[key] = arrays[0] if len(arrays) == 1 else stack(arrays)

    return result