pypesto.objective.jax
JAX objective
- class pypesto.objective.jax.JaxObjective(objective: ObjectiveBase, jax_fun: Callable, x_names: Sequence[str] | None = None)
Bases: ObjectiveBase
Objective function that combines a pyPESTO objective with a JAX function.
The resulting objective evaluates objective(jax_fun(x)).
- Parameters:
objective – pyPESTO objective, evaluated on the output of jax_fun
jax_fun – JAX function (not jitted) that computes the input to the pyPESTO objective
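To make the composition objective(jax_fun(x)) concrete, here is a minimal, dependency-free sketch of evaluating a composed objective and its gradient with the chain rule. It does not use pyPESTO or JAX. The names composed_value_and_grad, inner_fun, inner_jac, outer_fun and outer_grad are hypothetical, and the inner Jacobian is written by hand, whereas JaxObjective is assumed to obtain derivative information for jax_fun from JAX.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


def composed_value_and_grad(
    outer_fun: Callable[[Vector], float],
    outer_grad: Callable[[Vector], Vector],
    inner_fun: Callable[[Sequence[float]], Vector],
    inner_jac: Callable[[Sequence[float]], Matrix],
    x: Sequence[float],
) -> Tuple[float, Vector]:
    """Hypothetical helper: evaluate outer_fun(inner_fun(x)) and its gradient in x."""
    y = inner_fun(x)        # plays the role of jax_fun(x)
    value = outer_fun(y)    # plays the role of the pyPESTO objective
    g_y = outer_grad(y)     # gradient of the outer objective w.r.t. y
    jac = inner_jac(x)      # jac[i][j] = d y_i / d x_j (hand-written, not from JAX)
    # Chain rule: grad_j = sum_i g_y[i] * jac[i][j], i.e. J^T g_y
    grad = [sum(g_y[i] * jac[i][j] for i in range(len(y))) for j in range(len(x))]
    return value, grad


# Illustrative inner map: elementwise exp (optimizing log-parameters).
def inner_fun(x: Sequence[float]) -> Vector:
    return [math.exp(xi) for xi in x]


def inner_jac(x: Sequence[float]) -> Matrix:
    # Diagonal Jacobian of the elementwise exponential
    n = len(x)
    return [[math.exp(x[i]) if i == j else 0.0 for j in range(n)] for i in range(n)]


# Illustrative outer objective: sum of squared deviations from 1.
def outer_fun(y: Vector) -> float:
    return sum((yi - 1.0) ** 2 for yi in y)


def outer_grad(y: Vector) -> Vector:
    return [2.0 * (yi - 1.0) for yi in y]


value, grad = composed_value_and_grad(
    outer_fun, outer_grad, inner_fun, inner_jac, [0.0, math.log(2.0)]
)
print(value, grad)  # value ≈ 1.0, grad ≈ [0.0, 4.0]
```

With an exponential inner map, the optimizer works on the logarithm of the objective's inputs, which is one common way to keep those inputs positive.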
- __init__(objective: ObjectiveBase, jax_fun: Callable, x_names: Sequence[str] | None = None)