import re
import numpy as np
from scipy import interpolate
from desilike.jax import numpy as jnp
from desilike.jax import jit, interp1d
from desilike import jax
from desilike import plotting, utils, BaseCalculator
from .base import BaseTheoryPowerSpectrumMultipolesFromWedges
from .base import BaseTheoryPowerSpectrumMultipoles, BaseTheoryCorrelationFunctionMultipoles, BaseTheoryCorrelationFunctionFromPowerSpectrumMultipoles
from .power_template import DirectPowerSpectrumTemplate, StandardPowerSpectrumTemplate
[docs]
class BasePTPowerSpectrumMultipoles(BaseTheoryPowerSpectrumMultipoles):

    """Base class for perturbation theory matter power spectrum multipoles."""
    # Option names (with defaults) that :meth:`initialize` extracts from its keyword arguments.
    _default_options = dict()
    # (kmin, kmax, number of points) of the default template wavenumber grid.
    _klim = (1e-3, 1., 500)  # klim < 1e-3 h/Mpc causes problems in velocileptors and folps when Omega_k ~ 0.1

    def initialize(self, *args, template=None, z=None, **kwargs):
        """
        Collect options, attach the template and set the wavenumbers it is evaluated at.

        Parameters
        ----------
        *args : tuple
            Positional arguments forwarded to the parent ``initialize``.
        template : BasePowerSpectrumTemplate, default=None
            Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.
        z : default=None
            If not ``None``, forwarded to the template as its ``z``.
        **kwargs : dict
            Entries named in ``_default_options`` are moved to ``self.options``;
            the remainder is forwarded to the parent ``initialize``.
        """
        self.options = self._default_options.copy()
        for name, value in self._default_options.items():
            self.options[name] = kwargs.pop(name, value)
        super(BasePTPowerSpectrumMultipoles, self).initialize(*args, **kwargs)
        if template is None:
            template = DirectPowerSpectrumTemplate()
        self.template = template
        kmin, kmax, nk = self._klim
        # The grid covers the default range, the theory k-range with a factor 2 margin for the AP effect,
        # and any k-range already requested for the template.
        kmin = min(kmin, self.k[0] / 2, self.template.init.get('k', [1.])[0])
        # Upper edge: compare with the *last* template wavenumber (the first one would discard
        # a wider range requested by the user).
        kmax = max(kmax, self.k[-1] * 2, self.template.init.get('k', [0.])[-1])
        kin = np.geomspace(kmin, kmax, nk)
        self.template.init.update(k=kin)
        if z is not None: self.template.init.update(z=z)
        self.z = self.template.z

    def calculate(self):
        """Refresh ``self.z`` from the template."""
        self.z = self.template.z
[docs]
class BasePTCorrelationFunctionMultipoles(BaseTheoryCorrelationFunctionMultipoles):

    """Base class for perturbation theory matter correlation function multipoles."""
    # Option names (with defaults) that :meth:`initialize` extracts from its keyword arguments.
    _default_options = dict()
    # (kmin, kmax, number of points) of the default template wavenumber grid.
    _klim = (1e-3, 1., 500)

    def initialize(self, *args, template=None, **kwargs):
        """
        Collect options, attach the template and set the wavenumbers it is evaluated at.

        Parameters
        ----------
        *args : tuple
            Positional arguments forwarded to the parent ``initialize``.
        template : BasePowerSpectrumTemplate, default=None
            Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.
        **kwargs : dict
            Entries named in ``_default_options`` are moved to ``self.options``;
            the remainder is forwarded to the parent ``initialize``.
        """
        self.options = self._default_options.copy()
        for name, value in self._default_options.items():
            self.options[name] = kwargs.pop(name, value)
        super(BasePTCorrelationFunctionMultipoles, self).initialize(*args, **kwargs)
        if template is None:
            template = DirectPowerSpectrumTemplate()
        self.template = template
        kmin, kmax, nk = self._klim
        # The grid covers the default range, the inverse separations with a factor 2 margin for the AP effect,
        # and any k-range already requested for the template.
        kmin = min(kmin, 1 / self.s[-1] / 2, self.template.init.get('k', [1.])[0])
        # Upper edge: compare with the *last* template wavenumber (the first one would discard
        # a wider range requested by the user).
        kmax = max(kmax, 1 / self.s[0] * 2, self.template.init.get('k', [0.])[-1])
        kin = np.geomspace(kmin, kmax, nk)
        self.template.init.update(k=kin)
        self.z = self.template.z

    def calculate(self):
        """Refresh ``self.z`` from the template."""
        self.z = self.template.z
[docs]
class BaseTracerTheory(BaseCalculator):

    """
    Base class handling the namespaces of bias parameters for auto- and cross-correlations.

    The class reads ``cls._ntracers``, which is not defined here: subclasses must provide it.
    """
    _initialize_with_namespace = True
    _calculate_with_namespace = True  # multi-tracer
    _deterministic_bias_params = []
    _stochastic_bias_params = []
    _with_cross = False  # whether cross-correlation has been implemented

    def initialize(self, tracers=None):
        """
        Store the tracer names and per-instance copies of the bias parameter lists.

        Parameters
        ----------
        tracers : str, list, default=None
            Tracer name(s), see :meth:`multitracer_namespace`.
        """
        self.tracers = tracers
        # runtime bias parameters
        self.deterministic_bias_params = list(self._deterministic_bias_params)
        self.stochastic_bias_params = list(self._stochastic_bias_params)

    @classmethod
    def _params(cls, params, tracers=None):
        """
        Attach tracer namespaces to bias parameters.

        If ``tracers`` is empty, ``params`` is returned unchanged. Otherwise each deterministic
        bias parameter is replaced by one clone per tracer namespace, and each stochastic
        parameter is updated in place with the cross namespace.
        """
        if not tracers:
            return params
        *tracer_namespaces, cross_namespace = cls.multitracer_namespace(tracers)
        for name in cls._deterministic_bias_params:
            param = params.pop(name, None)
            if param is not None:  # only if parameter exists (in case of a preselection)
                for namespace in tracer_namespaces:
                    params.set(param.clone(namespace=namespace))
        for name in cls._stochastic_bias_params:
            param = params.get(name, None)
            if param is not None:
                param.update(namespace=cross_namespace)
        return params

    @classmethod
    def multitracer_namespace(cls, tracers, tail=''):
        """
        Return the tuple of ``cls._ntracers`` tracer namespaces followed by the cross namespace.

        Parameters
        ----------
        tracers : str, list, None
            No name: all namespaces are empty. One name: auto-correlation, that name is used everywhere.
            ``_ntracers`` names: cross-correlation, cross namespace is the names joined by 'x'.
            ``_ntracers + 1`` names: cross-correlation, the last name is the cross namespace.
        tail : str, default=''
            Suffix appended to each non-empty namespace.

        Raises
        ------
        NotImplementedError
            If a cross-correlation is requested while ``cls._with_cross`` is ``False``.
        ValueError
            If the number of names is not one of the cases above.
        """
        tracers = tracers or []
        if isinstance(tracers, str):
            tracers = [tracers]
        nnames, ntracers = len(tracers), cls._ntracers
        if nnames in (ntracers, ntracers + 1) and not cls._with_cross:
            raise NotImplementedError(f'{cls} has not implemented cross-correlation yet')
        if nnames == 0:
            # default auto correlation
            *tracer_namespaces, cross_namespace = [''] * (ntracers + 1)
        elif nnames == 1:
            # auto correlation
            *tracer_namespaces, cross_namespace = [tracers[0]] * (ntracers + 1)
        elif nnames == ntracers:
            # cross correlation
            tracer_namespaces, cross_namespace = tracers, 'x'.join(tracers)
        elif nnames == ntracers + 1:
            # cross correlation, with given cross namespace
            *tracer_namespaces, cross_namespace = tracers
        else:
            raise ValueError(f"`tracers` should be a string or a list of maximum {ntracers+1} names ({ntracers} auto and 1 cross)")
        return tuple(_ + tail if _ else '' for _ in list(tracer_namespaces) + [cross_namespace])

    def pack_input_bias_params(self, params, defaults=None):
        """
        Gather bias parameter values into a dictionary keyed by base parameter name.

        Parameters
        ----------
        params : dict
            Input parameter values, possibly with namespaced names.
        defaults : dict, default=None
            Fallback values; defaults to ``required_bias_params | optional_bias_params``.

        Returns
        -------
        toret : dict
            Deterministic bias parameters map to a tuple with one value per tracer namespace;
            stochastic parameters map to a single value.
        """
        if defaults is None:
            defaults = self.required_bias_params | self.optional_bias_params
        *tracer_namespaces, cross_namespace = self.multitracer_namespace(self.tracers, tail='.')
        if self.tracers:
            toret = {}
            for param in self.deterministic_bias_params:
                toret[param] = tuple(params.get(f'{namespace}{param}', defaults[param]) for namespace in tracer_namespaces)
            for param in self.stochastic_bias_params:
                toret[param] = params.get(f'{cross_namespace}{param}', defaults[param])
        else:  # fallback to standard, where the user may have provided a namespace
            toret = defaults | {name.split('.')[-1]: value for name, value in params.items()}
            for param in self.deterministic_bias_params:
                toret[param] = (toret[param],) * len(tracer_namespaces)
        return toret

    def is_cross_correlation(self):
        """
        Return ``True`` if more than one tracer name was provided.

        A single name, given as a string or as a one-element sequence, is an auto-correlation,
        consistently with :meth:`multitracer_namespace`.
        """
        tracers = self.tracers
        if not tracers or isinstance(tracers, str):
            return False
        return len(tracers) > 1
[docs]
class BaseTracerTwoPointTheory(BaseTracerTheory):

    """Tracer theory for statistics involving two tracers (``_ntracers = 2``)."""
    # Number of tracer namespaces handled by :meth:`multitracer_namespace` (one extra cross namespace is added there).
    _ntracers = 2
[docs]
class BaseTracerThreePointTheory(BaseTracerTheory):

    """Tracer theory for statistics involving three tracers (``_ntracers = 3``)."""
    # Number of tracer namespaces handled by :meth:`multitracer_namespace` (one extra cross namespace is added there).
    _ntracers = 3
[docs]
class BaseTracerPowerSpectrumMultipoles(BaseTracerTwoPointTheory):

    """Base class for perturbation theory tracer power spectrum multipoles."""
    config_fn = 'full_shape.yaml'
    _initialize_with_namespace = True  # to properly forward parameters to pt
    _default_options = dict()

    def initialize(self, pt=None, template=None, **kwargs):
        """
        Set up options, the shot noise, and the underlying perturbation theory calculator ``pt``.

        Parameters
        ----------
        pt : default=None
            Perturbation theory calculator. If ``None``, the class named by ``self.pt_cls`` if set,
            else by this class' name with 'Tracer' removed, is looked up in this module and instantiated.
        template : BasePowerSpectrumTemplate, default=None
            If not ``None``, forwarded to ``pt``.
        **kwargs : dict
            'tracers' is passed to the parent ``initialize``; 'shotnoise' (default 1e4, a sequence is
            reduced to its geometric mean) sets ``self.nd = 1 / shotnoise``; entries named in
            ``_default_options`` go to ``self.options``; everything left is forwarded to ``pt``.
        """
        super(BaseTracerPowerSpectrumMultipoles, self).initialize(tracers=kwargs.pop('tracers', None))
        self.options = self._default_options.copy()
        shotnoise = kwargs.get('shotnoise', 1e4)
        if utils.is_sequence(shotnoise):
            # cross correlation: geometric mean
            shotnoise = np.sqrt(np.prod(shotnoise))
        for name, value in self._default_options.items():
            self.options[name] = kwargs.pop(name, value)
        if 'shotnoise' in self.options:
            self.options['shotnoise'] = shotnoise
        self.nd = 1. / float(shotnoise)
        if pt is None:
            pt = globals()[getattr(self, 'pt_cls', self.__class__.__name__.replace('Tracer', ''))]()
        self.pt = pt
        if template is not None:
            self.pt.init.update(template=template)
        # Options understood by pt: explicit keyword arguments take precedence over this class' options.
        for name, value in self.pt._default_options.items():
            if name in kwargs:
                self.pt.init.update({name: kwargs.pop(name)})
            elif name in self.options:
                self.pt.init.update({name: self.options[name]})
        for name in ['method', 'mu']:
            if name in kwargs:
                self.pt.init.update({name: kwargs.pop(name)})
        self.required_bias_params, self.optional_bias_params = {}, {}
        self.pt.init.update(kwargs)
        for name in ['z', 'k', 'ells']:
            setattr(self, name, getattr(self.pt, name))
        self.set_params()

    def set_params(self, pt_params=None):
        """
        Split parameters between this calculator (bias and derived parameters) and ``pt``.

        Parameters
        ----------
        pt_params : list, default=None
            Base names of the parameters to hand over to ``pt``. If ``None``, all parameters that are
            neither bias parameters nor derived are handed over.
        """
        all_bias_params = list(self.required_bias_params.keys()) + list(self.optional_bias_params.keys())
        if pt_params is None:
            # Build the default list here: appending to ``None`` would raise AttributeError.
            pt_params = [param.basename for param in self.init.params
                         if param.basename not in all_bias_params and not (param.derived is True)]
        self.pt.init.params.update([param for param in self.init.params if param.basename in pt_params], basename=True)
        self.init.params = self.init.params.select(basename=[param.basename for param in self.init.params if param.basename in all_bias_params or (param.derived is True)])

    def calculate(self):
        """Copy ``z``, ``k`` and ``ells`` from ``pt``."""
        for name in ['z', 'k', 'ells']:
            setattr(self, name, getattr(self.pt, name))

    def get(self):
        """Return the power spectrum multipoles ``self.power``."""
        return self.power

    @property
    def template(self):
        """Template of the underlying ``pt`` calculator."""
        return self.pt.template

    def __getstate__(self):
        """Return the attributes among 'k', 'z', 'ells', 'nd', 'power' that are set."""
        state = {}
        for name in ['k', 'z', 'ells', 'nd', 'power']:
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state

    @plotting.plotter
    def plot(self, fig=None):
        """
        Plot power spectrum multipoles.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, default=None
            Optionally, a figure with at least 1 axis.

        fn : str, Path, default=None
            Optionally, path where to save figure.
            If not provided, figure is not saved.

        kw_save : dict, default=None
            Optionally, arguments for :meth:`matplotlib.figure.Figure.savefig`.

        show : bool, default=False
            If ``True``, show figure.

        Returns
        -------
        fig : matplotlib.figure.Figure
        """
        from matplotlib import pyplot as plt
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.axes[0]
        for ill, ell in enumerate(self.ells):
            ax.plot(self.k, self.k * self.power[ill], color='C{:d}'.format(ill), linestyle='-', label=r'$\ell = {:d}$'.format(ell))
        ax.grid(True)
        ax.legend()
        ax.set_ylabel(r'$k P_{\ell}(k)$ [$(\mathrm{Mpc}/h)^{2}$]')
        ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]')
        return fig
[docs]
class BaseTracerCorrelationFunctionMultipoles(BaseTracerTwoPointTheory):

    """Base class for perturbation theory tracer correlation function multipoles."""
    config_fn = 'full_shape.yaml'
    _initialize_with_namespace = True  # to properly forward parameters to pt
    _default_options = dict()

    def initialize(self, pt=None, template=None, **kwargs):
        """
        Set up options and the underlying perturbation theory calculator ``pt``.

        Parameters
        ----------
        pt : default=None
            Perturbation theory calculator. If ``None``, the class named by ``self.pt_cls`` if set,
            else by this class' name with 'Tracer' removed, is looked up in this module and instantiated.
        template : BasePowerSpectrumTemplate, default=None
            If not ``None``, forwarded to ``pt``.
        **kwargs : dict
            'tracers' is passed to the parent ``initialize``; entries named in ``_default_options``
            go to ``self.options``; everything left is forwarded to ``pt``.
        """
        super(BaseTracerCorrelationFunctionMultipoles, self).initialize(tracers=kwargs.pop('tracers', None))
        self.options = self._default_options.copy()
        for name, value in self._default_options.items():
            self.options[name] = kwargs.pop(name, value)
        if pt is None:
            pt = globals()[getattr(self, 'pt_cls', self.__class__.__name__.replace('Tracer', ''))]()
        self.pt = pt
        if template is not None:
            self.pt.init.update(template=template)
        # Options understood by pt: explicit keyword arguments take precedence over this class' options.
        for name, value in self.pt._default_options.items():
            if name in kwargs:
                self.pt.init.update({name: kwargs.pop(name)})
            elif name in self.options:
                self.pt.init.update({name: self.options[name]})
        # By default, every parameter of this calculator is registered as a required bias parameter (value None).
        self.required_bias_params, self.optional_bias_params = dict.fromkeys(self.init.params.basenames()), {}
        self.pt.init.update(kwargs)
        for name in ['z', 's', 'ells']:
            setattr(self, name, getattr(self.pt, name))
        self.set_params()

    def set_params(self, pt_params=None):
        """
        Split parameters between this calculator (bias and derived parameters) and ``pt``.

        Parameters
        ----------
        pt_params : list, default=None
            Base names of the parameters to hand over to ``pt``. If ``None``, all parameters that are
            neither bias parameters nor derived are handed over.
        """
        all_bias_params = list(self.required_bias_params) + list(self.optional_bias_params)
        if pt_params is None:
            # Build the default list here: appending to ``None`` would raise AttributeError.
            pt_params = [param.basename for param in self.init.params
                         if param.basename not in all_bias_params and not (param.derived is True)]
        self.pt.init.params.update([param for param in self.init.params if param.basename in pt_params], basename=True)
        self.init.params = self.init.params.select(basename=[param.basename for param in self.init.params if param.basename in all_bias_params or (param.derived is True)])

    def calculate(self):
        """Copy ``z``, ``s`` and ``ells`` from ``pt``."""
        for name in ['z', 's', 'ells']:
            setattr(self, name, getattr(self.pt, name))

    def get(self):
        """Return the correlation function multipoles ``self.corr``."""
        return self.corr

    @property
    def template(self):
        """Template of the underlying ``pt`` calculator."""
        return self.pt.template

    def __getstate__(self):
        """Return the attributes among 's', 'z', 'ells', 'corr' that are set."""
        state = {}
        for name in ['s', 'z', 'ells', 'corr']:
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state

    @plotting.plotter
    def plot(self, fig=None):
        """
        Plot correlation function multipoles.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, default=None
            Optionally, a figure with at least 1 axis.

        fn : str, Path, default=None
            Optionally, path where to save figure.
            If not provided, figure is not saved.

        kw_save : dict, default=None
            Optionally, arguments for :meth:`matplotlib.figure.Figure.savefig`.

        show : bool, default=False
            If ``True``, show figure.

        Returns
        -------
        fig : matplotlib.figure.Figure
        """
        from matplotlib import pyplot as plt
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.axes[0]
        for ill, ell in enumerate(self.ells):
            ax.plot(self.s, self.s**2 * self.corr[ill], color='C{:d}'.format(ill), linestyle='-', label=r'$\ell = {:d}$'.format(ell))
        ax.grid(True)
        ax.legend()
        ax.set_ylabel(r'$s^{2} \xi_{\ell}(s)$ [$(\mathrm{Mpc}/h)^{2}$]')
        ax.set_xlabel(r'$s$ [$\mathrm{Mpc}/h$]')
        return fig
[docs]
class BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles(BaseTheoryCorrelationFunctionFromPowerSpectrumMultipoles, BaseTracerTwoPointTheory):

    """Tracer correlation function multipoles obtained by Hankel-transforming the matching power spectrum multipoles."""
    config_fn = 'full_shape.yaml'

    def initialize(self, *args, pt=None, template=None, **kwargs):
        """
        Instantiate the power spectrum counterpart and initialize the Hankel-transform machinery with it.

        The counterpart is the class of this module whose name is this class' name with
        'CorrelationFunction' replaced by 'PowerSpectrum'; ``pt`` and ``template`` are forwarded
        to it when provided.
        """
        tracers = kwargs.pop('tracers', None)
        BaseTracerTwoPointTheory.initialize(self, tracers=tracers)
        power_cls_name = self.__class__.__name__.replace('CorrelationFunction', 'PowerSpectrum')
        power = globals()[power_cls_name]()
        for option, value in (('pt', pt), ('template', template)):
            if value is not None:
                power.init.update(**{option: value})
        super().initialize(*args, power=power, **kwargs)
        self.z = self.power.z
        self.ells = self.power.ells

    def calculate(self):
        """Copy ``z`` and ``ells`` from the power spectrum, then run the parent calculation."""
        self.z = self.power.z
        self.ells = self.power.ells
        super().calculate()

    @property
    def pt(self):
        """Perturbation theory calculator of the underlying power spectrum."""
        return self.power.pt

    @property
    def template(self):
        """Template of the underlying power spectrum."""
        return self.power.template

    def get(self):
        """Return the correlation function multipoles ``self.corr``."""
        return self.corr
