"""HDF5 history."""
import contextlib
import time
from collections.abc import Sequence
from functools import wraps
from pathlib import Path
from typing import Union
import h5py
import numpy as np
from ..C import (
EXITFLAG,
FVAL,
GRAD,
HESS,
HISTORY,
MESSAGE,
MESSAGES,
MODE_FUN,
MODE_RES,
N_FVAL,
N_GRAD,
N_HESS,
N_ITERATIONS,
N_RES,
N_SRES,
RES,
SRES,
START_TIME,
TIME,
TRACE,
TRACE_SAVE_ITER,
ModeType,
X,
)
from .base import HistoryBase, add_fun_from_res, reduce_result_via_options
from .options import HistoryOptions
from .util import MaybeArray, ResultDict, trace_wrap
def with_h5_file(mode: str):
    """Wrap function to work with hdf5 file.

    The wrapped method accesses the open :class:`h5py.File` via
    ``self._f``. If a handle is already open in a mode that permits the
    requested access (e.g. because the call is nested inside another wrapped
    method), it is reused. Otherwise ``self.file`` is opened for the duration
    of the call and closed afterwards. The handle that was set before the
    call is restored afterwards, also when the wrapped function raises.

    Parameters
    ----------
    mode:
        Access mode, see
        https://docs.h5py.org/en/stable/high/file.html.
        Only ``"r"`` and ``"a"`` are supported.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported modes.
    """
    modes = ["r", "a"]
    if mode not in modes:
        # can be extended if reasonable
        raise ValueError(f"Mode must be one of {modes}")

    def decorator(fun):
        @wraps(fun)
        def wrapper(self, *args, **kwargs):
            # file already opened in a compatible mode; h5py reports files
            # opened with "a" as "r+", hence the extra check
            if self._f is not None and (
                mode == self._f.mode
                or mode == "r"
                or (self._f.mode == "r+" and mode == "a")
            ):
                return fun(self, *args, **kwargs)

            # Restore the previous handle even if `fun` raises; otherwise
            # `self._f` would keep referring to a closed file and every
            # subsequent access would fail.
            previous = self._f
            try:
                with h5py.File(self.file, mode) as f:
                    self._f = f
                    return fun(self, *args, **kwargs)
            finally:
                self._f = previous

        return wrapper

    return decorator
def check_editable(fun):
    """Wrap a method so that it only runs on an editable history.

    Raises
    ------
    ValueError
        If ``self.editable`` is falsy, e.g. because the history id was
        already present in the file when the history object was created.
    """

    @wraps(fun)
    def wrapper(self, *args, **kwargs):
        if self.editable:
            return fun(self, *args, **kwargs)
        raise ValueError(
            f'ID "{self.id}" is already used in history file '
            f'"{self.file}".'
        )

    return wrapper
