Source code for pypesto.history.memory

"""In-memory history."""

import time
from collections.abc import Sequence
from typing import Any, Union

import numpy as np

from ..C import FVAL, GRAD, HESS, RES, SRES, TIME, ModeType, X
from .base import (
    CountHistoryBase,
    HistoryBase,
    add_fun_from_res,
    reduce_result_via_options,
)
from .options import HistoryOptions
from .util import MaybeArray, ResultDict, trace_wrap


[docs] class MemoryHistory(CountHistoryBase): """ Class for optimization history stored in memory. Tracks number of function evaluations and keeps an in-memory trace of function evaluations. Parameters ---------- options: History options, see :class:`pypesto.history.HistoryOptions`. Defaults to `None`, which implies default options. """
[docs] def __init__(self, options: Union[HistoryOptions, dict, None] = None): super().__init__(options=options) self._trace: dict[str, Any] = {key: [] for key in HistoryBase.ALL_KEYS}
[docs] def update( self, x: np.ndarray, sensi_orders: tuple[int, ...], mode: ModeType, result: ResultDict, ) -> None: """See :meth:`HistoryBase.update`.""" super().update(x, sensi_orders, mode, result) self._update_trace(x, mode, result)
def _update_trace(self, x, mode, result): """Update internal trace representation.""" # calculating function values from residuals # and reduce via requested history options result: dict = reduce_result_via_options( add_fun_from_res(result), self.options ) result[X] = x used_time = time.time() - self._start_time result[TIME] = used_time for key in HistoryBase.ALL_KEYS: self._trace[key].append(result[key]) def __len__(self) -> int: """Define length of history object.""" return len(self._trace[TIME])
[docs] @trace_wrap def get_x_trace( self, ix: Union[int, Sequence[int], None] = None, trim: bool = False, ) -> Union[Sequence[np.ndarray], np.ndarray]: """See :meth:`HistoryBase.get_x_trace`.""" return [self._trace[X][i] for i in ix]
[docs] @trace_wrap def get_fval_trace( self, ix: Union[int, Sequence[int], None] = None, trim: bool = False, ) -> Union[Sequence[float], float]: """See :meth:`HistoryBase.get_fval_trace`.""" return [self._trace[FVAL][i] for i in ix]
[docs] @trace_wrap def get_grad_trace( self, ix: Union[int, Sequence[int], None] = None, trim: bool = False, ) -> Union[Sequence[MaybeArray], MaybeArray]: """See :meth:`HistoryBase.get_grad_trace`.""" return [self._trace[GRAD][i] for i in ix]
[docs] @trace_wrap def get_hess_trace( self, ix: Union[int, Sequence[int], None] = None, trim: bool = False, ) -> Union[Sequence[MaybeArray], MaybeArray]: """See :meth:`HistoryBase.get_hess_trace`.""" return [self._trace[HESS][i] for i in ix]
[docs] @trace_wrap def get_res_trace( self, ix: Union[int, Sequence[int], None] = None, trim: bool = False, ) -> Union[Sequence[MaybeArray], MaybeArray]: """See :meth:`HistoryBase.get_res_trace`.""" return [self._trace[RES][i] for i in ix]
[docs] @trace_wrap def get_sres_trace( self, ix: Union[int, Sequence[int], None] = None, trim: bool = False, ) -> Union[Sequence[MaybeArray], MaybeArray]: """See :meth:`HistoryBase.get_sres_trace`.""" return [self._trace[SRES][i] for i in ix]
[docs] @trace_wrap def get_time_trace( self, ix: Union[int, Sequence[int], None] = None, trim: bool = False, ) -> Union[Sequence[float], float]: """See :meth:`HistoryBase.get_time_trace`.""" return [self._trace[TIME][i] for i in ix]