Source code for pypesto.objective.aesara.base

"""
Aesara models interface.

Adds an interface for the construction of loss functions
incorporating aesara models. This permits the computation of derivatives
using a combination of objective-based methods and aesara-based
backpropagation.
"""

import copy
from collections.abc import Sequence
from typing import Optional

import numpy as np

from ...C import FVAL, GRAD, HESS, MODE_FUN, RDATAS, ModeType
from ..base import ObjectiveBase, ResultDict

try:
    import aesara
    import aesara.tensor as aet
    from aesara.tensor import Op
    from aesara.tensor.var import TensorVariable
except ImportError:
    raise ImportError(
        "Using an aeasara objective requires an installation of "
        "the python package aesara. Please install aesara via "
        "`pip install aesara`."
    ) from None


class AesaraObjective(ObjectiveBase):
    """
    Wrapper around an ObjectiveBase.

    Computes the gradient at each evaluation, caching it for later calls.
    Caching is only enabled after the first time the gradient is asked for
    and disabled whenever the cached gradient is not used, in order not to
    increase computation time for derivative-free samplers.

    Parameters
    ----------
    objective:
        The `pypesto.ObjectiveBase` to wrap.
    aet_x:
        Tensor variables that define the variables of `aet_fun`.
    aet_fun:
        Aesara function that maps `aet_x` to the variables of `objective`.
    coeff:
        Multiplicative coefficient for the objective.
    """

    def __init__(
        self,
        objective: ObjectiveBase,
        aet_x: TensorVariable,
        aet_fun: TensorVariable,
        coeff: Optional[float] = 1.0,
        x_names: Sequence[str] = None,
    ):
        if not isinstance(objective, ObjectiveBase):
            raise TypeError("objective must be an ObjectiveBase instance")
        if not objective.check_mode(MODE_FUN):
            raise NotImplementedError(
                f"objective must support mode={MODE_FUN}"
            )
        super().__init__(x_names)
        self.base_objective = objective

        self.aet_x = aet_x
        self.aet_fun = aet_fun
        self._coeff = coeff

        self.obj_op = AesaraObjectiveOp(self, self._coeff)

        # compiled function
        if objective.has_fun:
            self.afun = aesara.function([aet_x], self.obj_op(aet_fun))

        # compiled gradient
        if objective.has_grad:
            self.agrad = aesara.function(
                [aet_x], aesara.grad(self.obj_op(aet_fun), [aet_x])
            )

        # compiled hessian
        if objective.has_hess:
            self.ahess = aesara.function(
                [aet_x], aesara.gradient.hessian(self.obj_op(aet_fun), [aet_x])
            )

        # compiled input mapping
        self.infun = aesara.function([aet_x], aet_fun)

        # temporary storage for evaluation results of objective
        self.cached_base_ret: ResultDict = {}
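
    # Note on the compiled graphs above: they compose the wrapped objective's
    # own derivatives (supplied through AesaraObjectiveOp and its gradient and
    # Hessian Ops) with the derivatives of `aet_fun`, which aesara obtains by
    # backpropagation. For a mapping y = f(x) defined by `aet_fun`, `agrad`
    # thus evaluates coeff * J_f(x)^T * grad_objective(f(x)).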

    def check_mode(self, mode: ModeType) -> bool:
        """See `ObjectiveBase` documentation."""
        return mode == MODE_FUN

    def check_sensi_orders(self, sensi_orders, mode: ModeType) -> bool:
        """See `ObjectiveBase` documentation."""
        if not self.check_mode(mode):
            return False
        else:
            return self.base_objective.check_sensi_orders(sensi_orders, mode)

    def call_unprocessed(
        self,
        x: np.ndarray,
        sensi_orders: tuple[int, ...],
        mode: ModeType,
        **kwargs,
    ) -> ResultDict:
        """
        See `ObjectiveBase` for more documentation.

        Main method to overwrite from the base class. It handles and
        delegates the actual objective evaluation.
        """
        # hess computation in aesara requires grad
        if 2 in sensi_orders and 1 not in sensi_orders:
            sensi_orders = (1, *sensi_orders)

        # this computes all the results from the inner objective, rendering
        # them accessible to aesara compiled functions
        set_return_dict, return_dict = (
            "return_dict" in kwargs,
            kwargs.pop("return_dict", False),
        )
        self.cached_base_ret = self.base_objective(
            self.infun(x), sensi_orders, mode, return_dict=True, **kwargs
        )
        if set_return_dict:
            kwargs["return_dict"] = return_dict

        ret = {}
        if RDATAS in self.cached_base_ret:
            ret[RDATAS] = self.cached_base_ret[RDATAS]
        if 0 in sensi_orders:
            ret[FVAL] = float(self.afun(x))
        if 1 in sensi_orders:
            ret[GRAD] = self.agrad(x)[0]
        if 2 in sensi_orders:
            ret[HESS] = self.ahess(x)[0]

        return ret

    def __deepcopy__(self, memodict=None):
        other = AesaraObjective(
            copy.deepcopy(self.base_objective),
            self.aet_x,
            self.aet_fun,
            self._coeff,
        )
        return other


class AesaraObjectiveOp(Op):
    """
    Aesara wrapper around a (non-normalized) log-probability function.

    Parameters
    ----------
    obj:
        Base aesara objective.
    coeff:
        Multiplicative coefficient for the objective function value.
    """

    itypes = [aet.dvector]  # expects a vector of parameter values when called
    otypes = [aet.dscalar]  # outputs a single scalar value (the log prob)

    def __init__(self, obj: AesaraObjective, coeff: Optional[float] = 1.0):
        self._objective: AesaraObjective = obj
        self._coeff: float = coeff

        # initialize the sensitivity Op
        if obj.has_grad:
            self._log_prob_grad = AesaraObjectiveGradOp(obj, coeff)
        else:
            self._log_prob_grad = None

    def perform(self, node, inputs, outputs, params=None):  # noqa
        # note that we use precomputed values from the outer
        # AesaraObjective.call_unprocessed here, which means we can
        # ignore inputs here
        log_prob = self._coeff * self._objective.cached_base_ret[FVAL]
        outputs[0][0] = np.array(log_prob)

    def grad(self, inputs, g):
        """
        Calculate the gradient.

        Actually returns the vector-Jacobian product - g[0] is the gradient
        of the downstream cost with respect to this Op's scalar output.
        """
        if self._log_prob_grad is None:
            return super().grad(inputs, g)
        (theta,) = inputs
        log_prob_grad = self._log_prob_grad(theta)
        return [g[0] * log_prob_grad]


class AesaraObjectiveGradOp(Op):
    """
    Aesara wrapper around a (non-normalized) log-probability gradient function.

    This Op will be called with a vector of values and also return a vector of
    values - the gradients in each dimension.

    Parameters
    ----------
    obj:
        Base aesara objective.
    coeff:
        Multiplicative coefficient for the objective function value.
    """

    itypes = [aet.dvector]  # expects a vector of parameter values when called
    otypes = [aet.dvector]  # outputs a vector (the log prob grad)

    def __init__(self, obj: AesaraObjective, coeff: Optional[float] = 1.0):
        self._objective: AesaraObjective = obj
        self._coeff: float = coeff

        if obj.has_hess:
            self._log_prob_hess = AesaraObjectiveHessOp(obj, coeff)
        else:
            self._log_prob_hess = None

    def perform(self, node, inputs, outputs, params=None):  # noqa
        # note that we use precomputed values from the outer
        # AesaraObjective.call_unprocessed here, which means we can
        # ignore inputs here
        log_prob_grad = self._coeff * self._objective.cached_base_ret[GRAD]
        outputs[0][0] = log_prob_grad

    def grad(self, inputs, g):
        """
        Calculate the hessian.

        Actually returns the vector-Hessian product - g[0] is the vector
        with which the Hessian is multiplied.
        """
        if self._log_prob_hess is None:
            return super().grad(inputs, g)
        (theta,) = inputs
        log_prob_hess = self._log_prob_hess(theta)
        return [g[0].dot(log_prob_hess)]


class AesaraObjectiveHessOp(Op):
    """
    Aesara wrapper around a (non-normalized) log-probability Hessian function.

    This Op will be called with a vector of values and also return a matrix of
    values - the Hessian in each dimension.

    Parameters
    ----------
    obj:
        Base aesara objective.
    coeff:
        Multiplicative coefficient for the objective function value.
    """

    itypes = [aet.dvector]
    otypes = [aet.dmatrix]

    def __init__(self, obj: AesaraObjective, coeff: Optional[float] = 1.0):
        self._objective: AesaraObjective = obj
        self._coeff: float = coeff

    def perform(self, node, inputs, outputs, params=None):  # noqa
        # note that we use precomputed values from the outer
        # AesaraObjective.call_unprocessed here, which means we can
        # ignore inputs here
        log_prob_hess = self._coeff * self._objective.cached_base_ret[HESS]
        outputs[0][0] = log_prob_hess
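

# Sketch of the evaluation flow, as implemented above: an evaluation always
# goes through AesaraObjective.call_unprocessed, which
#   1. maps x through `infun` (the compiled `aet_fun` graph) and evaluates the
#      wrapped objective there, storing the result in `cached_base_ret`;
#   2. evaluates the compiled `afun`/`agrad`/`ahess` graphs, whose Ops'
#      `perform` methods read FVAL/GRAD/HESS from `cached_base_ret` instead of
#      re-evaluating the objective, while aesara applies the chain rule
#      through `aet_fun`.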