"""HDF5 history."""
import contextlib
import functools
import time
from typing import Dict, Sequence, Tuple, Union

import h5py
import numpy as np
from ..C import (
EXITFLAG,
FVAL,
GRAD,
HESS,
HISTORY,
MESSAGE,
MESSAGES,
MODE_FUN,
MODE_RES,
N_FVAL,
N_GRAD,
N_HESS,
N_ITERATIONS,
N_RES,
N_SRES,
RES,
SRES,
START_TIME,
TIME,
TRACE,
TRACE_SAVE_ITER,
ModeType,
X,
)
from .base import HistoryBase, add_fun_from_res, reduce_result_via_options
from .options import HistoryOptions
from .util import MaybeArray, ResultDict, trace_wrap
def with_h5_file(mode: str):
    """
    Wrap function to work with hdf5 file.

    Intended for methods of an object with the attributes ``file`` (the
    HDF5 file name) and ``_f`` (the currently opened ``h5py.File`` or
    None). If a handle with a sufficient access mode is already open, the
    method is called directly and reuses it. Otherwise the file is opened
    in `mode` for the duration of the call. Afterwards, ``_f`` is restored
    to its previous value, also when the method raises.

    Parameters
    ----------
    mode:
        Access mode, see
        https://docs.h5py.org/en/stable/high/file.html.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported modes.
    """
    modes = ["r", "a"]
    if mode not in modes:
        # can be extended if reasonable
        raise ValueError(f"Mode must be one of {modes}")

    def decorator(fun):
        @functools.wraps(fun)
        def wrapper(self, *args, **kwargs):
            # file already opened in a sufficient mode: reuse the handle
            # (h5py reports a file opened with "a" as mode "r+")
            if self._f is not None and (
                mode == self._f.mode
                or mode == "r"
                or (self._f.mode == "r+" and mode == "a")
            ):
                return fun(self, *args, **kwargs)

            # restore the previous handle in `finally`, so that an exception
            # in `fun` never leaves `self._f` pointing to a closed file, and
            # a nested call does not clobber an outer caller's handle
            previous = self._f
            with h5py.File(self.file, mode) as f:
                self._f = f
                try:
                    return fun(self, *args, **kwargs)
                finally:
                    self._f = previous

        return wrapper

    return decorator
def check_editable(fun):
    """
    Check if the history is editable.

    Decorator for methods of an object with the attributes ``editable``,
    ``id`` and ``file``. The wrapped method is only called if ``editable``
    is truthy.

    Raises
    ------
    ValueError
        If the history is not editable, i.e. its id is already used in
        the history file.
    """

    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        if not self.editable:
            raise ValueError(
                f'ID "{self.id}" is already used in history file '
                f'"{self.file}".'
            )
        return fun(self, *args, **kwargs)

    return wrapper
