pypesto.objective.jax

JAX objective.

class pypesto.objective.jax.JaxObjective(objective: ObjectiveBase, jax_fun: Callable, x_names: Sequence[str] | None = None)

Bases: ObjectiveBase

Objective function that combines pyPESTO objectives with JAX functions.

The generated objective function will evaluate objective(jax_fun(x)).
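
The composition can be differentiated with the chain rule. For an outer objective f (the pyPESTO objective) and an inner map g (jax_fun), grad(f∘g)(x) = J_g(x)^T grad f(g(x)). The sketch below illustrates only that math, not how JaxObjective is implemented internally. The toy functions outer_fun, outer_grad and jax_fun are hypothetical. The gradient is pulled back through g with jax.vjp, treating the outer gradient as if the pyPESTO side had supplied it, and is then compared against jax.grad of the whole composition.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # double precision, common for optimization


def outer_fun(y):
    """Hypothetical stand-in for the pyPESTO objective: f(y) = sum((y - 1)^2)."""
    return float(np.sum((y - 1.0) ** 2))


def outer_grad(y):
    """Gradient of outer_fun, as a pyPESTO objective might provide it."""
    return 2.0 * (y - 1.0)


def jax_fun(x):
    """Hypothetical inner map g(x), written with jax.numpy and not jitted."""
    return jnp.exp(x)


def composed_value_and_grad(x):
    # Forward pass through g; vjp_fun maps a cotangent v to J_g(x)^T v.
    y, vjp_fun = jax.vjp(jax_fun, x)
    y_np = np.asarray(y)
    fval = outer_fun(y_np)
    # Chain rule: pull the outer gradient back through g.
    (grad,) = vjp_fun(jnp.asarray(outer_grad(y_np)))
    return fval, np.asarray(grad)


x0 = jnp.array([0.0, 0.5, -1.0])
fval, grad = composed_value_and_grad(x0)

# Reference: let JAX differentiate the whole composition directly.
ref_grad = jax.grad(lambda x: jnp.sum((jax_fun(x) - 1.0) ** 2))(x0)
```

Both routes yield the same gradient. The vjp route only needs the outer gradient as a plain array, so the outer objective never has to be written in JAX.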

Parameters:
  • objective – pyPESTO objective

  • jax_fun – JAX function (not jitted) that computes the input to the pyPESTO objective
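
The hypothetical function below illustrates what a suitable jax_fun could look like. It maps log10-scaled optimization parameters to linear scale and appends one fixed value, so the pyPESTO objective would receive a three-element input from a two-element x. It uses only jax.numpy operations so that JAX can trace and differentiate it. In line with the parameter description above, it is deliberately not wrapped in jax.jit. The function name, the transformation and FIXED_VALUE are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

FIXED_VALUE = 2.0  # hypothetical model parameter that is not estimated


def log10_to_model_params(x):
    """Hypothetical jax_fun: map log10-scaled parameters to the objective's input.

    Only jax.numpy operations are used, so JAX can trace and differentiate it.
    Intentionally NOT decorated with @jax.jit.
    """
    linear = jnp.power(10.0, x)
    return jnp.concatenate([linear, jnp.array([FIXED_VALUE])])


x = jnp.array([0.0, 1.0])
y = log10_to_model_params(x)
jac = jax.jacfwd(log10_to_model_params)(x)  # Jacobian of shape (3, 2)
```

The Jacobian with respect to x links gradients of the pyPESTO objective back to the optimization parameters. Its last row is zero because the fixed value does not depend on x.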

__init__(objective: ObjectiveBase, jax_fun: Callable, x_names: Sequence[str] | None = None)

cached_fval(_)

Return cached function value.

cached_grad(_)

Return cached gradient.

cached_hess(_)

Return cached Hessian.
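
The cached_* methods accept a single positional argument that they ignore (the `_` in their signatures) and return results stored from an earlier evaluation. The hypothetical class below sketches that caching pattern in plain JAX. Value, gradient and Hessian are computed together in one evaluate call, and the getters only read the stored results. The class and its evaluate method are illustrative assumptions, not pyPESTO API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class CachedEvaluation:
    """Hypothetical helper illustrating a cached_fval/grad/hess pattern."""

    def __init__(self, fun):
        self._value_and_grad = jax.value_and_grad(fun)
        self._hess = jax.hessian(fun)
        self._cache = {}
        self.n_evaluations = 0  # counts how often the expensive work ran

    def evaluate(self, x):
        # Compute everything once and store it for later retrieval.
        fval, grad = self._value_and_grad(x)
        self._cache = {"fval": fval, "grad": grad, "hess": self._hess(x)}
        self.n_evaluations += 1

    # The argument is ignored, mirroring the `_` parameter in the signatures above.
    def cached_fval(self, _):
        return self._cache["fval"]

    def cached_grad(self, _):
        return self._cache["grad"]

    def cached_hess(self, _):
        return self._cache["hess"]


cache = CachedEvaluation(lambda x: jnp.sum(x ** 2))
cache.evaluate(jnp.array([1.0, 2.0]))
fval = cache.cached_fval(None)
grad = cache.cached_grad(None)
hess = cache.cached_hess(None)
```

Reading from the cache does not trigger a new evaluation, so n_evaluations stays at 1 after the three getter calls.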

call_unprocessed(x: ndarray, sensi_orders: Tuple[int, ...], mode: Literal['mode_fun', 'mode_res'], **kwargs) → Dict[str, float | ndarray | Dict]

See ObjectiveBase for more documentation.

Main method to override from the base class. It handles the actual objective evaluation and delegates it.
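
call_unprocessed returns a dictionary with one entry per requested result. The hypothetical function below sketches that contract for function mode only. It fills the keys 'fval', 'grad' and 'hess' according to sensi_orders and computes the derivatives with JAX. These key names are assumed here to match the ones pyPESTO uses in its result dictionaries. This is not JaxObjective's implementation.

```python
from typing import Callable, Dict, Tuple

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def call_unprocessed_sketch(
    fun: Callable, x: np.ndarray, sensi_orders: Tuple[int, ...], mode: str = "mode_fun"
) -> Dict[str, object]:
    """Hypothetical stand-in: evaluate `fun` and return only the requested orders."""
    if mode != "mode_fun":
        raise ValueError(f"this sketch only handles mode_fun, got {mode!r}")
    x = jnp.asarray(x)
    result: Dict[str, object] = {}
    if 0 in sensi_orders:
        result["fval"] = float(fun(x))
    if 1 in sensi_orders:
        result["grad"] = np.asarray(jax.grad(fun)(x))
    if 2 in sensi_orders:
        result["hess"] = np.asarray(jax.hessian(fun)(x))
    return result


def rosen(x):
    """Rosenbrock function, minimal at (1, 1)."""
    return (1.0 - x[0]) ** 2 + 100.0 * (x[1] - x[0] ** 2) ** 2


res = call_unprocessed_sketch(rosen, np.array([1.0, 1.0]), sensi_orders=(0, 1))
```

Requesting sensi_orders=(0, 1) yields only the 'fval' and 'grad' entries. Adding 2 would also compute the Hessian.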

check_mode(mode: Literal['mode_fun', 'mode_res']) → bool

See ObjectiveBase documentation.

check_sensi_orders(sensi_orders, mode: Literal['mode_fun', 'mode_res']) → bool

See ObjectiveBase documentation.
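
check_mode and check_sensi_orders tell the caller whether a given mode and set of sensitivity orders can be evaluated. The authoritative behavior is described in the ObjectiveBase documentation. The checks below are purely illustrative and rest on two assumptions: only 'mode_fun' is supported, and orders up to 2 are allowed (the class exposes a cached Hessian, which suggests second-order support). The function names and MAX_SENSI_ORDER are hypothetical.

```python
from typing import Sequence

MODE_FUN = "mode_fun"
MAX_SENSI_ORDER = 2  # assumption: value, gradient and Hessian are available


def check_mode_sketch(mode: str) -> bool:
    """Hypothetical: accept function mode only (no residual mode)."""
    return mode == MODE_FUN


def check_sensi_orders_sketch(sensi_orders: Sequence[int], mode: str) -> bool:
    """Hypothetical: valid if the mode is supported and every order is in range."""
    if not check_mode_sketch(mode):
        return False
    return all(0 <= order <= MAX_SENSI_ORDER for order in sensi_orders)
```

Keeping the order check separate from the mode check lets a caller reject an unsupported request before any evaluation is attempted.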