class Hdf5History(HistoryBase):
    """
    Stores a representation of the history in an HDF5 file.

    Trace entries and evaluation counters are stored below
    ``{HISTORY}/{id}/{TRACE}/`` (one subgroup per iteration), optimizer
    message and exitflag below ``{HISTORY}/{id}/{MESSAGES}/``.

    Parameters
    ----------
    id:
        Id of the history
    file:
        HDF5 file name.
    options:
        History options. Defaults to ``None``.
    """

    def __init__(
        self,
        id: str,
        file: Union[str, Path],
        options: Union[HistoryOptions, dict, None] = None,
    ):
        super().__init__(options=options)
        self.id: str = id
        self.file: str = str(file)

        # filled during file access
        self._f: Union[h5py.File, None] = None

        # to check whether the trace can be edited
        self.editable: bool = self._editable()

    @check_editable
    @with_h5_file("a")
    def update(
        self,
        x: np.ndarray,
        sensi_orders: tuple[int, ...],
        mode: ModeType,
        result: ResultDict,
    ) -> None:
        """See :meth:`HistoryBase.update`.

        Raises
        ------
        ValueError
            If the history was not marked as editable upon initialization.
        """
        super().update(x, sensi_orders, mode, result)
        self._update_counts(sensi_orders, mode)
        self._update_trace(x, sensi_orders, mode, result)

    # Editability is checked before the file is opened, consistent with
    # `update`, so a non-editable history does not touch the file.
    @check_editable
    @with_h5_file("a")
    def finalize(self, message: str = None, exitflag: str = None) -> None:
        """See :class:`HistoryBase.finalize`.

        ``message`` and ``exitflag``, if given, are stored as attributes of
        the messages group of this history.
        """
        super().finalize()

        # add message and exitflag to trace
        grp = self._f.require_group(f"{HISTORY}/{self.id}/{MESSAGES}/")
        if message is not None:
            grp.attrs[MESSAGE] = message
        if exitflag is not None:
            grp.attrs[EXITFLAG] = exitflag

    @staticmethod
    def load(
        id: str,
        file: Union[str, Path],
        options: Union[HistoryOptions, dict, None] = None,
    ) -> "Hdf5History":
        """Load the History object from memory.

        If ``options`` is ``None``, the options are reconstructed from the
        stored trace via :meth:`recover_options`.
        """
        history = Hdf5History(id=id, file=file, options=options)
        if options is None:
            history.recover_options(file)
        return history

    @with_h5_file("r")
    def recover_options(self, file: Union[str, Path]):
        """Recover options when loading the hdf5 history from memory.

        Done by testing which entries were recorded: a quantity counts as
        recorded if at least one iteration stores a non-NaN value for it.
        The file is opened once for all probes.

        Parameters
        ----------
        file:
            Stored as ``storage_file`` of the recovered options.
        """
        trace_record = self._has_non_nan_entries(X)
        trace_record_grad = self._has_non_nan_entries(GRAD)
        trace_record_hess = self._has_non_nan_entries(HESS)
        trace_record_res = self._has_non_nan_entries(RES)
        trace_record_sres = self._has_non_nan_entries(SRES)

        self.options = HistoryOptions(
            trace_record=trace_record,
            trace_record_grad=trace_record_grad,
            trace_record_hess=trace_record_hess,
            trace_record_res=trace_record_res,
            trace_record_sres=trace_record_sres,
            trace_save_iter=self.trace_save_iter,
            storage_file=file,
        )

    def _has_non_nan_entries(self, hdf5_group: str) -> bool:
        """Check if there exist non-nan entries stored for a given group."""
        return any(
            entry is not None and not np.all(np.isnan(entry))
            for entry in self._get_hdf5_entries(hdf5_group, ix=None)
        )

    @with_h5_file("a")
    def _update_counts(self, sensi_orders: tuple[int, ...], mode: ModeType):
        """Update the counters in the hdf5 file."""
        group = self._require_group()

        if mode == MODE_FUN:
            if 0 in sensi_orders:
                group.attrs[N_FVAL] += 1
            if 1 in sensi_orders:
                group.attrs[N_GRAD] += 1
            if 2 in sensi_orders:
                group.attrs[N_HESS] += 1
        elif mode == MODE_RES:
            if 0 in sensi_orders:
                group.attrs[N_RES] += 1
            if 1 in sensi_orders:
                group.attrs[N_SRES] += 1

    @with_h5_file("r")
    def _get_trace_attr(self, key: str, default):
        """Return attribute ``key`` of the trace group.

        ``default`` is returned if the trace group or the attribute does not
        exist.
        """
        try:
            return self._get_group().attrs[key]
        except KeyError:
            return default

    def __len__(self) -> int:
        """Define length of history object (number of stored iterations)."""
        return self._get_trace_attr(N_ITERATIONS, 0)

    @property
    def n_fval(self) -> int:
        """See :meth:`HistoryBase.n_fval`."""
        return self._get_trace_attr(N_FVAL, 0)

    @property
    def n_grad(self) -> int:
        """See :meth:`HistoryBase.n_grad`."""
        return self._get_trace_attr(N_GRAD, 0)

    @property
    def n_hess(self) -> int:
        """See :meth:`HistoryBase.n_hess`."""
        return self._get_trace_attr(N_HESS, 0)

    @property
    def n_res(self) -> int:
        """See :meth:`HistoryBase.n_res`."""
        return self._get_trace_attr(N_RES, 0)

    @property
    def n_sres(self) -> int:
        """See :meth:`HistoryBase.n_sres`."""
        return self._get_trace_attr(N_SRES, 0)

    @property
    def trace_save_iter(self) -> int:
        """After how many iterations to store the trace."""
        return self._get_trace_attr(TRACE_SAVE_ITER, 0)

    @property
    def start_time(self) -> float:
        """See :meth:`HistoryBase.start_time`.

        Read from the trace group attribute written by ``_require_group``;
        ``np.nan`` if the trace group does not exist yet.
        """
        return self._get_trace_attr(START_TIME, np.nan)

    @property
    @with_h5_file("r")
    def message(self) -> str:
        """Optimizer message in case of finished optimization."""
        try:
            return self._f[f"{HISTORY}/{self.id}/{MESSAGES}/"].attrs[MESSAGE]
        except KeyError:
            return None

    @property
    @with_h5_file("r")
    def exitflag(self) -> str:
        """Optimizer exitflag in case of finished optimization."""
        try:
            return self._f[f"{HISTORY}/{self.id}/{MESSAGES}/"].attrs[EXITFLAG]
        except KeyError:
            return None

    @staticmethod
    def _simulation_to_values(x, result, used_time):
        """Collect the quantities of one evaluation, keyed by trace name."""
        values = {
            X: x,
            FVAL: result[FVAL],
            GRAD: result[GRAD],
            RES: result[RES],
            SRES: result[SRES],
            HESS: result[HESS],
            TIME: used_time,
        }
        return values

    @with_h5_file("a")
    def _update_trace(
        self,
        x: np.ndarray,
        sensi_orders: tuple[int, ...],
        mode: ModeType,
        result: ResultDict,
    ) -> None:
        """Update and possibly store the trace."""
        if not self.options.trace_record:
            return

        # calculating function values from residuals
        # and reduce via requested history options
        result = reduce_result_via_options(
            add_fun_from_res(result), self.options
        )

        used_time = time.time() - self.start_time

        values = self._simulation_to_values(x, result, used_time)

        group = self._require_group()
        iteration = group.attrs[N_ITERATIONS]
        for key, value in values.items():
            # missing entries are not stored; `_get_hdf5_entries` reports
            # them as None
            if value is not None:
                group[f"{iteration}/{key}"] = value

        group.attrs[N_ITERATIONS] += 1

    @with_h5_file("r")
    def _get_group(self) -> h5py.Group:
        """Get the HDF5 group for the current history."""
        return self._f[f"{HISTORY}/{self.id}/{TRACE}/"]

    @with_h5_file("a")
    def _require_group(self) -> h5py.Group:
        """Get, or if necessary create, the group in the hdf5 file.

        A newly created group gets zeroed counters and the current time as
        start time.
        """
        with contextlib.suppress(KeyError):
            return self._f[f"{HISTORY}/{self.id}/{TRACE}/"]

        grp = self._f.create_group(f"{HISTORY}/{self.id}/{TRACE}/")
        grp.attrs[N_ITERATIONS] = 0
        grp.attrs[N_FVAL] = 0
        grp.attrs[N_GRAD] = 0
        grp.attrs[N_HESS] = 0
        grp.attrs[N_RES] = 0
        grp.attrs[N_SRES] = 0
        grp.attrs[START_TIME] = time.time()
        # TODO Y it makes no sense to save this here
        #  Also, we do not seem to evaluate this at all
        grp.attrs[TRACE_SAVE_ITER] = self.options.trace_save_iter
        return grp

    @with_h5_file("r")
    def _get_hdf5_entries(
        self,
        entry_id: str,
        ix: Union[int, Sequence[int], None] = None,
    ) -> Sequence:
        """
        Get entries for field `entry_id` from HDF5 file, for indices `ix`.

        Parameters
        ----------
        entry_id:
            The key whose trace is returned.
        ix:
            Index or list of indices of the iterations that will produce
            the trace. Defaults to ``None`` (all iterations).

        Returns
        -------
        The entries ix for the key entry_id. Iterations without a stored
        entry yield ``None``.
        """
        if ix is None:
            ix = range(len(self))

        trace_result = []
        for iteration in ix:
            try:
                dataset = self._f[
                    f"{HISTORY}/{self.id}/{TRACE}/{iteration}/{entry_id}"
                ]
            except KeyError:
                trace_result.append(None)
                continue
            if dataset.shape == ():
                entry = dataset[()]  # scalar
            else:
                entry = np.array(dataset)
            trace_result.append(entry)

        return trace_result

    @trace_wrap
    def get_x_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[np.ndarray], np.ndarray]:
        """See :meth:`HistoryBase.get_x_trace`."""
        return self._get_hdf5_entries(X, ix)

    @trace_wrap
    def get_fval_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[float], float]:
        """See :meth:`HistoryBase.get_fval_trace`."""
        return self._get_hdf5_entries(FVAL, ix)

    @trace_wrap
    def get_grad_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See :meth:`HistoryBase.get_grad_trace`."""
        return self._get_hdf5_entries(GRAD, ix)

    @trace_wrap
    def get_hess_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See :meth:`HistoryBase.get_hess_trace`."""
        return self._get_hdf5_entries(HESS, ix)

    @trace_wrap
    def get_res_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See :meth:`HistoryBase.get_res_trace`."""
        return self._get_hdf5_entries(RES, ix)

    @trace_wrap
    def get_sres_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[MaybeArray], MaybeArray]:
        """See :meth:`HistoryBase.get_sres_trace`."""
        return self._get_hdf5_entries(SRES, ix)

    @trace_wrap
    def get_time_trace(
        self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
    ) -> Union[Sequence[float], float]:
        """See :meth:`HistoryBase.get_time_trace`."""
        return self._get_hdf5_entries(TIME, ix)

    def _editable(self) -> bool:
        """
        Check whether the id is already existent in the file.

        Note that the file is opened in append mode, so it is created if it
        does not exist yet.

        Returns
        -------
        True if the file is editable, False otherwise.
        """
        try:
            with h5py.File(self.file, "a") as f:
                # editable if the id entry does not exist
                if HISTORY not in f.keys() or self.id not in f[HISTORY]:
                    return True
                return False
        except OSError:
            # if something goes wrong, we assume the file is not editable
            return False

    @staticmethod
    def from_history(
        other: HistoryBase,
        file: Union[str, Path],
        id_: str,
        overwrite: bool = False,
    ) -> "Hdf5History":
        """Write some History to HDF5.

        Parameters
        ----------
        other:
            History to be copied to HDF5.
        file:
            HDF5 file to write to (append or create).
        id_:
            ID of the history.
        overwrite:
            Whether to overwrite an existing history with the same id.
            Defaults to ``False``.

        Returns
        -------
        The newly created :class:`Hdf5History`.

        Raises
        ------
        RuntimeError
            If a history with ``id_`` already exists in ``file`` and
            ``overwrite`` is ``False``.
        """
        history = Hdf5History(file=file, id=id_)
        # keep the file open for the whole copy; wrapped methods reuse it
        history._f = h5py.File(history.file, mode="a")
        try:
            if f"{HISTORY}/{history.id}" in history._f:
                if overwrite:
                    del history._f[f"{HISTORY}/{history.id}"]
                else:
                    raise RuntimeError(
                        f"ID {history.id} already exists in file {file}."
                    )

            trace_group = history._require_group()
            trace_group.attrs[N_FVAL] = other.n_fval
            trace_group.attrs[N_GRAD] = other.n_grad
            trace_group.attrs[N_HESS] = other.n_hess
            trace_group.attrs[N_RES] = other.n_res
            trace_group.attrs[N_SRES] = other.n_sres
            trace_group.attrs[START_TIME] = other.start_time
            trace_group.attrs[N_ITERATIONS] = (
                len(other.get_time_trace()) if other.implements_trace() else 0
            )

            group = trace_group.parent.require_group(MESSAGES)
            if other.message is not None:
                group.attrs[MESSAGE] = other.message
            if other.exitflag is not None:
                group.attrs[EXITFLAG] = other.exitflag

            if not other.implements_trace():
                return history

            for trace_key in (X, FVAL, GRAD, HESS, RES, SRES, TIME):
                getter = getattr(other, f"get_{trace_key}_trace")
                trace = getter()
                for iteration, value in enumerate(trace):
                    # h5py cannot store None; skip missing entries as
                    # `_update_trace` does (they are read back as None)
                    if value is None:
                        continue
                    trace_group.require_group(str(iteration))[
                        trace_key
                    ] = value
        finally:
            history._f.close()
            history._f = None

        return history