class Hdf5History(HistoryBase):
    """
    Stores a representation of the history in an HDF5 file.

    The trace of history ``id`` is stored in the group
    ``{HISTORY}/{id}/{TRACE}/``, with one sub-group per iteration holding
    one dataset per recorded quantity. The counters are attributes of that
    group. Optimizer message and exitflag are attributes of the group
    ``{HISTORY}/{id}/{MESSAGES}/``.

    Parameters
    ----------
    id:
        Id of the history
    file:
        HDF5 file name.
    options:
        History options.
    """

    def __init__(
        self,
        id: str,
        file: str,
        options: Union[HistoryOptions, Dict] = None,
    ):
        super().__init__(options=options)
        self.id: str = id
        self.file: str = file

        # handle of the opened file, filled by `with_h5_file` while a
        # decorated method runs
        self._f: Union[h5py.File, None] = None

        # to check whether the trace can be edited
        self.editable: bool = self._editable()

    @check_editable
    @with_h5_file("a")
    def update(
        self,
        x: np.ndarray,
        sensi_orders: Tuple[int, ...],
        mode: ModeType,
        result: ResultDict,
    ) -> None:
        """See `HistoryBase` docstring."""
        # editability is checked by `check_editable` before the file is opened
        super().update(x, sensi_orders, mode, result)
        self._update_counts(sensi_orders, mode)
        self._update_trace(x, sensi_orders, mode, result)

    @check_editable
    @with_h5_file("a")
    def finalize(self, message: str = None, exitflag: str = None) -> None:
        """
        See `HistoryBase` docstring.

        Additionally stores `message` and `exitflag`, if given, as attributes
        of the ``{HISTORY}/{id}/{MESSAGES}/`` group.
        """
        # the editability check runs before opening the file (as in
        # `update`), so a non-editable history is never opened for writing
        super().finalize()

        # add message and exitflag to trace
        grp = self._f.require_group(f'{HISTORY}/{self.id}/{MESSAGES}/')
        if message is not None:
            grp.attrs[MESSAGE] = message
        if exitflag is not None:
            grp.attrs[EXITFLAG] = exitflag

    @staticmethod
    def load(
        id: str, file: str, options: Union[HistoryOptions, Dict] = None
    ) -> 'Hdf5History':
        """
        Load the History object from an HDF5 file.

        If no `options` are given, they are reconstructed from the stored
        trace via `recover_options`.
        """
        history = Hdf5History(id=id, file=file, options=options)
        if options is None:
            history.recover_options(file)
        return history

    def recover_options(self, file: str):
        """
        Recover options when loading the hdf5 history from a file.

        Done by testing which entries were recorded. An entry counts as
        recorded if at least one iteration stored a non-NaN value for it.

        Parameters
        ----------
        file:
            HDF5 file name, used as `storage_file` of the recovered options.
        """
        trace_record = self._has_non_nan_entries(X)
        trace_record_grad = self._has_non_nan_entries(GRAD)
        trace_record_hess = self._has_non_nan_entries(HESS)
        trace_record_res = self._has_non_nan_entries(RES)
        trace_record_sres = self._has_non_nan_entries(SRES)

        self.options = HistoryOptions(
            trace_record=trace_record,
            trace_record_grad=trace_record_grad,
            trace_record_hess=trace_record_hess,
            trace_record_res=trace_record_res,
            trace_record_sres=trace_record_sres,
            trace_save_iter=self.trace_save_iter,
            storage_file=file,
        )

    def _has_non_nan_entries(self, hdf5_group: str) -> bool:
        """Check if there exist non-nan entries stored for a given group."""
        return any(
            entry is not None and not np.all(np.isnan(entry))
            for entry in self._get_hdf5_entries(hdf5_group, ix=None)
        )

    @with_h5_file("a")
    def _update_counts(self, sensi_orders: Tuple[int, ...], mode: ModeType):
        """
        Update the counters in the hdf5.

        In `MODE_FUN`, sensitivity orders 0/1/2 increment the fval/grad/hess
        counters. In `MODE_RES`, orders 0/1 increment the res/sres counters.
        """
        group = self._require_group()
        if mode == MODE_FUN:
            if 0 in sensi_orders:
                group.attrs[N_FVAL] += 1
            if 1 in sensi_orders:
                group.attrs[N_GRAD] += 1
            if 2 in sensi_orders:
                group.attrs[N_HESS] += 1
        elif mode == MODE_RES:
            if 0 in sensi_orders:
                group.attrs[N_RES] += 1
            if 1 in sensi_orders:
                group.attrs[N_SRES] += 1

    @with_h5_file("r")
    def _get_trace_attr(self, key: str, default=None):
        """Return attribute `key` of the trace group, or `default` if absent."""
        try:
            return self._get_group().attrs[key]
        except KeyError:
            return default

    def __len__(self) -> int:
        """Define length of history object."""
        return self._get_trace_attr(N_ITERATIONS, 0)

    @property
    def n_fval(self) -> int:
        """See `HistoryBase` docstring."""
        return self._get_trace_attr(N_FVAL, 0)

    @property
    def n_grad(self) -> int:
        """See `HistoryBase` docstring."""
        return self._get_trace_attr(N_GRAD, 0)

    @property
    def n_hess(self) -> int:
        """See `HistoryBase` docstring."""
        return self._get_trace_attr(N_HESS, 0)

    @property
    def n_res(self) -> int:
        """See `HistoryBase` docstring."""
        return self._get_trace_attr(N_RES, 0)

    @property
    def n_sres(self) -> int:
        """See `HistoryBase` docstring."""
        return self._get_trace_attr(N_SRES, 0)

    @property
    def trace_save_iter(self) -> int:
        """After how many iterations to store the trace."""
        return self._get_trace_attr(TRACE_SAVE_ITER, 0)

    @property
    def start_time(self) -> float:
        """See `HistoryBase` docstring."""
        # TODO Y This should also be saved in and recovered from the hdf5 file
        return self._get_trace_attr(START_TIME, np.nan)

    @property
    @with_h5_file("r")
    def message(self) -> str:
        """Optimizer message in case of finished optimization."""
        try:
            return self._f[f'{HISTORY}/{self.id}/{MESSAGES}/'].attrs[MESSAGE]
        except KeyError:
            return None

    @property
    @with_h5_file("r")
    def exitflag(self) -> str:
        """Optimizer exitflag in case of finished optimization."""
        try:
            return self._f[f'{HISTORY}/{self.id}/{MESSAGES}/'].attrs[EXITFLAG]
        except KeyError:
            return None

    @with_h5_file("a")
    def _update_trace(
        self,
        x: np.ndarray,
        sensi_orders: Tuple[int, ...],
        mode: ModeType,
        result: ResultDict,
    ) -> None:
        """
        Update and possibly store the trace.

        Does nothing unless ``options.trace_record`` is set. Otherwise every
        non-None value of the current evaluation is stored as dataset
        ``{iteration}/{key}`` in the trace group, and the iteration counter
        is incremented.
        """
        if not self.options.trace_record:
            return

        # calculating function values from residuals
        # and reduce via requested history options
        result = reduce_result_via_options(
            add_fun_from_res(result), self.options
        )

        used_time = time.time() - self.start_time

        values = {
            X: x,
            FVAL: result[FVAL],
            GRAD: result[GRAD],
            RES: result[RES],
            SRES: result[SRES],
            HESS: result[HESS],
            TIME: used_time,
        }

        group = self._require_group()
        iteration = group.attrs[N_ITERATIONS]
        for key, value in values.items():
            if value is not None:
                group[f'{iteration}/{key}'] = value
        group.attrs[N_ITERATIONS] += 1

    @with_h5_file("r")
    def _get_group(self) -> h5py.Group:
        """
        Get the HDF5 group for the current history.

        Raises KeyError if the group does not exist yet.
        """
        return self._f[f'{HISTORY}/{self.id}/{TRACE}/']

    @with_h5_file("a")
    def _require_group(self) -> h5py.Group:
        """
        Get, or if necessary create, the group in the hdf5 file.

        A newly created group gets all counters set to zero, the current
        time as start time, and the configured `trace_save_iter`.
        """
        with contextlib.suppress(KeyError):
            return self._f[f'{HISTORY}/{self.id}/{TRACE}/']

        grp = self._f.create_group(f'{HISTORY}/{self.id}/{TRACE}/')
        grp.attrs[N_ITERATIONS] = 0
        grp.attrs[N_FVAL] = 0
        grp.attrs[N_GRAD] = 0
        grp.attrs[N_HESS] = 0
        grp.attrs[N_RES] = 0
        grp.attrs[N_SRES] = 0
        grp.attrs[START_TIME] = time.time()
        # TODO Y it makes no sense to save this here
        # Also, we do not seem to evaluate this at all
        grp.attrs[TRACE_SAVE_ITER] = self.options.trace_save_iter
        return grp

    @with_h5_file("r")
    def _get_hdf5_entries(
        self,
        entry_id: str,
        ix: Union[int, Sequence[int], None] = None,
    ) -> Sequence:
        """
        Get entries for field `entry_id` from HDF5 file, for indices `ix`.

        Parameters
        ----------
        entry_id:
            The key whose trace is returned.
        ix:
            Index or list of indices of the iterations that will produce
            the trace. None means all iterations. A single integer index
            is treated as a one-element list.

        Returns
        -------
        The entries ix for the key entry_id, as a list. An iteration for
        which nothing was stored yields None. Scalar datasets are returned
        as scalars, all others as numpy arrays.
        """
        if ix is None:
            ix = range(len(self))
        elif isinstance(ix, (int, np.integer)):
            ix = [ix]

        trace_path = f'{HISTORY}/{self.id}/{TRACE}'
        trace_result = []
        for iteration in ix:
            try:
                dataset = self._f[f'{trace_path}/{iteration}/{entry_id}']
            except KeyError:
                # nothing stored for this iteration and key
                trace_result.append(None)
                continue
            if dataset.shape == ():
                entry = dataset[()]  # scalar
            else:
                entry = np.array(dataset)
            trace_result.append(entry)
        return trace_result

    @trace_wrap
    def get_x_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[np.ndarray], np.ndarray]:
        """See `HistoryBase` docstring."""
        return self._get_hdf5_entries(X, ix)

    @trace_wrap
    def get_fval_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[float], float]:
        """See `HistoryBase` docstring."""
        return self._get_hdf5_entries(FVAL, ix)

    @trace_wrap
    def get_grad_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See `HistoryBase` docstring."""
        return self._get_hdf5_entries(GRAD, ix)

    @trace_wrap
    def get_hess_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See `HistoryBase` docstring."""
        return self._get_hdf5_entries(HESS, ix)

    @trace_wrap
    def get_res_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See `HistoryBase` docstring."""
        return self._get_hdf5_entries(RES, ix)

    @trace_wrap
    def get_sres_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See `HistoryBase` docstring."""
        return self._get_hdf5_entries(SRES, ix)

    @trace_wrap
    def get_time_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[float], float]:
        """See `HistoryBase` docstring."""
        return self._get_hdf5_entries(TIME, ix)

    def _editable(self) -> bool:
        """
        Check whether the id is not yet present in the file.

        Opens ``self.file`` in append mode, which creates the file if it
        does not exist.

        Returns
        -------
        True if the file is editable, i.e. no history with ``self.id`` is
        stored in it yet. False otherwise, including when the file cannot
        be opened.
        """
        try:
            with h5py.File(self.file, "a") as f:
                # editable if the id entry does not exist
                return HISTORY not in f or self.id not in f[HISTORY]
        except OSError:
            # if something goes wrong, we assume the file is not editable
            return False