pypesto.objective.jax

Jax objective

class pypesto.objective.jax.JaxObjective

Bases: ObjectiveBase

Objective function that enables the use of pyPESTO objectives in JAX models.

The generated function should generally be compatible with JAX, but it cannot compute higher-order derivatives and is not vectorized (although it is still compatible with jax.vmap).
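One way to picture "not vectorized, but usable on batches" without depending on JAX or pyPESTO: a function that only accepts a single parameter vector can still be applied to a batch by evaluating it once per row. In the plain-Python sketch below, `scalar_objective` and `map_over_batch` are hypothetical names; the sketch illustrates the cost model only and makes no claim about how JaxObjective is batched internally under jax.vmap.

```python
from typing import Callable, Sequence


# Hypothetical non-vectorized objective: accepts exactly one parameter vector.
def scalar_objective(x: Sequence[float]) -> float:
    # Reject nested (batched) input to mimic a function that is not vectorized.
    if any(isinstance(xi, (list, tuple)) for xi in x):
        raise TypeError("expected a single parameter vector, not a batch")
    # Sum of squares as a stand-in objective value.
    return float(sum(xi * xi for xi in x))


# Hypothetical batching helper: one separate evaluation per row.
def map_over_batch(
    fun: Callable[[Sequence[float]], float],
    batch: Sequence[Sequence[float]],
) -> list:
    return [fun(row) for row in batch]


batch = [[1.0, 2.0], [0.0, 3.0], [0.5, 0.5]]
values = map_over_batch(scalar_objective, batch)
print(values)  # [5.0, 9.0, 0.5]
```

The batch is handled correctly, but the work grows linearly with the number of rows because nothing is evaluated in parallel.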

Parameters:

objective (ObjectiveBase) – pyPESTO objective to be wrapped.

Note

Currently, only MODE_FUN with sensi_orders=(0,) is implemented. Support for MODE_RES should be straightforward to add.

__call__(x, sensi_orders=(0,), mode='mode_fun', return_dict=False, **kwargs)

See ObjectiveBase for more documentation.

Note that this function delegates pre- and post-processing as well as history handling to the inner objective.
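The delegation described above can be pictured with a minimal pair of hypothetical classes: `InnerObjective` stands in for the wrapped objective and performs its own preprocessing (here, an illustrative fixed parameter offset) and history recording, while `Wrapper.__call__` only converts the input and forwards the call. None of these names belong to pyPESTO; the sketch only shows where each responsibility sits.

```python
class InnerObjective:
    """Hypothetical stand-in for the wrapped objective."""

    def __init__(self, offset: float):
        self.offset = offset  # illustrative preprocessing parameter
        self.history: list = []

    def __call__(self, x):
        x_proc = [xi - self.offset for xi in x]  # preprocessing
        fval = sum(xi * xi for xi in x_proc)  # actual evaluation
        self.history.append((tuple(x), fval))  # history handling
        return fval


class Wrapper:
    """Hypothetical wrapper with no preprocessing or history of its own."""

    def __init__(self, objective: InnerObjective):
        self.objective = objective

    def __call__(self, x):
        # Convert the input and forward; everything else happens in the inner objective.
        return float(self.objective(list(x)))


inner = InnerObjective(offset=1.0)
wrapped = Wrapper(inner)
print(wrapped((1.0, 3.0)))  # 4.0
print(len(inner.history))  # 1
```

Because the wrapper keeps no state, there is a single source of truth for the evaluation history.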

Return type:

Union[Array, tuple, dict[str, Union[float, ndarray, dict]]]

Parameters:
  • x (Array)

  • sensi_orders (tuple[int, ...])

  • mode (Literal['mode_fun', 'mode_res'])

  • return_dict (bool)

__init__(objective)

Parameters:

objective (ObjectiveBase)

call_unprocessed(x, sensi_orders, mode, **kwargs)

See ObjectiveBase for more documentation.

This method is not implemented for JaxObjective because the overridden __call__ never calls it. It is defined only because ObjectiveBase declares it abstract, so every subclass must override it.
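The need for an otherwise unused method follows from standard Python `abc` behaviour: a subclass that does not override every `@abstractmethod` cannot be instantiated, even if it never calls that method. The classes below (`Base`, `MissingOverride`, `WithStub`) are hypothetical and only mimic the situation; they are not pyPESTO's ObjectiveBase, and the stub body is this sketch's choice.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    """Hypothetical base class with an abstract hook, like call_unprocessed."""

    @abstractmethod
    def call_unprocessed(self, x, sensi_orders, mode, **kwargs):
        ...

    def __call__(self, x):
        return self.call_unprocessed(x, (0,), "mode_fun")


class MissingOverride(Base):
    # Overrides __call__ but not the abstract method, so it stays abstract.
    def __call__(self, x):
        return sum(x)


class WithStub(Base):
    def __call__(self, x):
        return sum(x)

    def call_unprocessed(self, x, sensi_orders, mode, **kwargs):
        # Never reached because __call__ is overridden; exists only to satisfy ABC.
        raise NotImplementedError


try:
    MissingOverride()
    instantiated = True
except TypeError:
    instantiated = False

print(instantiated)  # False
print(WithStub()([1, 2, 3]))  # 6
```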

Return type:

dict[str, Union[float, ndarray, dict]]

Parameters:
  • x

  • sensi_orders

  • mode

check_mode(mode)

See ObjectiveBase documentation.

Return type:

bool

Parameters:

mode (Literal['mode_fun', 'mode_res'])

check_sensi_orders(sensi_orders, mode)

See ObjectiveBase documentation.

Return type:

bool

Parameters:

mode (Literal['mode_fun', 'mode_res'])
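A minimal sketch of checks that match the restriction stated in the Note (function values only, i.e. mode='mode_fun' with sensi_orders=(0,)). The free functions `check_mode` and `check_sensi_orders` below are hypothetical re-implementations for illustration; the actual methods are defined on JaxObjective and may differ in detail.

```python
from typing import Literal

ModeType = Literal["mode_fun", "mode_res"]
MODE_FUN = "mode_fun"


def check_mode(mode: ModeType) -> bool:
    # Only function-value mode is supported in this sketch.
    return mode == MODE_FUN


def check_sensi_orders(sensi_orders: tuple[int, ...], mode: ModeType) -> bool:
    # Reject unsupported modes first, then allow only order 0 (function value).
    if not check_mode(mode):
        return False
    return all(order == 0 for order in sensi_orders)


print(check_mode("mode_fun"), check_mode("mode_res"))  # True False
print(check_sensi_orders((0,), "mode_fun"))  # True
print(check_sensi_orders((0, 1), "mode_fun"))  # False
```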

property history

Exposes the history of the inner objective.

property pre_post_processor

Exposes the pre_post_processor of the inner objective.

property x_names

Exposes the x_names of the inner objective.
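Exposing the inner objective's attributes through read-only properties keeps the wrapper and the wrapped objective in sync: any change made to the inner objective is immediately visible through the wrapper. A minimal sketch with hypothetical stand-in classes (`Inner`, `Wrapper`, and the private attribute `_inner` are illustrative names, not pyPESTO's); only the property names mirror the ones documented above.

```python
class Inner:
    """Hypothetical wrapped objective that owns the state."""

    def __init__(self):
        self.history = []
        self.pre_post_processor = None
        self.x_names = ["k1", "k2"]


class Wrapper:
    """Hypothetical wrapper that forwards attribute access to the inner objective."""

    def __init__(self, objective: Inner):
        self._inner = objective

    @property
    def history(self):
        return self._inner.history

    @property
    def pre_post_processor(self):
        return self._inner.pre_post_processor

    @property
    def x_names(self):
        return self._inner.x_names


inner = Inner()
wrapper = Wrapper(inner)
inner.history.append({"fval": 1.0})  # state changes in the inner objective
print(wrapper.history)  # [{'fval': 1.0}]
print(wrapper.x_names)  # ['k1', 'k2']
```

Since no setters are defined, assigning to these properties on the wrapper raises AttributeError; the state can only be changed on the inner objective.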