pypesto.objective.jax

Jax objective

class pypesto.objective.jax.JaxObjective

Bases: ObjectiveBase

Objective function that combines pyPESTO objectives with JAX functions.

The generated objective function evaluates objective(jax_fun(x)), i.e. jax_fun maps the parameters x to the input of the wrapped pyPESTO objective.

Parameters:
  • objective (ObjectiveBase) – pyPESTO objective

  • jax_fun (Callable) – JAX function (not jitted) that computes the input to the pyPESTO objective
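Because the objective is evaluated as a composition, its gradient with respect to x follows from the chain rule: grad_x f(g(x)) = J_g(x)^T grad_y f(y), with y = g(x) and J_g the Jacobian of g. The sketch below illustrates this with plain NumPy rather than pyPESTO or JAX. The functions `jax_fun_like`, `jacobian`, `inner_fval_grad` and `composed_fval_grad` are hypothetical stand-ins, and the Jacobian is written by hand where JaxObjective would presumably obtain derivatives of jax_fun through JAX.

```python
import numpy as np

# Hypothetical stand-in for jax_fun: maps parameters x (length 2)
# to the input y (length 3) of the wrapped objective.
def jax_fun_like(x):
    return np.array([x[0] ** 2, x[0] * x[1], np.sin(x[1])])

# Hand-written Jacobian dy/dx of jax_fun_like, shape (3, 2).
def jacobian(x):
    return np.array([
        [2.0 * x[0], 0.0],
        [x[1], x[0]],
        [0.0, np.cos(x[1])],
    ])

# Hypothetical stand-in for the wrapped objective: f(y) = 0.5 * ||y||^2,
# returning its value and its gradient with respect to y.
def inner_fval_grad(y):
    return 0.5 * float(y @ y), y

# Composed objective f(g(x)) and its gradient via the chain rule.
def composed_fval_grad(x):
    y = jax_fun_like(x)
    fval, grad_y = inner_fval_grad(y)
    grad_x = jacobian(x).T @ grad_y
    return fval, grad_x

# Compare the chain-rule gradient with central finite differences.
x0 = np.array([0.7, -1.3])
fval, grad = composed_fval_grad(x0)
eps = 1e-6
fd = np.array([
    (composed_fval_grad(x0 + eps * e)[0] - composed_fval_grad(x0 - eps * e)[0]) / (2 * eps)
    for e in np.eye(2)
])
print(np.allclose(grad, fd, atol=1e-6))  # True
```

The wrapped objective only needs to know its own gradient with respect to y; the Jacobian of the inner function is what links it back to x.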

__init__(objective, jax_fun, x_names=None)

Parameters:

  • objective, jax_fun – see the class description

  • x_names – parameter names (optional)

cached_fval(_)

Return cached function value.

cached_grad(_)

Return cached gradient.

cached_hess(_)

Return cached Hessian.
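The single `_` argument conventionally marks an unused parameter, which suggests these accessors return values stored by an earlier evaluation rather than recomputing them. A minimal sketch of that caching pattern, assuming the cache is refreshed on every evaluation; the class `CachingWrapper` and its `evaluate` method are hypothetical and not part of pyPESTO.

```python
import numpy as np

class CachingWrapper:
    """Hypothetical sketch: store the last fval/grad/hess so that
    accessors taking an ignored argument can return them later."""

    def __init__(self, fun):
        # fun(x) -> (fval, grad, hess)
        self._fun = fun
        self._cache = {}

    def evaluate(self, x):
        fval, grad, hess = self._fun(np.asarray(x, dtype=float))
        self._cache = {"fval": fval, "grad": grad, "hess": hess}
        return fval

    # The argument is ignored, mirroring the `_` in the signatures above.
    def cached_fval(self, _):
        return self._cache["fval"]

    def cached_grad(self, _):
        return self._cache["grad"]

    def cached_hess(self, _):
        return self._cache["hess"]


# Toy quadratic f(x) = x^T x with gradient 2x and Hessian 2I.
def quadratic(x):
    return float(x @ x), 2.0 * x, 2.0 * np.eye(x.size)

wrapper = CachingWrapper(quadratic)
wrapper.evaluate([1.0, 2.0])
print(wrapper.cached_fval(None))  # 5.0
print(wrapper.cached_grad(None))  # [2. 4.]
```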

call_unprocessed(x, sensi_orders, mode, **kwargs)

See ObjectiveBase for more documentation.

Main method overridden from the base class; it handles the actual objective evaluation and delegates it to the wrapped objective.

Return type:

Dict[str, Union[float, ndarray, Dict]]

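Parameters:

  • x (ndarray) – parameter vector

  • sensi_orders (Tuple[int, ...]) – sensitivity orders to compute

  • mode (Literal['mode_fun', 'mode_res']) – objective evaluation mode
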
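The return type is a dictionary keyed by result names. A minimal sketch of assembling such a dictionary for a scalar objective, assuming the pyPESTO key strings 'fval', 'grad' and 'hess' for sensitivity orders 0, 1 and 2 in 'mode_fun'; the function `toy_call_unprocessed` and its quadratic objective are hypothetical.

```python
from typing import Dict, Tuple, Union

import numpy as np

ResultDict = Dict[str, Union[float, np.ndarray, Dict]]

def toy_call_unprocessed(
    x: np.ndarray, sensi_orders: Tuple[int, ...], mode: str
) -> ResultDict:
    """Hypothetical sketch: evaluate f(x) = sum(x**2) and return only the
    quantities requested by sensi_orders."""
    if mode != "mode_fun":
        raise ValueError(f"unsupported mode: {mode}")
    result: ResultDict = {}
    if 0 in sensi_orders:
        result["fval"] = float(np.sum(x ** 2))
    if 1 in sensi_orders:
        result["grad"] = 2.0 * x
    if 2 in sensi_orders:
        result["hess"] = 2.0 * np.eye(x.size)
    return result

res = toy_call_unprocessed(np.array([1.0, -3.0]), (0, 1), "mode_fun")
print(sorted(res))  # ['fval', 'grad']
print(res["fval"])  # 10.0
```

Only the requested orders are computed, so callers that need just the function value do not pay for derivatives.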
check_mode(mode)

See ObjectiveBase documentation.

Return type:

bool

Parameters:

mode (Literal['mode_fun', 'mode_res']) – objective evaluation mode to check
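check_mode reports whether the objective can be evaluated in the given mode. A minimal sketch of such a check; which modes JaxObjective actually accepts is not stated here, so the `supported_modes` argument and its default are hypothetical.

```python
from typing import Collection, Literal

ModeType = Literal["mode_fun", "mode_res"]

def check_mode_sketch(
    mode: ModeType, supported_modes: Collection[str] = ("mode_fun",)
) -> bool:
    """Hypothetical sketch: True if `mode` is one of the supported modes."""
    return mode in supported_modes

print(check_mode_sketch("mode_fun"))  # True
print(check_mode_sketch("mode_res"))  # False
```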

check_sensi_orders(sensi_orders, mode)

See ObjectiveBase documentation.

Return type:

bool

Parameters:

  • sensi_orders (Tuple[int, ...]) – sensitivity orders to check

  • mode (Literal['mode_fun', 'mode_res']) – objective evaluation mode
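check_sensi_orders reports whether the requested derivative orders can be provided in the given mode. A minimal sketch in which the maximum order depends on the mode; the table `MAX_ORDER`, its values, and the function `check_sensi_orders_sketch` are hypothetical and not taken from JaxObjective.

```python
from typing import Dict, Tuple

# Hypothetical maximum supported sensitivity order per mode.
MAX_ORDER: Dict[str, int] = {"mode_fun": 2, "mode_res": 1}

def check_sensi_orders_sketch(sensi_orders: Tuple[int, ...], mode: str) -> bool:
    """Hypothetical sketch: True if the mode is known and every requested
    order lies between 0 and that mode's maximum."""
    if mode not in MAX_ORDER:
        return False
    return all(0 <= order <= MAX_ORDER[mode] for order in sensi_orders)

print(check_sensi_orders_sketch((0, 1, 2), "mode_fun"))  # True
print(check_sensi_orders_sketch((2,), "mode_res"))  # False
```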