"""Module implementing the zeus sampler."""
import warnings
import numpy as np
try:
import zeus
ZEUS_INSTALLED = True
except ModuleNotFoundError:
ZEUS_INSTALLED = False
from .base import _update_parameters, MarkovChainSampler
class ZeusSampler(MarkovChainSampler):
    """Wrapper for the ensemble slice sampler ``zeus``.

    .. rubric:: References

    - `zeus repo <https://github.com/minaskar/zeus>`_
    - `zeus docs <https://zeus-mcmc.readthedocs.io>`_
    - `zeus paper A <https://doi.org/10.48550/arXiv.2002.06212>`_
    - `zeus paper B <https://doi.org/10.1093/mnras/stab2867>`_

    """

    # NOTE(review): flag presumably read by the base class to mark samplers
    # whose chains (walkers) are evolved jointly — confirm in ``.base``.
    ensemble = True

    def __init__(self, likelihood, n_walkers=10, chains=None, rng=None,
                 directory=None, **kwargs):
        """Initialize the ``zeus`` sampler.

        Parameters
        ----------
        likelihood : BaseLikelihood
            Likelihood to sample.
        n_walkers : int, optional
            Number of walkers. Note that each walker produces a chain but
            different chains are not strictly independent. Default is 10.
        chains : list of desilike.samples.Chain, optional
            If given, continue the chains. In that case, we will ignore what
            was read from disk. Default is ``None``.
        rng : numpy.random.Generator, int, or None, optional
            Random number generator. ``zeus`` itself does not take a seed, so
            passing a value here emits a warning that results are not
            deterministic. Default is ``None``.
        directory : str, Path, or None, optional
            Save samples to this location. Default is ``None``.
        **kwargs
            Extra keyword arguments passed to ``zeus.EnsembleSampler`` during
            initialization.

        Raises
        ------
        ImportError
            If the ``zeus-mcmc`` package is not installed.

        """
        if not ZEUS_INSTALLED:
            raise ImportError("The 'zeus-mcmc' package is required but not "
                              "installed.")

        super().__init__(likelihood, n_chains=n_walkers, chains=chains,
                         rng=rng, directory=directory)

        # The zeus sampler object is only created on MPI rank 0; other ranks
        # never get a ``sampler`` attribute.
        if self.mpicomm.rank == 0:
            # NOTE(review): ``_update_parameters`` presumably merges these
            # fixed settings into the user-supplied kwargs — see ``.base``.
            kwargs = _update_parameters(
                kwargs, 'zeus', nwalkers=n_walkers, ndim=self.n_dim,
                logprob_fn=self._compute_posterior, pool=self.pool, args=None,
                kwargs=None, vectorize=False)
            self.sampler = zeus.EnsembleSampler(**kwargs)

        if rng is not None:
            warnings.warn("Zeus does not support random seeds. Results "
                          "are not deterministic.")

    def _run(self, n_steps):
        """Run the ``zeus`` sampler.

        Advances all walkers by ``n_steps`` iterations and appends the new
        positions, derived quantities and log-posterior values via
        ``self._extend``.

        Parameters
        ----------
        n_steps : int
            Number of steps to take.

        """
        # NOTE(review): the (positions, blobs, log-posterior) ordering of
        # ``self._state`` is taken from this unpacking; confirm in ``.base``.
        start, blobs0, log_prob0 = self._state

        samples = np.zeros((self.n_chains, n_steps, self.n_dim))
        derived = np.zeros((self.n_chains, n_steps, self.n_derived))
        log_post = np.zeros((self.n_chains, n_steps))

        # NOTE(review): squeezing presumably matches the layout zeus uses
        # for blobs it computes itself (singleton axes dropped) — confirm.
        states = self.sampler.sample(
            start, log_prob0=log_prob0, blobs0=np.squeeze(blobs0),
            iterations=n_steps, progress=False)

        for i, state in enumerate(states):
            # Each state is indexed as (positions, log-probability, blobs).
            samples[:, i, :] = state[0]
            log_post[:, i] = state[1]
            # With no derived parameters the target slice is empty, so skip
            # reading the blob entirely instead of relying on its format.
            if self.n_derived:
                derived[:, i, :] = np.reshape(
                    state[2], (self.n_chains, self.n_derived))

        self._extend(samples, derived, log_post)