# Source code for desilike.samplers.base

"""
Base classes for posterior samplers.

This module defines common functions and classes that are inherited by
specialized classes implementing specific samplers such as ``emcee`` or
``dynesty``.
"""

import functools
import json
import sys
import warnings
from abc import ABC, ABCMeta, abstractmethod
from pathlib import Path

import numpy as np

from desilike import Samples
from desilike.pool import MPIPool
from desilike.statistics import diagnostics
from desilike.utils import BaseClass


def _main(func):
    """Decorate a method so that only the main process executes it.

    The decorated method must belong to an object exposing ``pool`` (with a
    ``main`` flag and ``wait``/``stop_wait`` methods) and ``mpicomm`` (with
    ``bcast`` and ``Abort`` methods).

    Parameters
    ----------
    func : callable
        Method to wrap. It is called as ``func(self, *args, **kwargs)``.

    Returns
    -------
    callable
        Wrapped method. Where ``self.pool.main`` is true it runs ``func``;
        elsewhere it calls ``self.pool.wait()``. Afterwards every process
        receives the same return value, or raises the same exception,
        through ``self.mpicomm.bcast``.

    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        error = None
        if not self.pool.main:
            # Presumably blocks until the main process calls ``stop_wait``
            # -- TODO confirm against the pool implementation.
            self.pool.wait()
        else:
            try:
                result = func(self, *args, **kwargs)
            except Exception as exc:
                # Keep the error so it can be shared with all processes.
                error = exc
            finally:
                try:
                    self.pool.stop_wait()
                except BaseException:
                    # The waiting processes could not be released; abort
                    # the whole communicator.
                    self.mpicomm.Abort(1)

        # Share a potential failure first so that all processes raise.
        error = self.mpicomm.bcast(error)
        if error:
            raise error

        if self.pool.main:
            return self.mpicomm.bcast(result)
        return self.mpicomm.bcast(None)

    return wrapper


def _update_parameters(user_kwargs, sampler, **desilike_kwargs):
    """
    Merge user keyword arguments with the ones enforced by desilike.

    desilike homogenizes the interface to several samplers. In some cases,
    this requires overwriting parameters the user tries to pass to the
    sampler explicitly. A warning is issued for every user-supplied keyword
    that gets overwritten.

    Parameters
    ----------
    user_kwargs : dict
        Keyword arguments received from the user. Not modified.
    sampler : str
        Name of the sampler. This is used to make warnings informative.
    **desilike_kwargs
        Keyword arguments enforced by desilike. They take precedence over
        ``user_kwargs``.

    Returns
    -------
    dict
        Updated keyword arguments.

    """
    # Warn about every clash before building the merged dictionary.
    for name in desilike_kwargs:
        if name in user_kwargs:
            warnings.warn(
                f"Overwriting keyword argument '{name}' passed to {sampler} ")

    merged = user_kwargs.copy()
    merged.update(desilike_kwargs)
    return merged


class BaseSamplerMeta(type(BaseClass), ABCMeta):
    """Metaclass deriving from both the metaclass of ``BaseClass`` and
    ``ABCMeta``.

    It adds no behavior of its own; it is the metaclass used by
    ``BaseSampler``, which inherits from both ``BaseClass`` and ``ABC``.
    """


class BaseSampler(BaseClass, ABC, metaclass=BaseSamplerMeta):
    """Abstract class defining common functions used by all samplers."""

    def __init__(self, likelihood, rng=None, directory=None):
        """Initialize the sampler.

        Parameters
        ----------
        likelihood : BaseLikelihood
            Likelihood to sample.
        rng : numpy.random.Generator, int or None, optional
            Random number generator or seed. Ignored if a generator state
            was restored from ``directory``. Default is ``None``.
        directory : str, Path, or None, optional
            Save samples to this folder. Default is ``None``.

        Raises
        ------
        ValueError
            If ``directory`` has a file suffix.

        """
        self.likelihood = likelihood
        self.varied_params = self.likelihood.varied_params.names()
        self.n_dim = len(self.varied_params)
        params = (
            self.likelihood.all_params.select(derived=True) +
            self.likelihood.all_params.select(solved=True))
        # The log-likelihood and log-prior are excluded from the derived
        # parameters.
        params = [param for param in params if param.name not in
                  ['loglikelihood', 'logprior']]
        self.derived_params = [param.name for param in params]
        self.derived_shapes = [param.shape for param in params]
        self.n_derived = int(sum(
            np.prod(shape) for shape in self.derived_shapes))

        self.mpicomm = self.likelihood.mpicomm
        self.pool = MPIPool()
        # Replace the bound methods by their pool-cached counterparts.
        for name, f in zip(
                ['_prior_transform', '_compute_prior', '_compute_posterior',
                 '_compute_likelihood'],
                [self._prior_transform, self._compute_prior,
                 self._compute_posterior, self._compute_likelihood]):
            setattr(self, name, self.pool.cache_function(f, name))

        if directory is not None:
            directory = Path(directory)
            if directory.suffix:
                raise ValueError("The directory cannot have a suffix.")
            if self.pool.main:
                directory.mkdir(parents=True, exist_ok=True)
        self.directory = directory

        if self.directory is not None:
            try:
                self._load()
            except FileNotFoundError:
                # Nothing was saved yet; start from scratch.
                pass

        # ``_load`` may have restored the generator already.
        if not hasattr(self, 'rng'):
            if rng is None or isinstance(rng, (int, np.integer)):
                rng = np.random.default_rng(seed=rng)
            self.rng = rng

        self._jit_likelihood()

    def _jit_likelihood(self):
        """JIT the likelihood with JAX, if possible.

        Sets ``self._likelihood`` to the jitted likelihood on success and
        to the plain likelihood otherwise, so that ``_compute_posterior``
        always has a callable to evaluate.
        """
        # Fixed seed: these draws are only used to trigger compilation.
        rng = np.random.default_rng(seed=42)

        def get_start(size=1):
            toret = {}
            for param in self.likelihood.varied_params:
                if param.ref.is_proper():
                    value = param.ref.sample(size=size, random_state=rng)
                else:
                    value = np.full(size, param.value)
                toret[param.name] = value
            return toret

        self.likelihood()  # initialize before jit

        # Fallback used when JAX is unavailable or the JIT attempt fails.
        self._likelihood = self.likelihood
        try:
            import jax
            likelihood = jax.jit(
                self.likelihood, static_argnames=['return_derived'])
            # Call once per value of the static argument.
            likelihood(get_start())
            likelihood(get_start(), return_derived=True)
        except Exception:
            if self.mpicomm.rank == 0:
                self.log_info("Could *not* jit input likelihood.")
        else:
            self._likelihood = likelihood
            if self.mpicomm.rank == 0:
                self.log_info("Successfully jit input likelihood.")

    def _prior_transform(self, sample):
        """Transform from the unit cube to parameter space using the prior.

        Parameters
        ----------
        sample : numpy.ndarray of shape (n_dim, )
            Sample for which to perform the prior transform.

        Returns
        -------
        numpy.ndarray of shape (n_dim, )
            Prior transformation of the input sample.

        """
        # The sample only holds the varied parameters, so it must be paired
        # with those rather than with all parameters.
        return np.array([param.prior.ppf(x) for param, x in zip(
            self.likelihood.varied_params, sample)])

    def _compute_prior(self, sample):
        """
        Compute the natural logarithm of the prior.

        Parameters
        ----------
        sample : numpy.ndarray of shape (n_dim, ) or dict
            Sample for which to compute the prior.

        Returns
        -------
        log_prior : float
            Natural logarithm of the prior.

        """
        if not isinstance(sample, dict):
            sample = dict(zip(self.varied_params, sample))
        return self.likelihood.all_params.prior(**sample)

    def _compute_posterior(self, sample):
        """Compute the natural logarithm of the posterior.

        Parameters
        ----------
        sample : numpy.ndarray of shape (n_dim, ) or dict
            Sample for which to compute the posterior.

        Returns
        -------
        log_post : float
            Natural logarithm of the posterior.
        derived : numpy.ndarray
            Derived parameters, flattened and concatenated in the order of
            ``self.derived_params``.

        """
        if not isinstance(sample, dict):
            sample = dict(zip(self.varied_params, sample))
        log_post, derived = self._likelihood(sample, return_derived=True)
        derived = np.concatenate([
            np.asarray(derived[key]).flatten() for key in self.derived_params])

        return float(log_post), derived

    def _compute_likelihood(self, sample):
        """Compute the natural logarithm of the likelihood.

        Parameters
        ----------
        sample : numpy.ndarray of shape (n_dim, ) or dict
            Sample for which to compute the likelihood.

        Returns
        -------
        log_l : float
            Natural logarithm of the likelihood, computed as the
            log-posterior minus the log-prior.
        derived : numpy.ndarray
            Derived parameters.

        """
        log_prior = self._compute_prior(sample)
        log_post, derived = self._compute_posterior(sample)

        return log_post - log_prior, derived

    def _array_to_samples(self, samples, derived, **kwargs):
        """Convert NumPy arrays to desilike chains.

        Parameters
        ----------
        samples : numpy.ndarray of shape (n_samples, n_dim)
            Samples of varied parameters.
        derived : numpy.ndarray of shape (n_samples, n_derived)
            Samples of derived parameters.
        **kwargs
            Extra parameters such as weights.

        Returns
        -------
        samples : desilike.Samples
            Samples with all derived parameters, weights, etc.

        """
        samples = dict(zip(self.varied_params, samples.T))

        # Undo the flattening done in ``_compute_posterior``: split the
        # columns per derived parameter and restore each original shape.
        derived = np.split(derived, np.cumsum([
            int(np.prod(shape)) for shape in self.derived_shapes])[:-1],
            axis=1)
        derived = [derived[i].reshape((-1, ) + shape) for i, shape in
                   enumerate(self.derived_shapes)]
        derived = dict(zip(self.derived_params, derived))

        samples = Samples(**(samples | derived))
        for key, value in kwargs.items():
            samples[key] = value

        return samples

    def _save(self):
        """Write the state of the random number generator to disk."""
        if self.pool.main:
            with open(self.directory / 'rng.json', 'w') as fstream:
                json.dump(self.rng.bit_generator.state, fstream)

    def _load(self):
        """Read the state of the random number generator from disk.

        Raises
        ------
        FileNotFoundError
            If no state has been saved in ``self.directory``.

        """
        if self.pool.main:
            with open(self.directory / 'rng.json', 'r') as fstream:
                self.rng = np.random.default_rng()
                self.rng.bit_generator.state = json.load(fstream)

class StaticSampler(BaseSampler):
    """Class defining common functions used by static samplers."""

    @abstractmethod
    def get_samples(self, **kwargs):
        """Abstract method to get the samples to be evaluated.

        Parameters
        ----------
        **kwargs
            Extra keyword arguments.

        Returns
        -------
        numpy.ndarray of shape (n_samples, n_dim)
            Samples in parameter space to evaluate.

        """
        pass

    @_main
    def run(self, **kwargs):
        """Run the sampler.

        If samples are already available, either from a previous call or
        loaded from ``directory``, they are returned without re-evaluating
        the posterior.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to the ``get_samples`` method.

        Returns
        -------
        samples : desilike.Samples
            Posterior samples.

        """
        # ``samples`` is the attribute set below and by ``_load``.
        if not hasattr(self, 'samples'):
            # Do the calculations.
            samples = self.get_samples(**kwargs)
            log_prior = np.array(self.pool.map(self._compute_prior, samples))
            results = self.pool.map(self._compute_posterior, samples)
            log_posterior = np.array([r[0] for r in results])
            derived = np.array([r[1] for r in results])
            self.samples = self._array_to_samples(
                samples, derived, log_posterior=log_posterior,
                log_weight=log_posterior, log_prior=log_prior)

            if self.directory is not None:
                self._save()

        return self.samples

    def _save(self):
        """Write internal calculations to disk."""
        if self.pool.main:
            self.samples.save(self.directory / 'samples.npz')

    def _load(self):
        """Read internal calculations from disk."""
        if self.pool.main:
            self.samples = Samples.load(self.directory / 'samples.npz')
class PopulationSampler(BaseSampler):
    """Base class collecting functionality shared by population samplers."""

    @abstractmethod
    def _run(self, **kwargs):
        """Run the concrete sampler; to be implemented by subclasses.

        Parameters
        ----------
        **kwargs
            Extra keyword arguments passed to sampler's run method.

        Returns
        -------
        samples : numpy.ndarray of shape (n_samples, n_dim)
            Samples of varied parameters.
        derived : numpy.ndarray
            Samples of derived parameters.
        extras : dict
            Extra parameters such as weights.

        """

    @_main
    def run(self, **kwargs):
        """Run the sampler and package its output.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to the run function of the sampler.

        Returns
        -------
        samples : desilike.Samples
            Posterior samples.

        """
        points, blobs, additional = self._run(**kwargs)
        converted = self._array_to_samples(points, blobs, **additional)
        return converted
class MarkovChainSampler(BaseSampler):
    """Class defining common functions used by Markov chain samplers."""

    default_adaptation_steps = 0

    def __init__(self, likelihood, n_chains=1, chains=None, rng=None,
                 directory=None):
        """Initialize the sampler.

        Parameters
        ----------
        likelihood : BaseLikelihood
            Likelihood to sample.
        n_chains : int, optional
            Number of **independent** chains. Default is 1.
        chains : list of desilike.samples.Chain or None, optional
            If given on rank 0, continue the chains. In that case, we will
            ignore what was read from disk. Default is ``None``.
        rng : numpy.random.Generator, int, or None, optional
            Random number generator. Default is ``None``.
        directory : str, Path, or None, optional
            Save samples to this location. Default is ``None``.

        """
        # ``n_chains`` must be known before the parent constructor runs
        # because ``_load`` uses it to find the chain files.
        if chains is None:
            self.n_chains = n_chains
        else:
            self.n_chains = len(chains)

        super().__init__(likelihood, rng=rng, directory=directory)

        if chains is not None:
            # Explicitly passed chains take precedence over loaded ones.
            self.chains = chains
            self.checks = []

        if not hasattr(self, 'chains'):
            self.chains = []
            self.checks = []

        # ``_load`` may have restored the chains but not the checks.
        if not hasattr(self, 'checks'):
            self.checks = []

    @abstractmethod
    def _run(self, n_steps):
        """Run a specific sampler.

        Parameters
        ----------
        n_steps : int
            How many additional steps to run.

        """
        pass

    @abstractmethod
    def _adapt(self, n_steps):
        """Adapt a specific sampler.

        Parameters
        ----------
        n_steps : int
            How many steps to run for the adaptation.

        """
        pass

    def _initialize(self, max_init_attempts=100):
        """Initialize the chains.

        Starting points are drawn from the reference distribution of each
        varied parameter, or set to the parameter value if that
        distribution is not proper. Only points with a finite posterior
        are kept.

        Parameters
        ----------
        max_init_attempts : int or None, optional
            Maximum number of attempts per chain. If ``None``, there is no
            limit. Default is 100.

        Raises
        ------
        ValueError
            If no finite posterior has been found after
            ``max_init_attempts`` attempts.

        """
        if max_init_attempts is None:
            max_init_attempts = sys.maxsize

        # Distinct loop variables: the attempt counter must not be
        # overwritten by the inner loops.
        for attempt in range(1, max_init_attempts + 1):
            # Draw random samples.
            samples = np.zeros((self.n_chains, self.n_dim))
            for j, param in enumerate(self.likelihood.varied_params):
                if param.ref.is_proper():
                    samples[:, j] = param.ref.sample(
                        size=self.n_chains, random_state=self.rng)
                else:
                    samples[:, j] = np.full(self.n_chains, param.value)

            results = self.pool.map(self._compute_posterior, samples)
            log_post = np.array([r[0] for r in results])
            derived = np.array([r[1] for r in results])

            # Accept those with finite posterior, but never more than
            # ``n_chains`` chains in total.
            for k in np.flatnonzero(np.isfinite(log_post)):
                if len(self.chains) >= self.n_chains:
                    break
                chain = self._array_to_samples(
                    np.atleast_2d(samples[k]), np.atleast_2d(derived[k]),
                    log_posterior=np.atleast_1d(log_post[k]))
                self.chains.append(chain)

            if len(self.chains) >= self.n_chains:
                return

        msg = (f"Could not find finite posterior after {max_init_attempts} "
               f"attempts.")
        raise ValueError(msg)

    @property
    def _state(self):
        """Return the current state of the chains as NumPy arrays.

        Returns
        -------
        samples : numpy.ndarray of shape (n_chains, n_dim)
            Current position of the chains.
        derived : numpy.ndarray of shape (n_chains, n_derived)
            Current derived parameters.
        log_post : numpy.ndarray of shape (n_chains, )
            Current logarithm of the posterior.

        """
        samples = [[chain[key][-1] for key in self.varied_params]
                   for chain in self.chains]
        derived = [np.concatenate([
            np.asarray(chain[key][-1]).flatten()
            for key in self.derived_params]) for chain in self.chains]
        log_post = [chain['log_posterior'][-1] for chain in self.chains]

        return np.array(samples), np.array(derived), np.array(log_post)

    def _extend(self, samples, derived, log_post):
        """Extend the sampler chains.

        Parameters
        ----------
        samples : numpy.ndarray of shape (n_chains, n_steps, n_dim)
            Positions in parameter space.
        derived : numpy.ndarray of shape (n_chains, n_steps, ...)
            Blobs returned from the posterior.
        log_post : numpy.ndarray of shape (n_chains, n_steps)
            Logarithm of the posterior.

        """
        for i in range(self.n_chains):
            chain = self._array_to_samples(
                samples[i], derived[i], log_posterior=log_post[i])
            self.chains[i].append(chain)

    def _burn_in_steps(self, burn_in):
        """Convert ``burn_in`` into a number of steps to discard.

        Parameters
        ----------
        burn_in : float or int
            If a float, fraction of the first chain's current length. If
            an integer, it is returned unchanged.

        Returns
        -------
        int
            Number of leading steps to remove from each chain.

        Raises
        ------
        ValueError
            If ``burn_in`` is a float and larger than unity.

        """
        if isinstance(burn_in, float):
            if burn_in > 1:
                raise ValueError(
                    f"A fractional burn_in cannot exceed 1, got {burn_in}.")
            # Multiply before truncating; ``int(burn_in)`` alone would turn
            # any fraction below one into zero.
            return int(burn_in * len(self.chains[0]))
        return burn_in

    def _check(self, burn_in=0.2, gelman_rubin=1.1, geweke=None, ess=None,
               quiet=False):
        """Check the status of the sampling.

        This function will also output the status of the analysis to the
        log.

        Parameters
        ----------
        burn_in : float or int, optional
            Fraction of samples to remove from each chain. If an integer,
            number of iterations (steps) to remove. Default is 0.2.
        gelman_rubin : float or None
            If given, the maximum value of the Gelman-Rubin statistic.
            Default is 1.1.
        geweke : float or None
            Currently unused.
        ess : float or None
            If given, the minimum effective sample size per chain. The
            effective sample size is the number of chain elements divided
            by the autocorrelation time. Default is ``None``.
        quiet : bool, optional
            Currently unused.

        Returns
        -------
        passed : bool
            Whether the chains passed all convergence checks.

        Raises
        ------
        ValueError
            If ``burn_in`` is a float and larger than unity.

        """
        burn_in = self._burn_in_steps(burn_in)
        chains = [chain[burn_in:] for chain in self.chains]

        self.log_info('Diagnostics:')

        # NOTE(review): ``nsplits`` is computed but not passed to any
        # diagnostic below -- confirm whether ``diagnostics.gelman_rubin``
        # should receive it.
        if len(chains) == 1:
            nsplits = 4
        elif len(chains) < 4:
            nsplits = 2
        else:
            nsplits = None

        # Only warn if the Gelman-Rubin criterion is actually in use.
        if (gelman_rubin is not None and gelman_rubin != float('inf')
                and getattr(self, 'ensemble', False)):
            msg = "Gelman-Rubin is not strictly valid for ensemble samplers."
            warnings.warn(msg)

        gelman_rubin_value = max(diagnostics.gelman_rubin(
            chains, keys=self.varied_params).values())
        tau = max(diagnostics.integrated_autocorrelation_time(
            chains, keys=self.varied_params).values())
        ess_value = len(chains[0]) / tau

        passed_all = True

        # ``upper`` marks whether the threshold is an upper or lower bound.
        for name, threshold, upper, value in zip(
                ["Gelman-Rubin", "Effective Sample Size"],
                [gelman_rubin, ess], [True, False],
                [gelman_rubin_value, ess_value]):
            self.log_info(f"{name}: {value:.3g}")
            if threshold is not None:
                passed = value < threshold if upper else value >= threshold
                passed_all = passed_all and passed
                self.log_info(
                    f"{value:.3g} {'<' if value < threshold else '>='} "
                    f"{threshold:.3g} ({'' if passed else 'not '}passed)")

        return passed_all

    @_main
    def run(self, burn_in=0.2, min_steps=0, max_steps=None,
            adaptation_steps=None, check_every=300, checks_passed=2,
            gelman_rubin=1.1, ess=None, concatenate=True, save_every=300,
            max_init_attempts=100):
        """Run the sampler.

        Parameters
        ----------
        burn_in : float or int, optional
            Fraction of samples to remove from each chain. If an integer,
            number of iterations (steps) to remove. Default is 0.2.
        min_steps : int, optional
            Minimum number of steps to run. Default is 0.
        max_steps : int or None, optional
            Maximum number of steps to run. If ``None``, no limit is
            applied. Default is ``None``.
        adaptation_steps : int, optional
            Number of learning steps for samplers that can learn effective
            hyperparameters online. These samplers include
            Metropolis-Hastings MCMC, HMC, NUTS, and MCLMC. If ``None``,
            use the sampler-specific default value. Default is ``None``.
        check_every : int, optional
            After how many steps convergence is checked. Default is 300.
        checks_passed : int, optional
            Threshold for the number of successive successful convergence
            checks. If fulfilled (and the minimum number of iterations is
            reached), the sampling will stop. Default is 2.
        gelman_rubin : float or None
            Used to assess convergence. If given, the maximum value of the
            Gelman-Rubin statistic. Default is 1.1.
        ess : float or None
            Used to assess convergence. If given, the minimum effective
            sample size per chain. The effective sample size is the number
            of chain elements divided by the autocorrelation time. Default
            is ``None``.
        concatenate : bool, optional
            Whether to concatenate individual chains into one chain.
            Default is ``True``.
        save_every : int, optional
            After how many steps results are saved. Default is 300.
        max_init_attempts : int, optional
            Maximum number of attempts to initialize each chain. Default
            is 100.

        Returns
        -------
        samples : desilike.Samples or list of desilike.Samples
            Posterior chains.

        Raises
        ------
        ValueError
            If ``burn_in`` is a float and larger than unity, or if the
            chains could not be initialized.

        """
        if len(self.chains) == 0:
            self._initialize(max_init_attempts=max_init_attempts)

        if self.directory is None:
            save_every = check_every  # Don't stop to save.

        if adaptation_steps is None:
            adaptation_steps = self.default_adaptation_steps
        self.adaptation_steps = adaptation_steps  # only used for MH MCMC
        if adaptation_steps > 0:
            self._adapt(adaptation_steps)

        # Run the chain until convergence.
        n_steps = len(self.chains[0])
        if max_steps is None:
            max_steps = sys.maxsize

        while n_steps < max_steps:

            if (n_steps >= min_steps and
                    len(self.checks) >= checks_passed and
                    all(self.checks[-checks_passed:])):
                break

            # Advance the sampler and do convergence checks. Stop at the
            # next check point, save point or the step limit, whichever
            # comes first.
            n_steps_next = min(check_every - (n_steps % check_every),
                               save_every - (n_steps % save_every),
                               max_steps - n_steps)
            n_steps += n_steps_next
            self._run(n_steps_next)
            if n_steps % check_every == 0:
                self.checks.append(self._check(
                    burn_in=burn_in, gelman_rubin=gelman_rubin, ess=ess))

            # Write results.
            if self.directory is not None and n_steps % save_every == 0:
                self._save()

        # Write results in case it wasn't written in the last iteration.
        if self.directory is not None and n_steps % save_every != 0:
            self._save()

        burn_in = self._burn_in_steps(burn_in)
        chains = [chain[burn_in:] for chain in self.chains]

        if concatenate:
            chains = Samples.concatenate(chains)

        return chains

    def _save(self):
        """Write all results to disk."""
        super()._save()
        if self.mpicomm.rank == 0:
            for i, chain in enumerate(self.chains):
                chain.save(self.directory / f'chain_{i + 1}.npz')
            np.save(self.directory / 'checks.npy', self.checks)

    def _load(self):
        """Read internal calculations from disk."""
        super()._load()
        if self.mpicomm.rank == 0:
            self.chains = [Samples.load(self.directory / f'chain_{i + 1}.npz')
                           for i in range(self.n_chains)]
            self.checks = list(np.load(self.directory / 'checks.npy'))