pypesto.objective.jax
Jax objective
- class pypesto.objective.jax.JaxObjective
Bases: ObjectiveBase
Objective function that combines a pyPESTO objective with a JAX function.
The resulting objective function evaluates objective(jax_fun(x)).
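Writing g for jax_fun and f for the pyPESTO objective, the composed objective and its gradient follow from the chain rule, assuming both functions are differentiable. This docstring does not say how JaxObjective obtains derivatives internally, so read this as the underlying math only:

```latex
F(x) = f\bigl(g(x)\bigr), \qquad
\nabla_x F(x) = J_g(x)^{\top} \, \nabla_y f(y) \Big|_{y = g(x)}
```

Here J_g(x) is the Jacobian of g at x.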
- Parameters:
objective (ObjectiveBase) – pyPESTO objective.
jax_fun (Callable) – JAX function (not jitted) that computes the input to the pyPESTO objective.
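The jax_fun parameter expects an ordinary function written with jax.numpy and not wrapped in jax.jit. The sketch below uses a hypothetical jax_fun (log10 to linear scaling) and a hypothetical stand-in for the pyPESTO objective: a plain function with a hand-coded gradient, not an ObjectiveBase. It evaluates objective(jax_fun(x)) and pulls the outer gradient back through jax_fun with jax.vjp. It illustrates the composition only and is not JaxObjective's implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical jax_fun: maps log10-scaled parameters to linear scale.
# Plain Python + jax.numpy, deliberately NOT wrapped in jax.jit.
def jax_fun(x):
    return 10.0 ** x

# Hypothetical stand-in for a pyPESTO objective: sum of squared
# residuals against fixed targets, with its analytic gradient.
TARGET = jnp.array([1.0, 2.0])

def outer_fval(y):
    return jnp.sum((y - TARGET) ** 2)

def outer_grad(y):
    return 2.0 * (y - TARGET)

def composed_fval_and_grad(x):
    # Forward pass through jax_fun, keeping a VJP closure for the chain rule.
    y, vjp_fun = jax.vjp(jax_fun, x)
    fval = outer_fval(y)
    # Pull the outer gradient back through jax_fun: J_g(x)^T @ grad_f(y).
    (grad,) = vjp_fun(outer_grad(y))
    return fval, grad

x = jnp.array([0.0, 0.5])
fval, grad = composed_fval_and_grad(x)

# Cross-check against differentiating the full composition directly.
ref_grad = jax.grad(lambda x: outer_fval(jax_fun(x)))(x)
```

Using a VJP rather than forming the full Jacobian keeps the cost of the gradient proportional to one backward pass through jax_fun, which matters when jax_fun has many inputs and outputs.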
- __init__(objective, jax_fun, x_names=None)
- Parameters:
objective (ObjectiveBase) – pyPESTO objective.
jax_fun (Callable) – JAX function (not jitted) that computes the input to the pyPESTO objective.
- call_unprocessed(x, sensi_orders, mode, **kwargs)
See ObjectiveBase for more documentation.
Main method to override from the base class. It handles and delegates the actual objective evaluation.
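The sketch below suggests how a method like call_unprocessed could dispatch on sensi_orders for a composed objective: order 0 yields the function value and order 1 the gradient. The function name, the locally defined result keys "fval" and "grad" (assumed to resemble pyPESTO's result-dictionary keys), and the inner and outer functions are all hypothetical. The mode argument and **kwargs of the real method are omitted, and the sketch's rejection of orders above 1 is its own choice, not documented behavior of JaxObjective.

```python
import jax
import jax.numpy as jnp

# Result keys defined locally for this sketch.
FVAL, GRAD = "fval", "grad"

def jax_fun(x):
    # Hypothetical inner map: elementwise square.
    return x ** 2

def outer_fval(y):
    # Hypothetical outer objective: sum of the inputs.
    return jnp.sum(y)

def outer_grad(y):
    # Gradient of the sum with respect to y.
    return jnp.ones_like(y)

def call_unprocessed_sketch(x, sensi_orders):
    """Hypothetical stand-in for JaxObjective.call_unprocessed.

    Handles only sensitivity orders 0 (value) and 1 (gradient).
    """
    unsupported = set(sensi_orders) - {0, 1}
    if unsupported:
        raise ValueError(f"unsupported sensi_orders: {sorted(unsupported)}")
    # One forward pass serves both the value and the gradient.
    y, vjp_fun = jax.vjp(jax_fun, x)
    result = {}
    if 0 in sensi_orders:
        result[FVAL] = float(outer_fval(y))
    if 1 in sensi_orders:
        # Chain rule: J_g(x)^T @ grad_f(y).
        (grad,) = vjp_fun(outer_grad(y))
        result[GRAD] = grad
    return result
```

For example, calling it with x = [1.0, 2.0] and sensi_orders=(0, 1) returns an fval of 5.0 and a gradient of [2.0, 4.0]. Requesting only (0,) skips the backward pass entirely.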