# Source code for desilike.samplers.blackjax

"""Module implementing the BlackJAX samplers."""
try:
    import jax
    import blackjax
    BLACKJAX_INSTALLED = True
except ModuleNotFoundError:
    BLACKJAX_INSTALLED = False
import numpy as np

from .base import MarkovChainSampler


def make_steps_factory(step):
    """Build a JIT-compiled routine that advances a BlackJAX state repeatedly.

    Parameters
    ----------
    step : function
        The BlackJAX kernel step function, called as ``step(rng_key, state)``
        and returning a ``(new_state, info)`` pair.

    Returns
    -------
    function
        A ``jax.jit``-compiled ``make_steps(args)`` where ``args`` is a
        ``(state, rng_keys)`` tuple; it returns ``(final_state, states)``.

    """

    def _advance(carry, key):
        """Take a single step from ``carry`` using ``key``.

        The new state is emitted both as the next carry and as the per-step
        output, so ``jax.lax.scan`` stacks every visited state.
        """
        new_state, _info = step(key, carry)
        return new_state, new_state

    def make_steps(args):
        """Advance a state once per random key.

        Parameters
        ----------
        args : tuple
            ``(state, rng_keys)``: the BlackJAX state to start from and the
            random keys, one consumed per step.

        Returns
        -------
        final_state : NamedTuple
            State reached after the last step.
        states : NamedTuple
            Every state visited, stacked along the leading axis.

        """
        initial_state, keys = args
        final_state, history = jax.lax.scan(_advance, initial_state, keys)
        return final_state, history

    return jax.jit(make_steps)


