# Source code for desilike.likelihoods.base

import warnings

import numpy as np
import lsstypes as types

from desilike.base import BaseCalculator, Parameter, ParameterCollection, ParameterArray
from desilike.observables import ObservableCovariance
from desilike.jax import numpy as jnp
from desilike.jax import jit
from desilike import plotting, utils


@jit
def chi2(flatdiff, precision):
    """
    Return the quadratic form ``flatdiff . precision . flatdiff``.

    Parameters
    ----------
    flatdiff : array
        Difference vector.

    precision : array
        Precision matrix; a 1D array is interpreted as its diagonal.

    Returns
    -------
    chi2 : array
        Value of the quadratic form.
    """
    # A 1D precision is a diagonal: weight element-wise instead of a matrix product
    weighted = flatdiff * precision if precision.ndim == 1 else flatdiff.dot(precision)
    return weighted.dot(flatdiff.T)



class FastFisher(object):
    """
    Helper that differentiates the ``flatdiff`` of one or several likelihoods with respect to
    a set of "solved" parameters (with ``jax.jacfwd``), and assembles from it the gradient and
    Hessian of the log-posterior, together with the corresponding linear-solve update ``dx``.

    Likelihoods depending on at least one solved parameter are gathered into groups; each group
    is wrapped into its own :class:`SumLikelihood` and solved independently in :meth:`__call__`.
    """

    # If ``True``, all likelihoods with solved parameters form a single group;
    # else groups are the connected sets of likelihoods linked by shared solved parameters.
    alltogether = False

    def __init__(self, this, solved_params):  # this: current likelihood self
        """
        Build the groups of likelihoods to solve and the bookkeeping of parameter indices.

        Parameters
        ----------
        this : BaseLikelihood
            Current likelihood; its ``likelihoods`` attribute (if any, else ``[this]``)
            lists the likelihoods to consider, and its pipeline provides parameters and input values.

        solved_params : list, ParameterCollection
            Parameters to solve for.
        """
        pipeline = this.runtime_info.pipeline
        self.solved_params = ParameterCollection(solved_params)

        likelihoods = getattr(this, 'likelihoods', [this])

        def get_params(likelihood):
            # Gather parameters of ``likelihood`` and of every calculator it (recursively) requires

            calculators = []
            def callback(calculator):
                # Depth-first walk of the dependency graph, visiting each calculator once
                if calculator in calculators:
                    return
                calculators.append(calculator)
                for require in calculator.runtime_info.requires:
                    callback(require)

            callback(likelihood)
            return sum(calculator.runtime_info.params for calculator in calculators)

        # Likelihoods that depend on at least one solved parameter, with, for each, its solved parameters;
        # solved_params_friends maps each solved parameter name to the names it shares a likelihood with
        self.solve_likelihoods = []
        likelihood_solved_params, solved_params_friends = [], {param.name: set() for param in self.solved_params}
        for likelihood in likelihoods:
            likelihood_params = get_params(likelihood)
            solved_params = ParameterCollection([param for param in likelihood_params if param in self.solved_params])
            if solved_params:
                self.solve_likelihoods.append(likelihood)
                likelihood_solved_params.append(solved_params)
                solved_names = solved_params.names()
                for name in solved_names:
                    solved_params_friends[name] |= set(solved_names)

        if self.alltogether:
            group_solve_likelihoods = [list(self.solve_likelihoods)]
        else:

            def get_all_levels_of_friends(friends_dict, person):
                # Return all names reachable from ``person`` through ``friends_dict`` (``person`` included)
                from collections import deque

                if person not in friends_dict:
                    return []

                # Initialize a queue for BFS and a set for visited nodes
                queue = deque([person])
                visited = set([person])

                all_friends = set([person])  # To store all levels of friends

                # Perform BFS
                while queue:
                    current_person = queue.popleft()
                    # Get the direct friends of the current person
                    if current_person in friends_dict:
                        for friend in friends_dict[current_person]:
                            if friend not in visited:
                                visited.add(friend)  # Mark as visited
                                queue.append(friend)  # Add to queue for further exploration
                                all_friends.add(friend)  # Add to all_friends set

                return all_friends

            # Unique groups of solved parameters that are (transitively) linked through common likelihoods
            solved_params_groups = []
            for param in solved_params_friends:
                group = get_all_levels_of_friends(solved_params_friends, param)
                if group not in solved_params_groups:
                    solved_params_groups.append(group)

            # Assign each likelihood to the group of its first solved parameter
            # (all solved parameters of one likelihood belong to the same group, by construction)
            group_solve_likelihoods = [[] for i in solved_params_groups]
            for likelihood, solved_params in zip(self.solve_likelihoods, likelihood_solved_params):
                param = solved_params[0].name
                for igroup, params_group in enumerate(solved_params_groups):
                    if param in params_group:
                        group_solve_likelihoods[igroup].append(likelihood)
                        break

        # For each likelihood in self.solve_likelihoods: indices (in self.solved_params) of the solved parameters of its group
        self.ilikelihood_solved_indices = [None for i in self.solve_likelihoods]
        self._group_solved_params, self._group_solved_indices, self._all_params_group = [], [], {}
        self._group_solve_likelihoods, self._group_solve_group_likelihoods_indices = [], []
        for igroup, likelihoods in enumerate(group_solve_likelihoods):
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', message='.*Derived parameter.*')
                # Dedicated sum likelihood for this group, with the analytic-solve hooks disabled on its pipeline
                likelihood = SumLikelihood(likelihoods)
                likelihood.mpicomm = this.mpicomm
                likelihood.runtime_info.pipeline.more_initialize = None
                likelihood.runtime_info.pipeline.more_calculate = lambda: None
                all_params = likelihood.all_params
                for param in pipeline.params:
                    if param in likelihood.all_params:
                        # In the group pipeline, solved parameters are the only non-fixed ones
                        param = param.clone(derived=False if param in self.solved_params or param.depends else param.derived, fixed=param not in self.solved_params)
                        all_params.set(param)
                likelihood.all_params = all_params
            # Start the group pipeline from the current input values of the parent pipeline
            input_params = [param for param in likelihood.all_params if param.name in pipeline.input_values]
            values = {param.name: pipeline.input_values[param.name] for param in input_params}
            likelihood.runtime_info.pipeline.input_values = values
            self._group_solve_likelihoods.append(likelihood)
            self._group_solve_group_likelihoods_indices.append([self.solve_likelihoods.index(likelihood) for likelihood in likelihoods])
            for ilike in self._group_solve_group_likelihoods_indices[-1]:
                self.ilikelihood_solved_indices[ilike] = np.array([iparam for iparam, param in enumerate(self.solved_params) if param in likelihood.all_params])
            self._group_solved_params.append(ParameterCollection([param for param in likelihood.all_params if param in self.solved_params]))
            self._group_solved_indices.append(np.array([self.solved_params.index(param) for param in self._group_solved_params[-1]]))
            # Record, for each parameter name, the groups it enters (used to dispatch input values in __call__)
            for param in likelihood.all_params:
                self._all_params_group[param.name] = self._all_params_group.get(param.name, []) + [igroup]
        self.all_params = sum(likelihood.all_params for likelihood in self._group_solve_likelihoods)
        self.input_params = ParameterCollection([param for param in self.all_params if param.name in pipeline.input_values])

    def __call__(self, values, gradient=True):
        """
        Evaluate the likelihoods at ``values`` and, optionally, solve for the solved parameters.

        Parameters
        ----------
        values : dict
            Parameter name: value mapping; each name must be a parameter of at least one group
            (a ``KeyError`` is raised otherwise), and solved parameters must be included.

        gradient : bool, default=True
            If ``True``, also compute derivatives w.r.t. solved parameters and the solve update.

        Returns
        -------
        If ``gradient`` is ``False``, list of ``flatdiff`` (one per likelihood in ``solve_likelihoods``).
        Else, tuple ``(x, dx, posterior_hessian, prior_hessian, likelihoods_hessian, likelihoods_gradient,
        likelihoods_flatdiff, likelihoods_flatderiv)``, where ``dx`` is the update of solved parameters,
        ``x`` their updated values, and the ``likelihoods_*`` are lists with one entry per likelihood.
        """
        # NOTE(review): with gradient=True and no group at all, x, dx, ... are never assigned below
        import jax

        def _get_list():
            return [None for like in self.solve_likelihoods]

        likelihoods_gradient, likelihoods_hessian, likelihoods_flatdiff, likelihoods_flatderiv = (_get_list() for i in range(4))
        values_ilikelihood = [{} for igroup in range(len(self._group_solve_likelihoods))]

        # Dispatch input values to the groups that know each parameter
        for param, value in values.items():
            for igroup in self._all_params_group[param]:
                values_ilikelihood[igroup][param] = value

        multiple_groups = len(self._group_solve_likelihoods) > 1
        if multiple_groups:
            # Full-size containers, filled group by group below
            nsolved = len(self.solved_params)
            x, dx = jnp.zeros(nsolved), jnp.zeros(nsolved)
            posterior_hessian, prior_hessian = jnp.zeros((nsolved, nsolved)), jnp.zeros((nsolved, nsolved))

        for igroup, likelihood in enumerate(self._group_solve_likelihoods):
            diff_names = self._group_solved_params[igroup].names()
            all_values = values_ilikelihood[igroup]

            def getter(diff_values):
                # Run the group likelihood with solved parameters set to ``diff_values``,
                # and return the flatdiff of each of its likelihoods
                likelihood({**all_values, **dict(zip(diff_names, diff_values))})
                return [likelihood.flatdiff for likelihood in likelihood.likelihoods]

            diff_values = jnp.array([all_values[name] for name in diff_names])
            # Jacobian of each flatdiff w.r.t. the solved parameters of this group
            if gradient: flatderivs = jax.jacfwd(getter, argnums=0, has_aux=False, holomorphic=False)(diff_values)
            flatdiffs = getter(diff_values)

            group_likelihoods_indices = self._group_solve_group_likelihoods_indices[igroup]
            for idx, flatdiff in zip(group_likelihoods_indices, flatdiffs):
                likelihoods_flatdiff[idx] = flatdiff.T

            if gradient:
                for idx, flatdiff, flatderiv, like in zip(group_likelihoods_indices, flatdiffs, flatderivs, likelihood.likelihoods):
                    precision = like.precision
                    # After transposition the solved-parameter axis comes first
                    flatderiv = flatderiv.T
                    likelihoods_flatderiv[idx] = flatderiv
                    # A 1D precision is treated as a diagonal
                    if precision.ndim == 1:
                        derivp = flatderiv * precision
                    else:
                        derivp = flatderiv.dot(precision)
                    likelihoods_gradient[idx] = - derivp.dot(flatdiff.T)
                    likelihoods_hessian[idx] = - derivp.dot(flatderiv.T)

                # Prior contribution, using the ``loc`` and ``scale`` attributes of each prior;
                # priors without these attributes get loc = 0 and scale = inf, i.e. no contribution
                group_prior_gradient, group_prior_hessian = [], []
                group_params = self._group_solved_params[igroup]
                for param in group_params:
                    value = all_values[param.name]
                    loc, scale = getattr(param.prior, 'loc', 0.), getattr(param.prior, 'scale', np.inf)
                    prec = scale**(-2)
                    group_prior_gradient.append(- (value - loc) * prec)
                    group_prior_hessian.append(- prec)
                group_prior_gradient = jnp.array(group_prior_gradient)
                group_prior_hessian = jnp.diag(jnp.array(group_prior_hessian))
                group_posterior_gradient = group_prior_gradient + sum(likelihoods_gradient[idx] for idx in group_likelihoods_indices)
                group_posterior_hessian = group_prior_hessian + sum(likelihoods_hessian[idx] for idx in group_likelihoods_indices)
                # Newton step: solve hessian . dx = - gradient
                group_dx = - jnp.linalg.solve(group_posterior_hessian, group_posterior_gradient)
                group_x = jnp.array([all_values[param.name] for param in group_params]) + group_dx
                group_indices = self._group_solved_indices[igroup]
                if multiple_groups:
                    # Insert the group results at the positions of its parameters in self.solved_params
                    dx = dx.at[group_indices].set(group_dx)
                    x = x.at[group_indices].set(group_x)
                    posterior_hessian = posterior_hessian.at[np.ix_(group_indices, group_indices)].set(group_posterior_hessian)
                    prior_hessian = prior_hessian.at[np.ix_(group_indices, group_indices)].set(group_prior_hessian)
                else:
                    dx, x, posterior_hessian, prior_hessian = group_dx, group_x, group_posterior_hessian, group_prior_hessian
        if gradient:
            return x, dx, posterior_hessian, prior_hessian, likelihoods_hessian, likelihoods_gradient, likelihoods_flatdiff, likelihoods_flatderiv
        return likelihoods_flatdiff


