pypesto.objective.jax
Jax objective
- class pypesto.objective.jax.JaxObjective
Bases: ObjectiveBase
Objective function that enables the use of pyPESTO objectives in JAX models.
The generated function should generally be compatible with JAX. However, it cannot compute higher-order derivatives and is not vectorized, although it can still be used with jax.vmap.
- Parameters:
objective (ObjectiveBase) – pyPESTO objective to be wrapped.
Note
Currently, only MODE_FUN and sensi_orders <= 1 are implemented. Support for MODE_RES should be straightforward to add.
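A common way to make a non-JAX function usable inside JAX code with first-order derivatives only is to combine jax.pure_callback with jax.custom_jvp. The sketch below illustrates that general technique under stated assumptions. It does not claim to be how JaxObjective is implemented. The NumPy functions host_fun and host_grad are hypothetical stand-ins for evaluating a pyPESTO objective with sensi_orders=(0,) and (1,), and wrapped_fun and wrapped_fun_jvp are illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical host-side objective in plain NumPy (not traceable by JAX).
# It stands in for evaluating a wrapped objective with sensi_orders=(0,).
def host_fun(x):
    x = np.asarray(x)
    return np.asarray(np.sum((x - 1.0) ** 2), dtype=x.dtype)


# Hypothetical host-side gradient, standing in for sensi_orders=(1,).
def host_grad(x):
    x = np.asarray(x)
    return np.asarray(2.0 * (x - 1.0), dtype=x.dtype)


@jax.custom_jvp
def wrapped_fun(x):
    # pure_callback runs host_fun outside JAX tracing; the result shape and
    # dtype must be declared up front.
    return jax.pure_callback(host_fun, jax.ShapeDtypeStruct((), x.dtype), x)


@wrapped_fun.defjvp
def wrapped_fun_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = wrapped_fun(x)
    # The gradient also comes from a host callback, which JAX cannot
    # differentiate again, so only first-order derivatives are available.
    g = jax.pure_callback(host_grad, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    # The tangent output is linear in x_dot, which lets jax.grad transpose it.
    return y, jnp.dot(g, x_dot)


x = jnp.array([0.0, 2.0, 3.0])
value = jax.jit(wrapped_fun)(x)
gradient = jax.grad(wrapped_fun)(x)
```

Here value evaluates to 6 and gradient to [-2, 2, 4]. The gradient is itself produced by a host callback that has no differentiation rule. As a result, taking a second derivative of this sketch (for example with jax.hessian) is expected to fail, which mirrors the first-order limitation noted above.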
- __call__(x, sensi_orders=(0,), mode='mode_fun', return_dict=False, **kwargs)
See ObjectiveBase for more documentation. Note that this function delegates pre- and post-processing, as well as history handling, to the inner objective.
- __init__(objective)
- Parameters:
objective (ObjectiveBase)
- call_unprocessed(x, sensi_orders, mode, return_dict, **kwargs)
See ObjectiveBase for more documentation. This function is not implemented for JaxObjective because the overridden __call__ never calls it. However, it is marked as abstract, so it still has to be defined.
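The relationship between an overridden __call__ and an unused abstract call_unprocessed can be illustrated with a small, self-contained sketch. Base, SumOfSquares, Wrapper and Incomplete are hypothetical simplifications written only for this illustration, not pyPESTO classes. The way Base pre-processes inputs and records history is also an assumption, chosen only to make the control flow visible.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    """Hypothetical, heavily simplified stand-in for ObjectiveBase."""

    def __init__(self):
        self.calls = []  # hypothetical history: every evaluated point

    def __call__(self, x):
        x = [float(v) for v in x]          # pre-processing
        value = self.call_unprocessed(x)   # raw evaluation
        self.calls.append(tuple(x))        # history handling
        return value

    @abstractmethod
    def call_unprocessed(self, x):
        """Compute the raw objective value."""


class SumOfSquares(Base):
    """Hypothetical concrete objective."""

    def call_unprocessed(self, x):
        return sum(v * v for v in x)


class Wrapper(Base):
    """Hypothetical wrapper that forwards evaluation to an inner objective."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def __call__(self, x):
        # Pre/post-processing and history are handled by the inner objective.
        return self.inner(x)

    def call_unprocessed(self, x):
        # Never reached because __call__ is overridden, but it must exist:
        # otherwise Wrapper would stay abstract and could not be instantiated.
        raise NotImplementedError("not used by Wrapper")


class Incomplete(Base):
    """Overrides __call__ but leaves call_unprocessed undefined."""

    def __call__(self, x):
        return 0.0


inner = SumOfSquares()
wrapped = Wrapper(inner)
value = wrapped([1, 2])

try:
    Incomplete()
    instantiable = True
except TypeError:  # abstract methods block instantiation
    instantiable = False
```

After the call, value is 5.0 and the evaluated point is recorded in inner.calls rather than in wrapped.calls. The flag instantiable is False because Incomplete leaves the abstract method undefined.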
- property history
Expose the history of the inner objective.
- property pre_post_processor
Expose the pre_post_processor of the inner objective.
- property x_names
Expose the x_names of the inner objective.
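Exposing an inner objective's attributes through read-only properties can be sketched as follows. InnerObjective and Wrapper are hypothetical classes for illustration only. Representing the history as a plain list and leaving pre_post_processor as None are assumptions, not pyPESTO behavior.

```python
class InnerObjective:
    """Hypothetical inner objective holding the state the wrapper exposes."""

    def __init__(self, x_names):
        self.x_names = list(x_names)
        self.history = []               # hypothetical history container
        self.pre_post_processor = None  # hypothetical placeholder


class Wrapper:
    """Hypothetical wrapper forwarding attribute access to the inner objective."""

    def __init__(self, inner):
        self._inner = inner

    @property
    def history(self):
        return self._inner.history

    @property
    def pre_post_processor(self):
        return self._inner.pre_post_processor

    @property
    def x_names(self):
        return self._inner.x_names


inner = InnerObjective(["k1", "k2"])
wrapped = Wrapper(inner)
inner.history.append((0.5, 1.5))
```

The properties return the inner objects themselves rather than copies. Anything the inner objective records, such as the appended history entry, is therefore immediately visible through the wrapper.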