class BlackJAXSampler(MarkovChainSampler):
    """Wrapper for ``BlackJAX`` samplers.

    This class is not meant to be instantiated directly. Subclasses provide
    the ``kernel_type`` and ``adaptation_fn`` class attributes (names of
    ``blackjax`` attributes) and the ``adaptable_args`` class attribute
    (kernel arguments tuned by ``_adapt``). They also set
    ``self.kernel_args`` before calling ``super().__init__``.

    .. rubric:: References

    - `BlackJAX repo <https://github.com/blackjax-devs/blackjax>`_
    - `BlackJAX docs <https://blackjax-devs.github.io/blackjax/>`_

    """

    def __init__(self, likelihood, n_chains=1, chains=None, rng=None,
                 directory=None):
        """Initialize the ``BlackJAX`` sampler.

        Parameters
        ----------
        likelihood : BaseLikelihood
            Likelihood to sample.
        n_chains : int, optional
            Number of **independent** chains. Default is 1.
        chains : list of desilike.samples.Chain, optional
            If given, continue the chains. In that case, we will ignore what
            was read from disk. Default is ``None``.
        rng : numpy.random.Generator, int, or None, optional
            Random number generator. Default is ``None``.
        directory : str, Path, or None, optional
            Save samples to this location. Default is ``None``.

        Raises
        ------
        ImportError
            If the ``blackjax`` package is not installed.
        TypeError
            If this class is instantiated directly instead of a subclass.

        """
        if not BLACKJAX_INSTALLED:
            raise ImportError("The 'blackjax' package is required but not "
                              "installed.")

        if type(self) is BlackJAXSampler:
            raise TypeError(
                "BlackJAXSampler cannot be instantiated directly.")

        super().__init__(likelihood, n_chains, chains=chains, rng=rng,
                         directory=directory)

        # The kernels only see the log-posterior; derived parameters are
        # recomputed after sampling in ``_run`` via ``compute_derived``.
        self.compute_posterior_without_derived = self.pool.cache_function(
            lambda sample: self.likelihood(sample, return_derived=False),
            'compute_posterior_without_derived')
        self.compute_derived = self.pool.cache_function(
            jax.vmap(lambda sample: self.likelihood(
                sample, return_derived=True)[1]), 'compute_derived')

        # Resolve the names declared by the subclass into BlackJAX objects.
        self.kernel_type = getattr(blackjax, self.kernel_type)
        self.adaptation_fn = getattr(blackjax, self.adaptation_fn)
        self._setup_kernel()

    def _setup_kernel(self):
        """(Re)build the BlackJAX kernel and its JIT-compiled step routine.

        Uses the current ``self.kernel_args``. It is called at initialization
        and again after ``_adapt`` has updated the kernel arguments.

        """
        self.kernel = self.kernel_type(
            self.compute_posterior_without_derived, **self.kernel_args)
        self.make_steps = self.pool.cache_function(
            make_steps_factory(self.kernel.step), 'make_steps')

    def _run(self, n_steps):
        """Run the ``BlackJAX`` sampler.

        Each chain is advanced independently by ``n_steps`` steps; the work
        is mapped over ``self.pool``. The samples, derived parameters and
        log-posterior values are then passed to ``self._extend``.

        Parameters
        ----------
        n_steps : int
            Number of steps to take.

        """
        if not hasattr(self, 'blackjax_states'):
            # First call: build one BlackJAX state per chain from the current
            # positions.
            self.blackjax_states = []
            for i in range(self.n_chains):
                initial_position = dict(zip(
                    self.varied_params, self._state[0][i]))
                try:
                    self.blackjax_states.append(
                        self.kernel.init(initial_position))
                except TypeError:
                    # Some kernels' ``init`` also expect a random key; retry
                    # with one.
                    rng_key = jax.random.PRNGKey(self.rng.integers(2**32))
                    self.blackjax_states.append(self.kernel.init(
                        initial_position, rng_key))

        rng_keys = jax.random.split(jax.random.PRNGKey(
            self.rng.integers(2**32)), self.n_chains)

        # Make the steps: each chain gets its own key, split into one key
        # per step.
        inputs = [(self.blackjax_states[i], jax.random.split(
            rng_keys[i], n_steps)) for i in range(self.n_chains)]
        results = self.pool.map(self.make_steps, inputs)

        # Each result is ``(final_state, states)``; keep the final states so
        # the next call resumes from them.
        self.blackjax_states = [r[0] for r in results]

        # Update the chains.
        samples = np.vstack([np.column_stack([
            r[1].position[key] for key in self.varied_params])
            for r in results])
        log_post = np.concatenate([r[1].logdensity for r in results])

        if len(self.derived_params):
            # Recompute the derived parameters since they couldn't be saved
            # during the sampling.
            derived = self.pool.map(
                self.compute_derived, [r[1].position for r in results])
            derived = np.vstack([np.column_stack([
                d[key] for key in self.derived_params])
                for d in derived])
        else:
            derived = np.zeros((self.n_chains * n_steps, 0))

        # Rows were stacked chain after chain; split them back per chain.
        samples = samples.reshape((self.n_chains, n_steps, -1))
        derived = derived.reshape((self.n_chains, n_steps, -1))
        log_post = log_post.reshape((self.n_chains, n_steps))
        self._extend(samples, derived, log_post)

    def _adapt(self, steps):
        """Adapt the kernel arguments listed in ``adaptable_args``.

        The adaptation starts from the last point of the first chain. The
        tuned parameters are merged into ``self.kernel_args`` and the kernel
        is rebuilt.

        Parameters
        ----------
        steps : int
            How many steps to run for the adaptation.

        """
        fixed_kernel_args = {
            key: value for key, value in self.kernel_args.items() if key not in
            self.adaptable_args}
        initial_position = {key: self.chains[0][key][-1].astype(
            jax.numpy.float64) for key in self.varied_params}
        rng_key = jax.random.PRNGKey(self.rng.integers(2**32))
        # NOTE(review): this assumes ``adaptation_fn`` returns an object with
        # a ``run`` method, as ``blackjax.window_adaptation`` does. Confirm
        # this also holds for every subclass's ``adaptation_fn`` (e.g.
        # ``mclmc_find_L_and_step_size``).
        (_, parameters), _ = self.adaptation_fn(
            self.kernel_type, self.compute_posterior_without_derived,
            **fixed_kernel_args).run(
                rng_key, initial_position, num_steps=steps)
        self.kernel_args.update(parameters)
        self._setup_kernel()


class HMCSampler(BlackJAXSampler):
    """Wrapper for Hamiltonian Monte-Carlo (HMC)."""

    kernel_type = 'hmc'
    adaptable_args = ['step_size', 'inverse_mass_matrix']
    adaptation_fn = 'window_adaptation'

    def __init__(self, likelihood, n_chains=1, chains=None, step_size=1e-3,
                 inverse_mass_matrix=None, num_integration_steps=60, rng=None,
                 directory=None, **kwargs):
        """Initialize the HMC sampler.

        Parameters
        ----------
        likelihood : BaseLikelihood
            Likelihood to sample.
        n_chains : int, optional
            Number of **independent** chains. Default is 1.
        chains : list of desilike.samples.Chain or None, optional
            If given on rank 0, continue the chains. In that case, we will
            ignore what was read from disk. Default is ``None``.
        step_size : float, optional
            Size of the integration step. Default is 1e-3.
        inverse_mass_matrix : numpy.ndarray, optional
            The value to use for the inverse mass matrix when drawing a value
            for the momentum and computing the kinetic energy. If
            one-dimensional, a diagonal mass matrix is assumed. If ``None``,
            a unity matrix is used. Default is ``None``.
        num_integration_steps : int, optional
            Number of times we run the symplectic integrator to build the
            trajectory. Default is 60.
        rng : numpy.random.Generator, int, or None, optional
            Random number generator for seeding. If ``None``, no seed is
            used. Default is ``None``.
        directory : str, Path, or None, optional
            Save samples to this location. Default is ``None``.
        **kwargs
            Extra keyword arguments passed to ``blackjax.hmc`` during
            initialization.

        """
        if inverse_mass_matrix is None:
            # Diagonal unity mass matrix, one entry per varied parameter.
            inverse_mass_matrix = np.ones(len(likelihood.varied_params))
        # Must be set before ``super().__init__``, which builds the kernel.
        self.kernel_args = dict(
            step_size=step_size, inverse_mass_matrix=inverse_mass_matrix,
            num_integration_steps=num_integration_steps, **kwargs)
        super().__init__(likelihood, n_chains=n_chains, chains=chains,
                         rng=rng, directory=directory)