[docs]
class SimpleTracerPowerSpectrumMultipoles(BasePTPowerSpectrumMultipoles, BaseTheoryPowerSpectrumMultipolesFromWedges, BaseTracerTwoPointTheory):
    r"""
    Kaiser tracer power spectrum multipoles, with fixed damping, essentially used for Fisher forecasts.
    For the matter (unbiased) power spectrum, set b1=1 and sn0=0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    mu : int, default=8
        Number of :math:`\mu`-bins to use (in :math:`[0, 1]`).

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`StandardPowerSpectrumTemplate`.

    shotnoise : float, default=1e4
        Shot noise (which is usually marginalized over).
    """
    config_fn = 'full_shape.yaml'
    _deterministic_bias_params = ['b1']
    _stochastic_bias_params = ['sn0']
    _with_cross = True

    def initialize(self, *args, mu=8, method='leggauss', template=None, shotnoise=1e4, **kwargs):
        """Register tracers, convert the shot noise to ``self.nd`` and initialize the parent classes."""
        BaseTracerTwoPointTheory.initialize(self, tracers=kwargs.pop('tracers', None))
        # cross correlation: a sequence of shot noises is reduced to its geometric mean
        effective_shotnoise = np.sqrt(np.prod(shotnoise)) if utils.is_sequence(shotnoise) else shotnoise
        self.nd = 1. / float(effective_shotnoise)
        if template is None:
            template = StandardPowerSpectrumTemplate()
        super().initialize(*args, template=template, mu=mu, method=method, **kwargs)

    def calculate(self, sigmapar=0., sigmaper=0., **kwargs):
        """
        Compute ``self.power``.

        The damping is evaluated at ``self.k``, ``self.mu``, the Kaiser factors at the
        AP-distorted ``muap``, and the template ``pk_dd`` at the AP-distorted ``kap``.
        """
        super().calculate()
        bias = self.pack_input_bias_params(kwargs, defaults=dict(b1=1., sn0=0.))
        b1X, b1Y = bias['b1']
        sn0 = bias['sn0']
        jac, kap, muap = self.template.ap_k_mu(self.k, self.mu)
        f = self.template.f
        sigmanl2 = self.k[:, None]**2 * (sigmapar**2 * self.mu**2 + sigmaper**2 * (1. - self.mu**2))
        damping = jnp.exp(-sigmanl2 / 2.)
        #pkmu = jac * damping * (b1X + f * muap**2) * (b1Y + f * muap**2) * jnp.interp(jnp.log10(kap), jnp.log10(self.template.k), self.template.pk_dd) + sn0 / self.nd
        # cubic interpolation of pk_dd in log10(k)
        pk_dd_ap = interp1d(jnp.log10(kap), jnp.log10(self.template.k), self.template.pk_dd, method='cubic')
        pkmu = jac * damping * (b1X + f * muap**2) * (b1Y + f * muap**2) * pk_dd_ap + sn0 / self.nd
        self.power = self.to_poles(pkmu)

    def get(self):
        """Return the power spectrum multipoles ``self.power``."""
        return self.power

    def __getstate__(self):
        """Return the attributes among 'k', 'z', 'ells', 'nd', 'power' that are set."""
        return {name: getattr(self, name) for name in ['k', 'z', 'ells', 'nd', 'power'] if hasattr(self, name)}

    @plotting.plotter
    def plot(self, fig=None):
        """
        Plot power spectrum multipoles.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, default=None
            Optionally, a figure with at least 1 axis.

        fn : str, Path, default=None
            Optionally, path where to save figure.
            If not provided, figure is not saved.

        kw_save : dict, default=None
            Optionally, arguments for :meth:`matplotlib.figure.Figure.savefig`.

        show : bool, default=False
            If ``True``, show figure.

        Returns
        -------
        fig : matplotlib.figure.Figure
        """
        from matplotlib import pyplot as plt
        if fig is not None:
            ax = fig.axes[0]
        else:
            fig, ax = plt.subplots()
        for ipole, pole in enumerate(self.ells):
            color, label = 'C{:d}'.format(ipole), r'$\ell = {:d}$'.format(pole)
            ax.plot(self.k, self.k * self.power[ipole], color=color, linestyle='-', label=label)
        ax.grid(True)
        ax.legend()
        ax.set_ylabel(r'$k P_{\ell}(k)$ [$(\mathrm{Mpc}/h)^{2}$]')
        ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]')
        return fig
[docs]
class KaiserPowerSpectrumMultipoles(BasePTPowerSpectrumMultipoles, BaseTheoryPowerSpectrumMultipolesFromWedges):
    r"""
    Kaiser power spectrum multipoles.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    mu : int, default=8
        Number of :math:`\mu`-bins to use (in :math:`[0, 1]`).

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.
    """
    _params = {'sigmapar': {'value': 0., 'fixed': True}, 'sigmaper': {'value': 0, 'fixed': True}}

    def initialize(self, *args, mu=8, **kwargs):
        """Initialize the parent classes with ``method='leggauss'`` for the :math:`\\mu`-integration."""
        super(KaiserPowerSpectrumMultipoles, self).initialize(*args, mu=mu, method='leggauss', **kwargs)
        #self.template.init.update(k=np.logspace(-4, 2, 1000))

    def calculate(self, sigmapar=0., sigmaper=0.):
        """
        Fill ``self.pktable`` with the multipoles of the damped 'pk_dd', 'pk_dt' and 'pk_tt' terms.

        'pk11' is an alias of 'pk_dd'. The damping and the Kaiser factors are evaluated
        at the AP-distorted ``kap``, ``muap``.
        """
        super(KaiserPowerSpectrumMultipoles, self).calculate()
        jac, kap, muap = self.template.ap_k_mu(self.k, self.mu)
        f = self.template.f
        sigmanl2 = kap**2 * (sigmapar**2 * muap**2 + sigmaper**2 * (1. - muap**2))
        damping = jnp.exp(-sigmanl2 / 2.)
        self.k11 = self.template.k
        self.pk11 = self.template.pk_dd
        # cubic interpolation of pk_dd in log10(k)
        pktable = jac * damping * interp1d(jnp.log10(kap), jnp.log10(self.k11), self.pk11, method='cubic')
        # pktable is always a dictionary (no list placeholder), as expected by __getstate__ / __setstate__
        self.pktable = {'pk_dd': self.to_poles(pktable), 'pk_dt': self.to_poles(f * muap**2 * pktable), 'pk_tt': self.to_poles(f**2 * muap**4 * pktable)}
        self.pktable['pk11'] = self.pktable['pk_dd']

    def __getstate__(self):
        """
        Return the set attributes among 'k', 'z', 'ells', the entries of ``pktable`` and their 'names'.

        If ``pktable`` has not been computed yet, 'names' is an empty list.
        """
        state = {}
        for name in ['k', 'z', 'ells']:
            if hasattr(self, name):
                state[name] = getattr(self, name)
        # Guarded like the attributes above, such that the state can be taken before calculate()
        pktable = getattr(self, 'pktable', None) or {}
        for name in pktable:
            state[name] = pktable[name]
        state['names'] = list(pktable.keys())
        return state

    def __setstate__(self, state):
        """Rebuild ``pktable`` from the entries listed in ``state['names']``, then restore the remaining state."""
        state = dict(state)
        self.pktable = {name: state.pop(name, None) for name in state['names']}
        super(KaiserPowerSpectrumMultipoles, self).__setstate__(state)
[docs]
class KaiserTracerPowerSpectrumMultipoles(BaseTracerPowerSpectrumMultipoles):
    r"""
    Kaiser tracer power spectrum multipoles.
    For the matter (unbiased) power spectrum, set b1=1 and sn0=0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    mu : int, default=8
        Number of :math:`\mu`-bins to use (in :math:`[0, 1]`).

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.
    """
    _deterministic_bias_params = ['b1']
    _stochastic_bias_params = ['sn0']
    _with_cross = True

    def set_params(self):
        """Declare b1 and sn0 as required bias parameters; 'sigmapar' and 'sigmaper' are handed over to ``pt``."""
        self.required_bias_params.update(b1=1., sn0=0.)
        super().set_params(pt_params=['sigmapar', 'sigmaper'])

    def calculate(self, **kwargs):
        """Combine the ``pt`` tables into ``self.power``; the sn0 term is added to the monopole only."""
        super().calculate()
        bias = self.pack_input_bias_params(kwargs)
        b1X, b1Y = bias["b1"]
        is_monopole = np.array([(pole == 0) for pole in self.ells], dtype='f8')[:, None]
        sn0 = is_monopole * bias["sn0"] / self.nd
        table = self.pt.pktable
        self.power = b1X * b1Y * table['pk_dd'] + (b1X + b1Y) * table['pk_dt'] + table['pk_tt'] + sn0
[docs]
class KaiserTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    Kaiser tracer correlation function multipoles.
    For the matter (unbiased) correlation function, set b1=1 and sn0=0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    **kwargs : dict
        Options, defaults to: ``mu=8``.
    """
    # Same list object as the power spectrum class (not a copy).
    _deterministic_bias_params = KaiserTracerPowerSpectrumMultipoles._deterministic_bias_params
    # No stochastic bias parameter is declared for the correlation function.
    _stochastic_bias_params = []
    # Allows multitracer_namespace to accept cross-correlation tracer names.
    _with_cross = True
[docs]
class BaseEFTLikeTracerPowerSpectrumMultipoles(object):
    r"""
    Mixin adding EFT-like counter terms (ct*) and stochastic terms (sn*) to tracer power spectrum multipoles.
    Can be exactly marginalized over counter terms and stochastic parameters ct*, sn*.
    """
    def initialize(self, *args, **kwargs):
        """Set ``pt_cls`` to this class' name stripped of 'EFTLike' and 'Tracer', then initialize the next class."""
        own_name = self.__class__.__name__
        self.pt_cls = own_name.replace('EFTLike', '').replace('Tracer', '')
        super().initialize(*args, **kwargs)

    def set_params(self):
        """
        Build the counter term and stochastic design matrices from the parameter names.

        Parameters named ``<base><ell>_<pow>`` contribute :math:`(k / k_p)^{pow}` to multipole ``ell``;
        ``<base>0`` contributes a constant to the monopole. Parameters of multipoles not in ``self.ells``
        are removed from ``self.init.params``.
        """
        self.kp = 1.

        def get_params_matrix(base):
            # {ell: {parameter name: k-dependence}}
            terms = {ell: {} for ell in self.ells}
            for param in self.init.params.select(basename=base + '*_*'):
                name = param.basename
                found = re.match(base + '(.*)_(.*)', name)
                if not found:
                    continue
                ell, exponent = map(int, found.groups())
                if ell in self.ells:
                    terms[ell][name] = (self.k / self.kp)**exponent
                else:
                    del self.init.params[param]
            for param in self.init.params.select(basename=base + '0'):
                name = param.basename
                if 0 not in self.ells:
                    del self.init.params[param]
                    continue
                if name + '_0' in terms[0]:
                    raise ValueError('Choose between {} and {}'.format(name, name + '_0'))
                terms[0][name] = 1.
            params = [name for ell in self.ells for name in terms[ell]]
            if not params:
                return params, jnp.array([], dtype='f8')

            def design(ell):
                # one column per parameter, non-zero only for the parameters of this multipole
                columns = [np.zeros_like(self.k) for _ in params]
                for name, values in terms[ell].items():
                    columns[params.index(name)][:] = values
                return np.column_stack(columns)

            return params, jnp.array([design(ell) for ell in self.ells])

        self.counterterm_params, self.counterterm_matrix = get_params_matrix('ct')
        self.stochastic_params, self.stochastic_matrix = get_params_matrix('sn')
        extra = self.counterterm_params + self.stochastic_params
        self.required_bias_params = dict(**self.required_bias_params, **dict.fromkeys(extra, 0))
        super().set_params()
        required = self.required_bias_params
        self.deterministic_bias_params = [name for name in self.deterministic_bias_params if name in required]
        self.stochastic_bias_params = [name for name in self.stochastic_bias_params if name in required]

    def calculate(self, **params):
        """Run the next ``calculate`` and add the counter terms (times the 'pk11' monopole) and stochastic terms."""
        bias = self.pack_input_bias_params(params)
        # counter terms: half the sum of the per-tracer values
        counterterm_values = 0.5 * jnp.array([sum(bias[name]) for name in self.counterterm_params])
        stochastic_values = jnp.array([bias[name] for name in self.stochastic_params]) / self.nd
        super().calculate(**params)
        self.power += self.counterterm_matrix.dot(counterterm_values) * self.pt.pktable['pk11'][self.pt.ells.index(0)]
        self.power += self.stochastic_matrix.dot(stochastic_values)
[docs]
class EFTLikeKaiserTracerPowerSpectrumMultipoles(BaseEFTLikeTracerPowerSpectrumMultipoles, KaiserTracerPowerSpectrumMultipoles):
    r"""
    Kaiser tracer power spectrum multipoles with EFT-like counter and stochastic terms.
    Can be exactly marginalized over counter terms and stochastic parameters ct*, sn*.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    mu : int, default=8
        Number of :math:`\mu`-bins to use (in :math:`[0, 1]`).

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    shotnoise : float, default=1e4
        Shot noise (which is usually marginalized over).
    """
    # Kaiser bias parameters, plus counter terms named ct<ell>_<pow> (new list, the Kaiser list is not modified).
    _deterministic_bias_params = KaiserTracerPowerSpectrumMultipoles._deterministic_bias_params + ['ct0_2', 'ct2_2', 'ct4_2']
    # Kaiser stochastic parameters, plus stochastic terms named sn<ell>_<pow>.
    _stochastic_bias_params = KaiserTracerPowerSpectrumMultipoles._stochastic_bias_params + ['sn0_2', 'sn2_2', 'sn4_2']
    # Allows multitracer_namespace to accept cross-correlation tracer names.
    _with_cross = True
