# Source code for pypesto.sample.emcee

from typing import List, Union

import numpy as np

from ..problem import Problem
from ..result import McmcPtResult
from ..startpoint import UniformStartpoints
from .sampler import Sampler

    import emcee
except ImportError:
    emcee = None

class EmceeSampler(Sampler):
    """Use emcee for sampling.

    Wrapper around ``emcee.EnsembleSampler``, see the emcee documentation
    for details.
    """

    def __init__(
        self,
        nwalkers: int = 1,
        sampler_args: Union[dict, None] = None,
        run_args: Union[dict, None] = None,
    ):
        """
        Initialize sampler.

        Parameters
        ----------
        nwalkers:
            The number of walkers in the ensemble.
        sampler_args:
            Further keyword arguments that are passed on to
            ``emcee.EnsembleSampler.__init__``.
        run_args:
            Further keyword arguments that are passed on to
            ``emcee.EnsembleSampler.run_mcmc``.

        Raises
        ------
        ImportError
            If emcee is not installed.
        """
        # check dependencies
        if emcee is None:
            raise ImportError(
                "This sampler requires an installation of emcee. Install e.g. "
                "via ``pip install pypesto[emcee]``."
            )

        super().__init__()
        self.nwalkers: int = nwalkers

        if sampler_args is None:
            sampler_args = {}
        self.sampler_args: dict = sampler_args

        if run_args is None:
            run_args = {}
        self.run_args: dict = run_args

        # set in initialize
        self.problem: Union[Problem, None] = None
        self.sampler: Union[emcee.EnsembleSampler, None] = None
        self.state: Union[emcee.State, None] = None

    def initialize(
        self,
        problem: Problem,
        x0: Union[np.ndarray, List[np.ndarray]],
    ) -> None:
        """Initialize the sampler.

        Creates the ``emcee.EnsembleSampler`` and, if no state is stored yet,
        generates one start point per walker.

        Parameters
        ----------
        problem:
            The problem whose objective is sampled.
        x0:
            A single start vector or a collection of start vectors. They are
            temporarily prepended to the problem's guesses while the walker
            start points are generated.
        """
        self.problem = problem

        # extract for pickling efficiency
        objective = self.problem.objective
        # lower and upper bounds must be read separately; aliasing both to
        # the upper bound would reject every point that is not exactly ``ub``
        lb = self.problem.lb
        ub = self.problem.ub

        # parameter dimension
        ndim = len(self.problem.x_free_indices)

        def log_prob(x):
            """Log-probability density function."""
            # check if parameter lies within bounds
            if np.any(x < lb) or np.any(x > ub):
                return -np.inf
            # invert sign
            return -1.0 * objective(x)

        # initialize sampler
        self.sampler = emcee.EnsembleSampler(
            nwalkers=self.nwalkers,
            ndim=ndim,
            log_prob_fn=log_prob,
            **self.sampler_args,
        )

        # assign startpoints (only if there is no state to continue from)
        if self.state is None:
            # extract x0
            x0 = np.asarray(x0)
            if x0.ndim == 1:
                x0 = [x0]
            x0 = np.array([problem.get_full_vector(x) for x in x0])
            # add x0 to guesses
            problem.x_guesses_full = np.vstack((x0, problem.x_guesses_full))
            try:
                # sample start points
                self.state = UniformStartpoints(
                    use_guesses=True,
                    check_fval=True,
                    check_grad=False,
                )(
                    n_starts=self.nwalkers,
                    problem=problem,
                )
            finally:
                # restore original guesses, also if start point generation
                # failed, so that the problem is not left modified
                problem.x_guesses_full = problem.x_guesses_full[
                    x0.shape[0] :
                ]

    def sample(self, n_samples: int, beta: float = 1.0) -> None:
        """Run the sampler and store the most recent sample state.

        Parameters
        ----------
        n_samples:
            Number of steps passed to ``emcee.EnsembleSampler.run_mcmc``.
        beta:
            Not used by this sampler.
        """
        self.state = self.sampler.run_mcmc(
            self.state, n_samples, **self.run_args
        )

    def get_samples(self) -> McmcPtResult:
        """Get the samples into the fitting pypesto format."""
        # all walkers are concatenated, yielding a flat array; wrapping in a
        # list adds a leading axis of length one
        trace_x = np.array([self.sampler.get_chain(flat=True)])
        trace_neglogpost = np.array([-self.sampler.get_log_prob(flat=True)])
        # the sampler does not know priors
        trace_neglogprior = np.full(trace_neglogpost.shape, np.nan)
        # the walkers all run on temperature 1
        betas = np.array([1.0])

        result = McmcPtResult(
            trace_x=trace_x,
            trace_neglogpost=trace_neglogpost,
            trace_neglogprior=trace_neglogprior,
            betas=betas,
        )

        return result