class NoUTurnSampler(BlackJAXSampler):
    """Wrapper for No-U-Turn Sampler (NUTS).

    .. rubric:: References

    - `NUTS paper <https://www.jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf>`_

    """

    kernel_type = 'nuts'
    adaptable_args = ['step_size', 'inverse_mass_matrix']
    adaptation_fn = 'window_adaptation'

    def __init__(self, likelihood, n_chains=1, chains=None, step_size=1e-3,
                 inverse_mass_matrix=None, rng=None, directory=None,
                 **kwargs):
        """Initialize the No-U-Turn Sampler.

        Parameters
        ----------
        likelihood : BaseLikelihood
            Likelihood to sample.
        n_chains : int, optional
            Number of **independent** chains. Default is 1.
        chains : list of desilike.samples.Chain or None, optional
            If given on rank 0, continue the chains. In that case, we will
            ignore what was read from disk. Default is ``None``.
        step_size : float, optional
            Size of the integration step. Default is 1e-3.
        inverse_mass_matrix : numpy.ndarray, optional
            The value to use for the inverse mass matrix when drawing a value
            for the momentum and computing the kinetic energy. If
            one-dimensional, a diagonal mass matrix is assumed. If ``None``,
            a unity matrix is used. Default is ``None``.
        rng : numpy.random.Generator, int, or None, optional
            Random number generator for seeding. If ``None``, no seed is
            used. Default is ``None``.
        directory : str, Path, or None, optional
            Save samples to this location. Default is ``None``.
        **kwargs
            Extra keyword arguments passed to ``blackjax.nuts`` during
            initialization.

        """
        if inverse_mass_matrix is None:
            # Diagonal unity mass matrix, one entry per varied parameter.
            inverse_mass_matrix = np.ones(len(likelihood.varied_params))
        # Must be set before ``super().__init__``, which builds the kernel.
        self.kernel_args = dict(
            step_size=step_size, inverse_mass_matrix=inverse_mass_matrix,
            **kwargs)
        super().__init__(likelihood, n_chains=n_chains, chains=chains,
                         rng=rng, directory=directory)


class MCLMCSampler(BlackJAXSampler):
    """Wrapper for the Microcanonical Langevin Monte Carlo (MCLMC) sampler.

    .. rubric:: References

    - `MCLMC docs <https://blackjax-devs.github.io/blackjax/autoapi/blackjax/\
    mcmc/mclmc/index.html>`_
    - `MCLMC paper <https://doi.org/10.48550/arXiv.2212.08549>`_

    """

    kernel_type = 'mclmc'
    adaptable_args = ['L', 'step_size']
    adaptation_fn = 'mclmc_find_L_and_step_size'

    def __init__(self, likelihood, n_chains=1, chains=None, L=1.,
                 step_size=0.1, rng=None, directory=None, **kwargs):
        """Initialize the Microcanonical Langevin Monte Carlo (MCLMC) sampler.

        Parameters
        ----------
        likelihood : BaseLikelihood
            Likelihood to sample.
        n_chains : int, optional
            Number of **independent** chains. Default is 1.
        chains : list of desilike.samples.Chain or None, optional
            If given on rank 0, continue the chains. In that case, we will
            ignore what was read from disk. Default is ``None``.
        L : float, optional
            Momentum decoherence scale. Default is 1.
        step_size : float, optional
            The value to use for the step size in the integrator. Default is
            0.1.
        rng : numpy.random.Generator, int, or None, optional
            Random number generator. Default is ``None``.
        directory : str, Path, or None, optional
            Save samples to this location. Default is ``None``.
        **kwargs
            Extra keyword arguments passed to ``blackjax.mclmc`` during
            initialization.

        """
        # Must be set before ``super().__init__``, which builds the kernel.
        self.kernel_args = dict(L=L, step_size=step_size, **kwargs)
        super().__init__(likelihood, n_chains=n_chains, chains=chains,
                         rng=rng, directory=directory)