[docs]
class EFTLikeKaiserTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    EFT-like Kaiser tracer correlation function multipoles.
    Can be exactly marginalized over counter terms and stochastic parameters ct*, sn*.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    **kwargs : dict
        Options, defaults to: ``mu=8``.
    """
    # Same list object as the power spectrum class (not a copy).
    _deterministic_bias_params = EFTLikeKaiserTracerPowerSpectrumMultipoles._deterministic_bias_params
    # No stochastic bias parameter is declared for the correlation function.
    _stochastic_bias_params = []
    # Allows multitracer_namespace to accept cross-correlation tracer names.
    _with_cross = True
def tns_kernels(k, q, wq):
    """
    Precompute the template-independent integration kernels passed to :func:`tns_pt`.

    Parameters
    ----------
    k : array
        Output wavenumbers.
    q : array
        Integration wavenumbers.
    wq : array
        Integration weights associated to ``q``.

    Returns
    -------
    kernels : list
        Three arrays, passed to :func:`tns_pt` as ``kernel13_d``, ``kernel13_t`` and ``kernel_a``.
        The first two are functions of ``x = q / k`` on the (k, q) grid; the third one has
        an additional leading axis of size 5.
        NOTE(review): assumes ``k`` and ``q`` are 1D arrays -- confirm.
    """
    jq = q**2 * wq / (4. * np.pi**2)
    k = k[:, None]
    x = q / k

    kernels = [None] * 3

    # Integral of F3(q, -q, k) over mu cosine angle between k and q
    def kernel_ff(x):
        x = np.array(x)
        toret = (6. / x**2 - 79. + 50. * x**2 - 21. * x**4 + 0.75 * (1. / x - x)**3 * (2. + 7. * x**2) * 2 * np.log(np.abs((x - 1.) / (x + 1.)))) / 504.
        # x > 10: replace by a series in 1 / x^2
        mask = x > 10.
        toret[mask] = - 61. / 630. + 2. / 105. / x[mask]**2 - 10. / 1323. / x[mask]**4
        # |x - 1| < 0.01: the logarithm diverges at x = 1, replace by a quadratic polynomial in dx
        dx = x - 1.
        mask = np.abs(dx) < 0.01
        toret[mask] = - 11. / 126. + dx[mask] / 126. - 29. / 252. * dx[mask]**2
        return toret / x**2

    # Same structure as kernel_ff, with different coefficients (used as kernel13_t in tns_pt)
    def kernel_gg(x):
        x = np.array(x)
        toret = (6. / x**2 - 41. + 2. * x**2 - 3. * x**4 + 0.75 * (1. / x - x)**3 * (2. + x**2) * 2 * np.log(np.abs((x - 1.) / (x + 1.)))) / 168.
        # x > 10: replace by a series in 1 / x^2
        mask = x > 10.
        toret[mask] = - 3. / 10. + 26. / 245. / x[mask]**2 - 38. / 2205. / x[mask]**4
        # |x - 1| < 0.01: the logarithm diverges at x = 1, replace by a quadratic polynomial in dx
        dx = x - 1.
        mask = np.abs(dx) < 0.01
        toret[mask] = - 3. / 14. - 5. / 42. * dx[mask] - 1. / 84. * dx[mask]**2
        return toret / x**2

    kernels[0] = 2 * jq * kernel_ff(x)
    kernels[1] = 2 * jq * kernel_gg(x)

    # Five kernels stacked along the first axis; they multiply pk_q in the term added to A in tns_pt
    def kernel_a(x):
        toret = np.zeros((5,) + x.shape, dtype='f8')
        # The logarithm is left to 0 where x == 1 (within 1e-16), where it would diverge
        logx = np.zeros_like(x)
        mask = np.abs(x - 1) > 1e-16
        logx[mask] = np.log(np.abs((x[mask] + 1) / (x[mask] - 1)))
        toret[0] = -1. / 84. / x * (2 * x * (19 - 24 * x**2 + 9 * x**4) - 9 * (x**2 - 1)**3 * logx)
        toret[1] = 1. / 112. / x**3 * (2 * x * (x**2 + 1) * (3 - 14 * x**2 + 3 * x**4) - 3 * (x**2 - 1)**4 * logx)
        toret[2] = 1. / 336. / x**3 * (2 * x * (9 - 185 * x**2 + 159 * x**4 - 63 * x**6) + 9 * (x**2 - 1)**3 * (7 * x**2 + 1) * logx)
        toret[4] = 1. / 336. / x**3 * (2 * x * (9 - 109 * x**2 + 63 * x**4 - 27 * x**6) + 9 * (x**2 - 1)**3 * (3 * x**2 + 1) * logx)
        # x < 1e-4: replace by polynomials in x^2
        mask = x < 1e-4
        xm = x[mask]
        toret[0][mask] = 8 * xm**8 / 735 + 24 * xm**6 / 245 - 24 * xm**4 / 35 + 8 * xm**2 / 7 - 2. / 3
        toret[1][mask] = - 16 * xm**8 / 8085 - 16 * xm**6 / 735 + 48 * xm**4 / 245 - 16 * xm**2 / 35
        toret[2][mask] = 32 * xm**8 / 1617 + 128 * xm**6 / 735 - 288 * xm**4 / 245 + 64 * xm**2 / 35 - 4. / 3
        toret[4][mask] = 24 * xm**8 / 2695 + 8 * xm**6 / 105 - 24 * xm**4 / 49 + 24 * xm**2 / 35 - 2. / 3
        # x > 1e2: replace by series in 1 / x^2
        mask = x > 1e2
        xm = x[mask]
        toret[0][mask] = 2. / 105 - 24 / (245 * xm**2) - 8 / (735 * xm**4) - 8 / (2695 * xm**6) - 8 / (7007 * xm**8)
        toret[1][mask] = -16. / 35 + 48 / (245 * xm**2) - 16 / (735 * xm**4) - 16 / (8085 * xm**6) - 16 / (35035 * xm**8)
        toret[2][mask] = -44. / 105 - 32 / (735 * xm**4) - 64 / (8085 * xm**6) - 96 / (35035 * xm**8)
        toret[4][mask] = -46. / 105 + 24 / (245 * xm**2) - 8 / (245 * xm**4) - 8 / (1617 * xm**6) - 8 / (5005 * xm**8)
        # Entry 3 is a copy of entry 1, taken after the small- and large-x replacements
        toret[3] = toret[1]
        return toret / x**2

    kernels[2] = jq * kernel_a(x)

    return kernels
@jit
def tns_pt(k, q, wq, pk_q, kernel13_d, kernel13_t, kernel_a):
    """
    Compute the 1-loop terms and the A, B correction terms by direct integration over ``q`` and ``mu``.

    Parameters
    ----------
    k : array
        Output wavenumbers.
    q : array
        Integration wavenumbers, where ``pk_q`` is tabulated.
    wq : array
        Integration weights associated to ``q``.
    pk_q : array
        Power spectrum at ``q``.
    kernel13_d, kernel13_t, kernel_a : array
        Precomputed kernels, in the order returned by :func:`tns_kernels`.

    Returns
    -------
    terms : list
        ``[pk11, pk_dd, pk_b2d, pk_bs2d, pk_sig3sq, pk_b22, pk_b2s2, pk_bs22, pk_dt, pk_b2t, pk_bs2t, pk_tt, A, B]``,
        where ``A`` stacks 5 terms and ``B`` stacks 12 terms along the first axis; all other entries
        are functions of ``k``.
    """
    # We could have a speed-up with FFTlog, see https://arxiv.org/pdf/1603.04405.pdf
    k11 = k
    k = k[:, None]
    jq = q**2 * wq / (4. * np.pi**2)
    x = q / k
    # 10 mu nodes and weights -- presumably Gauss-Legendre given method='leggauss'; confirm in utils.weights_mu
    mus, wmus = utils.weights_mu(10, method='leggauss')

    # Compute P22
    # NOTE: the initial values of the four statements below are all overwritten after the mu-integration
    pk22_dd, pk22_dt, pk22_tt = (0.,) * 3
    pk_b2d, pk_bs2d, pk_b2t, pk_bs2t, sig3sq, pk_b22, pk_b2s2, pk_bs22 = (0.,) * 8
    A = jnp.zeros((5,) + k11.shape, dtype='f8')
    B = [jnp.zeros(k11.shape, dtype='f8') for i in range(12)]
    # Input power spectrum linearly interpolated at the output wavenumbers
    pk_k = jnp.interp(k11, q, pk_q)

    def get_terms(mu, wmu):
        # Integrands at a single mu, summed over q (last axis) and weighted by wmu
        kdq = k * q * mu  # k \cdot q
        kq2 = k**2 - 2. * kdq + q**2  # |k - q|^2
        qdkq = kdq - q**2  # q \cdot (k - q)
        F2_d = 5. / 7. + 1. / 2. * qdkq * (1. / q**2 + 1. / kq2) + 2. / 7. * qdkq**2 / (q**2 * kq2)
        F2_t = 3. / 7. + 1. / 2. * qdkq * (1. / q**2 + 1. / kq2) + 4. / 7. * qdkq**2 / (q**2 * kq2)
        # https://arxiv.org/pdf/0902.0991.pdf
        S = (qdkq)**2 / (q**2 * kq2) - 1. / 3.
        D = 2. / 7. * (mu**2 - 1.)
        # Power spectrum at |k - q|, set to zero outside of the range of q
        pk_kq = jnp.interp(kq2**0.5, q, pk_q, left=0., right=0.)
        jq_pk_q_pk_kq = jq * pk_q * pk_kq
        pk_b2d = wmu * jnp.sum(jq_pk_q_pk_kq * F2_d, axis=-1)
        pk_bs2d = wmu * jnp.sum(jq_pk_q_pk_kq * F2_d * S, axis=-1)
        pk_b2t = wmu * jnp.sum(jq_pk_q_pk_kq * F2_t, axis=-1)
        pk_bs2t = wmu * jnp.sum(jq_pk_q_pk_kq * F2_t * S, axis=-1)
        sig3sq = wmu * jnp.sum(105. / 16. * jq * pk_q * (D * S + 8. / 63.), axis=-1)
        pk_b22 = wmu / 2. * jnp.sum(jq * pk_q * (pk_kq - pk_q), axis=-1)
        pk_b2s2 = wmu / 2. * jnp.sum(jq * pk_q * (pk_kq * S - 2. / 3. * pk_q), axis=-1)
        pk_bs22 = wmu / 2. * jnp.sum(jq * pk_q * (pk_kq * S**2 - 4. / 9. * pk_q), axis=-1)
        pk22_dd = 2 * wmu * jnp.sum(F2_d**2 * jq_pk_q_pk_kq, axis=-1)
        pk22_dt = 2 * wmu * jnp.sum(F2_d * F2_t * jq_pk_q_pk_kq, axis=-1)
        pk22_tt = 2 * wmu * jnp.sum(F2_t * F2_t * jq_pk_q_pk_kq, axis=-1)

        xmu = kq2 / k**2  # |k - q|^2 / k^2
        kernel_A, kernel_tA = [0] * 5, [0] * 5
        kernel_A[0] = - x**3 / 7. * (mu + 6 * mu**3 + x**2 * mu * (-3 + 10 * mu**2) + x * (-3 + mu**2 - 12 * mu**4))
        kernel_A[1] = x**4 / 14. * (mu**2 - 1) * (-1 + 7 * x * mu - 6 * mu**2)
        kernel_A[2] = x**3 / 14. * (x**2 * mu * (13 - 41 * mu**2) - 4 * (mu + 6 * mu**3) + x * (5 + 9 * mu**2 + 42 * mu**4))
        kernel_A[3] = kernel_A[1]
        kernel_A[4] = x**3 / 14. * (1 - 7 * x * mu + 6 * mu**2) * (-2 * mu + x * (-1 + 3 * mu**2))
        kernel_tA[0] = 1. / 7. * (mu + x - 2 * x * mu**2) * (3 * x + 7 * mu - 10 * x * mu**2)
        kernel_tA[1] = x / 14. * (mu**2 - 1) * (3 * x + 7 * mu - 10 * x * mu**2)
        kernel_tA[2] = 1. / 14. * (28 * mu**2 + x * mu * (25 - 81 * mu**2) + x**2 * (1 - 27 * mu**2 + 54 * mu**4))
        kernel_tA[3] = x / 14. * (1 - mu**2) * (x - 7 * mu + 6 * x * mu**2)
        kernel_tA[4] = 1. / 14. * (x - 7 * mu + 6 * x * mu**2) * (-2 * mu - x + 3 * x * mu**2)
        # Taruya 2010 (arXiv 1006.0699v1) eq A3
        A = wmu * jnp.sum(jq / x**2 * (jnp.array(kernel_A) * pk_k[:, None] + jnp.array(kernel_tA) * pk_q) * pk_kq / xmu**2, axis=-1)

        # Common factor of all B terms below
        jq_pk_q_pk_kq /= x**2 * xmu
        B = [0.] * 12
        B[0] = wmu * jnp.sum(x**2 * (mu**2 - 1.) / 2. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 1,1,1
        B[1] = wmu * jnp.sum(3. * x**2 * (mu**2 - 1.)**2 / 8. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 1,1,2
        B[2] = wmu * jnp.sum(3. * x**4 * (mu**2 - 1.)**2 / xmu / 8. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 1,2,1
        B[3] = wmu * jnp.sum(5. * x**4 * (mu**2 - 1.)**3 / xmu / 16. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 1,2,2
        B[4] = wmu * jnp.sum(x * (x + 2. * mu - 3. * x * mu**2) / 2. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 2,1,1
        B[5] = wmu * jnp.sum(- 3. * x * (mu**2 - 1.) * (-x - 2. * mu + 5. * x * mu**2) / 4. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 2,1,2
        B[6] = wmu * jnp.sum(3. * x**2 * (mu**2 - 1.) * (-2. + x**2 + 6. * x * mu - 5. * x**2 * mu**2) / xmu / 4. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 2,2,1
        B[7] = wmu * jnp.sum(- 3. * x**2 * (mu**2 - 1.)**2 * (6. - 5. * x**2 - 30. * x * mu + 35. * x**2 * mu**2) / xmu / 16. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 2,2,2
        B[8] = wmu * jnp.sum(x * (4. * mu * (3. - 5. * mu**2) + x * (3. - 30. * mu**2 + 35. * mu**4)) / 8. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 3,1,2
        B[9] = wmu * jnp.sum(x * (-8. * mu + x * (-12. + 36. * mu**2 + 12. * x * mu * (3. - 5. * mu**2) + x**2 * (3. - 30. * mu**2 + 35. * mu**4))) / xmu / 8. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 3,2,1
        B[10] = wmu * jnp.sum(3. * x * (mu**2 - 1.) * (-8. * mu + x * (-12. + 60. * mu**2 + 20. * x * mu * (3. - 7. * mu**2) + 5. * x**2 * (1. - 14. * mu**2 + 21. * mu**4))) / xmu / 16. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 3,2,2
        B[11] = wmu * jnp.sum(x * (8. * mu * (-3. + 5. * mu**2) - 6. * x * (3. - 30. * mu**2 + 35. * mu**4) + 6. * x**2 * mu * (15. - 70. * mu**2 + 63 * mu**4) + x**3 * (5. - 21. * mu**2 * (5. - 15. * mu**2 + 11. * mu**4))) / xmu / 16. * jq_pk_q_pk_kq, axis=-1)  # n,a,b = 4,2,2
        # 11 scalar terms, then the 5 A terms, then the 12 B terms
        return jnp.stack([pk_b2d, pk_bs2d, pk_b2t, pk_bs2t, sig3sq, pk_b22, pk_b2s2, pk_bs22, pk22_dd, pk22_dt, pk22_tt] + list(A) + B)

    # Weighted sum over the mu nodes (the weights are applied inside get_terms)
    res = jnp.sum(jax.vmap(get_terms)(mus, wmus), axis=0)
    pk_b2d, pk_bs2d, pk_b2t, pk_bs2t, sig3sq, pk_b22, pk_b2s2, pk_bs22, pk22_dd, pk22_dt, pk22_tt = res[:11]
    A, B = res[11:16], res[16:]
    # Contribution to A from the precomputed kernel_a
    A += pk_k * jnp.sum(kernel_a * pk_q, axis=-1)

    pk11 = pk_k
    pk13_dd = 2. * jnp.sum(kernel13_d * pk_q, axis=-1) * pk_k
    pk13_tt = 2. * jnp.sum(kernel13_t * pk_q, axis=-1) * pk_k
    pk13_dt = (pk13_dd + pk13_tt) / 2.
    pk_sig3sq = sig3sq * pk_k
    pk_dd = pk11 + pk22_dd + pk13_dd
    pk_dt = pk11 + pk22_dt + pk13_dt
    pk_tt = pk11 + pk22_tt + pk13_tt
    return [pk11, pk_dd, pk_b2d, pk_bs2d, pk_sig3sq, pk_b22, pk_b2s2, pk_bs22, pk_dt, pk_b2t, pk_bs2t, pk_tt, A, B]
[docs]
class TNSPowerSpectrumMultipoles(BasePTPowerSpectrumMultipoles, BaseTheoryPowerSpectrumMultipolesFromWedges):
    r"""
    TNS power spectrum multipoles.

    The perturbation theory terms returned by ``tns_pt`` are tabulated on an internal
    :math:`k`-grid, interpolated at the AP-distorted wavenumbers, multiplied by the
    Finger-of-God damping and projected onto multipoles. Results are stored in the
    dictionary ``pktable``.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    mu : int, default=8
        Number of :math:`\mu`-bins to use (in :math:`[0, 1]`).

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    nloop : int, default=1
        Number of loops; only 1 is accepted.

    fog : str, default='lorentzian'
        Finger-of-God damping, either 'lorentzian' or 'gaussian'.
    """
    _default_options = dict(nloop=1, fog='lorentzian')
    _klim = (1e-3, 2., 500)

    def initialize(self, *args, mu=8, **kwargs):
        """Set up the :math:`\\mu`-integration (Gauss-Legendre) and validate options ``nloop`` and ``fog``."""
        super(TNSPowerSpectrumMultipoles, self).initialize(*args, mu=mu, method='leggauss', **kwargs)
        self.nloop = int(self.options['nloop'])
        if self.nloop != 1:
            raise ValueError('nloop must be 1 (1-loop)')
        if self.options['fog'] not in ('lorentzian', 'gaussian'):
            raise ValueError('fog must be lorentzian or gaussian')

    def calculate(self, sigmav=0):
        """Compute the table of multipoles ``pktable``, given velocity dispersion ``sigmav``."""
        super(TNSPowerSpectrumMultipoles, self).calculate()
        jac, kap, muap = self.template.ap_k_mu(self.k, self.mu)
        f = self.template.f
        # Finger-of-God damping
        x2 = (sigmav * kap * muap)**2
        if self.options['fog'] == 'lorentzian':
            damping = 1. / (1. + x2 / 2.)**2.
        else:
            damping = jnp.exp(-x2)
        # Internal k-grid, extending beyond the requested range
        nk11 = int(len(self.k) * 1.6 + 0.5)
        k11 = np.linspace(self.k[0] * 0.7, self.k[-1] * 1.3, nk11)
        q = self.template.k
        wq = utils.weights_trapz(q)
        # Kernels are computed on the first call only, then reused
        if getattr(self, 'kernels', None) is None:
            self.kernels = tns_kernels(k11, q, wq)
        terms = tns_pt(k11, q, wq, self.template.pk_dd, *self.kernels)
        names = ['pk11', 'pk_dd', 'pk_b2d', 'pk_bs2d', 'pk_sig3sq', 'pk_b22', 'pk_b2s2', 'pk_bs22', 'pk_dt', 'pk_b2t', 'pk_bs2t', 'pk_tt', 'A', 'B']
        # All terms but the last two (A, B) get a leading axis before stacking
        table = jnp.concatenate([term[None, :] for term in terms[:-2]] + terms[-2:], axis=0)
        interpolated = interp1d(jnp.log10(kap), np.log10(k11), table.T, method='cubic')
        table = jac * damping * jnp.moveaxis(interpolated, [0, 1], [1, 2])
        a = table[12:]
        b = table[17:]
        A = jnp.array([f * a[0] * muap**2, f**2 * (a[1] * muap**2 + a[2] * muap**4), f**3 * (a[3] * muap**4 + a[4] * muap**6)])  # for b1^2, b1, 1
        B = jnp.array([f**2 * (b[0] * muap**2 + b[4] * muap**4),
                       -f**3 * ((b[1] + b[2]) * muap**2 + (b[5] + b[6]) * muap**4 + (b[8] + b[9]) * muap**6),
                       f**4 * (b[3] * muap**2 + b[7] * muap**4 + b[10] * muap**6 + b[11] * muap**8)])  # for b1^2, b1, 1
        groups = [self.to_poles(table[:8, None]), self.to_poles(f * muap**2 * table[8:11, None]), self.to_poles(f**2 * muap**4 * table[11:12, None])]
        poles = [pole for group in groups for pole in group]
        self.pktable = {names[ipole]: pole for ipole, pole in enumerate(poles)}
        self.pktable['A'] = self.to_poles(A[:, None, ...])
        self.pktable['B'] = self.to_poles(B[:, None, ...])

    def __getstate__(self):
        """Return state: available attributes, entries of ``pktable`` and their names."""
        state = {name: getattr(self, name) for name in ['k', 'z', 'ells', 'nloop', 'fog'] if hasattr(self, name)}
        state.update(self.pktable)
        state['names'] = list(self.pktable.keys())
        return state

    def __setstate__(self, state):
        """Rebuild ``pktable`` from the entries listed in ``state['names']``, pass the rest on to the parent class."""
        state = dict(state)
        names = state['names']
        self.pktable = {}
        for name in names:
            self.pktable[name] = state.pop(name, None)
        super(TNSPowerSpectrumMultipoles, self).__setstate__(state)
[docs]
class TNSTracerPowerSpectrumMultipoles(BaseTracerPowerSpectrumMultipoles):
    r"""
    TNS tracer power spectrum multipoles.
    For the matter (unbiased) power spectrum, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    mu : int, default=8
        Number of :math:`\mu`-bins to use (in :math:`[0, 1]`).

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    shotnoise : float, default=1e4
        Shot noise (which is usually marginalized over).
    """
    _default_options = dict(freedom=None)
    _deterministic_bias_params = ['b1', 'b2', 'bs', 'b3']
    _stochastic_bias_params = ['sn0']

    def set_params(self):
        """Declare bias parameters and fix / free them following option ``freedom`` ('max', 'min' or ``None``)."""
        self.required_bias_params.update(dict(b1=1., b2=0., bs=0., b3=0., sn0=0.))
        super().set_params(pt_params=['sigmav'])
        freedom = self.options.get('freedom', None)
        fix = []
        if freedom == 'max':
            for param in self.init.params.select(basename=['b1', 'b2', 'bs', 'b3']):
                param.update(fixed=False)
            fix.append('alpha6')
        if freedom == 'min':
            fix.extend(['b3', 'bs'])
        for param in self.init.params.select(basename=fix):
            param.update(value=0., fixed=True)

    def calculate(self, **kwargs):
        """Combine the perturbation theory table with bias parameters into the multipoles ``power``."""
        bias_params = self.pack_input_bias_params(kwargs)
        # Deterministic parameters come as pairs, of which the first element is used
        (b1, _), (b2, _), (bs, _), (b3, _), sn0 = [bias_params[name] for name in ['b1', 'b2', 'bs', 'b3', 'sn0']]
        super(TNSTracerPowerSpectrumMultipoles, self).calculate()
        pkt = self.pt.pktable
        # Linear bias terms and shot noise
        power = b1**2 * pkt['pk_dd'] + 2. * b1 * pkt['pk_dt'] + pkt['pk_tt'] + sn0 / self.nd
        bs2 = bs - 4. / 7. * (b1 - 1.)
        b3nl = b3 + 32. / 315. * (b1 - 1.)
        # Higher-order bias terms
        # NOTE(review): 'pk_bs2t' is tabulated by the perturbation theory calculator but not used here -- confirm this is intended
        power = power + (2 * b1 * b2 * pkt['pk_b2d'] + 2. * b1 * bs2 * pkt['pk_bs2d']
                         + 2 * b1 * b3nl * pkt['pk_sig3sq'] + b2**2 * pkt['pk_b22']
                         + 2 * b2 * bs2 * pkt['pk_b2s2'] + bs2**2 * pkt['pk_bs22']
                         + b2 * pkt['pk_b2t'] + b3nl * pkt['pk_sig3sq'])
        # TNS A and B correction terms, for b1^2, b1, 1
        power = power + b1**2 * (pkt['A'][0] + pkt['B'][0])
        power = power + b1 * (pkt['A'][1] + pkt['B'][1])
        power = power + (pkt['A'][2] + pkt['B'][2])
        self.power = power
[docs]
class TNSTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    TNS tracer correlation function multipoles.
    For the matter (unbiased) correlation function, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    **kwargs : dict
        Options, defaults to: ``mu=8``.
    """
    # Same list object as the power spectrum model's deterministic bias parameters
    _deterministic_bias_params = TNSTracerPowerSpectrumMultipoles._deterministic_bias_params
    # No stochastic parameter is declared for the correlation function
    _stochastic_bias_params = []
[docs]
class EFTLikeTNSTracerPowerSpectrumMultipoles(BaseEFTLikeTracerPowerSpectrumMultipoles, TNSTracerPowerSpectrumMultipoles):
    r"""
    TNS tracer power spectrum multipoles with EFT-like counter and stochastic terms.
    Can be exactly marginalized over counter terms and stochastic parameters ct*, sn*.
    For the matter (unbiased) power spectrum, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    mu : int, default=8
        Number of :math:`\mu`-bins to use (in :math:`[0, 1]`).

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    shotnoise : float, default=1e4
        Shot noise (which is usually marginalized over).
    """
    # TNS deterministic bias parameters, followed by the counter terms ct*_2
    _deterministic_bias_params = TNSTracerPowerSpectrumMultipoles._deterministic_bias_params + ['ct0_2', 'ct2_2', 'ct4_2']
    # TNS stochastic parameters, followed by the stochastic terms sn*_2
    _stochastic_bias_params = TNSTracerPowerSpectrumMultipoles._stochastic_bias_params + ['sn0_2', 'sn2_2', 'sn4_2']
[docs]
class EFTLikeTNSTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    TNS tracer correlation function multipoles with EFT-like counter and stochastic terms.
    Can be exactly marginalized over counter terms and stochastic parameters ct*, sn*.
    For the matter (unbiased) correlation function, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    **kwargs : dict
        Options, defaults to: ``mu=8``.
    """
    # Same list object as the EFT-like power spectrum model's deterministic bias parameters
    _deterministic_bias_params = EFTLikeTNSTracerPowerSpectrumMultipoles._deterministic_bias_params
    # No stochastic parameter is declared for the correlation function
    _stochastic_bias_params = []
def get_nthreads(nthreads=None):
    """
    Return the number of threads to use, as an integer.

    Parameters
    ----------
    nthreads : int, str, default=None
        Requested number of threads. If ``None``, it is read from the environment variable
        ``OMP_NUM_THREADS``; 1 is used if the variable is unset or empty.

    Returns
    -------
    nthreads : int
        Number of threads.
    """
    if nthreads is None:
        import os
        # OMP_NUM_THREADS can be a comma-separated list (one value per nesting level):
        # keep the first (outermost) value; an empty value falls back to 1
        value = os.getenv('OMP_NUM_THREADS', '1').split(',')[0].strip()
        nthreads = value or '1'
    return int(nthreads)
[docs]
class BaseVelocileptorsPowerSpectrumMultipoles(BasePTPowerSpectrumMultipoles, BaseTheoryPowerSpectrumMultipolesFromWedges):
    """Base class for velocileptors-based matter power spectrum multipoles."""
    _default_options = dict()

    def initialize(self, *args, **kwargs):
        """Initialize the parent class, then turn option ``nthreads`` into option ``threads``."""
        super(BaseVelocileptorsPowerSpectrumMultipoles, self).initialize(*args, **kwargs)
        nthreads = self.options.pop('nthreads', None)
        self.options['threads'] = get_nthreads(nthreads)

    @classmethod
    def install(cls, installer):
        """Install velocileptors with pip, from its git repository."""
        installer.pip('git+https://github.com/sfschen/velocileptors')

    def __getstate__(self):
        """Return state: available attributes of this calculator, then those of ``self.pt`` listed in ``_pt_attrs``."""
        state = {name: getattr(self, name) for name in ['k', 'z', 'ells', 'wmu', 'sigma8', 'fsigma8'] if hasattr(self, name)}
        state.update({name: getattr(self.pt, name) for name in self._pt_attrs if hasattr(self.pt, name)})
        return state
def get_physical_stochastic_settings(tracer=None):
    """
    Return preset satellite fraction and velocity dispersion for the stochastic terms.

    Parameters
    ----------
    tracer : str, default=None
        Tracer name (case-insensitive), one of ['BGS', 'LRG', 'ELG', 'QSO'].
        If ``None``, generic values are returned.

    Returns
    -------
    settings : dict
        Dictionary with keys 'fsat' and 'sigv'.

    Raises
    ------
    ValueError
        If ``tracer`` is not one of the known tracers.
    """
    if tracer is None:
        return {'fsat': 0.1, 'sigv': 5.}
    tracer = str(tracer).upper()
    # Mark Maus, Ruiyang Zhao
    presets = {
        'BGS': {'fsat': 0.15, 'sigv': 150*(10)**(1/3)*(1+0.2)**(1/2)/70.},
        'LRG': {'fsat': 0.15, 'sigv': 150*(10)**(1/3)*(1+0.8)**(1/2)/70.},
        'ELG': {'fsat': 0.10, 'sigv': 150*2.1**(1/2)/70.},
        'QSO': {'fsat': 0.03, 'sigv': 150*(10)**(0.7/3)*(2.4)**(1/2)/70.},
    }
    try:
        return presets[tracer]
    except KeyError:
        raise ValueError('unknown tracer: {}, please use any of {}'.format(tracer, list(presets.keys())))
[docs]
class BaseVelocileptorsTracerPowerSpectrumMultipoles(BaseTracerPowerSpectrumMultipoles):
    """Base class for velocileptors-based tracer power spectrum multipoles."""

    @classmethod
    def _params(cls, params, freedom=None, prior_basis='physical', tracers=None):
        """
        Update priors / fixed status of input ``params`` following ``freedom`` ('max', 'min' or ``None``).
        If ``prior_basis`` is 'physical', suffix 'p' is added to all parameter base names,
        and priors and reference distributions are set for these parameters.
        """
        fix = []
        if freedom == 'max':
            for param in params.select(basename=['b1', 'b2', 'bs', 'b3']):
                param.update(fixed=False)
            for param in params.select(basename=['b2', 'bs', 'b3']):
                param.update(prior=dict(limits=[-15., 15.]))
            for param in params.select(basename=['alpha*', 'sn*']):
                param.update(prior=None)
            fix.extend(['alpha6'])
        elif freedom == 'min':
            fix.extend(['b3', 'bs', 'alpha6'])
            for param in params.select(basename=['b2']):
                param.update(prior=dict(dist='norm', loc=0., scale=10.))
            for param in params.select(basename=['alpha*', 'sn*']):
                param.update(prior=None)
        for param in params.select(basename=fix):
            param.update(value=0., fixed=True)
        # call `BaseTracerTwoPointTheory._params.__func__` here as classmethod `_params` is a descriptor
        params = BaseTracerTwoPointTheory._params.__func__(cls, params, tracers=tracers)
        if prior_basis != 'physical':
            return params
        # Physical basis: all parameters are renamed with suffix 'p'
        for param in list(params):
            param.update(basename=param.basename + 'p')
        for param in params.select(basename='b1p'):
            param.update(prior=dict(dist='uniform', limits=[0., 3.]), ref=dict(dist='norm', loc=1., scale=0.1))
        for param in params.select(basename=['b2p', 'bsp', 'b3p']):
            param.update(prior=dict(dist='norm', loc=0., scale=5.), ref=dict(dist='norm', loc=0., scale=1.))
        for param in params.select(basename='b3p'):
            param.update(value=0., fixed=True)
        for param in params.select(basename='alpha*p'):
            param.update(prior=dict(dist='norm', loc=0., scale=12.5), ref=dict(dist='norm', loc=0., scale=1.))  # 50% at k = 0.2 h/Mpc
        for param in params.select(basename='sn*p'):
            scale = 2. if 'sn0' in param.basename else 5.
            param.update(prior=dict(dist='norm', loc=0., scale=scale), ref=dict(dist='norm', loc=0., scale=1.))
        return params

    def set_params(self):
        """
        Set bias parameters. With the physical prior basis, parameter names get suffix 'p',
        and options ``fsat`` and ``sigv`` left to ``None`` are filled with the preset values for option ``tracer``.
        Counter and stochastic terms of multipoles not in ``ells`` are fixed to 0.
        """
        self.is_physical_prior = self.options['prior_basis'] == 'physical'
        if self.is_physical_prior:
            # Rename in place, keeping the order of parameters
            for name in list(self.required_bias_params):
                self.required_bias_params[name + 'p'] = self.required_bias_params.pop(name)
            presets = get_physical_stochastic_settings(tracer=self.options['tracer'])
            for name, value in presets.items():
                if self.options[name] is None:
                    self.options[name] = value
            if self.mpicomm.rank == 0:
                self.log_debug('Using fsat, sigv = {:.3f}, {:.3f}.'.format(self.options['fsat'], self.options['sigv']))
            self.deterministic_bias_params = [name + 'p' for name in self.deterministic_bias_params]
            self.stochastic_bias_params = [name + 'p' for name in self.stochastic_bias_params]
        super().set_params(pt_params=[])
        fix = []
        if 4 not in self.ells:
            fix.extend(['alpha4*', 'alpha6*', 'sn4*'])  # * to capture p
        if 2 not in self.ells:
            fix.extend(['alpha2*', 'sn2*'])
        for param in self.init.params.select(basename=fix):
            param.update(value=0., fixed=True)
        self.nd = 1e-4
        self.fsat = self.snd = 1.
        if self.is_physical_prior:
            self.fsat = self.options['fsat']
            self.snd = self.options['shotnoise'] * self.nd  # normalized by 1e-4
[docs]
class BaseVelocileptorsCorrelationFunctionMultipoles(BasePTCorrelationFunctionMultipoles):
    """Base class for velocileptors-based matter correlation function multipoles."""

    def initialize(self, *args, **kwargs):
        """Initialize the parent class, then turn option ``nthreads`` into option ``threads``."""
        super(BaseVelocileptorsCorrelationFunctionMultipoles, self).initialize(*args, **kwargs)
        nthreads = self.options.pop('nthreads', None)
        self.options['threads'] = get_nthreads(nthreads)

    def combine_bias_terms_poles(self, pars, **opts):
        """Return ``self.pt.compute_xi_ell`` evaluated at each separation of ``self.s``, as a transposed array."""
        xi = []
        for ss in self.s:
            xi.append(self.pt.compute_xi_ell(ss, self.template.f, *pars, apar=self.template.qpar, aperp=self.template.qper, **self.options, **opts))
        return np.array(xi).T
[docs]
class BaseVelocileptorsTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionMultipoles):
    """Base class for velocileptors-based tracer correlation function multipoles."""

    def calculate(self, **params):
        """Compute correlation function multipoles ``corr`` from input bias parameters (missing ones take their default value)."""
        super(BaseVelocileptorsTracerCorrelationFunctionMultipoles, self).calculate()
        pars = []
        for name, default in self.required_bias_params.items():
            pars.append(params.get(name, default))
        opts = {}
        for name, default in self.optional_bias_params.items():
            opts[name] = params.get(name, default)
        self.corr = self.pt.combine_bias_terms_poles(pars, **opts, **self.options)
@jit
def tablevel_combine_bias_terms_poles(pktable, pars, nd=1e-4):
    """
    Contract the last axis of ``pktable`` with the monomials of bias parameters.

    Parameters
    ----------
    pktable : array
        Table, with last axis of size 19 (one entry per monomial).

    pars : sequence
        The 11 parameters b1, b2, bs, b3, alpha0, alpha2, alpha4, alpha6, sn0, sn2, sn4.

    nd : float, default=1e-4
        Stochastic parameters sn0, sn2, sn4 are divided by this value.

    Returns
    -------
    poles : array
        ``pktable`` summed over its last axis, weighted by the monomials.
    """
    b1, b2, bs, b3, alpha0, alpha2, alpha4, alpha6, sn0, sn2, sn4 = pars
    bias_terms = [1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2, b3, b1 * b3]
    counter_terms = [alpha0, alpha2, alpha4, alpha6]
    stochastic_terms = [sn / nd for sn in (sn0, sn2, sn4)]
    weights = jnp.array(bias_terms + counter_terms + stochastic_terms)
    return jnp.sum(pktable * weights, axis=-1)
[docs]
class LPTVelocileptorsPowerSpectrumMultipoles(BaseVelocileptorsPowerSpectrumMultipoles):
    """Velocileptors Lagrangian perturbation theory (LPT) table of power spectrum multipoles, to be combined with bias parameters."""
    _default_options = dict(use_Pzel=False, kIR=0.2, cutoff=10, extrap_min=-5, extrap_max=3, N=4000, nthreads=None, jn=5)

    # Speed is linear with the number of output k
    def initialize(self, *args, mu=4, **kwargs):
        """Initialize with Gauss-Legendre :math:`\\mu`-integration."""
        super(LPTVelocileptorsPowerSpectrumMultipoles, self).initialize(*args, mu=mu, method='leggauss', **kwargs)

    def calculate(self):
        """Run velocileptors' ``LPT_RSD`` and store the table ``pktable`` for the requested ``ells``, with ``sigma8`` and ``fsigma8``."""
        super(LPTVelocileptorsPowerSpectrumMultipoles, self).calculate()

        def cubic_interp1d(x, y):
            return interpolate.interp1d(x, y, kind='cubic', assume_sorted=True)  # for AP

        # Replace the interpolator used by velocileptors with the cubic one above
        from velocileptors.LPT import lpt_rsd_fftw
        lpt_rsd_fftw.interp1d = cubic_interp1d
        from velocileptors.LPT.lpt_rsd_fftw import LPT_RSD
        kin, pkin = np.asarray(self.template.k), np.asarray(self.template.pk_dd)
        self.pt = LPT_RSD(kin, pkin, **self.options)
        self.pt.make_pltable(np.asarray(self.template.f), kv=np.asarray(self.k), apar=np.asarray(self.template.qpar), aperp=np.asarray(self.template.qper), ngauss=len(self.mu))
        tables = {0: self.pt.p0ktable, 2: self.pt.p2ktable, 4: self.pt.p4ktable}
        self.pktable = np.array([tables[ell] for ell in self.ells])
        self.sigma8 = self.template.sigma8
        self.fsigma8 = self.template.f * self.sigma8

    def combine_bias_terms_poles(self, pars, nd=1e-4):
        """Return the multipoles obtained by combining ``pktable`` with parameters ``pars``."""
        return tablevel_combine_bias_terms_poles(self.pktable, pars, nd=nd)

    def __getstate__(self):
        """Return state, restricted to the attributes that are available."""
        return {name: getattr(self, name) for name in ['k', 'z', 'ells', 'pktable', 'sigma8', 'fsigma8'] if hasattr(self, name)}

    @classmethod
    def install(cls, installer):
        """Install velocileptors with pip, from its git repository."""
        installer.pip('git+https://github.com/sfschen/velocileptors')