class BaseLikelihood(BaseCalculator):

    """Base class for likelihood."""

    _attrs = ['loglikelihood', 'logprior']
    name = None
    # Default treatment of parameters whose ``derived`` starts with '.auto'
    solved_default = '.marg'

    def initialize(self, catch_errors=None, **kwargs):
        """
        Register the derived parameters listed in ``_attrs`` and store ``catch_errors``.

        Parameters
        ----------
        catch_errors : list, tuple, default=None
            Errors to catch; stored as a tuple in ``_catch_errors``.
            If ``None``, they are collected from the pipeline calculators, see :attr:`catch_errors`.

        **kwargs : dict
            Optionally, 'name', used as likelihood name and namespace of the derived parameters.
        """
        if 'name' in kwargs:
            self.name = kwargs['name']
        for name in self._attrs:
            if name not in self.params.basenames():
                self.params.set(Parameter(basename=name, namespace=self.name, latex=utils.outputs_to_latex(name), derived=True))
            param = self.params.select(basename=name)
            if not len(param):
                raise ValueError('{} derived parameter not found'.format(name))
            elif len(param) > 1:
                # {0} (positional index), not {:0} (format spec), so that ``name`` can be used twice
                raise ValueError('Several parameters with name {0} found. Which one is the {0}?'.format(name))
            param = param[0]
            param.update(derived=True)
            setattr(self, '_param_{}'.format(name), param)
        if catch_errors is not None:
            catch_errors = tuple(catch_errors)
        self._catch_errors = catch_errors

    def more_initialize(self):
        """Pipeline hook: run likelihoods' ``_pipeline_initialize``, marginalize precision, and set :meth:`_solve` as ``more_calculate``."""
        pipeline = self.runtime_info.pipeline
        likelihoods = getattr(self, 'likelihoods', [self])
        for likelihood in likelihoods:
            pipeline_initialize = getattr(likelihood, '_pipeline_initialize', None)
            if pipeline_initialize is not None:
                pipeline_initialize(pipeline)
        self._marginalize_precision()
        pipeline.more_calculate = self._solve

    def get(self):
        """Return log-posterior, i.e. ``loglikelihood + logprior``, updating ``logprior`` from the pipeline input values."""
        pipeline = self.runtime_info.pipeline
        self.logprior = pipeline.params.prior(**pipeline.input_values)  # does not include solved params
        return self.loglikelihood + self.logprior

    @property
    def catch_errors(self):
        """Tuple of errors to catch; if not provided at initialization, gathered from calculators' ``_likelihood_catch_errors``."""
        toret = getattr(self, '_catch_errors', None)
        if toret is None:
            self._catch_errors = []
            for calculator in self.runtime_info.pipeline.calculators:
                self._catch_errors += list(getattr(calculator, '_likelihood_catch_errors', []))
            toret = self._catch_errors = tuple(self._catch_errors)
        return toret

    def _marginalize_precision(self):
        """
        Sort solved parameters into '.prec' ones and the others ('.auto', '.marg', '.best').
        For '.prec' parameters, update ``precision`` and ``flatdata`` of the likelihoods that depend on them;
        the others are stored to be handled in :meth:`_solve`.

        Raises
        ------
        ValueError
            If the solved option is unknown, or a '.prec' parameter is shared between likelihoods.
        """
        pipeline = self.runtime_info.pipeline
        all_params = pipeline._params
        prec_params, solved_params = [], []
        for param in all_params:
            if not param.solved:
                continue
            solved = param.derived
            if solved.startswith('.prec'):
                prec_params.append(param)
            elif solved.startswith(('.auto', '.marg', '.best')):
                solved_params.append(param)
            else:
                raise ValueError('unknown option for solved = {}'.format(solved))
        # Reset precision and flatdata to their input values (saved below in a previous call, if any)
        for likelihood in getattr(self, 'likelihoods', [self]):
            for name in ['precision', 'flatdata']:
                input_name = '_{}_input'.format(name)
                if hasattr(likelihood, input_name):
                    setattr(likelihood, name, getattr(likelihood, input_name))
        if prec_params:
            fisher = FastFisher(self, prec_params)
            if len(fisher.ilikelihood_solved_indices) > 1:
                intersection = set.intersection(*[set(indices) for indices in fisher.ilikelihood_solved_indices])
                if intersection:
                    raise ValueError('cannot use .prec for parameters that are shared between likelihoods (we would need to create a joint covariance matrix!); common block is {}'.format([fisher.solved_params[i] for i in intersection]))
            # Just to reject from ``values`` parameters from which base ones are derived, and are not kept in solve_likelihood.all_params
            values = {param.name: pipeline.input_values[param.name] for param in fisher.input_params}
            for param in prec_params:
                values[param.name] = 0.
            fisher(values, gradient=False)  # to set flatdata
            for likelihood in fisher.solve_likelihoods:
                # Keep the very first (input) precision and flatdata, to restart from them at next call
                likelihood._precision_input = getattr(likelihood, '_precision_input', likelihood.precision)
                likelihood._flatdata_input = getattr(likelihood, '_flatdata_input', likelihood.flatdata)
            x, dx, posterior_hessian, prior_hessian, likelihoods_hessian, likelihoods_gradient, likelihoods_flatdiff, likelihoods_flatderiv = fisher(values)
            # Evaluate likelihoods with '.prec' parameters at the center of their prior
            for param in prec_params:
                values[param.name] = getattr(param.prior, 'loc', 0.)
            fisher(values, gradient=False)
            for likelihood, flatdiff, flatderiv, solved_indices in zip(fisher.solve_likelihoods, likelihoods_flatdiff, likelihoods_flatderiv, fisher.ilikelihood_solved_indices):
                precision = likelihood._precision_input
                if precision.ndim == 1:
                    derivp = flatderiv * precision
                else:
                    derivp = flatderiv.dot(precision)
                likelihood.precision = np.asarray(precision - derivp.T.dot(np.linalg.solve(- posterior_hessian[np.ix_(solved_indices, solved_indices)], derivp)))
                # flatdiff = flattheory - flatdata; shift data by the change of theory between the two evaluations above
                likelihood.flatdata = np.asarray(likelihood._flatdata_input - (likelihood.flatdiff - flatdiff))
        self.__solved_params = ParameterCollection(solved_params)
        self.__fisher = None

    def _solve(self):
        """
        Analytic solve / marginalization over solved parameters; set as pipeline's ``more_calculate``.
        Updates ``loglikelihood``, ``logprior`` (and pipeline derived values, if any).

        Returns
        -------
        logposterior : scalar
            Sum of first elements of ``loglikelihood`` and ``logprior``.
        """
        pipeline = self.runtime_info.pipeline
        self.logprior = pipeline.params.prior(**pipeline.input_values)  # does not include solved params
        fisher = None
        if self.__solved_params:
            fisher = self.__fisher
            # (Re)build the solver only if not cached, or if MPI communicator or default solve mode changed
            if fisher is None or fisher.mpicomm is not self.mpicomm or fisher.solved_default is not self.solved_default:
                fisher = FastFisher(self, self.__solved_params)
                fisher.mpicomm = self.mpicomm
                fisher.solved_default = self.solved_default
                # Indices of parameters to marginalize over (others are set at best fit)
                marg_indices = []
                for iparam, param in enumerate(fisher.solved_params):
                    solved = param.derived
                    if param.solved and not solved.startswith('.prec'):
                        if solved.startswith('.auto'):
                            solved = solved.replace('.auto', self.solved_default)
                        if solved.startswith('.marg'):
                            marg_indices.append(iparam)
                fisher.marg_indices = np.array(marg_indices)
                # Pairs of parameters (upper triangle) whose second derivatives are exported; first entry () is the value itself
                derivs = [()]
                derivs_indices = [], []
                for iparam1, param1 in enumerate(fisher.solved_params):
                    if param1.derived.endswith('not_derived'):
                        continue  # do not export to .derived
                    for iparam2, param2 in enumerate(fisher.solved_params[iparam1:]):
                        if param2.derived.endswith('not_derived'):
                            continue
                        derivs.append((param1.name, param2.name))
                        derivs_indices[0].append(iparam1)
                        derivs_indices[1].append(iparam1 + iparam2)
                fisher.derivs = derivs
                fisher.derivs_indices = derivs_indices
                self.__fisher = fisher
            values = {param.name: pipeline.input_values[param.name] for param in fisher.input_params}
            x, dx, posterior_hessian, prior_hessian, likelihoods_hessian, likelihoods_gradient, likelihoods_flatdiff, likelihoods_flatderiv = fisher(values)

        derived = pipeline.derived
        sum_loglikelihood = jnp.zeros(len(fisher.derivs) if self.__solved_params and derived is not None else (), dtype='f8')
        sum_logprior = jnp.zeros((), dtype='f8')
        if fisher is not None:
            for param, xx in zip(self.__solved_params, x):
                sum_logprior += param.prior(xx)
                # NOTE: solved values are not written back to pipeline.input_values: may lead to instabilities
                if derived is not None:
                    derived.set(ParameterArray(xx, param=param))
        if fisher is not None and derived is not None:
            sum_logprior = jnp.insert(prior_hessian[fisher.derivs_indices], 0, sum_logprior + self.logprior)
        else:
            sum_logprior += self.logprior

        for likelihood in getattr(self, 'likelihoods', [self]):
            loglikelihood = jnp.array(likelihood.loglikelihood)
            if fisher is not None and likelihood in fisher.solve_likelihoods:
                index_likelihood = fisher.solve_likelihoods.index(likelihood)
                ddx = dx[fisher.ilikelihood_solved_indices[index_likelihood]]
                likelihood_hessian = likelihoods_hessian[index_likelihood]
                # Here we plug in best x into L = 1/2 dx.T.dot(likelihood_hessian).dot(dx) + dx.T.dot(likelihood_gradient) + likelihood_with_x_fixed
                # Note: priors of solved params have already been added
                loglikelihood += 1. / 2. * ddx.dot(likelihood_hessian).dot(ddx)
                loglikelihood += likelihoods_gradient[index_likelihood].dot(ddx)
                # Set derived values
                if derived is not None:
                    loglikelihood = jnp.insert(likelihood_hessian[fisher.derivs_indices], 0, loglikelihood)
                    derived.set(ParameterArray(loglikelihood, param=likelihood._param_loglikelihood, derivs=fisher.derivs))
            sum_loglikelihood += loglikelihood

        if fisher is not None and fisher.marg_indices.size:
            marg_likelihood = -1. / 2. * jnp.linalg.slogdet(- posterior_hessian[np.ix_(fisher.marg_indices, fisher.marg_indices)])[1]
            # Convention: no 1. / 2. * len(marg_indices) * np.log(2. * np.pi) term is added;
            # in the limit of no likelihood constraint on dx, no change to the loglikelihood
            # This allows to ~ keep the interpretation in terms of -1. / 2. * chi2
            if derived is not None:
                # Marginalization term only enters the value, not the exported derivatives
                marg_likelihood = marg_likelihood * np.array([1.] + [0.] * (len(fisher.derivs) - 1), dtype='f8')
            sum_loglikelihood += marg_likelihood

        self.loglikelihood = sum_loglikelihood
        self.logprior = sum_logprior
        if fisher is not None and derived is not None:
            derived.set(ParameterArray(self.loglikelihood, param=self._param_loglikelihood, derivs=fisher.derivs))
            derived.set(ParameterArray(self.logprior, param=self._param_logprior, derivs=fisher.derivs))
        return self.loglikelihood.ravel()[0] + self.logprior.ravel()[0]

    @classmethod
    def sum(cls, *others):
        """Sum likelihoods: return :class:`SumLikelihood` instance."""
        if len(others) == 1 and utils.is_sequence(others[0]):
            others = others[0]
        likelihoods = []
        for likelihood in others:
            if isinstance(likelihood, SumLikelihood):
                # Flatten nested sums; if not initialized yet, likelihoods are still in the init arguments
                if likelihood.runtime_info.initialized:
                    likelihoods += likelihood.likelihoods
                else:
                    likelihoods += list(likelihood.init.get('likelihoods', []))
            else:
                likelihoods.append(likelihood)
        return SumLikelihood(likelihoods=likelihoods)

    def __add__(self, other):
        """Sum likelihoods ``self`` and ``other``: return :class:`SumLikelihood` instance."""
        return self.sum(self, other)

    def __radd__(self, other):
        """Right-hand sum; ``0 + self`` is supported, such that the builtin :func:`sum` works on likelihoods."""
        if other == 0:
            return self.sum(self)
        return self.__add__(other)

    def __iadd__(self, other):
        """In-place sum, returning a new :class:`SumLikelihood` instance."""
        if other == 0:
            return self.sum(self)
        return self.__add__(other)

    @property
    def size(self):
        """Data vector size."""
        return len(self.flatdata)

    @property
    def nvaried(self):
        """Number of varied parameters, including solved ones."""
        return len(self.varied_params) + len(self.all_params.select(solved=True))

    @property
    def ndof(self):
        """Number of degrees of freedom: :attr:`size` minus :attr:`nvaried`."""
        return self.size - self.nvaried

    def __getstate__(self, varied=True, fixed=True):
        """Return state: 'loglikelihood' (if set), when ``varied`` is ``True``."""
        state = {}
        for name in (['loglikelihood'] if varied else []):
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state
class BaseGaussianLikelihood(BaseLikelihood):
    """
    Base class for Gaussian likelihood, which allows parameters the theory is linear with to be analytically marginalized over.

    Parameters
    ----------
    data : array
        Data.

    covariance : array, default=None
        Covariance matrix (or its diagonal).

    precision : array, default=None
        If ``covariance`` is not provided, precision matrix (or its diagonal).
    """
    _attrs = ['loglikelihood', 'logprior']

    def initialize(self, data, covariance=None, precision=None, **kwargs):
        """
        Set ``flatdata`` and ``precision``.

        Raises
        ------
        ValueError
            If neither ``covariance`` nor ``precision`` is provided.
        """
        self.flatdata = np.ravel(data)
        if precision is None:
            if covariance is None:
                raise ValueError('Provide either precision or covariance matrix to {}'.format(self.__class__))
            covariance = np.array(covariance, dtype='f8')
            if covariance.ndim == 1:
                # Diagonal of the covariance: precision is diagonal too, kept as 1D array
                # (np.atleast_2d would turn it into a non-square (1, n) matrix)
                self.precision = 1. / covariance
            else:
                self.precision = utils.inv(np.atleast_2d(covariance))
        else:
            self.precision = np.atleast_1d(np.array(precision, dtype='f8'))
        super(BaseGaussianLikelihood, self).initialize(**kwargs)

    def calculate(self):
        """Compute ``flatdiff`` (theory minus data) and ``loglikelihood = -0.5 * chi2``."""
        self.flatdiff = self.flattheory - self.flatdata
        self.loglikelihood = -0.5 * chi2(self.flatdiff, self.precision)

    def __getstate__(self, varied=True, fixed=True):
        """Return state: fixed (data-related) and / or varied (theory-dependent) attributes, when set."""
        state = {}
        for name in (['flatdata', 'covariance', 'precision', 'transform'] if fixed else []) + (['flatdiff', 'loglikelihood'] if varied else []):
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state
class ObservablesGaussianLikelihood(BaseGaussianLikelihood):
    """
    Gaussian likelihood of observables.

    Parameters
    ----------
    observables : list, BaseCalculator
        List of (or single) observable, e.g. :class:`TracerPowerSpectrumMultipolesObservable` or :class:`TracerCorrelationFunctionMultipolesObservable`.

    covariance : array, default=None
        Covariance matrix (or its diagonal) for input ``observables``.
        If ``None``, covariance matrix is computed on-the-fly using observables' mocks.

    scale_covariance : float, default=1.
        Scale precision by the inverse of this value.

    correct_covariance : str, default='hartlap-percival2014'
        Only applies if mocks are provided to input observables.
        'hartlap' to apply Hartlap 2007 factor (https://arxiv.org/abs/astro-ph/0608064).
        'percival2014' to apply Percival 2014 factor (https://arxiv.org/abs/1312.4841).
        Can be a dictionary to specify the number of observations, ``{'nobs': nobs, 'correction': 'hartlap-percival2014'}``.

    precision : array, default=None
        Precision matrix to be used instead of the inverse covariance.
    """
    def initialize(self, observables, covariance=None, scale_covariance=1., correct_covariance='hartlap-percival2014', precision=None, **kwargs):
        if not utils.is_sequence(observables):
            observables = [observables]
        self.nobs = getattr(covariance, 'nobs', None)
        if isinstance(correct_covariance, dict):
            self.nobs = correct_covariance.get('nobs', self.nobs)
            correct_covariance = correct_covariance['correction']
        self.observables = list(observables)
        for obs in self.observables:
            obs._mpicomm = self.mpicomm
        # Inputs of rank 0 prevail
        covariance, scale_covariance, precision = (self.mpicomm.bcast(obj if self.mpicomm.rank == 0 else None, root=0) for obj in (covariance, scale_covariance, precision))
        if covariance is None:
            nmocks = [self.mpicomm.bcast(len(obs.mocks) if getattr(obs, 'mocks', None) is not None else 0) for obs in self.observables]
            if any(nmocks):
                # Covariance estimated from observables' mocks
                if self.nobs is None:
                    self.nobs = nmocks[0]
                if not all(nmock == nmocks[0] for nmock in nmocks):
                    raise ValueError('Provide the same number of mocks for each observable, found {}'.format(nmocks))
                if self.mpicomm.rank == 0:
                    list_y = [np.concatenate(y, axis=0) for y in zip(*[obs.mocks for obs in self.observables])]
                    covariance = np.cov(list_y, rowvar=False, ddof=1)
                covariance = self.mpicomm.bcast(covariance if self.mpicomm.rank == 0 else None, root=0)
            elif all(getattr(obs, 'covariance', None) is not None for obs in self.observables):
                # Block-diagonal covariance built from each observable's own covariance
                covariances = [obs.covariance for obs in self.observables]
                if self.nobs is None:
                    nobs = [getattr(obs, 'nobs', None) for obs in self.observables]
                    if all(nobs):
                        self.nobs = np.mean(nobs).astype('i4')
                size = sum(cov.shape[0] for cov in covariances)
                covariance = np.zeros((size, size), dtype='f8')
                start = 0
                for cov in covariances:
                    stop = start + cov.shape[0]
                    sl = slice(start, stop)
                    covariance[sl, sl] = cov
                    start = stop
            elif precision is None:
                raise ValueError('Observables must have mocks or their own covariance if global covariance or precision matrix not provided')
        self.flatdata = np.concatenate([obs.flatdata for obs in self.observables], axis=0)

        def check_matrix(matrix, name):
            # Return a 2D copy of ``matrix``, checked to be square and to match the data vector size
            if matrix is None:
                return None
            matrix = np.atleast_2d(matrix).copy()
            if matrix.shape != (matrix.shape[0],) * 2:
                raise ValueError('{} must be a square matrix, but found shape {}'.format(name, matrix.shape))
            mshape = '({0}, {0})'.format(matrix.shape[0])
            shape = '({0}, {0})'.format(self.flatdata.size)
            shape_obs = '({0}, {0})'.format(' + '.join(['{:d}'.format(obs.flatdata.size) for obs in self.observables]))
            if matrix.shape[0] != self.flatdata.size:
                raise ValueError('based on provided observables, {} expected to be a matrix of shape {} = {}, but found {}'.format(name, shape, shape_obs, mshape))
            return matrix

        if isinstance(covariance, ObservableCovariance):
            warnings.warn('desilike ObservableCovariance is deprecated. Please use lsstypes CovarianceMatrix.')
            cov_nobservables = len(covariance.observables())
            if len(self.observables) != cov_nobservables:
                raise ValueError('provided {:d} observables, but the covariance contains {:d}'.format(len(self.observables), cov_nobservables))
            for iobs, obs in enumerate(self.observables):
                array = obs.to_array()
                x = [(edges[:-1] + edges[1:]) / 2. for edges in array.edges()]
                # Cut covariance matrix to input scales
                covariance = covariance.xmatch(observables=iobs, x=x, projs=array.projs, select_projs=True, method='mid')
            covariance = covariance.view()
        elif isinstance(covariance, types.CovarianceMatrix):
            covariance = covariance.at.observable.get(observables=[observable.name for observable in self.observables])
            tree = covariance.observable
            # Match each observable of the covariance to the corresponding input observable
            for obs in self.observables:
                observable = obs.to_lsstypes('data')
                tree = tree.at(observables=obs.name).match(observable)
            covariance = covariance.at.observable.match(tree)
            covariance = covariance.value()

        self.precision = check_matrix(precision, 'precision')
        self.covariance = check_matrix(covariance, 'covariance')
        self.runtime_info.requires = self.observables

        if self.covariance is not None:
            self.covariance *= scale_covariance
            start, slices = 0, []
            for obs in self.observables:
                stop = start + len(obs.flatdata)
                sl = slice(start, stop)
                slices.append(sl)
                obs.covariance = self.covariance[sl, sl]  # Set each observable's (scaled) covariance (for, e.g., plots)
                start = stop
        if self.precision is None:
            # Covariance is necessarily available here (else an error has been raised above)
            # Block-inversion is usually more numerically stable
            self.precision = utils.blockinv([[self.covariance[sl1, sl2] for sl2 in slices] for sl1 in slices])
        else:
            self.precision /= scale_covariance
        self.correct_covariance = correct_covariance
        if self.nobs is not None and 'hartlap' in self.correct_covariance:
            nbins = self.precision.shape[0]
            self.hartlap2007_factor = (self.nobs - nbins - 2.) / (self.nobs - 1.)
            if self.mpicomm.rank == 0:
                self.log_info('Covariance matrix with {:d} points built from {:d} observations.'.format(nbins, self.nobs))
                self.log_info('...resulting in a Hartlap 2007 factor of {:.4f}.'.format(self.hartlap2007_factor))
            self.precision *= self.hartlap2007_factor
        super(ObservablesGaussianLikelihood, self).initialize(self.flatdata, covariance=self.covariance, precision=self.precision, **kwargs)
        # Saved to apply the Percival 2014 factor (which depends on the varied parameters) in _pipeline_initialize
        self.precision_hartlap2007 = self.precision.copy()

    def _pipeline_initialize(self, pipeline):
        """Apply Percival 2014 factor to the precision matrix, given the varied parameters of ``pipeline`` the observables depend on."""
        varied_params = pipeline._params.select(varied=True, input=True)
        if self.nobs is not None and 'percival' in self.correct_covariance:
            nbins = self.precision_hartlap2007.shape[0]
            # eq. 8 and 18 of https://arxiv.org/pdf/1312.4841.pdf
            A = 2. / (self.nobs - nbins - 1.) / (self.nobs - nbins - 4.)
            B = (self.nobs - nbins - 2.) / (self.nobs - nbins - 1.) / (self.nobs - nbins - 4.)
            params = set()

            def callback(calculator, params):
                # Add (in-place) to ``params`` the parameter names of ``calculator`` and of all its requirements
                params |= set(calculator.runtime_info.params.names())
                for require in calculator.runtime_info.requires:
                    callback(require, params)

            # Do not use obs.all_params here: this would reinitialize calculators once more,
            # which would result in unreferenced calculators if created at initialize() step in the current pipeline
            for obs in self.observables:
                callback(obs, params)
            params = [param for param in params if param in varied_params]
            nparams = len(params)
            self.percival2014_factor = (1 + B * (nbins - nparams)) / (1 + A + B * (nparams + 1))
            if self.mpicomm.rank == 0:
                self.log_info('Covariance matrix with {:d} points built from {:d} observations, varying {:d} parameters.'.format(nbins, self.nobs, nparams))
                self.log_info('...resulting in a Percival 2014 factor of {:.4f}.'.format(self.percival2014_factor))
            self.precision = self.precision_hartlap2007 / self.percival2014_factor

    def calculate(self):
        """Compute ``flatdiff`` (theory minus data) and ``loglikelihood = -0.5 * chi2``."""
        self.flatdiff = self.flattheory - self.flatdata
        self.loglikelihood = -0.5 * chi2(self.flatdiff, self.precision)

    @property
    def flattheory(self):
        """Concatenated theory vectors of all observables."""
        return jnp.concatenate([obs.flattheory for obs in self.observables], axis=0)

    def to_lsstypes(self, kind='covariance'):
        """Return covariance as lsstypes ``CovarianceMatrix`` (``kind`` is currently unused)."""
        observables = [observable.to_lsstypes('data') for observable in self.observables]
        tree = types.ObservableTree(observables, observables=[observable.name for observable in self.observables])
        return types.CovarianceMatrix(value=self.covariance, observable=tree)

    def to_covariance(self):
        """Return covariance as (deprecated) desilike :class:`ObservableCovariance`."""
        warnings.warn('desilike ObservableCovariance is deprecated. Please use lsstypes CovarianceMatrix.')
        from desilike.observables import ObservableCovariance
        return ObservableCovariance(value=self.covariance, observables=[observable.to_array() for observable in self.observables])

    @plotting.plotter
    def plot_covariance_matrix(self, corrcoef=True, **kwargs):
        """
        Plot covariance matrix.

        Parameters
        ----------
        corrcoef : bool, default=True
            If ``True``, plot the correlation matrix; else the covariance.

        barlabel : str, default=None
            Optionally, label for the color bar.

        label1 : str, list of str, default=None
            Optionally, label(s) for the observable(s).

        figsize : int, tuple, default=None
            Optionally, figure size.

        norm : matplotlib.colors.Normalize, default=None
            Scales the covariance / correlation to the canonical colormap range [0, 1] for mapping to colors.
            By default, the covariance / correlation range is mapped to the color bar range using linear scaling.

        labelsize : int, default=None
            Optionally, size for labels.

        fig : matplotlib.figure.Figure, default=None
            Optionally, a figure with at least ``len(self.observables) * len(self.observables)`` axes.

        Returns
        -------
        fig : matplotlib.figure.Figure
        """
        from desilike.observables.plotting import plot_covariance_matrix
        # Split the covariance into blocks, one for each pair of observables
        cumsize = np.insert(np.cumsum([len(obs.flatdata) for obs in self.observables]), 0, 0)
        mat = [[self.covariance[start1:stop1, start2:stop2] for start2, stop2 in zip(cumsize[:-1], cumsize[1:])] for start1, stop1 in zip(cumsize[:-1], cumsize[1:])]
        return plot_covariance_matrix(mat, corrcoef=corrcoef, **kwargs)
class SumLikelihood(BaseLikelihood):

    """
    Sum of likelihoods.

    Parameters
    ----------
    likelihoods : list, BaseLikelihood
        List of (or single) likelihood.
    """
    _attrs = ['loglikelihood', 'logprior']

    def initialize(self, likelihoods, **kwargs):
        """Store ``likelihoods`` as a list and declare them as requirements."""
        if not utils.is_sequence(likelihoods):
            likelihoods = [likelihoods]
        self.likelihoods = list(likelihoods)
        super(SumLikelihood, self).initialize(**kwargs)
        self.runtime_info.requires = self.likelihoods

    def calculate(self):
        """Set ``loglikelihood`` as the sum of the ``loglikelihood`` of all likelihoods."""
        # more_calculate = solve doesn't apply to ``self.likelihoods``.
        self.loglikelihood = sum(likelihood.loglikelihood for likelihood in self.likelihoods)

    @property
    def size(self):
        """Sum of the sizes of all likelihoods."""
        return sum(likelihood.size for likelihood in self.likelihoods)