[docs]
class LPTVelocileptorsTracerPowerSpectrumMultipoles(BaseVelocileptorsTracerPowerSpectrumMultipoles):
    r"""
    Velocileptors Lagrangian perturbation theory (LPT) tracer power spectrum multipoles.
    Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn*.
    For the matter (unbiased) power spectrum, set all bias parameters to 0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    prior_basis : str, default='physical'
        If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
        :math:`b_{1}^\prime = (1 + b_{1}) \sigma_{8}(z), b_{2}^\prime = b_{2} \sigma_{8}(z)^2, b_{s}^\prime = b_{s} \sigma_{8}(z)^2, b_{3}^\prime = b_{3} \sigma_{8}(z)^3`
        :math:`\alpha_{0} = (1 + b_{1})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}) \alpha_{4}^\prime), \alpha_{6} = f^{2} \alpha_{4}^\prime`.
        :math:`s_{n, 0} = 1/\bar{n} s_{n, 0}^\prime, s_{n, 2} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{2} s_{n, 2}^\prime, s_{n, 4} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{4} s_{n, 4}^\prime`.
        In this case, ``use_Pzel = False``.

    tracer : str, default=None
        If ``prior_basis = 'physical'``, tracer to load preset ``fsat`` and ``sigv``. One of ['BGS', 'LRG', 'ELG', 'QSO'].

    fsat : float, default=None
        If ``prior_basis = 'physical'``, satellite fraction to assume.

    sigv : float, default=None
        If ``prior_basis = 'physical'``, velocity dispersion to assume.

    shotnoise : float, default=1e4
        Shot noise, to scale stochastic terms.

    **kwargs : dict
        Velocileptors options, defaults to: ``use_Pzel=False, kIR=0.2, cutoff=10, extrap_min=-5, extrap_max=3, N=4000, nthreads=1, jn=5``.


    Reference
    ---------
    - https://arxiv.org/abs/2005.00523
    - https://arxiv.org/abs/2012.04636
    - https://github.com/sfschen/velocileptors
    """
    _default_options = dict(freedom=None, prior_basis='physical', tracer=None, fsat=None, sigv=None, shotnoise=1e4)
    _deterministic_bias_params = ['b1', 'b2', 'bs', 'b3', 'alpha0', 'alpha2', 'alpha4', 'alpha6']
    _stochastic_bias_params = ['sn0', 'sn2', 'sn4']

    def initialize(self, *args, k=None, **kwargs):
        """Initialize, and set the internal :math:`k`-grid of the perturbation theory calculator ``pt``."""
        super(LPTVelocileptorsTracerPowerSpectrumMultipoles, self).initialize(*args, **kwargs)
        if k is not None:
            self.k = np.array(k, dtype='f8')
        # Increasing the resolution, necessary
        boost_prec = 2
        klow = [min(0.0005, self.k[0])]
        kmid = np.geomspace(0.0015, 0.025, 10 * boost_prec, endpoint=True)
        # margin for interpolation below (and numerical noise in endpoint)
        khigh = np.arange(0.03, max(0.5, self.k[-1]) + 0.015 / boost_prec, 0.01 / boost_prec)
        kvec = np.concatenate([klow, kmid, khigh])
        ells = kwargs.get('ells', None)
        if ells is not None:
            self.ells = tuple(ells)
        self.pt.init.update(k=kvec, ells=self.ells, use_Pzel=not self.is_physical_prior)

    def set_params(self):
        """Declare all bias parameters with default value 0, except b1 with default value 1."""
        names = self._deterministic_bias_params + self._stochastic_bias_params
        self.required_bias_params = dict.fromkeys(names, 0.)
        self.required_bias_params['b1'] = 1.
        super().set_params()

    def calculate(self, **params):
        """Convert input parameters to velocileptors parameters, and interpolate the resulting multipoles at ``self.k``."""
        self.z = self.pt.z
        params = self.pack_input_bias_params(params)
        # Keep the first element of parameters packed as tuples
        params = {name: value[0] if isinstance(value, tuple) else value for name, value in params.items()}
        if self.is_physical_prior:
            sigma8 = self.pt.sigma8
            f = self.pt.fsigma8 / sigma8
            b1L = params['b1p'] / sigma8 - 1.
            pars = [b1L, params['b2p'] / sigma8**2, params['bsp'] / sigma8**2, params['b3p'] / sigma8**3]
            pars += [(1 + b1L)**2 * params['alpha0p'], f * (1 + b1L) * (params['alpha0p'] + params['alpha2p']),
                     f * (f * params['alpha2p'] + (1 + b1L) * params['alpha4p']), f**2 * params['alpha4p']]
            sigv = self.options['sigv']
            # fsat multiplies sn2 and sn4 only
            pars += [params['sn{:d}p'.format(i)] * self.snd * (self.fsat if i > 0 else 1.) * sigv**i for i in [0, 2, 4]]
        else:
            pars = [params[name] for name in self.required_bias_params]
        opts = {name: params.get(name, default) for name, default in self.optional_bias_params.items()}
        index = np.array([self.pt.ells.index(ell) for ell in self.ells])
        poles = self.pt.combine_bias_terms_poles(pars, **opts, nd=self.nd)[index]
        self.power = interp1d(self.k, self.pt.k, poles.T).T
[docs]
class LPTVelocileptorsTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    Velocileptors LPT tracer correlation function multipoles.
    Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn*.
    For the matter (unbiased) correlation function, set all bias parameters to 0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    prior_basis : str, default='physical'
        If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
        :math:`b_{1}^\prime = (1 + b_{1}) \sigma_{8}(z), b_{2}^\prime = b_{2} \sigma_{8}(z)^2, b_{s}^\prime = b_{s} \sigma_{8}(z)^2, b_{3}^\prime = b_{3} \sigma_{8}(z)^3`
        :math:`\alpha_{0} = (1 + b_{1})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}) \alpha_{4}^\prime), \alpha_{6} = f^{2} \alpha_{4}^\prime`.

    **kwargs : dict
        Velocileptors options, defaults to: ``use_Pzel=False, kIR=0.2, cutoff=10, extrap_min=-5, extrap_max=3, N=4000, nthreads=1, jn=5``.


    Reference
    ---------
    - https://arxiv.org/abs/2005.00523
    - https://arxiv.org/abs/2012.04636
    - https://github.com/sfschen/velocileptors
    """
    # Reuse the implementation of ``_params`` of the power spectrum model
    _params = classmethod(LPTVelocileptorsTracerPowerSpectrumMultipoles._params.__func__)
    _deterministic_bias_params = LPTVelocileptorsTracerPowerSpectrumMultipoles._deterministic_bias_params
    _stochastic_bias_params = []

    def set_params(self):
        """Set parameters; with the physical prior basis, suffix 'p' is added to the names of bias parameters."""
        super().set_params()
        if not self.power.is_physical_prior:
            return
        self.deterministic_bias_params = [name + 'p' for name in self.deterministic_bias_params]
        self.stochastic_bias_params = [name + 'p' for name in self.stochastic_bias_params]
[docs]
def f_over_f0_EH(z, k, Omega0_m, h, fnu, Nnu=3, Neff=3.044):
    r"""
    Computes f(k)/f0, adapted from https://github.com/henoriega/FOLPS-nu, following H&E (1998).

    Note
    ----
    The radiation density is :math:`\Omega_\gamma (1 + 7/8 (4/11)^{4/3} N_\mathrm{eff})`,
    i.e. the neutrino contribution multiplies (rather than divides) the photon density.

    Reference
    ---------
    https://arxiv.org/pdf/astro-ph/9710216

    Parameters
    ----------
    z : float
        Redshift.

    k : array
        Wavenumber.

    Omega0_m : float
        :math:`\Omega_\mathrm{b} + \Omega_\mathrm{c} + \Omega_\nu` (dimensionless matter density parameter).

    h : float
        :math:`H_0 / 100`.

    fnu : float
        :math:`\Omega_\nu / \Omega_\mathrm{m}`.

    Nnu : int, default=3
        Number of massive neutrinos.

    Neff : float, default=3.044
        Effective number of relativistic species.

    Returns
    -------
    fk : array
        :math:`f(k) / f0`; equal to 1 when ``fnu`` is 0.
    """
    eta = jnp.log(1 / (1 + z))  # log of scale factor
    # rad: photons (Omega_gamma h^2 = 2.469e-5) times the factor including neutrinos
    Omega0_r = 2.469e-5 / h**2 * (1 + 7 / 8 * (4 / 11)**(4 / 3) * Neff)
    aeq = Omega0_r / Omega0_m  # matter-radiation equality
    pcb = 5. / 4 - jnp.sqrt(1 + 24 * (1 - fnu)) / 4  # neutrino suppression
    c = 0.7
    theta272 = (1.00)**2  # T_{CMB} = 2.7*(theta272)
    pf = (k * theta272) / (Omega0_m * h**2)
    DEdS = jnp.exp(eta) / aeq  # growth function: EdS cosmology
    # Avoid dividing by zero below; pcb = 0. and yFS = 0. when fnu = 0.
    fnunonzero = jnp.where(fnu != 0., fnu, 1.)
    yFS = 17.2 * fnu * (1 + 0.488 * fnunonzero**(-7 / 6)) * (pf * Nnu / fnunonzero)**2  # free streaming
    rf = DEdS / (1 + yFS)
    return 1 - pcb / (1 + rf**c)  # f(k)/f0
[docs]
class REPTVelocileptorsPowerSpectrumMultipoles(BaseVelocileptorsPowerSpectrumMultipoles):
    """
    Velocileptors resummed Eulerian perturbation theory (REPT) table of power spectrum multipoles,
    to be combined with bias parameters. When the template redshift ``z`` is an array, the table
    has one entry per redshift along its last axis.
    """
    _default_options = dict(rbao=110, sbao=None, beyond_gauss=True,
                            one_loop=True, shear=True, cutoff=20, jn=5, N=4000,
                            nthreads=None, extrap_min=-4, extrap_max=3, import_wisdom=False)

    # Speed does not depend on the number of output k
    def initialize(self, *args, mu=4, **kwargs):
        """Initialize with Gauss-Legendre :math:`\\mu`-integration, and request the 'peakaverage' no-wiggle power spectrum from the template."""
        super(REPTVelocileptorsPowerSpectrumMultipoles, self).initialize(*args, mu=mu, method='leggauss', **kwargs)
        self.template.init.update(with_now='peakaverage')

    def _emulator_initialize(self):
        """
        Restrict the emulator to the redshifts bracketing the requested ``z``, and prepend an operation
        that linearly interpolates 'pktable', 'fsigma8' and 'sigma8' between these redshifts.

        Raises
        ------
        ValueError
            If requested ``z`` is outside of the range of emulated redshifts.
        """
        # Always start from a copy of the pristine emulator
        self._emulator_bak = getattr(self, '_emulator_bak', self.emulator)
        self.emulator = self._emulator_bak.deepcopy()
        if 'z' not in self.init: return
        z = np.asarray(self.init['z'])
        allz = self.emulator.fixed['z']
        if np.allclose(z, allz): return
        if np.any((z < allz[0]) | (z > allz[-1])):
            raise ValueError('input z = {} is outside of the range of emulated z: {} - {}'.format(z, *allz[[0, -1]]))
        # Indices of emulated redshifts just below / above requested z
        iz = np.searchsorted(allz, z, side='right') - 1
        izp1 = np.minimum(iz + 1, len(allz) - 1)
        keepiz = np.unique(np.concatenate([iz, izp1], axis=0))
        allz = allz[keepiz]
        self.emulator.fixed['z'] = z
        # Same indices, now relative to the kept redshifts
        iz = np.searchsorted(keepiz, iz, side='right') - 1
        izp1 = np.minimum(iz + 1, len(allz) - 1)
        # Linear interpolation weight: fractional position of z within [allz[iz], allz[izp1]]
        # (0 when z sits on the last node, where iz == izp1)
        dz = allz[izp1] - allz[iz]
        wz = np.where(dz > 0., (z - allz[iz]) / np.where(dz > 0., dz, 1.), 0.)
        from desilike.emulators import Operation
        # Keep only the iz predictions we are interested in (for jaxeffort, maybe we should fix this later)
        for engine in self.emulator.engines.values():
            for operation in engine.model_operations + engine.yoperations:
                operation.update(locals={key: value[keepiz] for key, value in operation._locals.items()})
            engine.yshape = keepiz.shape + engine.yshape[1:]
        # izp1 (clipped) is used as upper node, such that the index never exceeds the kept redshifts
        self.emulator.yoperations.insert(0, Operation("", "{name: v[name][..., iz] * (1 - wz) + v[name][..., izp1] * wz if name in ['pktable', 'fsigma8', 'sigma8'] else v[name] for name in v}", locals={'wz': wz, 'iz': iz, 'izp1': izp1}))

    def calculate(self):
        """Run velocileptors' ``REPT`` and store the table ``pktable`` interpolated at ``self.k``, with ``sigma8`` and ``fsigma8``."""
        super(REPTVelocileptorsPowerSpectrumMultipoles, self).calculate()
        from velocileptors.EPT.ept_fullresum_varyDz_nu_fftw import REPT
        pk_dd, pknow_dd = self.template.pk_dd, self.template.pknow_dd
        # With several redshifts, REPT is instantiated with the power spectrum at the first one
        if self.z.ndim: pk_dd, pknow_dd = pk_dd[..., 0], pknow_dd[..., 0]
        self.pt = REPT(np.asarray(self.template.k), np.asarray(pk_dd), pnw=np.asarray(pknow_dd), kmin=self.k[0], kmax=self.k[-1], nk=200, **self.options)
        pktable = {ell: [] for ell in [0, 2, 4]}
        self.sigma8 = self.template.sigma8
        self.fsigma8 = self.template.f * self.sigma8
        qpar, qper = map(np.asarray, [self.template.qpar, self.template.qper])
        # Power spectra interpolated (in log-log) on the REPT k-grid, plus k = 1 as last point, used for Dz below
        pcb, pcb_nw, pttcb = [10**interpolate.interp1d(np.log10(self.template.k), np.log10(pk), kind='cubic', fill_value='extrapolate', axis=0, assume_sorted=True)(np.log10(np.append(self.pt.kv, 1.))) for pk in [self.template.pk_dd, self.template.pknow_dd, self.template.pk_dd * self.template.fk**2]]
        # Scale-dependent growth rate on the REPT k-grid
        fk = np.sqrt(pttcb / pcb)[:-1]
        if self.z.ndim:
            for iz, z in enumerate(self.z):
                # Growth factor relative to the first redshift, from the power spectrum at k = 1
                Dz = np.sqrt(pcb[-1, iz] / pcb[-1, 0])
                pks = self.pt.compute_redshift_space_power_multipoles_tables(fk[..., iz], apar=qpar[iz], aperp=qper[iz], ngauss=len(self.mu), pcb=pcb[:-1, iz], pcb_nw=pcb_nw[:-1, iz], Dz=Dz)[1:]
                for ill, ell in enumerate(pktable): pktable[ell].append(pks[ill])
            # Redshift as last axis
            pktable = {ell: np.concatenate([v[..., None] for v in value], axis=-1) for ell, value in pktable.items()}
        else:
            pks = self.pt.compute_redshift_space_power_multipoles_tables(fk, apar=qpar, aperp=qper, ngauss=len(self.mu))[1:]
            for ill, ell in enumerate(pktable): pktable[ell] = pks[ill]
        self.pktable = interpolate.interp1d(self.pt.kv, np.array([pktable[ell] for ell in self.ells]), kind='cubic', fill_value='extrapolate', axis=1, assume_sorted=True)(self.k)

    def combine_bias_terms_poles(self, pars, z=None, nd=1e-4):
        """
        Return the multipoles obtained by combining ``pktable`` with parameters ``pars``.
        If ``z`` is provided, the table at this redshift (which must be in ``self.z``) is used.
        """
        # Add co-evolution part
        pars = list(pars)
        b1 = pars[0]
        pars[2] = pars[2] - (2 / 7) * (b1 - 1.)  # bs
        pars[3] = 3 * pars[3] + (b1 - 1.)  # b3
        pktable = self.pktable
        if z is not None: pktable = pktable[..., list(self.z).index(z)]
        return tablevel_combine_bias_terms_poles(pktable, pars, nd=nd)

    def __getstate__(self, varied=True, fixed=True):
        """Return state: 'k', 'z', 'ells' if ``fixed``, and 'pktable', 'sigma8', 'fsigma8' if ``varied`` (when available)."""
        state = {}
        for name in (['k', 'z', 'ells'] if fixed else []) + (['pktable', 'sigma8', 'fsigma8'] if varied else []):
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state

    @classmethod
    def install(cls, installer):
        """Install velocileptors with pip, from its git repository."""
        installer.pip('git+https://github.com/sfschen/velocileptors')
[docs]
class REPTVelocileptorsTracerPowerSpectrumMultipoles(BaseVelocileptorsTracerPowerSpectrumMultipoles):
r"""
Velocileptors resummmed Eulerian perturbation theory (REPT) tracer power spectrum multipoles.
Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn*.
For the matter (unbiased) power spectrum, set all bias parameters to 0.
Parameters
----------
k : array, default=None
Theory wavenumbers where to evaluate multipoles.
ells : tuple, default=(0, 2, 4)
Multipoles to compute.
template : BasePowerSpectrumTemplate
Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.
prior_basis : str, default='physical'
If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
:math:`b_{1}^\prime = (1 + b_{1}^{L}) \sigma_{8}(z), b_{2}^\prime = b_{2}^{L} \sigma_{8}(z)^2, b_{s}^\prime = b_{s}^{L} \sigma_{8}(z)^2, b_{3}^\prime = 0`
with: :math:`b_{1} = 1 + b_{1}^{L}, b_{2} = 8/21 b_{1}^{L} + b_{2}^{L}, b_{s} = b_{s}^{L}, b_{3} = b_{3}^{L}`.
:math:`\alpha_{0} = (1 + b_{1}^{L})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}^{L}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}^{L}) \alpha_{4}^\prime)`.
:math:`s_{n, 0} = f_{\mathrm{sat}}/\bar{n} s_{n, 0}^\prime, s_{n, 2} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{2} s_{n, 2}^\prime, s_{n, 4} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{4} s_{n, 4}^\prime`.
tracer : str, default=None
If ``prior_basis = 'physical'``, tracer to load preset ``fsat`` and ``sigv``. One of ['LRG', 'ELG', 'QSO'].
fsat : float, default=None
If ``prior_basis = 'physical'``, satellite fraction to assume.
sigv : float, default=None
If ``prior_basis = 'physical'``, velocity dispersion to assume.
shotnoise : float, default=1e4
Shot noise, to scale stochastic terms.
**kwargs : dict
Velocileptors options, defaults to: ``rbao=110, sbao=None, beyond_gauss=True, one_loop=True, shear=True, cutoff=20, jn=5, N=4000, nthreads=None, extrap_min=-4, extrap_max=3``.
Reference
---------
- https://arxiv.org/abs/2005.00523
- https://arxiv.org/abs/2012.04636
- https://github.com/sfschen/velocileptors
"""
_default_options = dict(freedom=None, prior_basis='physical', tracer=None, fsat=None, sigv=None, shotnoise=1e4)
_deterministic_bias_params = ['b1', 'b2', 'bs', 'b3', 'alpha0', 'alpha2', 'alpha4', 'alpha6']
_stochastic_bias_params = ['sn0', 'sn2', 'sn4']
def initialize(self, *args, k=None, z=None, **kwargs):
super(REPTVelocileptorsTracerPowerSpectrumMultipoles, self).initialize(*args, **kwargs)
if k is not None:
self.k = np.array(k, dtype='f8')
# Increasing the resolution, necessary
boost_prec = 4
kvec = np.concatenate([[min(0.0005, self.k[0])], np.geomspace(0.0015, 0.025, 10 * boost_prec, endpoint=True), np.arange(0.03, max(0.5, self.k[-1]) + 0.015 / boost_prec, 0.01 / boost_prec)]) # margin for interpolation below (and numerical noise in endpoint)
ells = kwargs.get('ells', None)
if ells is not None: self.ells = tuple(ells)
self.pt.init.update(k=kvec, ells=self.ells)
if z is not None:
self.z = float(z)
z = self.pt.init.get('z', [])
if self.z not in z: z.append(self.z)
self.pt.init.update(z=sorted(z))
def set_params(self):
    """Require every deterministic and stochastic bias parameter, defaulting to 0 (1 for ``b1``)."""
    defaults = dict.fromkeys(self._deterministic_bias_params + self._stochastic_bias_params, 0.)
    defaults['b1'] = 1.
    self.required_bias_params = defaults
    super().set_params()
def calculate(self, **params):
    """
    Combine the perturbation theory terms of ``self.pt`` with the bias parameters and
    store the multipoles, interpolated at ``self.k``, in ``self.power``.

    Parameters
    ----------
    **params : dict
        Bias, counterterm and stochastic parameters. In the physical prior basis the
        names read here carry a trailing 'p' (``b1p``, ``alpha0p``, ``sn0p``...).
    """
    pt = self.pt
    multiple_z = pt.z.ndim != 0
    if not multiple_z:
        self.z = pt.z
    params = self.pack_input_bias_params(params)
    # tuple-valued parameters: keep the first entry
    params = {name: (value[0] if isinstance(value, tuple) else value) for name, value in params.items()}
    if self.is_physical_prior:
        sigma8 = pt.sigma8
        f = pt.fsigma8 / sigma8
        if multiple_z:
            # pick the entry matching this tracer's redshift
            iz = list(pt.z).index(self.z)
            sigma8, f = sigma8[iz], f[iz]
        # Lagrangian biases from the sigma8-scaled sampled parameters
        b1L = params['b1p'] / sigma8 - 1.
        b2L = params['b2p'] / sigma8**2
        bsL = params['bsp'] / sigma8**2
        b3L = params['b3p'] / sigma8**3
        # b1_E = 1 + b1_L, b2_E = b2_L + (8/21) * b1_L; bs and b3 are handed over as Lagrangian values
        # NOTE(review): presumably self.pt applies any further bs / b3 shift -- confirm
        b1E = 1. + b1L
        alpha0p, alpha2p, alpha4p = (params[name] for name in ('alpha0p', 'alpha2p', 'alpha4p'))
        pars = [b1E, 8. / 21. * b1L + b2L, bsL, b3L]
        pars += [b1E**2 * alpha0p, f * b1E * (alpha0p + alpha2p),
                 f * (f * alpha2p + b1E * alpha4p), f**2 * alpha4p]
        sigv = self.options['sigv']
        for i in (0, 2, 4):
            # the satellite fraction only enters the velocity-dependent terms (i > 0)
            fsat = self.fsat if i > 0 else 1.
            pars.append(params[f'sn{i:d}p'] * self.snd * fsat * sigv**i)
    else:
        pars = [params[name] for name in self.required_bias_params]
    opts = {name: params.get(name, default) for name, default in self.optional_bias_params.items()}
    # rows of the pt output corresponding to the requested multipoles
    index = np.array([pt.ells.index(ell) for ell in self.ells])
    if multiple_z:
        opts['z'] = self.z
    poles = pt.combine_bias_terms_poles(pars, **opts, nd=self.nd)
    self.power = interp1d(self.k, pt.k, poles[index].T).T
[docs]
class REPTVelocileptorsTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    Velocileptors REPT tracer correlation function multipoles.
    Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn*.
    For the matter (unbiased) correlation function, set all bias parameters to 0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    prior_basis : str, default='physical'
        If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
        :math:`b_{1}^\prime = (1 + b_{1}^{L}) \sigma_{8}(z), b_{2}^\prime = b_{2}^{L} \sigma_{8}(z)^2, b_{s}^\prime = b_{s}^{L} \sigma_{8}(z)^2, b_{3}^\prime = 0`
        with: :math:`b_{1} = 1 + b_{1}^{L}, b_{2} = 8/21 b_{1}^{L} + b_{2}^{L}, b_{s} = b_{s}^{L}, b_{3} = b_{3}^{L}`.
        :math:`\alpha_{0} = (1 + b_{1}^{L})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}^{L}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}^{L}) \alpha_{4}^\prime)`.

    **kwargs : dict
        Velocileptors options, defaults to: ``rbao=110, sbao=None, beyond_gauss=True, one_loop=True, shear=True, cutoff=20, jn=5, N=4000, nthreads=None, extrap_min=-4, extrap_max=3``.

    Reference
    ---------
    - https://arxiv.org/abs/2005.00523
    - https://arxiv.org/abs/2012.04636
    - https://github.com/sfschen/velocileptors
    """
    # parameter setup and deterministic bias names are shared with the power spectrum class
    _params = classmethod(REPTVelocileptorsTracerPowerSpectrumMultipoles._params.__func__)
    _deterministic_bias_params = REPTVelocileptorsTracerPowerSpectrumMultipoles._deterministic_bias_params
    _stochastic_bias_params = []

    def set_params(self):
        """Run the parent setup, then append 'p' to bias parameter names if the power spectrum uses the physical prior basis."""
        super().set_params()
        if not self.power.is_physical_prior:
            return
        for attr in ('deterministic_bias_params', 'stochastic_bias_params'):
            setattr(self, attr, [name + 'p' for name in getattr(self, attr)])
[docs]
class PyBirdPowerSpectrumMultipoles(BasePTPowerSpectrumMultipoles):
    """
    PyBird perturbation theory terms for the power spectrum multipoles.

    :meth:`calculate` builds a ``pybird.bird.Bird`` instance, stored as ``self.pt``, from the template
    linear power spectrum; :meth:`combine_bias_terms_poles` (auto-correlation) and
    :meth:`combine_bias_terms_poles_for_cross` (cross-correlation) then weight its terms by bias parameters.

    Options (see ``_default_options``): ``km``, ``kr`` (scalar or pair; the minimum is passed to pybird's ``Common``),
    ``accboost``, ``fftaccboost`` (``NFFT = 256 * fftaccboost``), ``fftbias``, ``with_nnlo_counterterm``, ``with_stoch``,
    ``with_resum`` (falsy: no resummation; 'opti': pybird's ``optiresum``), ``with_ap``,
    ``eft_basis`` (``None`` and 'velocileptors' are mapped to 'eftoflss').
    """
    _default_options = dict(km=0.7, kr=0.25, accboost=1, fftaccboost=1, fftbias=-1.6, with_nnlo_counterterm=False, with_stoch=True, with_resum='full', with_ap=True, eft_basis='eftoflss')
    _klim = (1e-3, 11., 3000)  # numerical instability in pybird's fftlog at 10.
    # attributes of self.pt exported by __getstate__
    _pt_attrs = ['co', 'f', 'eft_basis', 'with_stoch', 'with_nnlo_counterterm', 'with_tidal_alignments',
                 'P11l', 'Ploopl', 'Pctl', 'Pstl', 'Pnnlol', 'C11l', 'Cloopl', 'Cctl', 'Cstl', 'Cnnlol']

    def initialize(self, *args, **kwargs):
        """Set up the pybird helpers (``Common``, ``NonLinear``, ``Resum``, ``Projection``, optional NNLO counterterm)."""
        super(PyBirdPowerSpectrumMultipoles, self).initialize(*args, **kwargs)
        # self.co is fixed, so we can just export it in __getstate__
        from pybird.common import Common
        from pybird.nonlinear import NonLinear
        from pybird.nnlo import NNLO_counterterm
        from pybird.resum import Resum
        from pybird.projection import Projection
        eft_basis = self.options.get('eft_basis', None)
        if eft_basis in [None, 'velocileptors']: eft_basis = 'eftoflss'
        # nd used by combine_bias_terms_poles only
        # No way to go below kmin = 1e-3 h/Mpc (nan)
        if self.k[0] * 0.8 < 1e-3:
            import warnings
            warnings.warn('pybird does not predict P(k) for k < 0.001 h/Mpc; nan will be replaced by 0')
        # km, kr are stored as pairs (one value per tracer, used by the cross-correlation)
        for name in ['km', 'kr']:
            self.options[name] = tuple(self.options[name]) if utils.is_sequence(self.options[name]) else (self.options[name],) * 2
        self.km = self.options['km']
        self.kr = self.options['kr']
        self.co = Common(Nl=len(self.ells), kmin=1e-3, kmax=self.k[-1] * 1.3, km=min(self.options['km']), kr=min(self.options['kr']), nd=1e-4,
                         eft_basis=eft_basis, halohalo=True, with_cf=False,
                         with_time=True, accboost=float(self.options['accboost']), optiresum=self.options['with_resum'] == 'opti', with_uvmatch=False,
                         exact_time=False, quintessence=False, with_tidal_alignments=False, nonequaltime=False, keep_loop_pieces_independent=False)
        self.nonlinear = NonLinear(load=False, save=False, NFFT=256 * int(self.options['fftaccboost']), fftbias=self.options['fftbias'], co=self.co)
        self.resum = Resum(co=self.co)
        self.nnlo_counterterm = None
        if self.options['with_nnlo_counterterm']:
            self.nnlo_counterterm = NNLO_counterterm(co=self.co)
            # the NNLO counterterm is built from the no-wiggle power spectrum, see calculate()
            self.template.init.update(with_now='peakaverage')
        self.projection = Projection(self.k, with_ap=self.options['with_ap'], H_fid=None, D_fid=None, co=self.co)  # placeholders for H_fid and D_fid, as we will provide q's

    def calculate(self):
        """Build ``self.pt`` (a pybird ``Bird``) from the template, then apply resummation, AP and projection onto ``self.k``."""
        super(PyBirdPowerSpectrumMultipoles, self).calculate()
        from pybird.bird import Bird
        cosmo = {'kk': self.template.k, 'pk_lin': self.template.pk_dd, 'pk_lin_2': None, 'f': self.template.f, 'DA': 1., 'H': 1.}
        self.pt = Bird(cosmo, with_bias=False, eft_basis=self.co.eft_basis, with_stoch=self.options['with_stoch'], with_nnlo_counterterm=self.nnlo_counterterm is not None, co=self.co)
        if self.nnlo_counterterm is not None:  # we use smooth power spectrum since we don't want spurious BAO signals
            from scipy import interpolate
            self.nnlo_counterterm.Ps(self.pt, interpolate.interp1d(np.log(self.template.k), np.log(self.template.pknow_dd), fill_value='extrapolate', assume_sorted=True))
        self.nonlinear.PsCf(self.pt)
        self.pt.setPsCfl()
        if self.options['with_resum']:
            self.resum.PsCf(self.pt, makeIR=True, makeQ=True, setIR=True, setPs=True, setCf=False)
        if self.options['with_ap']:
            self.projection.AP(self.pt, q=(self.template.qper, self.template.qpar))
        self.projection.xdata(self.pt)

    def combine_bias_terms_poles(self, params, nd=1e-4):
        """
        Return auto-power spectrum multipoles for bias parameters ``params`` (dict) and density ``nd``;
        nan are replaced by 0.
        """
        from pybird import bird
        self.pt.co.nd = nd
        # let pybird combine the terms with jax.numpy; always restore numpy, even if pybird raises
        bird.np = jnp
        try:
            self.pt.setreducePslb(params, what='full')
        finally:
            bird.np = np
        return jnp.nan_to_num(self.pt.fullPs, nan=0.0, posinf=jnp.inf, neginf=-jnp.inf)

    def combine_bias_terms_poles_for_cross(self, biasX, biasY, nd=1e-4, km=(0.7, 0.7), kr=(0.25, 0.25)):
        """
        Return cross-power spectrum multipoles of tracers X and Y; nan are replaced by 0.

        Parameters
        ----------
        biasX, biasY : dict
            Bias parameters of each tracer: ``b1`` to ``b4``, counterterms (``cct, cr1, cr2`` or ``c0, c2, c4``
            depending on ``eft_basis``), and stochastic terms ``ce0, ce1, ce2`` (read from ``biasX`` only).
        nd : float, default=1e-4
            Density dividing the stochastic terms.
        km, kr : tuple, default=(0.7, 0.7), (0.25, 0.25)
            Scales (X, Y) normalizing counterterms and stochastic terms.

        Raises
        ------
        NotImplementedError
            If ``self.pt`` was built with the NNLO counterterm.
        """
        # Follows https://arxiv.org/abs/2308.06206 eq(13), except that stochastic terms are scaled by geometric means of nd and km
        bird = self.pt
        f = bird.f
        b1X, b2X, b3X, b4X = (biasX[f'b{i}'] for i in [1, 2, 3, 4])
        b1Y, b2Y, b3Y, b4Y = (biasY[f'b{i}'] for i in [1, 2, 3, 4])
        kmX, kmY = km
        krX, krY = kr
        if bird.eft_basis in ["eftoflss", "westcoast"]:
            b5X, b6X, b7X = (biasX[name]/ks**2 for name, ks in zip(["cct", "cr1", "cr2"], [kmX, krX, krX]))
            b5Y, b6Y, b7Y = (biasY[name]/ks**2 for name, ks in zip(["cct", "cr1", "cr2"], [kmY, krY, krY]))
        elif bird.eft_basis == 'eastcoast':  # inversion of (2.23) of 2004.10607
            ct0X = biasX["c0"] - f/3. * biasX["c2"] + 3/35. * f**2 * biasX["c4"]
            ct2X = biasX["c2"] - 6/7. * f * biasX["c4"]
            ct4X = biasX["c4"]
            ct0Y = biasY["c0"] - f/3. * biasY["c2"] + 3/35. * f**2 * biasY["c4"]
            # same inversion as for X (the c4 factor was missing for Y)
            ct2Y = biasY["c2"] - 6/7. * f * biasY["c4"]
            ct4Y = biasY["c4"]
        b11 = jnp.array([b1X * b1Y, (b1X + b1Y) * f, f**2])
        if bird.eft_basis in ["eftoflss", "westcoast"]:
            bct = jnp.array([b1X * b5Y + b1Y * b5X, b1Y * b6X + b1X * b6Y, b1Y * b7X + b1X * b7Y, (b5X + b5Y) * f, (b6X + b6Y) * f, (b7X + b7Y) * f])
        elif bird.eft_basis == 'eastcoast':
            # jnp (not np), as for all other coefficient vectors, so that jax-traced inputs are supported
            bct = - jnp.array([ct0X + ct0Y, f * (ct2X + ct2Y), f**2 * (ct4X + ct4Y)])
        if bird.with_nnlo_counterterm:
            raise NotImplementedError("PyBird cross-power spectrum with nnlo counterterm is not implemented yet.")
        bloop = jnp.array([1., 0.5*(b1X+b1Y), 0.5*(b2X+b2Y), 0.5*(b3X+b3Y), 0.5*(b4X+b4Y), b1X*b1Y, 0.5*(b1X*b2Y+b1Y*b2X), 0.5*(b1X*b3Y+b1Y*b3X), 0.5*(b1X*b4Y+b1Y*b4X), b2X*b2Y, 0.5*(b2X*b4Y+b2Y*b4X), b4X*b4Y])
        if bird.with_stoch:
            # ces in biasX and biasY refer to the same jnp object
            bst = jnp.array([biasX["ce0"], biasX["ce1"] / (km[0] * km[1]), biasX["ce2"] / (km[0] * km[1])]) / nd
        # Ps[0]: linear, Ps[1]: loop + counterterm (+ stochastic), Ps[2]: NNLO (not implemented for cross, left at 0)
        Ps = [None] * 3
        Ps[0] = jnp.einsum('b,lbx->lx', b11, bird.P11l)
        Ps[1] = jnp.einsum('b,lbx->lx', bloop, bird.Ploopl) + jnp.einsum('b,lbx->lx', bct, bird.Pctl)
        if bird.with_stoch: Ps[1] += jnp.einsum('b,lbx->lx', bst, bird.Pstl)
        if Ps[2] is None:
            Ps[2] = jnp.zeros_like(Ps[0])
        Ps = jnp.array(Ps)
        fullPs = jnp.sum(Ps, axis=0)
        return jnp.nan_to_num(fullPs, nan=0.0, posinf=jnp.inf, neginf=-jnp.inf)

    def __getstate__(self):
        """Export ``k, z, ells, km, kr`` and the attributes of ``self.pt`` listed in ``_pt_attrs``."""
        state = {}
        for name in ['k', 'z', 'ells', 'km', 'kr']:
            if hasattr(self, name):
                state[name] = getattr(self, name)
        for name in self._pt_attrs:
            if hasattr(self.pt, name):
                state[name] = getattr(self.pt, name)
        return state

    def __setstate__(self, state):
        """Restore the calculator from ``state``; remaining entries are set on a bare ``Bird`` (``__init__`` not run)."""
        for name in ['k', 'z', 'ells', 'km', 'kr']:
            if name in state: setattr(self, name, state.pop(name))
        from pybird import bird
        self.pt = bird.Bird.__new__(bird.Bird)
        self.pt.with_bias = False
        self.pt.__dict__.update(state)

    @classmethod
    def install(cls, installer):
        """Install pybird with pip."""
        installer.pip('git+https://github.com/pierrexyz/pybird')
[docs]
class PyBirdTracerPowerSpectrumMultipoles(BaseTracerPowerSpectrumMultipoles):
    """
    Pybird tracer power spectrum multipoles.
    Can be exactly marginalized over counter terms and stochastic parameters c* and bias term b3*.
    For the matter (unbiased) power spectrum, set b1=1, b2=1, b3=1 (eft_basis='eftoflss') and all other bias parameters to 0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    shotnoise : float, default=1e4
        Shot noise (which is usually marginalized over).

    **kwargs : dict
        Pybird options, defaults to: ``with_nnlo_counterterm=False, with_stoch=True, eft_basis=None, freedom=None``.

    Reference
    ---------
    - https://arxiv.org/abs/2003.07956
    - https://github.com/pierrexyz/pybird
    """
    _default_options = dict(with_nnlo_counterterm=False, with_stoch=True, eft_basis=None, freedom=None, shotnoise=1e4)
    _deterministic_bias_params = ['b1', 'b2', 'b3', 'b4', 'bs', 'b2p4', 'b2m4', 'b2t', 'b2g', 'b3g', 'cct', 'cr1', 'cr2', 'cr4', 'cr6', 'c0', 'c2', 'c4', 'ct']
    _stochastic_bias_params = ['ce0', 'ce1', 'ce2']
    _with_cross = True

    @classmethod
    def _params(cls, params, freedom=None, tracers=None):
        """Adjust priors and fixed parameters according to ``freedom`` ('min', 'max' or ``None``)."""
        to_fix = []
        if freedom in ['min', 'max']:
            for b1 in params.select(basename=['b1']):
                b1.update(prior={'limits': [0., 4.]})
            for b4 in params.select(basename=['b4']):
                b4.update(prior={'limits': [-15., 15.]})
            # all other biases and counterterms: no prior
            for other in params.select(basename=['b2', 'b3', 'bs', 'b2p4', 'b2m4', 'b2t', 'b2g', 'b3g', 'c*']):
                other.update(prior=None)
        if freedom == 'max':
            for bias in params.select(basename=['b1', 'b2', 'b3', 'b4', 'bs', 'b2p4', 'b2m4', 'b2t', 'b2g', 'b3g']):
                bias.update(fixed=False)
            to_fix.append('ce1')
        elif freedom == 'min':
            to_fix.extend(['b2', 'b3', 'ce1'])
        for fixed in params.select(basename=to_fix):
            fixed.update(value=0., fixed=True)
        return BaseTracerTwoPointTheory._params.__func__(cls, params, tracers=tracers)

    def set_params(self):
        """
        Select the bias, counterterm and stochastic parameters required by the chosen ``eft_basis``,
        and fix to 0 those attached to multipoles that are not computed.

        Raises
        ------
        ValueError
            If ``eft_basis`` is unknown, or ``freedom == 'min'`` with ``eft_basis != 'eftoflss'``.
        """
        options = self.options
        freedom = options.get('freedom', None)
        if options['eft_basis'] is None:
            options['eft_basis'] = 'eftoflss' if freedom == 'min' else 'westcoast'
        eft_basis = options['eft_basis']
        allowed_eft_basis = ['eftoflss', 'velocileptors', 'eastcoast', 'westcoast']
        if eft_basis not in allowed_eft_basis:
            raise ValueError('eft_basis must be one of {}'.format(allowed_eft_basis))
        if freedom == 'min' and eft_basis != 'eftoflss':
            raise ValueError('freedom = "min" only defined in eft_basis = "eftoflss"')
        # in pybird:
        # - westcoast: c2, c4 are b2p4, b2m4
        # - eastcoast: b2t, b2g, b3g are bt2, bG2, bGamma3
        bias_by_basis = {'eftoflss': ['b1', 'b2', 'b3', 'b4'],
                         'velocileptors': ['b1', 'b2', 'bs', 'b3'],
                         'westcoast': ['b1', 'b2p4', 'b3', 'b2m4'],
                         'eastcoast': ['b1', 'b2t', 'b2g', 'b3g']}
        required = list(bias_by_basis[eft_basis])
        self.pt.init.update(eft_basis=eft_basis)
        # now EFT parameters
        with_nnlo = options['with_nnlo_counterterm']
        if eft_basis in ['eftoflss', 'velocileptors', 'westcoast']:
            required += ['cct', 'cr1', 'cr2']
            if with_nnlo:
                required += ['cr4', 'cr6']
        else:
            required += ['c0', 'c2', 'c4']
            if with_nnlo:
                required += ['ct']
        # now shotnoise
        if options['with_stoch']:
            required += ['ce0', 'ce1', 'ce2']
        default_values = {'b1': 1.6}
        self.required_bias_params = {name: default_values.get(name, 0.) for name in required}
        self.deterministic_bias_params = [name for name in self._deterministic_bias_params if name in self.required_bias_params]
        self.stochastic_bias_params = [name for name in self._stochastic_bias_params if name in self.required_bias_params]
        BaseTracerPowerSpectrumMultipoles.set_params(self, pt_params=[])  # not super, for PyBirdTracerCorrelationFunctionMultipoles
        to_fix = []
        if 4 not in self.ells:
            to_fix += ['cr2', 'c4']
        if 2 not in self.ells:
            to_fix += ['cr1', 'c2', 'ce2']
        for param in self.init.params.select(basename=to_fix):
            param.update(value=0., fixed=True)

    def transform_params(self, **params):
        """Convert parameters of the chosen ``eft_basis`` to the ``b1, b2, b3, b4`` expected by pybird; return the updated dict."""
        eft_basis = self.options['eft_basis']
        if eft_basis == 'westcoast':
            b2p4 = params.pop('b2p4')
            b2m4 = params.pop('b2m4')
            params['b2'] = (b2p4 + b2m4) / 2.**0.5
            params['b4'] = (b2p4 - b2m4) / 2.**0.5
        elif eft_basis == 'eastcoast':
            b2g = params.pop('b2g')
            b2t = params.pop('b2t')
            b3g = params.pop('b3g')
            b1 = params['b1']
            params['b2'] = b1 + 7. / 2. * b2g
            params['b3'] = b1 + 15. * b2g + 6. * b3g
            params['b4'] = 1 / 2. * b2t - 7. / 2. * b2g
        elif eft_basis == 'velocileptors':
            b1v = params.pop('b1')
            b2v = params.pop('b2')
            bsv = params.pop('bs')
            b3v = params.pop('b3')
            params['b1'] = b1v  # + 1 - 1
            params['b2'] = 1. + 7. / 2. * bsv
            params['b3'] = 21. / 882. * (42. - 145. * b1v - 21. * b3v + 630. * bsv)
            params['b4'] = (params['b1'] - 1.) + b2v / 2.
        if self.options['freedom'] == 'min':
            # b2 and b3 are not free: set from b1
            params['b2'] = 1.
            params['b3'] = (294. - 1015. * (params['b1'] - 1.)) / 441.
        return params

    def calculate(self, **params):
        """Compute the power spectrum multipoles (auto- or cross-correlation) and store them in ``self.power``."""
        super(PyBirdTracerPowerSpectrumMultipoles, self).calculate()
        params = self.pack_input_bias_params(params)
        if not self.is_cross_correlation():
            single = {name: (value[0] if isinstance(value, tuple) else value) for name, value in params.items()}
            self.power = self.pt.combine_bias_terms_poles(self.transform_params(**single), nd=self.nd)
            return
        paramsX, paramsY = {}, {}
        for name, value in params.items():
            if utils.is_sequence(value):
                paramsX[name], paramsY[name] = value
            else:
                paramsX[name] = paramsY[name] = value  # stochastic terms
        paramsX = self.transform_params(**paramsX)
        paramsY = self.transform_params(**paramsY)
        self.power = self.pt.combine_bias_terms_poles_for_cross(paramsX, paramsY, nd=self.nd, km=self.pt.km, kr=self.pt.kr)
[docs]
class PyBirdCorrelationFunctionMultipoles(BasePTCorrelationFunctionMultipoles):
    """
    PyBird perturbation theory terms for the correlation function multipoles.

    :meth:`calculate` builds a ``pybird.bird.Bird`` instance, stored as ``self.pt``, from the template
    linear power spectrum; :meth:`combine_bias_terms_poles` then weights its terms by bias parameters.

    Options (see ``_default_options``): ``km``, ``kr`` (scalar or pair; the minimum is passed to pybird's ``Common``),
    ``accboost``, ``fftaccboost`` (``NFFT = 256 * fftaccboost``), ``fftbias``, ``with_nnlo_counterterm``, ``with_stoch``,
    ``with_resum`` (falsy: no resummation; 'opti': pybird's ``optiresum``), ``with_ap``,
    ``eft_basis`` (``None`` and 'velocileptors' are mapped to 'eftoflss').
    """
    _default_options = dict(km=0.7, kr=0.25, accboost=1, fftaccboost=1, fftbias=-1.6, with_nnlo_counterterm=False, with_stoch=False, with_resum='full', with_ap=True, eft_basis='eftoflss')
    _klim = (1e-3, 11., 3000)  # numerical instability in pybird's fftlog at 10.
    # attributes of self.pt exported by __getstate__
    _pt_attrs = ['co', 'f', 'eft_basis', 'with_stoch', 'with_nnlo_counterterm', 'with_tidal_alignments',
                 'P11l', 'Ploopl', 'Pctl', 'Pstl', 'Pnnlol', 'C11l', 'Cloopl', 'Cctl', 'Cstl', 'Cnnlol']

    def initialize(self, *args, **kwargs):
        """Set up the pybird helpers (``Common``, ``NonLinear``, ``Resum``, ``Projection``, optional NNLO counterterm)."""
        super(PyBirdCorrelationFunctionMultipoles, self).initialize(*args, **kwargs)
        from pybird.common import Common
        from pybird.nonlinear import NonLinear
        from pybird.nnlo import NNLO_counterterm
        from pybird.resum import Resum
        from pybird.projection import Projection
        eft_basis = self.options.get('eft_basis', None)
        if eft_basis in [None, 'velocileptors']: eft_basis = 'eftoflss'
        # nd used by combine_bias_terms_poles only
        for name in ['km', 'kr']:
            self.options[name] = self.options[name] if utils.is_sequence(self.options[name]) else (self.options[name],) * 2
        self.co = Common(Nl=len(self.ells), kmin=1e-3, kmax=0.25, km=min(self.options['km']), kr=min(self.options['kr']), nd=1e-4,
                         eft_basis=eft_basis, halohalo=True, with_cf=True,
                         with_time=True, accboost=float(self.options['accboost']), optiresum=self.options['with_resum'] == 'opti', with_uvmatch=False,
                         exact_time=False, quintessence=False, with_tidal_alignments=False, nonequaltime=False, keep_loop_pieces_independent=False)
        self.nonlinear = NonLinear(load=False, save=False, NFFT=256 * int(self.options['fftaccboost']), fftbias=self.options['fftbias'], co=self.co)  # NFFT=256, fftbias=-1.6
        self.resum = Resum(co=self.co)  # LambdaIR=.2, NFFT=192
        self.nnlo_counterterm = None
        if self.options['with_nnlo_counterterm']:
            self.nnlo_counterterm = NNLO_counterterm(co=self.co)
            # the NNLO counterterm is built from the no-wiggle power spectrum, see calculate()
            self.template.init.update(with_now='peakaverage')
        self.projection = Projection(self.s, with_ap=self.options['with_ap'], H_fid=None, D_fid=None, co=self.co)  # placeholders for H_fid and D_fid, as we will provide q's

    def calculate(self):
        """Build ``self.pt`` (a pybird ``Bird``) from the template, then apply resummation, AP and projection onto ``self.s``."""
        super(PyBirdCorrelationFunctionMultipoles, self).calculate()
        from pybird.bird import Bird
        cosmo = {'kk': self.template.k, 'pk_lin': self.template.pk_dd, 'pk_lin_2': None, 'f': self.template.f, 'DA': 1., 'H': 1.}
        self.pt = Bird(cosmo, with_bias=False, eft_basis=self.co.eft_basis, with_stoch=self.options['with_stoch'], with_nnlo_counterterm=self.nnlo_counterterm is not None, co=self.co)
        if self.nnlo_counterterm is not None:  # we use smooth power spectrum since we don't want spurious BAO signals
            from scipy import interpolate
            self.nnlo_counterterm.Cf(self.pt, interpolate.interp1d(np.log(self.template.k), np.log(self.template.pknow_dd), fill_value='extrapolate', assume_sorted=True))
        self.nonlinear.PsCf(self.pt)
        self.pt.setPsCfl()
        if self.options['with_resum']:
            self.resum.PsCf(self.pt, makeIR=True, makeQ=True, setIR=True, setPs=True, setCf=True)
        if self.options['with_ap']:
            self.projection.AP(self.pt, q=(self.template.qper, self.template.qpar))
        self.projection.xdata(self.pt)

    def combine_bias_terms_poles(self, params, nd=1e-4):
        """Return correlation function multipoles for bias parameters ``params`` (dict) and density ``nd``."""
        from pybird import bird
        self.pt.co.nd = nd
        # let pybird combine the terms with jax.numpy; always restore numpy, even if pybird raises
        bird.np = jnp
        try:
            self.pt.setreduceCflb(params, what='full')
        finally:
            bird.np = np
        return self.pt.fullCf

    def __getstate__(self):
        """Export ``s, z, ells`` and the attributes of ``self.pt`` listed in ``_pt_attrs``."""
        state = {}
        for name in ['s', 'z', 'ells']:
            if hasattr(self, name):
                state[name] = getattr(self, name)
        for name in self._pt_attrs:
            if hasattr(self.pt, name):
                state[name] = getattr(self.pt, name)
        return state

    def __setstate__(self, state):
        """Restore the calculator from ``state``; remaining entries are set on a bare ``Bird`` (``__init__`` not run)."""
        for name in ['s', 'z', 'ells']:
            if name in state: setattr(self, name, state.pop(name))
        from pybird import bird
        self.pt = bird.Bird.__new__(bird.Bird)
        self.pt.with_bias = False
        self.pt.__dict__.update(state)

    @classmethod
    def install(cls, installer):
        """Install pybird with pip."""
        installer.pip('git+https://github.com/pierrexyz/pybird')
[docs]
class PyBirdTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionMultipoles):
    """
    Pybird tracer correlation function multipoles.
    Can be exactly marginalized over counter terms and stochastic parameters c* and bias term b3*.
    For the matter (unbiased) correlation function, set b1=1, b2=1, b3=1 (eft_basis='eftoflss') and all other bias parameters to 0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    **kwargs : dict
        Pybird options, defaults to: ``with_nnlo_counterterm=False, with_stoch=False, eft_basis=None, freedom=None``.
    """
    _default_options = dict(with_nnlo_counterterm=False, with_stoch=False, eft_basis=None, freedom=None)
    # parameter handling is shared with the power spectrum class
    _params = classmethod(PyBirdTracerPowerSpectrumMultipoles._params.__func__)
    set_params = PyBirdTracerPowerSpectrumMultipoles.set_params
    transform_params = PyBirdTracerPowerSpectrumMultipoles.transform_params
    _deterministic_bias_params = PyBirdTracerPowerSpectrumMultipoles._deterministic_bias_params
    _stochastic_bias_params = []

    def calculate(self, **params):
        """Compute the correlation function multipoles and store them in ``self.corr``."""
        super(PyBirdTracerCorrelationFunctionMultipoles, self).calculate()
        packed = self.pack_input_bias_params(params)
        # tuple-valued parameters: keep the first entry
        single = {name: (value[0] if isinstance(value, tuple) else value) for name, value in packed.items()}
        self.corr = self.pt.combine_bias_terms_poles(self.transform_params(**single))
class Namespace(object):
    """Plain attribute container: keyword arguments become instance attributes."""

    def __init__(self, **kwargs):
        self.update(**kwargs)

    def update(self, **kwargs):
        """Set (or overwrite) one attribute per keyword argument."""
        vars(self).update(kwargs)
@jit
def folps_combine_bias_terms_pkmu(k, mu, jac, f0, table, table_now, sigma2t, pars, nd=1e-4):
    """
    Return the FOLPS :math:`P(k, \\mu)`, multiplied by ``jac``.

    Parameters
    ----------
    k, mu : array
        Wavenumbers and cosine angles passed to ``FOLPS.PEFTs``.
    jac : float
        Overall multiplicative factor.
    f0 : float
        Growth rate; also set as ``FOLPS.f0``.
    table, table_now : sequence
        FOLPS tables, with and without wiggles; entry 0 is used as the linear power spectrum,
        entry 1 of ``table`` times ``f0`` as the growth rate entering the Kaiser factor.
    sigma2t : array
        Damping scale, entering as ``exp(-k^2 sigma2t)``.
    pars : sequence
        Bias parameters, with ``b1, bs, b3`` at positions 0, 2, 3; ``1 / nd`` is appended as shot noise.
    nd : float, default=1e-4
        Density.
    """
    import FOLPSnu as FOLPS
    bias = [*pars, 1. / nd]  # add shot noise
    b1 = bias[0]
    # Add co-evolution part
    coevolution = b1 - 1.
    bias[2] = bias[2] - 4. / 7. * coevolution  # bs
    bias[3] = bias[3] + 32. / 315. * coevolution  # b3
    FOLPS.f0 = f0
    fk = table[1] * f0
    pkl, pkl_now = table[0], table_now[0]
    k2sigma2 = k**2 * sigma2t
    damping = jnp.exp(-k**2 * sigma2t)
    # Kaiser term on the damped linear power spectrum
    kaiser = (b1 + fk * mu**2)**2 * (pkl_now + damping * (pkl - pkl_now) * (1 + k2sigma2))
    # loop and EFT terms: damped wiggle part, plus the complement evaluated on the no-wiggle table
    wiggle = damping * FOLPS.PEFTs(k, mu, bias, table)
    no_wiggle = (1 - damping) * FOLPS.PEFTs(k, mu, bias, table_now)
    return jac * (kaiser + wiggle + no_wiggle)
[docs]
class FOLPSPowerSpectrumMultipoles(BasePTPowerSpectrumMultipoles, BaseTheoryPowerSpectrumMultipolesFromWedges):
    """
    FOLPS perturbation theory terms for the power spectrum multipoles.

    :meth:`calculate` stores the FOLPS tables in ``self.pt``; :meth:`combine_bias_terms_poles`
    turns them into multipoles for given bias parameters.
    Option ``kernels``: 'fk' (default) or 'eds' (``EdSkernels=True`` in FOLPS).
    """
    _default_options = dict(kernels='fk')
    # attributes of self.pt exported by __getstate__
    _pt_attrs = ['kap', 'muap', 'table', 'table_now', 'sigma2t', 'f0', 'jac']

    def initialize(self, *args, mu=6, **kwargs):
        """Initialize with Gauss-Legendre :math:`\\mu`-integration (``mu`` points) and keep a copy of the FOLPS matrices."""
        super(FOLPSPowerSpectrumMultipoles, self).initialize(*args, mu=mu, method='leggauss', **kwargs)
        import FOLPSnu as FOLPS
        FOLPS.Matrices()
        saved = {name: getattr(FOLPS, name) for name in ['M22matrices', 'M13vectors', 'bnu_b', 'N']}
        self.matrices = Namespace(**saved)

    def calculate(self):
        """Run FOLPS on the template linear power spectrum and store the resulting tables in ``self.pt``."""
        super(FOLPSPowerSpectrumMultipoles, self).calculate()
        import FOLPSnu as FOLPS
        # put the matrices saved in initialize() back into the FOLPS module namespace
        vars(FOLPS).update(vars(self.matrices))
        # [z, omega_b, omega_cdm, omega_ncdm, h]
        # only used for neutrinos
        # sensitive to omega_b + omega_cdm, not omega_b, omega_cdm separately
        cosmo = getattr(self.template, 'cosmo', None)
        if cosmo is None:
            cosmo_params = [self.z, 0.022, 0.12, 0., 0.7]
        else:
            cosmo_params = [self.z, cosmo['omega_b'], cosmo['omega_cdm'], cosmo['omega_ncdm_tot'], cosmo['h']]
        FOLPS.NonLinear([self.template.k, self.template.pk_dd], cosmo_params,
                        kminout=self.k[0] * 0.7, kmaxout=self.k[-1] * 1.3, nk=max(len(self.k), 120),
                        EdSkernels=self.options['kernels'] == 'eds')
        kout = FOLPS.kTout
        jac, kap, muap = self.template.ap_k_mu(self.k, self.mu)
        FOLPS.f0 = f0 = self.template.f0  # for Sigma2Total
        table = FOLPS.Table_interp(kap, kout, FOLPS.TableOut_interp(kout))
        table_now_out = FOLPS.TableOut_NW_interp(kout)
        # damping is computed from the no-wiggle table before it is interpolated at kap
        sigma2t = FOLPS.Sigma2Total(kout, muap, table_now_out)
        table_now = FOLPS.Table_interp(kap, kout, table_now_out)
        self.pt = Namespace(kap=kap, muap=muap, table=table, table_now=table_now, sigma2t=sigma2t, f0=f0, jac=jac)
        self.sigma8 = self.template.sigma8
        self.fsigma8 = self.template.f * self.sigma8

    def combine_bias_terms_poles(self, pars, nd=1e-4):
        """Return power spectrum multipoles for bias parameters ``pars`` (sequence) and density ``nd``."""
        pt = self.pt
        pkmu = folps_combine_bias_terms_pkmu(pt.kap, pt.muap, pt.jac, pt.f0, pt.table, pt.table_now, pt.sigma2t, pars, nd=nd)
        return self.to_poles(pkmu)

    def __getstate__(self):
        """Export ``k, z, ells, wmu, sigma8, fsigma8`` and the attributes of ``self.pt`` listed in ``_pt_attrs``."""
        state = {name: getattr(self, name) for name in ['k', 'z', 'ells', 'wmu', 'sigma8', 'fsigma8'] if hasattr(self, name)}
        state.update({name: getattr(self.pt, name) for name in self._pt_attrs if hasattr(self.pt, name)})
        return state

    def __setstate__(self, state):
        """Restore the calculator from ``state``; remaining entries go into the ``self.pt`` namespace."""
        for name in ['k', 'z', 'ells', 'wmu', 'sigma8', 'fsigma8']:
            if name in state:
                setattr(self, name, state.pop(name))
        self.pt = Namespace(**state)

    @classmethod
    def install(cls, installer):
        """Install FOLPS-nu with pip."""
        installer.pip('git+https://github.com/henoriega/FOLPS-nu')
[docs]
class FOLPSTracerPowerSpectrumMultipoles(BaseTracerPowerSpectrumMultipoles):
r"""
FOLPS tracer power spectrum multipoles.
Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn* and bias term b3*.
By default, bs and b3 are fixed to 0, following co-evolution.
For the matter (unbiased) power spectrum, set b1=1 and all other bias parameters to 0.
Parameters
----------
k : array, default=None
Theory wavenumbers where to evaluate multipoles.
ells : tuple, default=(0, 2, 4)
Multipoles to compute.
template : BasePowerSpectrumTemplate
Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.
shotnoise : float, default=1e4
Shot noise (which is usually marginalized over).
prior_basis : str, default='physical'
If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
:math:`b_{1}^\prime = (1 + b_{1}^{L}) \sigma_{8}(z), b_{2}^\prime = b_{2}^{L} \sigma_{8}(z)^2, b_{s}^\prime = b_{s}^{L} \sigma_{8}(z)^2, b_{3}^\prime = 0`
with: :math:`b_{1} = 1 + b_{1}^{L}, b_{2} = 8/21 b_{1}^{L} + b_{2}^{L}, b_{s} = -4/7 b_{1}^{L} + b_{s}^{L}`.
:math:`\alpha_{0} = (1 + b_{1}^{L})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}^{L}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}^{L}) \alpha_{4}^\prime)`.
:math:`s_{n, 0} = f_{\mathrm{sat}}/\bar{n} s_{n, 0}^\prime, s_{n, 2} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{2} s_{n, 2}^\prime, s_{n, 4} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{4} s_{n, 4}^\prime`.
tracer : str, default=None
If ``prior_basis = 'physical'``, tracer to load preset ``fsat`` and ``sigv``. One of ['LRG', 'ELG', 'QSO'].
fsat : float, default=None
If ``prior_basis = 'physical'``, satellite fraction to assume.
sigv : float, default=None
If ``prior_basis = 'physical'``, velocity dispersion to assume.
Reference
---------
- https://arxiv.org/abs/2208.02791
- https://github.com/henoriega/FOLPS-nu
"""
_default_options = dict(freedom=None, prior_basis='physical', tracer=None, fsat=None, sigv=None, shotnoise=1e4)
_deterministic_bias_params = ['b1', 'b2', 'bs', 'b3', 'alpha0', 'alpha2', 'alpha4', 'ct']
_stochastic_bias_params = ['sn0', 'sn2', 'sn4']
@classmethod
def _params(cls, params, freedom=None, prior_basis='physical', tracers=None):
fix = []
if freedom in ['min', 'max']:
for param in params.select(basename=['b1']):
param.update(prior=dict(limits=[0., 10.]))
for param in params.select(basename=['b2']):
param.update(prior=dict(limits=[-50., 50.]))
for param in params.select(basename=['bs', 'b3', 'alpha*', 'sn*']):
param.update(prior=None)
if freedom == 'max':
for param in params.select(basename=['b1', 'b2', 'bs', 'b3']):
param.update(fixed=False)
fix += ['ct']
if freedom == 'min':
fix += ['b3', 'bs', 'ct']
for param in params.select(basename=fix):
param.update(value=0., fixed=True)
params = BaseTracerTwoPointTheory._params.__func__(cls, params, tracers=tracers)
if prior_basis == 'physical':
for param in list(params):
basename = param.basename
param.update(basename=basename + 'p')
#params.set({'basename': basename, 'namespace': param.namespace, 'derived': True})
for param in params.select(basename='b1p'):
param.update(prior=dict(dist='uniform', limits=[0., 3.]), ref=dict(dist='norm', loc=1., scale=0.1))
for param in params.select(basename=['b2p', 'bsp', 'b3p']):
param.update(prior=dict(dist='norm', loc=0., scale=5.), ref=dict(dist='norm', loc=0., scale=1.))
for param in params.select(basename='b3p'):
param.update(value=0., fixed=True)
for param in params.select(basename='alpha*p'):
param.update(prior=dict(dist='norm', loc=0., scale=12.5), ref=dict(dist='norm', loc=0., scale=1.)) # 50% at k = 0.2 h/Mpc
for param in params.select(basename='sn*p'):
param.update(prior=dict(dist='norm', loc=0., scale=2. if 'sn0' in param.basename else 5.), ref=dict(dist='norm', loc=0., scale=1.))
return params
def set_params(self):
self.required_bias_params = ['b1', 'b2', 'bs', 'b3', 'alpha0', 'alpha2', 'alpha4', 'ct', 'sn0', 'sn2']
self.stochastic_bias_params = [param for param in self.stochastic_bias_params if param in self.required_bias_params]
default_values = {'b1': 2.}
self.required_bias_params = {name: default_values.get(name, 0.) for name in self.required_bias_params}
self.is_physical_prior = self.options['prior_basis'] == 'physical'
if self.is_physical_prior:
for name in list(self.required_bias_params):
self.required_bias_params[name + 'p'] = self.required_bias_params.pop(name)
settings = get_physical_stochastic_settings(tracer=self.options['tracer'])
for name, value in settings.items():
if self.options[name] is None: self.options[name] = value
if self.mpicomm.rank == 0:
self.log_debug('Using fsat, sigv = {:.3f}, {:.3f}.'.format(self.options['fsat'], self.options['sigv']))
self.deterministic_bias_params = [name + 'p' for name in self.deterministic_bias_params]
self.stochastic_bias_params = [name + 'p' for name in self.stochastic_bias_params]
super().set_params(pt_params=[])
fix = []
if 4 not in self.ells: fix += ['alpha4']
if 2 not in self.ells: fix += ['alpha2', 'sn2']
for param in self.init.params.select(basename=fix):
param.update(value=0., fixed=True)
self.nd = 1e-4
self.fsat = self.snd = 1.
if self.is_physical_prior:
self.fsat, self.snd = self.options['fsat'], self.options['shotnoise'] * self.nd # normalized by 1e-4
def calculate(self, **params):
    """
    Compute the tracer power spectrum multipoles and store them in ``self.power``.

    Input bias parameters are packed with ``pack_input_bias_params``; in the 'physical'
    prior basis the primed parameters are first converted to the standard FOLPS basis
    using ``sigma8`` and ``fsigma8`` of the perturbation theory calculator ``self.pt``.
    """
    super(FOLPSTracerPowerSpectrumMultipoles, self).calculate()
    packed = self.pack_input_bias_params(params)
    # Tuple-valued entries: only the first element is used
    params = {name: (value[0] if isinstance(value, tuple) else value) for name, value in packed.items()}
    if not self.is_physical_prior:
        pars = [params[name] for name in self.required_bias_params]
    else:
        sigma8 = self.pt.sigma8
        f = self.pt.fsigma8 / sigma8
        # Lagrangian biases from the primed (sigma8-scaled) parameters;
        # Eulerian counterparts are b1E = b1L + 1, b2E = 8/21 * b1L + b2L, bsE = -4/7 * b1L + bsL
        b1L = params['b1p'] / sigma8 - 1.
        b2L = params['b2p'] / sigma8**2
        bsL = params['bsp'] / sigma8**2
        b3 = params['b3p']
        # bs is passed without the -4/7 * b1L shift
        # NOTE(review): that shift is expected to be applied inside combine_bias_terms_poles
        # (FOLPSAXPowerSpectrumMultipoles does so) - confirm for the pt in use
        pars = [1. + b1L, b2L + 8. / 21. * b1L, bsL, b3]
        # Counterterms alpha0, alpha2, alpha4 and a vanishing last term
        pars += [(1 + b1L)**2 * params['alpha0p'], f * (1 + b1L) * (params['alpha0p'] + params['alpha2p']),
                 f * (f * params['alpha2p'] + (1 + b1L) * params['alpha4p']), 0.]
        sigv = self.options['sigv']
        # Stochastic terms sn0, sn2: fsat only multiplies the k^2 term
        pars += [params['sn{:d}p'.format(i)] * self.snd * (self.fsat if i > 0 else 1.) * sigv**i for i in [0, 2]]
    opts = {name: params.get(name, default) for name, default in self.optional_bias_params.items()}
    self.power = self.pt.combine_bias_terms_poles(pars, **opts, nd=self.nd)
[docs]
class FOLPSTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    FOLPS tracer correlation function multipoles.

    Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn* and bias term b3*.
    By default, bs and b3 are fixed to 0, following co-evolution.
    For the matter (unbiased) correlation function, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    prior_basis : str, default='physical'
        :math:`b_{1}^\prime = (1 + b_{1}^{L}) \sigma_{8}(z), b_{2}^\prime = b_{2}^{L} \sigma_{8}(z)^2, b_{s}^\prime = b_{s}^{L} \sigma_{8}(z)^2, b_{3}^\prime = 0`
        with: :math:`b_{1} = 1 + b_{1}^{L}, b_{2} = 8/21 b_{1}^{L} + b_{2}^{L}, b_{s} = -4/7 b_{1}^{L} + b_{s}^{L}`.
        :math:`\alpha_{0} = (1 + b_{1}^{L})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}^{L}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}^{L}) \alpha_{4}^\prime)`.

    Reference
    ---------
    - https://arxiv.org/abs/2208.02791
    - https://github.com/cosmodesi/folpsax
    """
    # Parameter definitions are borrowed from the power spectrum counterpart
    _params = classmethod(FOLPSTracerPowerSpectrumMultipoles._params.__func__)
    _deterministic_bias_params = FOLPSTracerPowerSpectrumMultipoles._deterministic_bias_params
    _stochastic_bias_params = []

    def set_params(self):
        """Set parameters; in the 'physical' prior basis, bias parameter names get a trailing 'p'."""
        super().set_params()
        if not self.power.is_physical_prior:
            return
        for attr in ('deterministic_bias_params', 'stochastic_bias_params'):
            setattr(self, attr, [name + 'p' for name in getattr(self, attr)])
[docs]
class FOLPSAXPowerSpectrumMultipoles(BasePTPowerSpectrumMultipoles, BaseTheoryPowerSpectrumMultipolesFromWedges):
    """
    Perturbation theory calculator wrapping the ``folpsax`` package.

    :meth:`calculate` builds the ``folpsax`` non-linear tables from the template power spectrum;
    :meth:`combine_bias_terms_poles` turns these tables into multipoles for a given set of bias parameters.
    """
    # Options forwarded to folpsax.get_non_linear (see calculate)
    _default_options = dict(kernels='fk', rbao=104.)
    # Attributes of self.pt saved by __getstate__ and restored by __setstate__
    _pt_attrs = ['jac', 'kap', 'muap', 'table', 'table_now', 'scalars', 'scalars_now']

    def initialize(self, *args, mu=6, **kwargs):
        """
        Initialize the calculator.

        Parameters
        ----------
        mu : int, default=6
            Passed as ``mu`` to the parent ``initialize``, together with ``method='leggauss'``.

        *args, **kwargs
            Other arguments forwarded to the parent ``initialize``.
        """
        super(FOLPSAXPowerSpectrumMultipoles, self).initialize(*args, mu=mu, method='leggauss', **kwargs)
        # Imported lazily so that folpsax is only needed when this calculator is used
        from folpsax import get_mmatrices
        self.matrices = get_mmatrices()
        # pknow_dd of the template is read in calculate()
        # NOTE(review): presumably 'peakaverage' selects the wiggle / no-wiggle split method - confirm
        self.template.init.update(with_now='peakaverage')

    def calculate(self):
        """
        Compute the ``folpsax`` tables and the AP-distorted (k, mu) grid.

        Sets ``self.pt`` (namespace with the attributes listed in ``_pt_attrs``),
        ``self.kt``, ``self.sigma8`` and ``self.fsigma8``.
        """
        super(FOLPSAXPowerSpectrumMultipoles, self).calculate()
        # [z, omega_b, omega_cdm, omega_ncdm, h]
        # only used for neutrinos
        # sensitive to omega_b + omega_cdm, not omega_b, omega_cdm separately
        #cosmo_params = {'z': self.z, 'fnu': 0., 'Omega_m': 0.3, 'h': 0.7}
        #cosmo = getattr(self.template, 'cosmo', None)
        #if cosmo is not None:
        #    cosmo_params['fnu'] = cosmo['Omega_ncdm_tot'] / cosmo['Omega_m']
        #    cosmo_params['Omega_m'] = cosmo['Omega_m']
        #    cosmo_params['h'] = cosmo['h']
        #    cosmo_params['Nnu'] = cosmo['N_ncdm']
        #    cosmo_params['Neff'] = cosmo['N_eff']
        #cosmo_params['f0'] = self.template.f0
        cosmo_params = {}
        # Template density power spectrum weighted by the squared (scale-dependent) growth rate fk
        cosmo_params['pkttlin'] = self.template.pk_dd * self.template.fk**2
        if getattr(self, '_get_non_linear', None) is None:
            # Build and jit the folpsax call once; later calls reuse the cached function.
            # The closure reads self.template.k, self.k, self.matrices and self.options at call (trace) time.
            from folpsax import get_non_linear

            def _get_non_linear(pk_dd, pknow_dd, **cosmo_params):
                # Output k-range is extended by -30% / +30% around the requested k
                return get_non_linear(self.template.k, pk_dd, self.matrices, pknow=pknow_dd,
                                      kminout=self.k[0] * 0.7, kmaxout=self.k[-1] * 1.3, nk=max(len(self.k), 150),
                                      kernels=self.options['kernels'], rbao=self.options['rbao'], **cosmo_params)

            self._get_non_linear = jit(_get_non_linear)
        table, table_now = self._get_non_linear(self.template.pk_dd, self.template.pknow_dd, **cosmo_params)
        # Jacobian and distorted (k, mu) from the template's AP transform
        jac, kap, muap = self.template.ap_k_mu(self.k, self.mu)
        # Row 0 of each table is kept as self.kt, rows 1..25 as 'table', the remaining rows as 'scalars'
        # NOTE(review): row 0 is presumably the output wavenumber grid of folpsax - confirm
        self.pt = Namespace(jac=jac, kap=kap, muap=muap, table=table[1:26], table_now=table_now[1:26], scalars=table[26:], scalars_now=table_now[26:])
        self.kt = table[0]
        self.sigma8 = self.template.sigma8
        self.fsigma8 = self.template.f * self.sigma8

    def combine_bias_terms_poles(self, pars, nd=1e-4):
        """
        Return the multipoles for bias parameters ``pars``.

        Parameters
        ----------
        pars : sequence
            Bias parameters; element 0 is read as b1, elements 2 and 3 as bs and b3.
            The input sequence is not modified.

        nd : float, default=1e-4
            ``1 / nd`` is appended to ``pars`` as the shot noise.

        Returns
        -------
        poles : array
            Output of ``self.to_poles`` applied to the Jacobian-weighted P(k, mu) from ``folpsax.get_rsd_pkmu``.
        """
        # Re-assemble the full tables (k, table rows, scalars) as expected by folpsax
        table = (self.kt,) + tuple(self.pt.table) + tuple(self.pt.scalars)
        table_now = (self.kt,) + tuple(self.pt.table_now) + tuple(self.pt.scalars_now)
        pars = list(pars) + [1. / nd]  # add shot noise
        b1 = pars[0]
        # add co-evolution part
        pars[2] = pars[2] - 4. / 7. * (b1 - 1.)  # bs
        pars[3] = pars[3] + 32. / 315. * (b1 - 1.)  # b3
        ncols = len(table)
        if getattr(self, '_get_poles', None) is None:
            # Build and jit once; ncols is captured from this first call
            from folpsax import get_rsd_pkmu

            def _get_poles(jac, kap, muap, pars, *table):
                # table holds both tables back to back: the first ncols entries are 'table', the rest 'table_now'
                return self.to_poles(jac * get_rsd_pkmu(kap, muap, pars, table[:ncols], table[ncols:]))

            self._get_poles = jit(_get_poles)
        return self._get_poles(self.pt.jac, self.pt.kap, self.pt.muap, jnp.array(pars), *table, *table_now)
        #pkmu = self.pt.jac * get_rsd_pkmu(self.pt.kap, self.pt.muap, pars, table, table_now)
        #return self.to_poles(pkmu)

    def __getstate__(self, varied=True, fixed=True):
        """
        Return the state as a dictionary.

        Parameters
        ----------
        varied : bool, default=True
            If ``True``, include 'sigma8', 'fsigma8' and the ``self.pt`` attributes listed in ``_pt_attrs``.

        fixed : bool, default=True
            If ``True``, include 'k', 'z', 'ells', 'wmu', 'kt'.
        """
        state = {}
        for name in (['k', 'z', 'ells', 'wmu', 'kt'] if fixed else []) + (['sigma8', 'fsigma8'] if varied else []):
            if hasattr(self, name):
                state[name] = getattr(self, name)
        if varied:
            for name in self._pt_attrs:
                if hasattr(self.pt, name):
                    state[name] = getattr(self.pt, name)
        return state

    def __setstate__(self, state):
        """Restore attributes from ``state``; note that the entries set on ``self`` are popped from the input dictionary."""
        for name in ['k', 'z', 'ells', 'wmu', 'kt', 'sigma8', 'fsigma8']:
            if name in state: setattr(self, name, state.pop(name))
        if not hasattr(self, 'pt'): self.pt = Namespace()
        # Whatever remains in state is stored on the pt namespace
        self.pt.update(**state)

    @classmethod
    def install(cls, installer):
        """Install the ``folpsax`` dependency with pip, from its git repository."""
        installer.pip('git+https://github.com/cosmodesi/folpsax')
[docs]
class FOLPSAXTracerPowerSpectrumMultipoles(FOLPSTracerPowerSpectrumMultipoles):
    r"""
    FOLPS (``folpsax`` implementation) tracer power spectrum multipoles.

    Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn* and bias term b3*.
    By default, bs and b3 are fixed to 0, following co-evolution.
    For the matter (unbiased) power spectrum, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    shotnoise : float, default=1e4
        Shot noise (which is usually marginalized over).

    prior_basis : str, default='physical'
        If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
        :math:`b_{1}^\prime = (1 + b_{1}^{L}) \sigma_{8}(z), b_{2}^\prime = b_{2}^{L} \sigma_{8}(z)^2, b_{s}^\prime = b_{s}^{L} \sigma_{8}(z)^2, b_{3}^\prime = 0`
        with: :math:`b_{1} = 1 + b_{1}^{L}, b_{2} = 8/21 b_{1}^{L} + b_{2}^{L}, b_{s} = -4/7 b_{1}^{L} + b_{s}^{L}`.
        :math:`\alpha_{0} = (1 + b_{1}^{L})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}^{L}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}^{L}) \alpha_{4}^\prime)`.
        :math:`s_{n, 0} = 1/\bar{n} s_{n, 0}^\prime, s_{n, 2} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{2} s_{n, 2}^\prime`
        (the satellite fraction only multiplies the :math:`k^{2}` stochastic term; there is no :math:`s_{n, 4}` in this model).

    tracer : str, default=None
        If ``prior_basis = 'physical'``, tracer to load preset ``fsat`` and ``sigv``. One of ['LRG', 'ELG', 'QSO'].

    fsat : float, default=None
        If ``prior_basis = 'physical'``, satellite fraction to assume.

    sigv : float, default=None
        If ``prior_basis = 'physical'``, velocity dispersion to assume.

    Reference
    ---------
    - https://arxiv.org/abs/2208.02791
    - https://github.com/cosmodesi/folpsax
    """
[docs]
class FOLPSAXTracerCorrelationFunctionMultipoles(BaseTracerCorrelationFunctionFromPowerSpectrumMultipoles):
    r"""
    FOLPS (``folpsax`` implementation) tracer correlation function multipoles.

    Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn* and bias term b3*.
    By default, bs and b3 are fixed to 0, following co-evolution.
    For the matter (unbiased) correlation function, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    s : array, default=None
        Theory separations where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    prior_basis : str, default='physical'
        If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
        :math:`b_{1}^\prime = (1 + b_{1}^{L}) \sigma_{8}(z), b_{2}^\prime = b_{2}^{L} \sigma_{8}(z)^2, b_{s}^\prime = b_{s}^{L} \sigma_{8}(z)^2, b_{3}^\prime = 0`
        with: :math:`b_{1} = 1 + b_{1}^{L}, b_{2} = 8/21 b_{1}^{L} + b_{2}^{L}, b_{s} = -4/7 b_{1}^{L} + b_{s}^{L}`.
        :math:`\alpha_{0} = (1 + b_{1}^{L})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}^{L}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}^{L}) \alpha_{4}^\prime)`.

    Reference
    ---------
    - https://arxiv.org/abs/2208.02791
    - https://github.com/cosmodesi/folpsax
    """
    # Parameter definitions are borrowed from the power spectrum counterpart
    _params = classmethod(FOLPSAXTracerPowerSpectrumMultipoles._params.__func__)

    def set_params(self):
        """Set parameters; in the 'physical' prior basis, bias parameter names get a trailing 'p'."""
        super().set_params()
        if not self.power.is_physical_prior:
            return
        for attr in ('deterministic_bias_params', 'stochastic_bias_params'):
            setattr(self, attr, [name + 'p' for name in getattr(self, attr)])
def pt_kernel(k, q, wq):
    """
    Return the kernel used for the 13 term of the 1-loop power spectrum.

    Parameters
    ----------
    k : array
        1D array of output wavenumbers.

    q : array
        1D array of integration wavenumbers.

    wq : array
        Integration weights, same size as ``q``.

    Returns
    -------
    kernel : array
        Array of shape ``(len(k), len(q))``: angle-integrated F3(q, -q, k) kernel,
        times the measure ``q^2 wq / (4 pi^2)`` and a factor 2.
    """
    measure = q**2 * wq / (4. * np.pi**2)
    # Integral of F3(q, -q, k) over mu, cosine angle between k and q, as a function of ratio = q / k
    ratio = np.array(q / k[:, None])
    integrand = (6. / ratio**2 - 79. + 50. * ratio**2 - 21. * ratio**4 + 0.75 * (1. / ratio - ratio)**3 * (2. + 7. * ratio**2) * 2 * np.log(np.abs((ratio - 1.) / (ratio + 1.)))) / 504.
    # Series expansion for q >> k, where the closed form suffers from cancellations
    large = ratio > 10.
    integrand[large] = - 61. / 630. + 2. / 105. / ratio[large]**2 - 10. / 1323. / ratio[large]**4
    # Series expansion around q = k, where the logarithm diverges
    offset = ratio - 1.
    near_one = np.abs(offset) < 0.01
    integrand[near_one] = - 11. / 126. + offset[near_one] / 126. - 29. / 252. * offset[near_one]**2
    return 2 * measure * (integrand / ratio**2)
@jit
def pt_pk_1loop(k, q, wq, pk_q, kernel13_d):
    """
    Return the 1-loop density power spectrum P11 + P22 + P13 at wavenumbers ``k``.

    Parameters
    ----------
    k : array
        1D array of output wavenumbers.

    q : array
        1D array of wavenumbers where the input power spectrum is sampled; also the integration variable.

    wq : array
        Integration weights, same size as ``q``.

    pk_q : array
        Input power spectrum evaluated at ``q``.

    kernel13_d : array
        Precomputed kernel for the 13 term, of shape ``(len(k), len(q))``, as returned by :func:`pt_kernel`.

    Returns
    -------
    pk_dd : array
        1-loop power spectrum at ``k``.
    """
    # We could have a speed-up with FFTlog, see https://arxiv.org/pdf/1603.04405.pdf
    k11 = k  # keep the 1D k for the interpolation of the linear term
    k = k[:, None]  # column vector, to broadcast against q
    jq = q**2 * wq / (4. * np.pi**2)  # integration measure
    mus, wmus = utils.weights_mu(10, method='leggauss')
    # Compute P22
    pk_k = jnp.interp(k11, q, pk_q)

    def get_pk22_dd(mu, wmu):
        # Contribution of one angular node (mu, wmu) to the P22 integral
        kdq = k * q * mu  # k \cdot q
        kq2 = k**2 - 2. * kdq + q**2  # |k - q|^2
        qdkq = kdq - q**2  # q \cdot (k - q)
        F2_d = 5. / 7. + 1. / 2. * qdkq * (1. / q**2 + 1. / kq2) + 2. / 7. * qdkq**2 / (q**2 * kq2)
        # Input power spectrum at |k - q|, set to zero outside of the sampled q-range
        pk_kq = jnp.interp(kq2**0.5, q, pk_q, left=0., right=0.)
        jq_pk_q_pk_kq = jq * pk_q * pk_kq
        return 2 * wmu * jnp.sum(F2_d**2 * jq_pk_q_pk_kq, axis=-1)

    # Sum over the angular nodes
    pk22_dd = jnp.sum(jax.vmap(get_pk22_dd)(mus, wmus), axis=0)
    pk11 = pk_k
    pk13_dd = 2. * jnp.sum(kernel13_d * pk_q, axis=-1) * pk_k
    pk_dd = pk11 + pk22_dd + pk13_dd
    return pk_dd
[docs]
class GeoFPTAXTracerBispectrumMultipoles(BaseTracerThreePointTheory):
    r"""
    GeoFPTAX bispectrum multipoles.

    Can be exactly marginalized over stochastic parameters sn*.
    For the matter (unbiased) bispectrum, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    k : array or tuple of arrays, default=None
        Triangles of wavenumbers, of shape (nk, 3), where to evaluate multipoles.
        A single array is used for all multipoles; a tuple or list provides one array per multipole, in the order of ``ells``.
        Defaults to all closed triangles on an 11-point linear grid between 0.01 and 0.1.

    ells : tuple, default=((0, 0, 0), (2, 0, 0), (0, 2, 0), (0, 0, 2))
        Multipoles to compute. A single multipole, e.g. ``(0, 0, 0)``, is also accepted.

    template : BasePowerSpectrumTemplate
        Power spectrum template. Defaults to :class:`DirectPowerSpectrumTemplate`.

    pt : str, default=None
        Order of :math:`P(k)` fed into the bispectrum calculation.
        If ``None``, linear :math:`P(k)`.
        If '1loop', use 1-loop standard PT.

    shotnoise : float, array or tuple, default=None
        Shot noise term; a tuple or list provides one value (or array) per multipole, in the order of ``ells``.
        ``None`` is equivalent to 0.

    prior_basis : str, default='physical'
        If 'physical', use physically-motivated prior basis for bias parameters:
        :math:`b_{1}^\prime = b_{1}^{E} \sigma_{8}(z), b_{2}^\prime = b_{2}^{E} \sigma_{8}(z)^2`.

    mu : int, default=50
        Passed as ``num_points`` to ``geofptax.kernels.bk_multip``.

    Reference
    ---------
    - https://arxiv.org/pdf/2303.15510v1
    - https://github.com/dforero0896/geofptax
    """
    config_fn = 'full_shape.yaml'
    _klim = (1e-3, 2., 500)
    _default_options = dict(prior_basis='physical', mu=50)
    _deterministic_bias_params = ['b1', 'b2', 'sigmav']  # let's regard sigmav as a bias parameter (counter term) for now
    _stochastic_bias_params = ['sn0']

    def initialize(self, k=None, z=None, template=None, ells=((0, 0, 0), (2, 0, 0), (0, 2, 0), (0, 0, 2)), shotnoise=None, pt=None, **kwargs):
        """Set the triangles, multipoles, shot noise, template and parameters; see the class docstring for the arguments."""
        super(GeoFPTAXTracerBispectrumMultipoles, self).initialize(tracers=kwargs.pop('tracers', None))
        self.options = self._default_options | dict(kwargs)
        # Accept a single multipole (l1, l2, l3) as well as a sequence of them
        if len(ells) > 0 and np.ndim(ells) == 1:
            ells = (ells,)
        # Tuples of ints, so that multipoles can be looked up in the reference list used in calculate()
        self.ells = tuple(tuple(int(ell) for ell in ells3) for ells3 in ells)
        if k is None:
            # Default k-bins (k1, k2, k3) in Scoccimarro basis
            grid = np.linspace(0.01, 0.1, 11)
            k = np.column_stack([kk.ravel() for kk in np.meshgrid(grid, grid, grid, indexing='ij')])
            # Impose triangular condition: each side must not exceed the sum of the other two,
            # so all three inequalities must hold together
            mask = (k[:, 0] <= k[:, 1] + k[:, 2]) & (k[:, 1] <= k[:, 0] + k[:, 2]) & (k[:, 2] <= k[:, 0] + k[:, 1])
            k = k[mask]
        if not utils.is_sequence(k):
            # Tuple of k1k2k3's, one for each multipole
            k = (k,) * len(self.ells)
        self.k = tuple(np.array(kk, dtype='f8') for kk in k)
        if shotnoise is None:
            shotnoise = 0.
        if not utils.is_sequence(shotnoise):
            # (k-dependent) bispectrum shot-noise
            shotnoise = (shotnoise,) * len(self.ells)
        self.shotnoise = tuple(np.atleast_1d(sn) for sn in shotnoise)
        # The input linear power spectrum (template)
        if template is None:
            template = DirectPowerSpectrumTemplate()
        self.template = template
        # k for input linear power spectrum
        kmin = min(kk.min() for kk in self.k)
        kmax = max(kk.max() for kk in self.k)
        kin = np.geomspace(min(self._klim[0], kmin / 2, self.template.init.get('k', [1.])[0]), max(self._klim[1], kmax * 2, self.template.init.get('k', [0.])[0]), self._klim[2])  # margin for AP effect
        # Ask for input k, z
        self.template.init.update(k=kin)
        if z is not None: self.template.init.update(z=z)
        self.z = self.template.z
        # Set parameters
        self.set_params()
        self.pt = pt
        assert self.pt in [None, '1loop'], 'pt must be None or "1loop"'

    @classmethod
    def _params(cls, params, prior_basis='physical', tracers=None):
        """Return the parameter collection, renamed and given priors in the 'physical' basis."""
        params = super()._params(params, tracers=tracers)
        # Prior basis is 'physical' = sampled bias bp is b * sigma8^n
        if prior_basis == 'physical':
            for param in list(params):
                basename = param.basename
                param.update(basename=basename + 'p')
                #params.set({'basename': basename, 'namespace': param.namespace, 'derived': True})
            for param in params.select(basename='b1p'):
                param.update(prior=dict(dist='uniform', limits=[0., 3.]), ref=dict(dist='norm', loc=1., scale=0.1))
            for param in params.select(basename=['b2p']):
                param.update(prior=dict(dist='norm', loc=0., scale=5.), ref=dict(dist='norm', loc=0., scale=1.))
            for param in params.select(basename='sn*p'):
                param.update(prior=dict(dist='norm', loc=0., scale=2. if 'sn0' in param.basename else 5.), ref=dict(dist='norm', loc=0., scale=1.))
        return params

    def set_params(self):
        """Declare the required bias parameters (self.init.params); names get a trailing 'p' in the 'physical' basis."""
        names = ['b1', 'b2', 'sigmav', 'sn0']
        default_values = {'b1': 2.}
        self.required_bias_params = {name: default_values.get(name, 0.) for name in names}
        self.optional_bias_params = {}
        self.is_physical_prior = self.options['prior_basis'] == 'physical'
        if self.is_physical_prior:
            self.required_bias_params = {name + 'p': value for name, value in self.required_bias_params.items()}
            self.deterministic_bias_params = [name + 'p' for name in self.deterministic_bias_params]
            self.stochastic_bias_params = [name + 'p' for name in self.stochastic_bias_params]
        fix = []
        if 2 not in self.ells: fix += ['sn2']
        for param in self.init.params.select(basename=fix):
            param.update(value=0., fixed=True)

    def calculate(self, **params):
        """Calculate the bispectrum multipoles and store them, as a list with one array per multipole, in ``self.power``."""
        self.z = self.template.z
        self.sigma8 = self.template.sigma8
        self.fsigma8 = self.template.f * self.sigma8
        params = self.pack_input_bias_params(params)
        params = {name: value[0] if isinstance(value, tuple) else value for name, value in params.items()}
        # Conversion from "physical" bias parameters to standard basis
        if self.is_physical_prior:
            b1E, b2E = params['b1p'] / self.sigma8, params['b2p'] / self.sigma8**2
            pars = [b1E, b2E, params['sigmavp'], params['sn0p']]
        else:
            pars = [params[name] for name in self.required_bias_params]
        # b1, b2, A_P, sigma_P, A_B, sigma_B: A_B is sn0 and sigma_B is sigmav
        pars = pars[:2] + [1., 4.] + [pars[3], pars[2]]
        # Alcock-Paczynski parameters are self.template.qpar, self.template.qper
        all_pars = jnp.array([self.sigma8, self.fsigma8 / self.sigma8, self.template.qpar, self.template.qper] + pars)
        from geofptax.kernels import bk_multip
        kt = self.template.k
        pkt = self.template.pk_dd  # theory linear pk
        if self.pt:  # loop correction: update pkt with 1-loop calculation
            q = kt
            ktmin, ktmax = min(kk.min() for kk in self.k) * 0.7, max(kk.max() for kk in self.k) * 1.3
            kt = jnp.linspace(ktmin, ktmax, self._klim[2])
            wq = utils.weights_trapz(q)
            if getattr(self, 'kernel', None) is None:
                # Compute pt kernel the first time only
                self.kernel = pt_kernel(kt, q, wq)
            pkt = pt_pk_1loop(kt, q, wq, pkt, self.kernel)
        # bk_multip takes the triangles for, and returns, bk0, bk200, bk020, bk002 in this fixed order
        tells = [(0, 0, 0), (2, 0, 0), (0, 2, 0), (0, 0, 2)]

        def _k_for(ell):
            # Triangles requested for this multipole; the last provided ones for multipoles that are not requested
            index = self.ells.index(ell) if ell in self.ells else len(self.k) - 1
            return self.k[min(index, len(self.k) - 1)]

        kk = [_k_for(ell) for ell in tells]
        # Compute bk multipoles
        res = bk_multip(*kk, kt, pkt, all_pars, redshift=self.z, num_points=self.options['mu'])
        res = [res[tells.index(ell)] for ell in self.ells]
        # Include shot noise term, rescaling by AP (alpha_par * alpha_per**2)**2
        A_B = all_pars[8] / (all_pars[2] * all_pars[3]**2)**2
        self.power = [rr + A_B * sn for rr, sn in zip(res, self.shotnoise)]

    def get(self):
        """Return the bispectrum multipoles (value returned when calling the calculator)."""
        return self.power

    @classmethod
    def install(cls, installer):
        """Install the ``geofptax`` dependency with pip, from its git repository."""
        installer.pip('git+https://github.com/dforero0896/geofptax')

    def __getstate__(self):
        """Return the state as a dictionary; required only for quick emulation (Taylor expansion)."""
        state = {}
        for name in ['k', 'z', 'ells', 'power']:
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state
from desilike.theories.primordial_cosmology import Cosmoprimo, get_cosmo
from .base import APEffect
# Legendre polynomials P_ell(x) for ell = 0 ... 10, written out explicitly; the list index is ell
_registered_legendre = [
    lambda x: jnp.ones_like(x),
    lambda x: x,
    lambda x: 3*x**2/2 - 1/2,
    lambda x: 5*x**3/2 - 3*x/2,
    lambda x: 35*x**4/8 - 15*x**2/4 + 3/8,
    lambda x: 63*x**5/8 - 35*x**3/4 + 15*x/8,
    lambda x: 231*x**6/16 - 315*x**4/16 + 105*x**2/16 - 5/16,
    lambda x: 429*x**7/16 - 693*x**5/16 + 315*x**3/16 - 35*x/16,
    lambda x: 6435*x**8/128 - 3003*x**6/32 + 3465*x**4/64 - 315*x**2/32 + 35/128,
    lambda x: 12155*x**9/128 - 6435*x**7/32 + 9009*x**5/64 - 1155*x**3/32 + 315*x/128,
    lambda x: 46189*x**10/256 - 109395*x**8/256 + 45045*x**6/128 - 15015*x**4/128 + 3465*x**2/256 - 63/256,
]


def get_legendre(ell):
    """
    Return the Legendre polynomial of order ``ell``, as a callable of x.

    Parameters
    ----------
    ell : int
        Order, between 0 and 10; used directly as an index in the list of registered polynomials.

    Returns
    -------
    legendre : callable
        Function evaluating :math:`P_{\\ell}(x)`.
    """
    return _registered_legendre[ell]
[docs]
class JAXEffortTracerPowerSpectrumMultipoles(BaseTheoryPowerSpectrumMultipoles):
    r"""
    Wrapper to JAXEffort emulator.

    Can be exactly marginalized over counter terms and stochastic parameters alpha*, sn* and bias term b3*.
    By default, bs and b3 are fixed to 0, following co-evolution.
    For the matter (unbiased) power spectrum, set b1=1 and all other bias parameters to 0.

    Parameters
    ----------
    k : array, default=None
        Theory wavenumbers where to evaluate multipoles.

    ells : tuple, default=(0, 2, 4)
        Multipoles to compute.

    model : str, default='velocileptors_rept_mnuw0wacdm'
        Name of the trained emulator in ``jaxeffort.trained_emulators``.
        Only names containing 'velocileptors' are supported.

    cosmo : BasePrimordialCosmology, default=None
        Cosmology calculator. Defaults to ``Cosmoprimo(fiducial=fiducial)``.

    fiducial : str, default='DESI'
        Fiducial cosmology, used for the Alcock-Paczynski effect.

    shotnoise : float or sequence, default=1e4
        Shot noise (which is usually marginalized over).
        For a sequence (cross-correlation), the geometric mean is used.

    z : float, default=0.
        Redshift.

    prior_basis : str, default='physical'
        If 'physical', use physically-motivated prior basis for bias parameters, counterterms and stochastic terms:
        :math:`b_{1}^\prime = (1 + b_{1}^{L}) \sigma_{8}(z), b_{2}^\prime = b_{2}^{L} \sigma_{8}(z)^2, b_{s}^\prime = b_{s}^{L} \sigma_{8}(z)^2, b_{3}^\prime = 0`
        with: :math:`b_{1} = 1 + b_{1}^{L}, b_{2} = 8/21 b_{1}^{L} + b_{2}^{L}, b_{s} = -4/7 b_{1}^{L} + b_{s}^{L}`.
        :math:`\alpha_{0} = (1 + b_{1}^{L})^{2} \alpha_{0}^\prime, \alpha_{2} = f (1 + b_{1}^{L}) (\alpha_{0}^\prime + \alpha_{2}^\prime), \alpha_{4} = f (f \alpha_{2}^\prime + (1 + b_{1}^{L}) \alpha_{4}^\prime)`.
        :math:`s_{n, 0} = f_{\mathrm{sat}}/\bar{n} s_{n, 0}^\prime, s_{n, 2} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{2} s_{n, 2}^\prime, s_{n, 4} = f_{\mathrm{sat}}/\bar{n} \sigma_{v}^{4} s_{n, 4}^\prime`.
        Note: the conversion from this basis is not implemented yet, so evaluating the theory
        currently requires another value for ``prior_basis``.

    tracer : str, default=None
        If ``prior_basis = 'physical'``, tracer to load preset ``fsat`` and ``sigv``. One of ['LRG', 'ELG', 'QSO'].

    fsat : float, default=None
        If ``prior_basis = 'physical'``, satellite fraction to assume.

    sigv : float, default=None
        If ``prior_basis = 'physical'``, velocity dispersion to assume.

    Reference
    ---------
    - https://github.com/CosmologicalEmulators/jaxeffort
    """
    _default_options = dict(freedom=None, prior_basis='physical', tracer=None, fsat=None, sigv=None, shotnoise=1e4)

    @classmethod
    def _params(cls, params, model='velocileptors_rept_mnuw0wacdm', freedom=None, prior_basis='physical', tracers=None):
        """Return the parameters of the velocileptors calculator that ``model`` emulates."""
        from desilike.base import get_calculator_config
        if 'velocileptors_lpt' in model:
            reference = LPTVelocileptorsTracerPowerSpectrumMultipoles
        elif 'velocileptors_rept' in model:
            reference = REPTVelocileptorsTracerPowerSpectrumMultipoles
        else:
            raise NotImplementedError
        params = get_calculator_config(reference)[-1]
        return reference._params(params, freedom=freedom, prior_basis=prior_basis, tracers=tracers)

    def set_params(self):
        """
        Declare the bias parameters and fix those that only enter multipoles which are not requested.

        Sets ``required_bias_params`` (name -> default value), ``is_physical_prior``, ``nd``, ``fsat`` and ``snd``.
        """
        if 'velocileptors' not in self.model:
            raise NotImplementedError
        # FIXME (couldn't use set_params as requires pt attribute)
        self.deterministic_bias_params = ['b1', 'b2', 'bs', 'b3', 'alpha0', 'alpha2', 'alpha4', 'alpha6']
        self.stochastic_bias_params = ['sn0', 'sn2', 'sn4']
        self.required_bias_params = {param: 0. for param in self.deterministic_bias_params + self.stochastic_bias_params}
        self.required_bias_params['b1'] = 1.
        self.is_physical_prior = self.options['prior_basis'] == 'physical'
        if self.is_physical_prior:
            # Sampled parameters are the primed ones: b1 -> b1p, etc. (order is preserved)
            self.required_bias_params = {name + 'p': value for name, value in self.required_bias_params.items()}
            settings = get_physical_stochastic_settings(tracer=self.options['tracer'])
            for name, value in settings.items():
                # Explicit user options take precedence over the tracer presets
                if self.options[name] is None: self.options[name] = value
            if self.mpicomm.rank == 0:
                self.log_debug('Using fsat, sigv = {:.3f}, {:.3f}.'.format(self.options['fsat'], self.options['sigv']))
            self.deterministic_bias_params = [name + 'p' for name in self.deterministic_bias_params]
            self.stochastic_bias_params = [name + 'p' for name in self.stochastic_bias_params]
        fix = []
        if 4 not in self.ells: fix += ['alpha4*', 'alpha6*', 'sn4*']  # * to capture p
        if 2 not in self.ells: fix += ['alpha2*', 'sn2*']
        for param in self.init.params.select(basename=fix):
            param.update(value=0., fixed=True)
        self.nd = 1e-4
        self.fsat = self.snd = 1.
        if self.is_physical_prior:
            self.fsat, self.snd = self.options['fsat'], self.options['shotnoise'] * self.nd  # normalized by 1e-4

    def transform_params(self, cosmo, **params):
        """
        Return the list of bias parameters, in the basis expected by the emulator.

        Parameters
        ----------
        cosmo : object
            Cosmology built in :meth:`calculate`; currently unused.

        **params : dict
            Bias parameters, with one entry for each of ``required_bias_params``.

        Returns
        -------
        pars : list
            Values ordered as ``required_bias_params``:
            b1, b2, bs, b3, alpha0, alpha2, alpha4, alpha6, sn0, sn2, sn4.

        Raises
        ------
        NotImplementedError
            If the model is not a velocileptors one, or in the 'physical' prior basis.
        """
        if 'velocileptors' not in self.model:
            raise NotImplementedError
        if self.is_physical_prior:
            # The physical -> standard mapping needs sigma8(z) and f(z), which are not wired in yet
            raise NotImplementedError('prior_basis="physical" is not supported yet for the JAXEffort wrapper')
        pars = [params[name] for name in self.required_bias_params]
        if 'rept' in self.model:
            b1 = pars[0]
            pars[2] = pars[2] - (2 / 7) * (b1 - 1.)  # bs
            pars[3] = 3 * pars[3] + (b1 - 1.)  # b3
        return pars

    def initialize(self, *args, model='velocileptors_rept_mnuw0wacdm', cosmo=None, fiducial='DESI', shotnoise=1e4, z=0., **kwargs):
        """Set options, cosmology, AP effect, parameters and load the emulators; see the class docstring for the arguments."""
        self.options = dict()
        # shotnoise is a named argument: it must be used as given, it never shows up in kwargs
        if utils.is_sequence(shotnoise):
            # cross correlation: geometric mean
            shotnoise = np.sqrt(np.prod(shotnoise))
        for name, value in self._default_options.items():
            self.options[name] = kwargs.pop(name, value)
        if 'shotnoise' in self.options:
            self.options['shotnoise'] = shotnoise
        self.nd = 1. / float(shotnoise)
        self.z = float(z)
        # Sets k, ells
        super(JAXEffortTracerPowerSpectrumMultipoles, self).initialize(*args, **kwargs)
        self.nd = 1. / float(shotnoise)
        self.fiducial = get_cosmo(fiducial)
        self.cosmo = cosmo
        if cosmo is None:
            self.cosmo = Cosmoprimo(fiducial=self.fiducial)
        self.apeffect = APEffect(z=self.z, fiducial=self.fiducial, mode='geometry', cosmo=self.cosmo)
        self.required_bias_params, self.optional_bias_params = {}, {}
        self.model = model
        self.set_params()
        # Imported lazily so that jaxeffort is only needed when this calculator is used
        import jaxeffort
        # One emulator per multipole
        self.emulators = [jaxeffort.trained_emulators[model][f"{ell:d}"] for ell in self.ells]
        self.set_k_mu(self.k, mu=8, ells=self.ells)

    def set_k_mu(self, k, mu=20, method='leggauss', ells=(0, 2, 4)):
        """
        Set the wavenumbers, the mu-nodes and the weights used to project P(k, mu) onto multipoles.

        Parameters
        ----------
        k : array
            Wavenumbers.

        mu : int, default=20
            Passed to ``utils.weights_mu`` together with ``method``.

        method : str, default='leggauss'
            Integration method.

        ells : tuple, default=(0, 2, 4)
            Multipoles; ``self.wmu`` has one row of weights per multipole.
        """
        self.k = np.asarray(k, dtype='f8')
        self.mu, wmu = utils.weights_mu(mu, method=method)
        self.wmu = np.array([wmu * (2 * ell + 1) * get_legendre(ell)(self.mu) for ell in ells])

    def calculate(self, **params):
        """Evaluate the emulators, apply the Alcock-Paczynski effect and store the multipoles in ``self.power``."""
        cosmo_dict = {'ln10As': self.cosmo['logA'], 'ns': self.cosmo['n_s'], 'h': self.cosmo['H0'] / 100.,
                      'omega_b': self.cosmo['omega_b'], 'omega_c': self.cosmo['omega_cdm'], 'm_nu': self.cosmo['m_ncdm_tot'],
                      'w0': self.cosmo['w0_fld'], 'wa': self.cosmo['wa_fld']}
        import jaxeffort
        cosmo_jaxeffort = jaxeffort.W0WaCDMCosmology(**cosmo_dict)
        # Emulator input: z, ln10As, ns, H0, omega_b, omega_c, m_nu, w0, wa
        theta = jnp.array([self.z, cosmo_dict["ln10As"], cosmo_dict["ns"], 100. * cosmo_dict["h"], cosmo_dict["omega_b"], cosmo_dict["omega_c"], cosmo_dict["m_nu"], cosmo_dict["w0"], cosmo_dict["wa"]])
        D = cosmo_jaxeffort.D_z(self.z)
        bias = self.transform_params(cosmo_jaxeffort, **params)
        poles = [emulator.get_Pl(theta, bias, D) for emulator in self.emulators]
        jac, kap, muap = self.apeffect.ap_k_mu(self.k, self.mu)
        # Rebuild P(k, mu) on the emulator k-grid from its multipoles, at the distorted mu
        pkmu = sum(pole[:, None] * get_legendre(ell)(muap) for ell, pole in zip(self.ells, poles))

        def interp_column(kap, pkmu):
            # Interpolate one mu-column from the emulator k-grid to the distorted k
            return interp1d(kap, self.emulators[0].P11.k_grid, pkmu)

        pkmu = jac * jax.vmap(interp_column, in_axes=1, out_axes=1)(kap, pkmu)
        # Project back onto the requested multipoles
        self.power = jnp.sum(pkmu * self.wmu[:, None, :], axis=-1)

    def get(self):
        """Return the power spectrum multipoles (value returned when calling the calculator)."""
        return self.power