"""Classes to handle parameters."""
import os
import re
import fnmatch
import copy
import numbers
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
from itertools import chain as _chain
from itertools import repeat as _repeat
from itertools import starmap as _starmap
from operator import itemgetter as _itemgetter
import numpy as np
import scipy as sp
from .jax import numpy as jnp
from .jax import scipy as jsp
from .jax import rv_frozen, use_jax, register_pytree_node_class
from .io import BaseConfig
from . import mpi, utils
from .mpi import CurrentMPIComm
from .utils import BaseClass, NamespaceDict, deep_eq, is_path
[docs]
def decode_name(name, default_start=0, default_stop=None, default_step=1):
    """
    Split ``name`` into literal strings and allowed index ranges.

    Each template ``[start:stop:step]`` found in ``name`` is turned into a :class:`range`;
    the text around the templates is returned separately.

    >>> decode_name('a_[-4:5:2]_b_[0:2]')
    (['a_', '_b_', ''], [range(-4, 5, 2), range(0, 2)])

    Parameters
    ----------
    name : str
        Parameter name, e.g. ``a_[-4:5:2]``.

    default_start : int, default=0
        Range start to use when omitted in the template.

    default_stop : int, default=None
        Range stop to use when omitted in the template.

    default_step : int, default=1
        Range step to use when omitted in the template.

    Returns
    -------
    strings : list
        Literal pieces of ``name`` around the templates; one more entry than ``ranges``
        (the last entry is the text following the last template, possibly empty).

    ranges : list
        One :class:`range` per template, in order of appearance.

    Raises
    ------
    ValueError
        If a bound is omitted in a template while the corresponding default is ``None``.
    """
    def bound(text, default, message):
        # An explicit bound wins; otherwise use the default, which can be disabled with None
        if text:
            return int(text)
        if default is None:
            raise ValueError(message)
        return default

    name = str(name)
    strings, ranges = [], []
    string_start = 0
    for replace in re.finditer(r'\[([-+]?\d*):([-+]?\d*):*([-+]?\d*)\]', name):
        start, stop, step = replace.groups()
        start = bound(start, default_start, 'You must provide a lower limit to parameter index')
        stop = bound(stop, default_stop, 'You must provide an upper limit to parameter index')
        step = bound(step, default_step, 'You must provide a step for parameter index')
        # Literal text between the previous template (or start of name) and this one
        strings.append(name[string_start:replace.start()])
        string_start = replace.end()
        ranges.append(range(start, stop, step))
    # Trailing literal text, after the last template
    strings.append(name[string_start:])
    return strings, ranges
[docs]
def yield_names_latex(name, latex=None, **kwargs):
    r"""
    Yield parameter name and latex strings, expanding the ``[::]`` templates of ``name``.

    Every combination of indices allowed by the templates is generated, and each index
    is substituted in ``name`` in place of its template; in ``latex``, indices are
    substituted, in the same order, in place of the ``[]`` placeholders.

    >>> list(yield_names_latex('a_[0:2]', latex=r'\alpha_{[]}'))
    [('a_0', '\\alpha_{0}'), ('a_1', '\\alpha_{1}')]

    Parameters
    ----------
    name : str
        Parameter name.

    latex : str, default=None
        Latex for parameter.

    kwargs : dict
        Arguments for :func:`decode_name`

    Returns
    -------
    name : str
        Parameter name with template forms ``[::]`` replaced.

    latex : str, None
        If input ``latex`` is ``None``, ``None``.
        Else latex string with ``[]`` placeholders replaced.
    """
    strings, ranges = decode_name(name, **kwargs)
    if not ranges:
        # No template in name: nothing to expand
        yield strings[0], latex
        return

    from itertools import product

    name_template = '%d'.join(strings)
    latex_template = None if latex is None else latex.replace('[]', '%d')
    for indices in product(*ranges):
        expanded = name_template % indices
        if latex_template is None:
            yield expanded, None
        else:
            yield expanded, latex_template % indices
[docs]
def find_names(allnames, name, quiet=True):
    """
    Search parameter name ``name`` in list of names ``allnames``,
    matching template forms ``[::]``;
    return corresponding parameter names.
    Contrary to :func:`find_names_latex`, it does not handle latex strings,
    but can take a list of parameter names as ``name``
    (thus returning the concatenated list of matching names in ``allnames``).

    >>> find_names(['a_1', 'a_2', 'b_1', 'c_2'], ['a_[:]', 'b_[:]'])
    ['a_1', 'a_2', 'b_1']

    Parameters
    ----------
    allnames : list
        List of parameter names (strings).

    name : list, str, re.Pattern
        Parameter name(s) to match in ``allnames``.
        In a string, ``*`` matches any run of characters, and a template ``[start:stop:step]``
        matches any integer in the corresponding range; an omitted ``stop`` means no upper limit,
        an omitted ``start`` means 0.
        A compiled pattern is matched as is (with :func:`re.match`).

    quiet : bool, default=True
        If ``False`` and no match for parameter name was found in ``allnames``, raise :class:`ParameterError`.

    Returns
    -------
    toret : list
        List of parameter names (strings).
    """
    import sys  # local import: only needed for the open-ended upper limit below

    if not utils.is_sequence(allnames):
        allnames = [allnames]

    if utils.is_sequence(name):
        toret = []
        for nn in name:
            toret += find_names(allnames, nn, quiet=quiet)
        return toret

    if isinstance(name, re.Pattern):
        pattern, ranges = name, []
    else:
        # '?' for non-greedy, '$' to match end of string
        template = name.replace('*', '.*?') + '$'
        # Open-ended templates such as '[:]' must be accepted: without a default stop,
        # decode_name raises ValueError
        strings, ranges = decode_name(template, default_stop=sys.maxsize)
        pattern = re.compile(r'([-+]?\d*)'.join(strings))

    toret = []
    for paramname in allnames:
        match = pattern.match(paramname)
        if match is None:
            continue
        try:
            # Membership test in a range is O(1): the range is never expanded in memory
            keep = all(int(group) in ra for group, ra in zip(match.groups(), ranges))
        except ValueError:
            # The index group matched an empty string or a lone sign: not a valid index
            keep = False
        if keep:
            toret.append(paramname)

    if not toret and not quiet:
        raise ParameterError('No match found for {}'.format(name))
    return toret
[docs]
class ParameterError(Exception):
    """Exception raised for parameter-related errors, e.g. no parameter name matched or invalid :class:`Parameter` settings."""
[docs]
class Deriv(dict):
    """
    This class encodes derivative orders: it maps parameter names to the (strictly positive)
    order of the derivative w.r.t. this parameter.
    It is a modification of :class:`Counter` in https://github.com/python/cpython/blob/main/Lib/collections/__init__.py,
    restricting to positive elements: parameters with a non-positive order are never stored,
    and the order of a parameter that is not stored is zero.
    """
    # References:
    #   http://en.wikipedia.org/wiki/Multiset
    #   http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
    #   http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
    #   http://code.activestate.com/recipes/259174/
    #   Knuth, TAOCP Vol. II section 4.6.3

    def __init__(self, iterable=None, /, **kwds):
        r"""
        Create a new :class:`Deriv` object.

        >>> c = Deriv()                              # a new, empty derivative object, i.e. zeroth order
        >>> c = Deriv(['x', Parameter('x'), 'y'])    # a new derivative from the list of parameters w.r.t. derivatives are taken
        >>> c = Deriv({'x': 2, 'y': 1})              # a new derivative from a mapping: :math:`\partial_{x}^{2} \partial y`
        >>> c = Deriv(x=2, y=1)                      # a new derivative from keyword args

        Raises
        ------
        ValueError
            If ``iterable`` is neither a mapping, nor made of :class:`Parameter` or string.
        """
        super().__init__()
        if isinstance(iterable, Deriv) and not kwds:  # shortcut, saves ~1e-6 s
            super().update(iterable)
            return
        if iterable is None or isinstance(iterable, (Mapping,)):
            self.update(iterable, **kwds)
        else:
            iterable = (iterable,) if not utils.is_sequence(iterable) else iterable
            if all(isinstance(param, (Parameter, str)) for param in iterable):
                self.update((str(param) for param in iterable), **kwds)
            else:
                raise ValueError('Unable to make Deriv from {}'.format(iterable))

    def __setitem__(self, name, item):
        """Set derivative order ``item`` for parameter ``name``; a non-positive order removes ``name``."""
        if item > 0:
            super().__setitem__(name, item)
        else:
            # Only positive orders are stored: do not leave a stale entry behind
            self.pop(name, None)

    def setdefault(self, name, item):
        """Like :meth:`dict.setdefault`, except that non-positive ``item`` is never stored; return resulting order."""
        if item > 0:
            return super().setdefault(name, item)
        return self[name]

    def __missing__(self, key):
        """The order of a derivative w.r.t. a parameter not in the :class:`Deriv` is zero."""
        # Needed so that self[missing_item] does not raise KeyError
        return 0

    def total(self):
        """Total derivative order."""
        return sum(self.values())

    def most_common(self, n=None):
        """
        List the ``n`` most common derivatives and their orders from the most
        common to the least. If n is ``None``, then list all derivative orders.

        >>> Deriv(['x', 'x', 'x', 'y', 'y', 'z', 'z', 't']).most_common(3)
        [('x', 3), ('y', 2), ('z', 2)]
        """
        # Emulate Bag.sortedByCount from Smalltalk
        if n is None:
            return sorted(self.items(), key=_itemgetter(1), reverse=True)

        # Lazy import to speedup Python startup time
        import heapq
        return heapq.nlargest(n, self.items(), key=_itemgetter(1))

    def elements(self):
        """
        Iterator over derivatives repeating each as many times as its order.

        >>> c = Deriv({'x': 2, 'y': 2, 'z': 2})
        >>> sorted(c.elements())
        ['x', 'x', 'y', 'y', 'z', 'z']
        """
        # Emulate Bag.do from Smalltalk and Multiset.begin from C++.
        return _chain.from_iterable(_starmap(_repeat, self.items()))

    def update(self, iterable=None, /, **kwds):
        """
        Like :meth:`dict.update` but add derivative orders instead of replacing them.
        Source can be an iterable of names, a dictionary, or another :class:`Deriv` instance.
        Parameters whose resulting order is not positive are removed.

        >>> c = Deriv({'x': 1, 'y': 1})
        >>> c.update(['x'])             # add derivatives from an iterable of names
        >>> d = Deriv({'x': 1, 'y': 1})
        >>> c.update(d)
        >>> c['x']                      # 1 + 1 + 1
        3
        """
        if iterable is not None:
            if isinstance(iterable, Mapping):
                if self:
                    for elem, count in iterable.items():
                        self[elem] = count + self.get(elem, 0)
                else:
                    # fast path when counter is empty; bypasses __setitem__,
                    # hence non-positive orders are stripped below
                    super().update(iterable)
            else:
                for elem in iterable:
                    self[elem] = self.get(elem, 0) + 1
        self._keep_positive()
        if kwds:
            self.update(kwds)

    def __reduce__(self):
        return self.__class__, (dict(self),)

    def __delitem__(self, elem):
        """Like :meth:`dict.__delitem__` but does not raise :class:`KeyError` for missing values."""
        if elem in self:
            super().__delitem__(elem)

    def __repr__(self):
        if not self:
            return f'{self.__class__.__name__}()'
        try:
            # dict() preserves the ordering returned by most_common()
            d = dict(self.most_common())
        except TypeError:
            # handle case where values are not orderable
            d = dict(self)
        return f'{self.__class__.__name__}({d!r})'

    def __eq__(self, other):
        """``True`` if all derivative orders agree. Missing derivatives are treated as zero-order derivatives."""
        if not isinstance(other, Deriv):
            return NotImplemented
        return all(self[e] == other[e] for c in (self, other) for e in c)

    def __le__(self, other):
        """``True`` if all derivative orders in ``self`` are less than or equal to those in ``other``."""
        if not isinstance(other, Deriv):
            return NotImplemented
        return all(self[e] <= other[e] for c in (self, other) for e in c)

    def __lt__(self, other):
        """``True`` if ``self <= other`` and at least one derivative order differs."""
        if not isinstance(other, Deriv):
            return NotImplemented
        return self <= other and self != other

    def __ge__(self, other):
        """``True`` if all derivative orders in ``self`` are greater than or equal to those in ``other``."""
        if not isinstance(other, Deriv):
            return NotImplemented
        return all(self[e] >= other[e] for c in (self, other) for e in c)

    def __gt__(self, other):
        """``True`` if ``self >= other`` and at least one derivative order differs."""
        if not isinstance(other, Deriv):
            return NotImplemented
        return self >= other and self != other

    def __add__(self, other):
        """
        Add derivative orders.

        >>> Deriv({'x': 2, 'y': 1}) + Deriv({'x': 1, 'y': 2})
        Deriv({'x': 3, 'y': 3})
        """
        if not isinstance(other, Deriv):
            return NotImplemented
        result = Deriv()
        for elem, count in self.items():
            newcount = count + other[elem]
            if newcount > 0:
                result[elem] = newcount
        for elem, count in other.items():
            if elem not in self and count > 0:
                result[elem] = count
        return result

    def _keep_positive(self):
        """Internal method to strip derivatives with a negative or zero order; return ``self``."""
        # Collect first: the dictionary cannot change size while being iterated over
        nonpositive = [elem for elem, count in self.items() if not count > 0]
        for elem in nonpositive:
            del self[elem]
        return self

    def __iadd__(self, other):
        """
        Inplace add from another derivative.

        >>> c = Deriv({'x': 2, 'y': 1})
        >>> c += Deriv({'x': 1, 'y': 2})
        >>> c
        Deriv({'x': 3, 'y': 3})
        """
        for elem, count in other.items():
            self[elem] += count
        return self._keep_positive()
import numpy.lib.mixins
[docs]
@register_pytree_node_class
class ParameterArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    """
    Array wrapper that keeps track of the :class:`Parameter` the array refers to (:attr:`param`)
    and, optionally, of the derivatives it holds (:attr:`derivs`, tuple of :class:`Deriv`).

    The trailing axes of :attr:`value` are the parameter axes (``param.ndim`` of them),
    preceded by one derivative axis when :attr:`derivs` is not ``None``;
    all other (leading) axes are the "array" axes, see :attr:`ashape` and :attr:`pshape`.
    Attributes that are not found on the wrapper are looked up on :attr:`value`.
    """

    def __init__(self, value, param=None, derivs=None, copy=False, dtype=None, **kwargs):
        """
        Initialize :class:`ParameterArray`.

        Parameters
        ----------
        value : array
            Array value.

        param : Parameter, str, default=None
            Parameter.

        derivs : list
            List of derivatives (:class:`Deriv` instances).

        copy : bool, default=False
            Whether to copy input array.

        dtype : dtype, default=None
            If provided, enforce this dtype.

        **kwargs : dict
            Optional arguments for :func:`np.array`.
        """
        if isinstance(value, ParameterArray):
            value = value.value
        # dtype is compared to None: a plain np.dtype instance has zero length, hence is falsy
        convert = copy or dtype is not None or (not use_jax(value) and not isinstance(value, np.ndarray))
        if value is not None and convert:
            if copy:
                value = np.array(value, copy=copy, dtype=dtype, **kwargs)
            else:  # numpy>=2
                value = np.asarray(value, dtype=dtype, **kwargs)
        self._value = value
        self.param = None if param is None else Parameter(param)
        self._derivs = None if derivs is None else tuple(Deriv(deriv) for deriv in derivs)

    @property
    def value(self):
        """Underlying array."""
        return self._value

    @property
    def derivs(self):
        """Tuple of :class:`Deriv`, or ``None`` if array holds no derivatives."""
        return self._derivs

    @property
    def shape(self):
        """Shape of :attr:`value`."""
        return self.value.shape

    def __float__(self):
        return float(self.value)

    def __len__(self):
        return len(self.value)

    def __bool__(self):
        return self.value.__bool__()

    def __iter__(self):
        values = self.value.__iter__()  # to raise TypeError in case of 0d array
        # yield would not raise an error in case of 0d array

        def get(value):
            new = self.__class__(value)
            new.__array_finalize__(self, copy=True)
            return new

        return (get(value) for value in values)

    @property
    def zero(self):
        """Return zero-order derivative."""
        if self.derivs is not None:
            return self[()]
        return self

    @property
    def pndim(self):
        """Number of dimensions of stored parameter, plus 1 if derivatives."""
        return int(self.derivs is not None) + (self.param.ndim if self.param is not None else 0)

    @property
    def andim(self):
        """Number of dimensions of array, minus parameter dimensions and derivatives (if any)."""
        return self.value.ndim - self.pndim

    @property
    def pshape(self):
        """Parameter shape, including derivatives along first dimension (if any)."""
        return self.value.shape[self.andim:]

    @property
    def ashape(self):
        """Array shape, removing parameter shape and derivatives (if any)."""
        return self.value.shape[:self.andim]

    @ashape.setter
    def ashape(self, shape):
        # New array shape comes first; trailing parameter (and derivative) axes are left unchanged
        self.value.shape = tuple(shape) + self.pshape

    def __array_finalize__(self, obj, copy=False):
        """Take :attr:`param` from ``obj``, and :attr:`derivs` if trailing ``obj.pndim`` axes have the same shape."""
        if obj.derivs is not None and (self.shape[-obj.pndim:] == obj.shape[-obj.pndim:]):
            self._derivs = tuple(Deriv(deriv) for deriv in obj.derivs) if copy else obj.derivs
        if obj.param is not None:
            self.param = Parameter(obj.param) if copy else obj.param

    def __copy__(self):
        return self.clone(value=self.value.copy())

    def copy(self, *args, **kwargs):
        """Return copy, with :attr:`value` copied."""
        return self.__copy__(*args, **kwargs)

    def __repr__(self):
        return '{}({}, {}, {})'.format(self.__class__.__name__, self.param, self.derivs, self.value)

    def __array__(self, *args, **kwargs):
        return np.asarray(self._value, *args, **kwargs)

    def __jax_array__(self, *args, **kwargs):
        return jnp.asarray(self._value, *args, **kwargs)

    def __format__(self, *args, **kwargs):
        return self._value.__format__(*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """Apply ``ufunc`` to underlying values; keep :attr:`param` / :attr:`derivs` if all inputs agree."""
        # Only authorise operations between arrays of same parameter / derivs
        input_param_derivs = [(input.param, input.derivs) for input in inputs if isinstance(input, self.__class__)]
        input_values = [input.value if isinstance(input, self.__class__) else input for input in inputs]
        # NOTE(review): numpy hands ``out`` over as a tuple, so this branch is presumably only taken
        # on direct calls; in all other cases ``out`` is dropped and a new array is returned -- confirm intent
        if isinstance(out, self.__class__):
            new = self.__class__(getattr(ufunc, method)(*input_values, out=out, **kwargs))
            new.__array_finalize__(self)
            return new
        new = getattr(ufunc, method)(*input_values, **kwargs)
        if input_param_derivs:
            param, derivs = input_param_derivs[0]
            if any(param_derivs[0] != param for param_derivs in input_param_derivs[1:]):
                param = None
            if any(param_derivs[1] != derivs for param_derivs in input_param_derivs[1:]):
                derivs = None
            if param is not None:
                param = Parameter(param)
            if derivs is not None:
                # Derivatives only make sense if trailing (parameter / derivative) axes are untouched
                if (new.shape[-self.pndim:] == self.shape[-self.pndim:]):
                    derivs = tuple(Deriv(deriv) for deriv in derivs)
                else:
                    derivs = None
            new = self.__class__(new, param=param, derivs=derivs)
        return new

    def _isderiv(self, deriv):
        """Try turning ``deriv`` into :class:`Deriv`; return ``(deriv, success)``."""
        try:
            deriv = Deriv(deriv)
            return deriv, True
        except ValueError:
            return deriv, False

    def isin(self, deriv):
        """Test if input deriv in array."""
        deriv, isderiv = self._isderiv(deriv)
        if isderiv:
            return (self.derivs is not None) and (deriv in self.derivs)
        return np.isin(deriv, self)

    def _index(self, index):
        """Turn derivative ``index`` into an index for :attr:`value`; any other index is returned as is."""
        toret = index
        deriv, isderiv = self._isderiv(index)
        if isderiv:
            if self.derivs is not None:
                try:
                    ideriv = self.derivs.index(deriv)
                except ValueError as exc:
                    raise KeyError('{} is not in computed derivatives: {}'.format(deriv, self.derivs)) from exc
                else:
                    # Derivative axis sits just before the parameter axes
                    toret = (Ellipsis, ideriv)
                    if self.param is not None:
                        toret += (slice(None),) * self.param.ndim
            elif deriv:
                raise KeyError('Array has no derivatives')
            else:
                # Zeroth order derivative of an array without derivatives: the array itself
                toret = Ellipsis
        return toret

    def __getitem__(self, deriv):
        """Derivative w.r.t. parameter 'a' can be obtained (if exists) as array[('a',)]."""
        deriv, isderiv = self._isderiv(deriv)
        if isderiv:
            return self.__class__(self.value.__getitem__(self._index(deriv)), param=self.param, derivs=None)
        new = self.__class__(self.value.__getitem__(deriv))
        new.__array_finalize__(self, copy=True)
        return new

    def __setitem__(self, deriv, item):
        """Derivative w.r.t. parameter 'a' can be set (if exists) as array[('a',)] = deriv."""
        return self.value.__setitem__(self._index(deriv), item)

    def __setstate__(self, state):
        self._value = state['value']
        self.param = state['param']
        if self.param is not None:
            self.param = Parameter.from_state(self.param)  # Set the info attribute
        self._derivs = state['derivs']
        if self._derivs is not None:
            self._derivs = tuple(Deriv(deriv) for deriv in self._derivs)

    def __getstate__(self):
        state = {name: getattr(self, name) for name in ['value', 'param', 'derivs']}
        if self.param is not None:
            state['param'] = self.param.__getstate__()
        if self.derivs is not None:
            state['derivs'] = [dict(deriv) for deriv in self.derivs]
        return state

    def __getattr__(self, name):
        # Only called when regular lookup fails: forward to the underlying array.
        # '_value' itself may not be set yet (instance created without __init__);
        # looking it up again below would recurse without end.
        if name == '_value':
            raise AttributeError(name)
        return object.__getattribute__(self._value, name)

    @classmethod
    def from_state(cls, state):
        """Create :class:`ParameterArray` for state (dictionary)."""
        return cls(state['value'], None if state.get('param', None) is None else Parameter.from_state(state['param']), state.get('derivs', None))

    def clone(self, **kwargs):
        """Clone :class:`ParameterArray`, optionally updating :attr:`value`, :attr:`param` or :attr:`derivs`."""
        state = {name: getattr(self, name) for name in ['value', 'param', 'derivs']}
        state.update(**kwargs)
        return self.__class__(**state)

    def tree_flatten(self):
        """Return children (the array value) and auxiliary data (parameter and derivatives)."""
        return (self.value,), {name: getattr(self, name) for name in ['param', 'derivs']}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild :class:`ParameterArray` from the output of :meth:`tree_flatten`."""
        return cls(*children, **aux_data)
def get_wrapper(func):
    """
    Return a method that calls ``self.value.<func>`` with the provided arguments,
    wraps the result in ``self.__class__``, and passes ``self`` to the
    ``__array_finalize__`` of this new instance (with ``copy=True``).
    """
    def wrapper(self, *args, **kwargs):
        wrapped = self.__class__(getattr(self.value, func)(*args, **kwargs))
        wrapped.__array_finalize__(self, copy=True)
        return wrapped

    return wrapper
# Give ParameterArray the shape-changing methods of its underlying array, re-wrapping their result
for name in ('ravel', 'reshape'):
    setattr(ParameterArray, name, get_wrapper(name))
[docs]
class Parameter(BaseClass):
    """Class that represents a parameter."""
    _attrs = ['basename', 'namespace', 'value', 'fixed', 'derived', 'prior', 'ref', 'proposal', 'delta', 'latex', 'depends', 'shape', 'drop']
    _allowed_solved = ['.best', '.marg', '.auto', '.best_not_derived', '.marg_not_derived', '.auto_not_derived', '.prec']

    def __init__(self, basename, namespace='', value=None, fixed=None, derived=False, prior=None, ref=None, proposal=None, delta=None, latex=None, shape=(), drop=False):
        """
        Initialize :class:`Parameter`.

        Parameters
        ----------
        basename : str
            Parameter base name (which defines parameter meaning).
            If :class:`Parameter` (or :class:`ParameterConfig`), update ``self`` attributes.
            If dictionary, used as keyword arguments ('name' being an alias for 'basename').

        namespace : str, default=''
            Parameter namespace (to differentiate several occurences of the same parameter in the same pipeline).

        value : float, default=None
            Default value for parameter.

        fixed : bool, default=None
            Whether parameter is fixed.
            If ``None``, defaults to ``True`` if neither ``prior`` nor ``ref`` is provided
            and parameter does not depend on other parameters, else ``False``.

        derived : bool, str, default=False
            ``True`` if parameter is taken from a calulator's attributes (or :meth:`BaseCalculator.__getstate__` at run time).
            '.best', '.marg', or '.auto' to solve for this parameter (given a Gaussian likelihood),
            respectively taking the best-fit solution, performing analytic marginalization, or choosing
            between these two options depending on whether a profiler ('.best') or a sampler ('.marg') is used.
            When using analytic marginalization, the hessian of the loglikelihood and of the prior is stored (as 'derived' parameters);
            potentially yielding large (in terms of memory space) chains. To circumvent this, you can provide e.g. '.auto_not_derived'.
            '.prec' can be used for linear parameters for which the gradient does not depend on the value of other parameters;
            in this case, the likelihood's precision matrix is marginalized over this parameter at initialization, and the parameter
            is ignored in the profiling / sampling.
            One can also define the value of this parameter as a function of others (e.g. 'a', 'b'), by providing
            e.g. the string '{a} + {b}' (or any other operation; numpy is available with 'np', scipy with 'sp',
            and their jax version with 'jnp' and 'jsp').

        prior : ParameterPrior, dict, default=None
            Prior distribution for parameter, arguments for :class:`ParameterPrior`.

        ref : Prior, dict, default=None
            Reference distribution for parameter, arguments for :class:`ParameterPrior`.
            This is supposed to represent the expected posterior for this parameter.
            If ``None``, defaults to ``prior``.

        proposal : float, default=None
            Proposal uncertainty for parameter.
            If ``None``, defaults to ``ref.std()``.

        delta : float, tuple, default=None
            Variation for finite-differentiation, w.r.t. ``value``.
            If tuple, (variation below value, variation above value),
            e.g.: ``(0.1, 0.2)``, with ``value = 1``, means a variation range ``(0.9, 1.2)``.

        latex : str, default=None
            Latex string for parameter.

        shape : tuple, default=()
            Parameter shape; typically non-trivial when ``derived`` is ``True``.

        drop : bool, default=False
            If ``True``, this parameter will not be provided to the calculator.
        """
        from . import base
        if isinstance(basename, Parameter):
            self.__dict__.update(basename.__dict__)
            return
        if isinstance(basename, ParameterConfig):
            self.__dict__.update(basename.init().__dict__)
            return
        try:
            basename = dict(basename)
        except (ValueError, TypeError):
            pass
        else:
            # A mapping was provided: use it as keyword arguments
            if 'name' in basename:
                basename['basename'] = basename.pop('name')
            self.__init__(**basename)
            return
        self._namespace = '' if namespace is None else str(namespace)
        # Input basename may itself carry a namespace, which is appended to the provided one
        names = str(basename).split(base.namespace_delimiter)
        self._basename, namespace = names[-1], base.namespace_delimiter.join(names[:-1])
        if namespace:
            if self._namespace:
                self._namespace = base.namespace_delimiter.join([self._namespace, namespace])
            else:
                self._namespace = namespace
        self._value = float(value) if value is not None else None
        self._prior = prior if isinstance(prior, ParameterPrior) else ParameterPrior(**(prior or {}))
        if ref is not None:
            self._ref = ref if isinstance(ref, ParameterPrior) else ParameterPrior(**(ref or {}))
        else:
            self._ref = self._prior.copy()
        self._latex = latex
        self._proposal = proposal
        self._delta = delta
        if delta is not None:
            if np.ndim(delta) == 0:
                delta = (delta,) * 2
            self._delta = tuple(delta)
        self._derived = derived
        self._depends = {}
        if isinstance(derived, str):
            if self._derived in self._allowed_solved:
                allowed_dists = ['norm', 'uniform']
                if self._prior.dist not in allowed_dists or self._prior.is_limited():
                    raise ParameterError('Prior must be one of {}, with no limits, to use analytic marginalisation for {}'.format(allowed_dists, self))
            else:
                # Replace each placeholder '{name}' by a key that cannot be found in the input expression
                # (it is longer than the expression); name is stored in _depends
                placeholders = re.finditer(r'\{.*?\}', derived)
                nderived = len(derived)
                for placeholder in placeholders:
                    placeholder = placeholder.group()
                    if placeholder not in derived:
                        continue  # already replaced
                    key = '_' * nderived + '{:d}_'.format(len(self._depends) + 1)
                    assert key not in derived
                    derived = derived.replace(placeholder, key)
                    self._depends[key] = placeholder[1:-1]
                self._derived = derived
        else:
            self._derived = bool(self._derived)
        if fixed is None:
            fixed = prior is None and ref is None and not self.depends
        self._fixed = bool(fixed)
        self._shape = tuple(int(s) for s in (shape if utils.is_sequence(shape) else (shape,)))
        self._drop = bool(drop)
        self.updated = True

    @property
    def size(self):
        """Parameter size, i.e. product of :attr:`shape`."""
        return np.prod(self._shape, dtype='i')

    @property
    def ndim(self):
        """Parameter dimension, typically non-trivial when :attr:`derived` is ``True``."""
        return len(self._shape)

    def eval(self, **values):
        """
        Return parameter value, given all parameter values, e.g. if :attr:`derived` is '{a} + {b}',

        >>> param.eval(a=2., b=3.)
        5.

        If parameter is not derived from others, return ``values[self.name]``.
        Raise :class:`ParameterError` if a parameter this one depends on is not provided.
        """
        if isinstance(self._derived, str) and self._derived not in self._allowed_solved:
            try:
                values = {k: values[n] for k, n in self._depends.items()}
            except KeyError as exc:
                raise ParameterError('Parameter {} is to be derived from parameters {}, as {}, but they are not provided'.format(self, list(self._depends.values()), self.derived)) from exc
            return utils.evaluate(self._derived, locals=values)
        return values[self.name]

    @property
    def value(self):
        """Default value for parameter; if not specified, defaults ``ref.center()``."""
        value = self._value
        if value is None:
            try:
                value = self._ref.center()
            except AttributeError as exc:
                raise AttributeError('reference distribution has no center(), probably because it is not proper... provide value argument or proper reference distribution') from exc
        return value

    @property
    def proposal(self):
        """Proposal uncertainty for parameter; if not specified, defaults to ``ref.std()``."""
        proposal = self._proposal
        if proposal is None:
            try:
                proposal = self._ref.std()
            except AttributeError as exc:
                raise AttributeError('reference distribution has no std(), probably because it is not proper... provide proposal argument or proper reference distribution') from exc
        return proposal

    @property
    def delta(self):
        """
        Variation for finite-differentiation, as (value, variation below value, variation above value);
        e.g.: ``(1., 0.1, 0.2)``, means a variation range ``(0.9, 1.2)``.
        If not specified, defaults to ``(value, 0.1 * proposal, 0.1 * proposal)``.
        """
        delta = self._delta
        proposal_scale = 1e-1
        if delta is None:
            try:
                proposal = self.proposal
            except AttributeError as exc:
                raise AttributeError('reference distribution has no std(), probably because it is not proper... provide delta argument, or proposal, or proper reference distribution') from exc
            delta = (proposal_scale * proposal, proposal_scale * proposal)
        if len(delta) == 2:
            delta = (self.value,) + tuple(delta)
        return delta

    @property
    def derived(self):
        """If parameter is derived from others 'a', 'b', return e.g. '{a} * {b}'."""
        if isinstance(self._derived, str) and self._derived not in self._allowed_solved:
            # Put placeholders back in place of the internal keys
            toret = self._derived
            for k, v in self._depends.items():
                toret = toret.replace(k, '{{{}}}'.format(v))
            return toret
        return self._derived

    @property
    def solved(self):
        """Whether parameter is solved, i.e. fixed at best fit or marginalized over."""
        return (not self._fixed) and self._derived in self._allowed_solved

    @property
    def input(self):
        """Whether parameter should be fed as input to calculator."""
        return ((self._derived is False) or isinstance(self._derived, str)) and not self.depends

    @property
    def name(self):
        """Return parameter name, as namespace.basename if :attr:`namespace` is not empty, else basename."""
        from . import base
        if self._namespace:
            return base.namespace_delimiter.join([self._namespace, self._basename])
        return self._basename

    def update(self, *args, **kwargs):
        """
        Update parameter attributes with (at most one) :class:`Parameter` or :class:`ParameterConfig`,
        then with new arguments ``kwargs``; 'name' sets both basename and namespace.
        """
        state = self.__getstate__()
        if len(args) == 1 and isinstance(args[0], self.__class__):
            state.update(args[0].__getstate__())
        elif len(args) == 1 and isinstance(args[0], ParameterConfig):
            state = ParameterConfig(self).clone(args[0]).init().__getstate__()
        elif len(args):
            raise ValueError('Unrecognized arguments {}'.format(args))
        if 'name' in kwargs:
            # Full name is provided: namespace is to be read from it
            kwargs['basename'] = kwargs.pop('name')
            kwargs['namespace'] = None
        state.update(kwargs)
        state.pop('updated', None)
        self.__init__(**state)

    def clone(self, *args, **kwargs):
        """Clone parameter, i.e. copy and update."""
        new = self.copy()
        new.update(*args, **kwargs)
        return new

    @property
    def varied(self):
        """Whether parameter is varied (i.e. not fixed)."""
        return (not self._fixed)

    @property
    def limits(self):
        """Parameter limits."""
        return self._prior.limits

    def __copy__(self):
        """Shallow copy."""
        new = super(Parameter, self).__copy__()
        new._depends = copy.copy(new._depends)
        return new

    def deepcopy(self):
        """Deep copy."""
        return copy.deepcopy(self)

    def __getstate__(self):
        """Return this class' state dictionary."""
        state = {}
        for key in self._attrs:
            state[key] = getattr(self, '_' + key)
            if key in ['prior', 'ref']:
                state[key] = state[key].__getstate__()
        # Dependencies are encoded as placeholders in 'derived'
        state['derived'] = self.derived
        state.pop('depends')
        state['updated'] = self.updated
        return state

    def __setstate__(self, state):
        """Set this class' state dictionary."""
        state = state.copy()
        updated = state.pop('updated', True)
        # For backward-compatibility
        state.pop('saved', None)
        self.__init__(**state)
        self.updated = updated

    def __repr__(self):
        """Represent parameter as string (name and fixed or varied)."""
        return '{}({}, {})'.format(self.__class__.__name__, self.name, 'fixed' if self._fixed else 'varied')

    def __str__(self):
        """Return parameter as string (name)."""
        return str(self.name)

    def __eq__(self, other):
        """Is ``self`` equal to ``other``, i.e. same type and attributes?"""
        return type(other) == type(self) and all(deep_eq(getattr(other, '_' + name), getattr(self, '_' + name)) for name in self._attrs)

    def __diff__(self, other):
        """Return dictionary of attributes that differ, as name: (value in ``self``, value in ``other``)."""
        toret = {}
        for name in self._attrs:
            self_value = getattr(self, '_' + name)
            other_value = getattr(other, '_' + name)
            if not deep_eq(self_value, other_value):
                toret[name] = (self_value, other_value)
        return toret

    def __hash__(self):
        return hash(str(self))

    def latex(self, namespace=None, inline=False):
        """
        Return latex string for parameter if :attr:`latex` is specified (i.e. not ``None``), else :attr:`name`.

        Parameters
        ----------
        namespace : bool, str, default=None
            If ``False``, no namespace is added to the latex string.
            If ``True``, :attr:`namespace` is turned into a latex string, and added as a subscript.
            If string, add this subscript to the latex string.
            If ``None``, and none of :attr:`namespace` "words" (obtained by splitting at ', ', '_', '-')
            are in the current latex string, then same as ``True``; else, same as ``False``.

        inline : bool, default=False
            If ``True``, add '$' around the latex string.

        Returns
        -------
        latex : str
            Latex string.
        """
        auto_namespace = namespace is None
        force_namespace = namespace is True
        provided_namespace = False
        if force_namespace or auto_namespace:
            namespace = str(self._namespace)
        elif namespace is not False:
            namespace = str(namespace)
            provided_namespace = force_namespace = True
        if self._latex is None:
            return str(self.name)

        def add_namespace():
            # Namespace is added unless one of its words is already in latex string (and not in basename)
            words = re.split(', |_|-', namespace)  # parse namespace
            return not any(word in self._latex and word not in self.basename for word in words)

        latex = self._latex
        if namespace and (force_namespace or (auto_namespace and add_namespace())):
            # Raw strings: '\_' is not a valid escape sequence
            latex_namespace = namespace if provided_namespace else (r'\mathrm{%s}' % namespace.replace(r'\_', '_').replace('_', r'\_'))
            # Current latex string may already end with a subscript: _x or _{xxx}
            match = re.match(r'(.*)_(.)$', self._latex)
            if match is None:
                match = re.match(r'(.*)_{(.*)}$', self._latex)
            if match is not None:
                # Append namespace to existing subscript
                latex = r'%s_{%s, %s}' % (match.group(1), match.group(2), latex_namespace)
            else:
                latex = r'%s_{%s}' % (self._latex, latex_namespace)
        if inline:
            latex = '${}$'.format(latex)
        return latex
def _make_property(name):
def getter(self):
return getattr(self, '_' + name)
return getter
# Expose the remaining attributes as read-only properties returning the private attribute;
# the attributes skipped here are defined explicitly in the class
for name in Parameter._attrs:
    if name in ['value', 'proposal', 'delta', 'derived', 'latex']:
        continue
    setattr(Parameter, name, property(_make_property(name)))
[docs]
class BaseParameterCollection(BaseClass):
"""Base class holding a collection of items identified by parameter."""
_type = Parameter
_attrs = ['attrs']
@classmethod
def _get_name(cls, item):
    """
    Return name for ``item``: ``item`` itself if string, else the name of the parameter
    (``item`` if :class:`Parameter`, else as returned by :meth:`_get_param`), ``None`` if no parameter.
    """
    if isinstance(item, str):
        return item
    param = item if isinstance(item, Parameter) else cls._get_param(item)
    return None if param is None else str(param.name)
@classmethod
def _get_param(cls, item):
    """Return the parameter attached to ``item``; in this base class, the item itself."""
    return item
def __init__(self, data=None, attrs=None):
    """
    Initialize :class:`BaseParameterCollection`.

    Parameters
    ----------
    data : list, tuple, dict, BaseParameterCollection
        Can be:

        - list (or tuple) of items
        - dictionary mapping name to item
        - instance of this class, whose (shallow-copied) state is adopted

    attrs : dict, default=None
        Optionally, other attributes, stored in :attr:`attrs`.
    """
    if isinstance(data, self.__class__):
        self.__dict__.update(data.copy().__dict__)
        return
    self.attrs = dict(attrs or {})
    self.data = []
    if data is None:
        return
    if utils.is_sequence(data):
        # only items are provided: derive each name from its item
        named_items = ((self._get_name(item), item) for item in data)
    else:
        named_items = data.items()
    for name, item in named_items:
        self[name] = item
def __setitem__(self, name, item):
    """
    Update item in collection.

    Parameters
    ----------
    name : Parameter, str, int
        Parameter name.
        If :class:`Parameter` instance, search for parameter with same name.
        If integer, index in collection.

    item : object
        Item to store; must be an instance of :attr:`_type`.
    """
    if not isinstance(item, self._type):
        raise TypeError('{} is not a {} instance.'.format(item, self._type))
    try:
        self.data[name] = item  # list index
    except TypeError:
        # ``name`` is not a list index: it must agree with the name carried by ``item``
        expected_name = self._get_name(item)
        if self._get_name(name) == expected_name:
            self.set(item)
        else:
            raise KeyError('Parameter {} must be indexed by name (incorrect {})'.format(expected_name, name))
def __getitem__(self, name):
    """
    Return item corresponding to parameter ``name``.

    Parameters
    ----------
    name : Parameter, str, int
        Parameter name.
        If :class:`Parameter` instance, search for parameter with same name.
        If integer, index in collection.
    """
    try:
        return self.data[name]  # direct list indexing
    except TypeError:
        # ``name`` cannot index a list: resolve its position by parameter name
        return self.data[self.index(name)]
def __delitem__(self, name):
    """
    Delete parameter ``name``.

    Parameters
    ----------
    name : Parameter, str, int
        Parameter name.
        If :class:`Parameter` instance, search for parameter with same name.
        If integer, index in collection.
    """
    try:
        del self.data[name]  # direct list indexing
    except TypeError:
        # ``name`` cannot index a list: resolve its position by parameter name
        del self.data[self.index(name)]
[docs]
def sort(self, key=None):
    """
    Sort (in-place) collection, such that it follows the list of parameter names ``key``.
    If ``None``, no sorting is performed (the internal list is only copied).
    Returns the collection itself.
    """
    if key is None:
        self.data = self.data.copy()
    else:
        self.data = [self[kk] for kk in key]
    return self
[docs]
def pop(self, name, *args, **kwargs):
    """
    Remove and return item indexed by ``name``.
    Extra arguments are passed on to :meth:`get` (e.g. a default value);
    a failed deletion (item not in collection) is ignored.
    """
    popped = self.get(name, *args, **kwargs)
    try:
        del self[name]
    except (IndexError, KeyError):
        # nothing to delete: ``get`` returned its default
        pass
    return popped
[docs]
def get(self, name, *args, **kwargs):
    """
    Return item of parameter name ``name`` in collection.

    Parameters
    ----------
    name : Parameter, str
        Parameter name.
        If :class:`Parameter` instance, search for parameter with same name.

    default : object, optional
        Value returned if ``name`` is not in collection; can be passed
        as second positional argument or as keyword ``default``.

    Returns
    -------
    item : object
        Stored item, or ``default`` if provided and ``name`` is not found.

    Raises
    ------
    KeyError
        If ``name`` is not found and no default is provided.
    TypeError
        If a keyword argument other than ``default`` is provided.
    SyntaxError
        If more than one positional or keyword default is provided.
    """
    has_default = False
    default = None
    if args:
        if len(args) > 1:
            raise SyntaxError('Too many arguments!')
        has_default = True
        default = args[0]
    if kwargs:
        if len(kwargs) > 1:
            raise SyntaxError('Too many arguments!')
        if 'default' not in kwargs:
            # A misspelled keyword used to surface as KeyError('default'),
            # which callers cannot tell apart from "parameter not found".
            raise TypeError('get() got an unexpected keyword argument {!r}'.format(next(iter(kwargs))))
        has_default = True
        default = kwargs['default']
    try:
        return self[name]
    except KeyError:
        if has_default:
            return default
        raise KeyError('Parameter {} not found'.format(name))
[docs]
def set(self, item):
    """
    Set item in collection.
    If an item with the same parameter name is already stored, it is replaced by the input one;
    otherwise the input item is appended to the collection.
    """
    try:
        position = self.index(item)
    except KeyError:
        self.data.append(item)
    else:
        self.data[position] = item
[docs]
def setdefault(self, item):
    """
    Set item in collection if not already in it.
    Raises a :class:`TypeError` if ``item`` is not an instance of :attr:`_type`.
    """
    if not isinstance(item, self._type):
        raise TypeError('{} is not a {} instance.'.format(item, self._type))
    if item in self:
        return
    self.set(item)
[docs]
def index(self, name):
    """
    Return index of parameter ``name``.

    Parameters
    ----------
    name : Parameter, str
        Parameter name.
        If not a string, the name is obtained through :meth:`_get_name`.

    Returns
    -------
    index : int
        Position in :attr:`data`. A :class:`KeyError` is raised if no item matches.
    """
    return self._index_name(self._get_name(name))
def _index_name(self, name):
    """Return position of the first stored item whose parameter name equals ``name``; raise :class:`KeyError` if none."""
    positions = (ii for ii, item in enumerate(self.data) if self._get_name(item) == name)
    found = next(positions, None)
    if found is None:
        raise KeyError('Parameter {} not found'.format(name))
    return found
def __contains__(self, name):
    """Whether collection contains parameter ``name`` (a name, or anything :meth:`_get_name` accepts)."""
    try:
        self._index_name(self._get_name(name))
    except KeyError:
        return False
    return True
def _select(self, **kwargs):
    """
    Return a (shallow) copy of the collection restricted to the items whose parameter
    matches all ``kwargs`` (attribute name: requested value). Without ``kwargs``, return a plain copy.

    For 'name', 'basename' and 'namespace' the requested value is matched with :func:`find_names`
    (``None`` matches everything). Other attributes match if they are equal to the requested value,
    or equal to any element of it when it is iterable.
    """
    selected = self.copy()
    if not kwargs:
        return selected
    selected.clear()

    def attribute_matches(param, key, value):
        # whether attribute ``key`` of ``param`` fulfills the requested ``value``
        param_value = getattr(param, key)
        if key in ['name', 'basename', 'namespace']:
            return value is None or bool(find_names([param_value], value))
        if deep_eq(value, param_value):
            return True
        try:
            return any(deep_eq(v, param_value) for v in value)
        except TypeError:
            # ``value`` is not iterable
            return False

    for item in self:
        param = self._get_param(item)
        if all(attribute_matches(param, key, value) for key, value in kwargs.items()):
            selected.data.append(item)
    return selected
[docs]
def select(self, **kwargs):
    """
    Return new collection, after selection of parameters whose attributes match input values::

        collection.select(fixed=True)

    returns collection of fixed parameters.
    If 'name' is provided, consider all matching parameters, e.g.::

        collection.select(varied=True, name='a_[0:2]')

    returns a collection of varied parameters, with name in ``['a_0', 'a_1']``.
    """
    return self._select(**kwargs)
[docs]
def params(self, **kwargs):
    """
    Return :class:`ParameterCollection`, collection of parameters corresponding to items stored in this collection.
    ``kwargs`` are selection criteria, see :meth:`select`.
    """
    return ParameterCollection([self._get_param(item) for item in self._select(**kwargs)])
[docs]
def names(self, **kwargs):
    """Return list of parameter names in collection; ``kwargs`` are passed to :meth:`params`."""
    return [param.name for param in self.params(**kwargs)]
[docs]
def basenames(self, **kwargs):
    """Return list of base parameter names in collection; ``kwargs`` are passed to :meth:`params`."""
    return [param.basename for param in self.params(**kwargs)]
[docs]
@classmethod
def concatenate(cls, *others):
    """
    Concatenate input collections.
    Unique items only are kept: an item replaces any previous item with the same parameter name.

    Parameters
    ----------
    others : collections
        Collections (or inputs accepted by the constructor), given either as
        separate arguments or as a single list / tuple.

    Returns
    -------
    new : BaseParameterCollection
        New collection; empty if no input is provided.
    """
    if len(others) == 1 and utils.is_sequence(others[0]):
        others = others[0]
    # Test emptiness only after unwrapping, such that ``concatenate([])``
    # returns an empty collection instead of raising an IndexError.
    if not others:
        return cls()
    new = cls(others[0])
    for other in others[1:]:
        for item in cls(other).data:
            new.set(item)
    return new
[docs]
def extend(self, other):
    """
    Extend (in-place) collection with ``other``.
    Unique items only are kept, see :meth:`concatenate`.
    """
    self.__dict__.update(self.concatenate(self, other).__dict__)
def __repr__(self):
    """String representation: class name followed by the list of parameter names."""
    return '{}({})'.format(self.__class__.__name__, self.names())
def __len__(self):
    """Collection length, i.e. number of items."""
    return len(self.data)
def __iter__(self):
    """Iterator on the stored items."""
    return iter(self.data)
def __getstate__(self):
    """Return this class state dictionary: state of each item under 'data', plus the attributes listed in :attr:`_attrs`."""
    state = {'data': [item.__getstate__() for item in self.data]}
    state.update((name, getattr(self, name)) for name in self._attrs)
    return state
def __setstate__(self, state):
    """Set this class state dictionary, rebuilding each item from its own state with ``_type.from_state``."""
    super(BaseParameterCollection, self).__setstate__(state)
    self.data = [self._type.from_state(item) for item in state['data']]
def __copy__(self):
    """Shallow copy; :attr:`data` and the attributes listed in :attr:`_attrs` are themselves shallow-copied."""
    new = super(BaseParameterCollection, self).__copy__()
    for name in ['data'] + self._attrs:
        # if hasattr(self, name):
        setattr(new, name, copy.copy(getattr(new, name)))
    return new
[docs]
def clear(self):
    """Empty collection (in-place) and return it."""
    self.data.clear()
    return self
[docs]
def update(self, *args, **kwargs):
    """
    Update collection with new one; arguments can be an instance of this class
    or arguments to instantiate such a class (see :meth:`__init__`).
    Each item of the new collection is passed to :meth:`set`.
    """
    is_collection = len(args) == 1 and isinstance(args[0], self.__class__)
    other = args[0] if is_collection else self.__class__(*args, **kwargs)
    for item in other:
        self.set(item)
[docs]
def clone(self, *args, **kwargs):
    """Clone collection, i.e. (shallow) copy and update; arguments are passed to :meth:`update`."""
    new = self.copy()
    new.update(*args, **kwargs)
    return new
[docs]
def keys(self, **kwargs):
    """Return list of parameter names of the items selected by ``kwargs`` (all items if empty)."""
    return [self._get_name(item) for item in self._select(**kwargs)]
[docs]
def values(self, **kwargs):
    """Return list of the items selected by ``kwargs`` (all items if empty)."""
    return [item for item in self._select(**kwargs)]
[docs]
def items(self, **kwargs):
    """Return list of tuples (parameter name, item), for the items selected by ``kwargs`` (all items if empty)."""
    return [(self._get_name(item), item) for item in self._select(**kwargs)]
[docs]
def deepcopy(self):
    """Return a deep copy, obtained with :func:`copy.deepcopy`."""
    return copy.deepcopy(self)
def __eq__(self, other):
    """Is ``self`` equal to ``other``, i.e. same type, same parameters and equal items?"""
    if not type(other) == type(self):
        return False
    if not list(other.params()) == list(self.params()):
        return False
    return all(deep_eq(other_value, self_value) for other_value, self_value in zip(other, self))
[docs]
class ParameterConfig(NamespaceDict):
"""A convenient object, used internally by the code, to store configuration for a given parameter."""
def __init__(self, conf=None, **kwargs):
    """
    Initialize :class:`ParameterConfig`.

    Parameters
    ----------
    conf : dict, Parameter, default=None
        Configuration. If :class:`Parameter`, its state dictionary is used,
        without 'updated' and without 'namespace' when the latter is ``None``.

    **kwargs : dict
        Other arguments for the parent class constructor.
    """
    if isinstance(conf, Parameter):
        conf = conf.__getstate__()
        if conf['namespace'] is None:
            conf.pop('namespace')
        conf.pop('updated', None)
    super(ParameterConfig, self).__init__(conf, **kwargs)
    # store priors as plain state dictionaries
    for name in ['prior', 'ref']:
        if name in self and isinstance(self[name], ParameterPrior):
            self[name] = self[name].__getstate__()
def init(self):
    """
    Return the :class:`Parameter` built from this configuration.

    A non-string namespace is replaced by ``None``.
    If 'prior' or 'ref' contains 'rescale', this factor multiplies 'scale' (if any),
    and stretches finite 'limits' (if any) around their center.
    """
    state = self.__getstate__()
    if not isinstance(self.get('namespace', None), str):
        state['namespace'] = None
    for name in ['prior', 'ref']:
        spec = state.get(name, None)
        if spec is None or 'rescale' not in spec:
            continue
        spec = copy.copy(spec)  # leave the stored configuration untouched
        factor = spec.pop('rescale')
        if 'scale' in spec:
            spec['scale'] *= factor
        if 'limits' in spec:
            bounds = np.array(spec['limits'])
            if np.isfinite(bounds).all():
                midpoint = np.mean(bounds)
                spec['limits'] = (bounds - midpoint) * factor + midpoint
        state[name] = spec
    return Parameter(**state)
@property
def param(self):
    """The :class:`Parameter` built from this configuration, see :meth:`init`."""
    return self.init()
def update(self, *args, exclude=(), **kwargs):
    """
    Update configuration with new one.

    Parameters
    ----------
    *args, **kwargs
        Arguments to instantiate the configuration to update from (see :meth:`__init__`).

    exclude : tuple, list, default=()
        Names of entries to leave untouched.

    Note
    ----
    If the new 'prior' (or 'ref') only contains 'rescale', the current 'prior' (or 'ref') is kept,
    with this 'rescale' added.
    """
    other = self.__class__(*args, **kwargs)
    for name, value in other.items():
        if name in exclude:
            continue
        # ``None`` entries are skipped here (as in :meth:`init`),
        # as ``'rescale' in None`` and ``{**None}`` would raise a TypeError
        rescale_only = (name in ['prior', 'ref'] and name in self
                        and value is not None and self[name] is not None
                        and 'rescale' in value and len(value) == 1)
        if rescale_only:
            value = {**self[name], 'rescale': value['rescale']}
        self[name] = copy.copy(value)
def update_derived(self, oldname, newname):
    """
    Replace ``{oldname}`` by ``{newname}`` in 'derived', if the latter is a (non-empty) string
    that is not in ``Parameter._allowed_solved``.
    """
    if not self.get('derived', False):
        return
    if not isinstance(self.derived, str):
        return
    if self.derived in Parameter._allowed_solved:
        return
    old_key = '{{{}}}'.format(str(oldname))
    new_key = '{{{}}}'.format(str(newname))
    self.derived = self.derived.replace(old_key, new_key)
@property
def name(self):
    """Parameter name: namespace (if a non-empty string) and base name, joined by ``base.namespace_delimiter``."""
    from . import base
    namespace = self.get('namespace', None)
    if not (isinstance(namespace, str) and namespace):
        return self.basename
    return base.namespace_delimiter.join([namespace, self.basename])
@name.setter
def name(self, name):
    """
    Set base name and namespace from ``name``, split at ``base.namespace_delimiter``:
    the last part is the base name, the rest (if any) the namespace.
    Without delimiter, only the base name is set.
    """
    from . import base
    delimiter = base.namespace_delimiter
    *prefix, basename = str(name).split(delimiter)
    if prefix:
        self.basename, self.namespace = basename, delimiter.join(prefix)
    else:
        self.basename = basename
[docs]
class ParameterCollectionConfig(BaseParameterCollection):
"""
A collection of :class:`ParameterConfig` objects, used internally by the code.
TODO: As :class:`ParameterConfig`, should be either documented and/or simplified.
"""
_type = ParameterConfig
_attrs = ['fixed', 'derived', 'namespace', 'delete', 'wildcard', 'identifier']
def _get_name(self, item):
    """
    Return the identifier of ``item``: for a string, its last part after splitting
    at ``base.namespace_delimiter``; else, attribute :attr:`identifier` of ``item``.
    """
    from . import base
    if isinstance(item, str):
        return item.split(base.namespace_delimiter)[-1]
    return getattr(item, self.identifier)
@classmethod
def _get_param(cls, item):
    """Return the ``param`` attribute of ``item``."""
    return item.param
def __init__(self, data=None, identifier='basename', **kwargs):
    """
    Initialize :class:`ParameterCollectionConfig`.

    Parameters
    ----------
    data : dict, ParameterCollection, ParameterCollectionConfig, or any input of :class:`BaseConfig`
        Parameter configurations. Keys '.fixed', '.varied', '.derived', '.namespace', '.delete'
        are popped and stored as meta-information.

    identifier : str, default='basename'
        Name of the attribute used to identify parameters in the collection.

    **kwargs : dict
        Other arguments for :class:`BaseConfig`.
    """
    if isinstance(data, self.__class__):
        # adopt the state of a (shallow) copy of the input collection
        self.__dict__.update(data.copy().__dict__)
        self.identifier = identifier
        return
    self.identifier = identifier
    if isinstance(data, ParameterCollection):
        dd = data
        data = {getattr(param, self.identifier): ParameterConfig(param) for param in data}
        if len(data) != len(dd):
            # several parameters collapsed on the same dictionary key
            raise ValueError('Found different parameters with same {} in {}'.format(self.identifier, dd))
    else:
        data = BaseConfig(data, **kwargs).data
    self.fixed, self.derived, self.namespace, self.delete, self.wildcard = {}, {}, {}, {}, []
    for meta_name in ['fixed', 'varied', 'derived', 'namespace', 'delete']:
        meta = data.pop('.{}'.format(meta_name), {})
        # normalize meta-information to a dictionary name: value
        if utils.is_sequence(meta):
            meta = {name: True for name in meta}
        elif not isinstance(meta, dict):
            meta = {meta: True}
        if meta_name in ['fixed', 'varied']:
            # 'varied' is stored as the negation of 'fixed'
            self.fixed.update({name: (meta_name == 'fixed' and value) or (meta_name == 'varied' and not value) for name, value in meta.items()})
        else:
            getattr(self, meta_name).update({name: bool(value) for name, value in meta.items()})
    self.data = []
    for name, conf in data.items():
        if isinstance(conf, numbers.Number):
            conf = {'value': conf}
        conf = dict(conf or {})
        latex = conf.pop('latex', None)
        # one configuration for each (name, latex) pair yielded by yield_names_latex
        for name, latex in yield_names_latex(name, latex=latex):
            tmp = conf.__getstate__() if isinstance(conf, ParameterConfig) else conf
            tmp = ParameterConfig(name=name, **tmp)
            if latex is not None: tmp.latex = latex
            if '*' in tmp[self.identifier]:
                # wildcard configurations are kept apart, and applied in _set_meta
                self.wildcard.append(tmp)
                continue
            self.set(tmp)
            #for meta_name in ['fixed', 'derived', 'namespace']:
            #    meta = getattr(self, meta_name)
            #    if meta_name in tmp:
            #        meta.pop(name, None)
            #        meta[name] = tmp[meta_name]
    self._set_meta()
def _set_meta(self):
    """
    Apply meta-information to the stored configurations:
    'fixed', 'derived', 'namespace' values are set where not already specified,
    then wildcard configurations are applied, and 'namespace' defaults to ``True``.
    """
    for meta_name in ['fixed', 'derived', 'namespace']:
        meta = getattr(self, meta_name)
        # reversed: as values already set are not overwritten, last entries take precedence
        for name in reversed(meta):
            for tmpconf in self.select(**{self.identifier: name}):
                if meta_name not in tmpconf:
                    tmpconf[meta_name] = meta[name]
    # Wildcard
    for conf in reversed(self.wildcard):
        for tmpconf in self.select(**{self.identifier: conf[self.identifier]}):
            # entries of the matched configuration win over those of the wildcard
            tmpconf.update(conf.clone(tmpconf))
    for conf in self:
        if 'namespace' not in conf:
            conf.namespace = True
def updated(self, param):
    """
    Whether ``param`` is matched by any meta-information ('fixed', 'derived', 'namespace')
    or by any wildcard configuration.
    """
    paramname = self._get_name(param)
    for meta_name in ['fixed', 'derived', 'namespace']:
        if any(find_names([paramname], name) for name in getattr(self, meta_name)):
            return True
    return any(find_names([paramname], conf[self.identifier]) for conf in self.wildcard)
[docs]
def select(self, **kwargs):
    """
    Return a (shallow) copy of the collection restricted to the configurations
    matching all ``kwargs`` (attribute name: requested value). Without ``kwargs``, return a plain copy.

    For 'name', 'basename' and 'namespace' the requested value is matched with :func:`find_names`
    (``None`` matches everything). Other attributes match if they are equal to the requested value,
    or equal to any element of it when it is iterable.
    Attributes are read from the configurations themselves.
    """
    selected = self.copy()
    if not kwargs:
        return selected
    selected.clear()

    def attribute_matches(conf, key, value):
        # whether attribute ``key`` of ``conf`` fulfills the requested ``value``
        conf_value = getattr(conf, key)
        if key in ['name', 'basename', 'namespace']:
            return value is None or bool(find_names([conf_value], value))
        if deep_eq(value, conf_value):
            return True
        try:
            return any(deep_eq(v, conf_value) for v in value)
        except TypeError:
            # ``value`` is not iterable
            return False

    for item in self:
        if all(attribute_matches(item, key, value) for key, value in kwargs.items()):
            selected.data.append(item)
    return selected
[docs]
def update(self, *args, **kwargs):
    """
    Update collection with new one, instantiated from input arguments (see :meth:`__init__`).

    In order: configurations flagged in 'delete' are removed, meta-information and wildcard
    configurations of the new collection are applied to current configurations, new configurations
    are merged into current ones, and meta-information of both collections is combined.
    """
    other = self.__class__(*args, **kwargs)
    # deletions
    for name, b in other.delete.items():
        if b:
            for param in self.select(**{self.identifier: name}):
                del self[param[self.identifier]]
    # new meta-information overwrites current values
    for meta_name in ['fixed', 'derived', 'namespace']:
        meta = getattr(other, meta_name)
        for name in meta:
            for tmpconf in self.select(**{self.identifier: name}):
                tmpconf[meta_name] = meta[name]
    # new wildcard configurations update matching configurations, except for their basename
    for conf in other.wildcard:
        for tmpconf in self.select(**{self.identifier: conf[self.identifier]}):
            self[tmpconf[self.identifier]] = tmpconf.clone(conf, exclude=('basename',))
    # the merged configuration is set again after being popped
    for conf in other:
        new = self.pop(conf[self.identifier], ParameterConfig()).clone(conf)
        self.set(new)

    def update_order(d1, d2):
        # entries of d1 not in d2 first, then all entries of d2
        toret = {name: value for name, value in d1.items() if name not in d2}
        for name, value in d2.items():
            toret[name] = value
        return toret

    for meta_name in ['fixed', 'derived', 'namespace', 'delete']:
        setattr(self, meta_name, update_order(getattr(self, meta_name), getattr(other, meta_name)))
    self.wildcard = self.wildcard + other.wildcard
    self._set_meta()
def with_namespace(self, namespace=None):
    """
    Return a deep copy where ``namespace`` is assigned to configurations whose namespace
    is truthy but not a string, with 'derived' expressions updated to the resulting names.
    """
    new = self.deepcopy()
    new._set_meta()
    for name, param in new.items():
        if not isinstance(param.namespace, str) and param.namespace:
            param.namespace = namespace
        # NOTE(review): nesting of this loop is inferred (source indentation was lost) — confirm
        for dparam in new:
            dparam.update_derived(param.basename, param.name)
    return new
def init(self, namespace=None):
    """Return :class:`ParameterCollection` built from the configurations, after :meth:`with_namespace`."""
    return ParameterCollection([conf.param for conf in self.with_namespace(namespace=namespace)])
[docs]
def set(self, item):
    """
    Set configuration in collection, after conversion to :class:`ParameterConfig` if necessary.
    A stored configuration with same identifier is replaced; otherwise the input one is appended.
    """
    if not isinstance(item, ParameterConfig):
        item = ParameterConfig(item)
    try:
        position = self.index(item)
    except KeyError:
        self.data.append(item)
    else:
        self.data[position] = item
def __setitem__(self, name, item):
    """
    Replace configuration ``name`` (identifier or list index) by ``item``,
    converted to :class:`ParameterConfig` if necessary.
    """
    if not isinstance(item, ParameterConfig):
        item = ParameterConfig(item)
    # use ``name`` as identifier if the configuration does not provide one
    item.setdefault(self.identifier, name)
    try:
        self.data[name] = item  # list index
    except TypeError:
        item_name = self._get_name(item)
        if str(name) != item_name:
            raise KeyError('Parameter {} must be indexed by name (incorrect name {})'.format(item_name, name))
        self.data[self._index_name(name)] = item
[docs]
class ParameterCollection(BaseParameterCollection):
"""
Class holding a collection of parameters.
It additionally keeps track whether the collection has been updated,
as used in :class:`BasePipeline`.
"""
_attrs = ['_updated']
def __init__(self, data=None, attrs=None):
    """
    Initialize :class:`ParameterCollection`.

    Parameters
    ----------
    data : list, tuple, str, Path, dict, ParameterCollection
        Can be:

        - list (or tuple) of parameters (:class:`Parameter` or dictionary to initialize :class:`Parameter`)
        - path to *yaml* defining a list of parameters
        - dictionary mapping name to parameter
        - :class:`ParameterCollection` instance

    attrs : dict, default=None
        Optionally, other attributes, stored in :attr:`attrs`.
    """
    if is_path(data):
        data = ParameterCollectionConfig(data)
    if isinstance(data, ParameterCollectionConfig):
        data = data.init()
    if isinstance(data, self.__class__):
        # adopt the state of a (shallow) copy of the input collection
        self.__dict__.update(data.copy().__dict__)
        return
    self.attrs = dict(attrs or {})
    self.data = []
    self._updated = True
    if data is None:
        return
    if utils.is_sequence(data):
        # turn the sequence into a dictionary name: parameter or configuration
        dd = data
        data = {}
        for name in dd:
            if isinstance(name, Parameter):
                data[name.name] = name
            elif isinstance(name, dict):
                data[name['name']] = name
            else:
                data[name] = {}  # only name is provided
    for name, conf in data.items():
        if isinstance(conf, Parameter):
            self.set(conf)
        else:
            if not isinstance(conf, dict):  # parameter value
                conf = {'value': conf}
            else:
                conf = conf.copy()
            latex = conf.pop('latex', None)
            # one parameter for each (name, latex) pair yielded by yield_names_latex
            for name, latex in yield_names_latex(name, latex=latex):
                param = Parameter(basename=name, latex=latex, **conf)
                self.set(param)
@property
def updated(self):
    """Whether the collection (the list of parameters itself or any of these parameters) has just been updated."""
    return self._updated or any(param.updated for param in self.data)
@updated.setter
def updated(self, updated):
    """Set the 'updated' status (cast to bool) of the collection and of each of its parameters."""
    flag = bool(updated)
    self._updated = flag
    for param in self.data:
        param.updated = flag
def __delitem__(self, name):
    """
    Delete parameter ``name``, and flag the collection as updated.

    Parameters
    ----------
    name : Parameter, str, int
        Parameter name.
        If :class:`Parameter` instance, search for parameter with same name.
        If integer, index in collection.
    """
    self._updated = True
    return super(ParameterCollection, self).__delitem__(name)
[docs]
def update(self, *args, name=None, basename=None, **kwargs):
    """
    Update collection with new one.
    To e.g. fix parameters whose name matches the pattern 'a*':

    >>> params.update(name='a*', fixed=True)

    Parameters
    ----------
    *args : ParameterCollection, list, tuple
        Optionally, one collection (or sequence) of parameters to update from.

    name : str, default=None
        Selection on parameter names, for updates specified in ``kwargs``.

    basename : str, bool, default=None
        Selection on parameter base names, for updates specified in ``kwargs``.
        When updating from a collection, ``True`` to match parameters on their base name
        if their name is not in the collection.

    **kwargs : dict
        'fixed', 'varied' (bool, or list of names) and 'namespace' are handled.
    """
    self._updated = True
    if len(args) == 1 and (isinstance(args[0], self.__class__) or utils.is_sequence(args[0])):
        other = self.__class__(args[0])
        self_basenames = self.basenames()
        for item in other:
            if item in self:
                self[item] = self[item].clone(item)
            elif basename is True and item.basename in self_basenames:
                index = self_basenames.index(item.basename)
                self[index] = self[index].clone(item)
            else:
                self.set(item.copy())
    elif len(args) <= 1:
        list_update = self.names(name=name, basename=basename)
        for meta_name, fixed in zip(['fixed', 'varied'], [True, False]):
            if meta_name in kwargs:
                meta = kwargs[meta_name]
                if isinstance(meta, bool):
                    # True applies to all selected parameters, False to none
                    if meta:
                        meta = list_update
                    else:
                        meta = []
                for name in meta:
                    for name in self.names(name=name):
                        self[name] = self[name].clone(fixed=fixed)
        # NOTE(review): nesting below is inferred (source indentation was lost) — confirm
        if 'namespace' in kwargs:
            namespace = kwargs['namespace']
            indices = [self.index(name) for name in list_update]
            oldnames = self.names()
            for index in indices:
                self.data[index] = self.data[index].clone(namespace=namespace)
            newnames = self.names()
            # propagate the new names to the dependencies of updated parameters
            for index in indices:
                dparam = self.data[index]
                for k, v in dparam.depends.items():
                    if v in oldnames: dparam.depends[k] = newnames[oldnames.index(v)]
            # parameter names must remain unique
            names = {}
            for param in self.data: names[param.name] = names.get(param.name, 0) + 1
            duplicates = {basename: multiplicity for basename, multiplicity in names.items() if multiplicity > 1}
            if duplicates:
                raise ValueError('Cannot update namespace, as following duplicates found: {} in {}'.format(duplicates, self))
    else:
        raise ValueError('Unrecognized arguments {}'.format(args))
def __add__(self, other):
    """Concatenate two parameter collections; ``other`` is first converted to this class."""
    return self.concatenate(self, self.__class__(other))
def __radd__(self, other):
    """Reflected addition; ``0 + collection`` returns a copy, such that e.g. :func:`sum` can be used."""
    return self.copy() if other == 0 else self.__class__(other).__add__(self)
def __sub__(self, other):
    """Subtract two parameter collections: return the parameters of this collection that are not in ``other``."""
    excluded = self.__class__(other)
    kept = [param for param in self if param not in excluded]
    return self.__class__(kept)
def __rsub__(self, other):
    """Reflected subtraction; if ``other`` is 0, return a copy of this collection."""
    return self.copy() if other == 0 else self.__class__(other).__sub__(self)
def __and__(self, other):
    """Intersection of two parameter collections: parameters of this collection that are in ``other``."""
    return self.__class__([param for param in self if param in other])
def __rand__(self, other):
    """Reflected intersection; if ``other`` is 0, return a copy of this collection."""
    return self.copy() if other == 0 else self.__class__(other).__and__(self)
[docs]
def params(self, **kwargs):
    """Return a collection of parameters :class:`ParameterCollection`, with optional selection. See :meth:`select`."""
    return self.select(**kwargs)
[docs]
def set(self, item):
    """
    Set parameter in collection, and flag the collection as updated.
    The input is converted to :class:`Parameter` if necessary.
    A stored parameter with same name is replaced; otherwise the input one is appended.
    """
    self._updated = True
    if not isinstance(item, Parameter):
        item = Parameter(item)
    try:
        position = self.index(item)
    except KeyError:
        self.data.append(item)
    else:
        self.data[position] = item
def __setitem__(self, name, item):
    """
    Update parameter in collection, and flag the collection as updated.

    Parameters
    ----------
    name : Parameter, str, int
        Parameter name.
        If :class:`Parameter` instance, search for parameter with same name.
        If integer, index in collection.

    item : Parameter, ParameterConfig, dict
        Parameter, or input to instantiate :class:`Parameter`.
    """
    self._updated = True
    if not isinstance(item, Parameter):
        if not isinstance(item, ParameterConfig):
            try:
                # ``name`` is the default base name of a mapping-like input
                item = {'basename': name, **item}
            except TypeError:
                pass
        item = Parameter(item)
    try:
        self.data[name] = item  # list index
    except TypeError:
        item_name = self._get_name(item)
        if str(name) != item_name:
            raise KeyError('Parameter {} must be indexed by name (incorrect {})'.format(item_name, name))
        self.set(item)
[docs]
def eval(self, **params):
    """
    Return parameter values, given all parameter values, e.g. if ``c.derived`` is '{a} + {b}',

    >>> params.eval(a=2., b=3.)
    {'a': 2., 'b': 3, 'c': 5}

    Parameters whose evaluation raises :class:`ParameterError` or :class:`KeyError` are skipped.
    See :meth:`Parameter.eval`.
    """
    values = {}
    for param in self:
        try:
            value = param.eval(**params)
        except (ParameterError, KeyError):
            continue
        values[param.name] = value
    return values
[docs]
def prior(self, **params):
    """
    Compute total (log-)prior for input parameter values (except parameters that are solved).
    Summed over varied parameters that have dependencies or are not derived,
    and that could be evaluated, see :meth:`eval`.
    """
    values = self.eval(**params)
    total = 0.
    for param in self.data:
        if not param.varied:
            continue
        if not (param.depends or (not param.derived)):
            continue
        if param.name not in values:
            continue
        total += param.prior(values[param.name])
    return total
[docs]
class ParameterPriorError(Exception):

    """Exception raised in case of an issue with a prior, e.g. invalid limits or sampling from an improper prior."""
[docs]
class ParameterPrior(BaseClass):
"""
Class that describes a 1D prior distribution.
TODO: make immutability explicit.
Parameters
----------
dist : str
Distribution name.
rv : rv_continuous
Random variate.
attrs : dict
Arguments used to initialize :attr:`rv`.
"""
def __init__(self, dist='uniform', limits=None, **kwargs):
    r"""
    Initialize :class:`ParameterPrior`.

    Parameters
    ----------
    dist : str, ParameterPrior
        Distribution name in :mod:`jax.scipy.stats`, and as a fallback, :mod:`scipy.stats`.
        A leading 'trunc' is dropped: truncation is decided from ``limits``.
        If :class:`ParameterPrior`, its attributes are adopted.

    limits : tuple, default=None
        Tuple corresponding to lower, upper limits.
        ``None`` means :math:`-\infty` for lower bound and :math:`\infty` for upper bound.
        Defaults to :math:`-\infty, +\infty`.

    kwargs : dict
        Arguments for distribution, typically ``loc``, ``scale``
        (mean and standard deviation in case of a normal distribution ``'dist' == 'norm'``).

    Raises
    ------
    ParameterPriorError
        If upper limit is not strictly greater than lower limit.
    AttributeError
        If distribution is found neither in :mod:`jax.scipy.stats` nor in :mod:`scipy.stats`.
    """
    if isinstance(dist, ParameterPrior):
        self.__dict__.update(dist.__dict__)
        return
    if limits is None:
        limits = (-np.inf, np.inf)
    limits = list(limits)
    if limits[0] is None: limits[0] = -np.inf
    if limits[1] is None: limits[1] = np.inf
    limits = tuple(limits)
    if limits[1] <= limits[0]:
        raise ParameterPriorError('ParameterPrior range {} has min greater than max'.format(limits))
    self.limits = limits
    self.attrs = dict(kwargs)
    for name, value in self.attrs.items():
        if name in ['loc', 'scale', 'a', 'b']: self.attrs[name] = float(value)
    self.dist = str(dist)
    if self.dist.startswith('trunc'): self.dist = self.dist[5:]
    # Build the distribution name from the normalized base name: using the raw input
    # would turn e.g. 'truncnorm' into 'trunctruncnorm' when limits are provided.
    dist_name = self.dist
    if self.is_limited() and dist_name != 'uniform':
        dist_name = 'trunc{}'.format(dist_name)
    try:
        rv_dist = getattr(jsp.stats, dist_name)
    except AttributeError:
        try:
            rv_dist = getattr(sp.stats, dist_name)
        except AttributeError:
            raise AttributeError('Neither jax.scipy.stats nor scipy.stats have {} for attribute'.format(dist_name))
    if self.is_limited():
        if self.dist == 'uniform':
            # scipy convention: uniform on [loc, loc + scale]
            args = (self.limits[0], self.limits[1] - self.limits[0])
            kwargs = {}
        else:
            loc, scale = self.attrs.get('loc', 0.), self.attrs.get('scale', 1.)
            # See notes of https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html
            limits = tuple((lim - loc) / scale for lim in limits)
            args = limits
    elif self.dist == 'uniform':  # improper prior
        return
    else:
        args = ()
    self.rv = rv_frozen(rv_dist, *args, **kwargs)
    # self.limits = self.rv.support()
[docs]
def isin(self, x):
    """Whether ``x`` is within prior, i.e. strictly within limits - strictly positive probability."""
    x = jnp.asarray(x)
    return (self.limits[0] < x) & (x < self.limits[1])
def __hash__(self):
    """Hash of the parent class (the class defines :meth:`__eq__`, which would otherwise disable hashing)."""
    return super().__hash__()
#@jit(static_argnums=[0, 2])
[docs]
def logpdf(self, x, remove_zerolag=True):
    """
    Return log-probability density at ``x``.

    Parameters
    ----------
    x : float, array
        Where to evaluate the log-probability density.

    remove_zerolag : bool, default=True
        If ``True``, remove the log-probability density at ``loc``
        (or at the center of the limits if ``loc`` is not specified).

    Returns
    -------
    logpdf : array
        Log-probability density; ``-inf`` outside of limits for the uniform and normal fast paths.
    """
    _jnp = jnp if use_jax(x) else np  # Worth testing if input is jax, as jax incurs huge overheads
    x = _jnp.asarray(x)
    isin = (self.limits[0] <= x) & (x <= self.limits[1])
    # Fast version
    if remove_zerolag:
        if self.dist == 'uniform':
            return _jnp.where(isin, 0, -np.inf)
        if self.dist == 'norm':
            # same defaults as in __init__, such that omitted loc / scale do not raise a KeyError
            loc, scale = self.attrs.get('loc', 0.), self.attrs.get('scale', 1.)
            return _jnp.where(isin, - 0.5 * (x - loc)**2 / scale**2, -np.inf)
    if not self.is_proper():
        # improper uniform prior: no random variate available
        return _jnp.where(isin, 0, -np.inf)
    toret = self.rv.logpdf(x)
    if remove_zerolag:
        loc = self.attrs.get('loc', None)
        if loc is None: loc = np.mean(self.limits)
        toret -= self.rv.logpdf(loc)
    return toret
def __call__(self, x, remove_zerolag=True):
    """Shortcut for :meth:`logpdf`."""
    return self.logpdf(x, remove_zerolag=remove_zerolag)
[docs]
def sample(self, size=None, random_state=None):
    """
    Draw ``size`` samples from prior. Possible only if prior is proper.

    Parameters
    ----------
    size : int, default=None
        Number of samples to draw.
        If ``None``, return one sample (float).

    random_state : int, numpy.random.Generator, numpy.random.RandomState, default=None
        If integer, a new :class:`numpy.random.RandomState` instance is used, seeded with ``random_state``.
        If ``random_state`` is a :class:`numpy.random.Generator` or :class:`numpy.random.RandomState` instance then that instance is used.
        If ``None``, the :class:`numpy.random.RandomState` singleton is used.

    Returns
    -------
    samples : float, array
        Samples drawn from prior.

    Raises
    ------
    ParameterPriorError
        If prior is improper.
    """
    if not self.is_proper():
        raise ParameterPriorError('Cannot sample from improper prior')
    return self.rv.rvs(size=size, random_state=random_state)
def __repr__(self):
    """String representation with distribution name, limits (if any is finite), and attributes (e.g. ``loc`` and ``scale``)."""
    base = self.dist
    if self.is_limited():
        base = '{}[{}, {}]'.format(base, *self.limits)
    return '{}({})'.format(base, self.attrs)
def __setstate__(self, state):
    """Set this class' state dictionary, by calling :meth:`__init__` with its content."""
    self.__init__(**state)
def __getstate__(self):
    """Return this class' state dictionary: distribution name, limits, and distribution attributes."""
    return {'dist': self.dist, 'limits': self.limits, **self.attrs}
[docs]
def is_proper(self):
    """Whether distribution is proper, i.e. has finite integral: all but uniform distributions with an infinite limit."""
    return not (self.dist == 'uniform' and np.isinf(self.limits).any())
[docs]
def is_limited(self):
    """Whether distribution has (at least one) finite limit."""
    infinite = np.isinf(self.limits)
    return not infinite.all()
def center(self):
    """
    Return the center of the distribution: ``loc`` if available,
    else the mean of the finite limits, else 0.
    """
    try:
        return self.loc
    except AttributeError:
        if self.is_limited():
            return np.mean([lim for lim in self.limits if not np.isinf(lim)])
        return 0.
def __getattr__(self, name):
    """
    Make :attr:`rv` attributes directly available in :class:`ParameterPrior`.
    If not found (or :attr:`rv` is not set), fall back to :attr:`attrs`, except for the uniform distribution.
    """
    try:
        # object.__getattribute__ avoids calling __getattr__ again if 'rv' is not set
        return getattr(object.__getattribute__(self, 'rv'), name)
    except AttributeError as exc:
        if object.__getattribute__(self, 'dist') == 'uniform':
            raise AttributeError('uniform distribution has no {}'.format(name)) from exc
        attrs = object.__getattribute__(self, 'attrs')
        if name in attrs:
            return attrs[name]
        raise exc
def __eq__(self, other):
    """Is ``self`` equal to ``other``, i.e. same type and same 'dist', 'limits' and 'attrs'?"""
    return type(other) == type(self) and all(getattr(other, key) == getattr(self, key) for key in ['dist', 'limits', 'attrs'])
def _reshape(array, shape):
    """
    Reshape ``array`` to ``shape`` (int or sequence of ints) followed by its ``pshape``.
    A :class:`ValueError` mentioning the array is raised if reshaping fails.
    """
    leading = (shape,) if np.ndim(shape) == 0 else tuple(shape)
    try:
        return array.reshape(leading + array.pshape)
    except ValueError as exc:
        raise ValueError('Error with array {}'.format(repr(array))) from exc
[docs]
@register_pytree_node_class
class Samples(BaseParameterCollection):
"""Class that holds samples, as a collection of :class:`ParameterArray`."""
_type = ParameterArray
_attrs = BaseParameterCollection._attrs + ['_derived']
_derived = []
def __init__(self, data=None, params=None, attrs=None):
    """
    Initialize :class:`Samples`.

    Parameters
    ----------
    data : list, dict, Samples
        Can be:

        - list of :class:`ParameterArray`, or :class:`np.ndarray` if list of parameters
          (or :class:`ParameterCollection`) is provided in ``params``
        - dictionary mapping parameter to array

    params : list, ParameterCollection
        Optionally, list of parameters, one for each array of ``data``.

    attrs : dict, default=None
        Optionally, other attributes, stored in :attr:`attrs`.
    """
    self.attrs = dict(attrs or {})
    self.data = []
    if params is None:
        super(Samples, self).__init__(data=data, attrs=attrs)
        return
    if len(params) != len(data):
        raise ValueError('Provide as many parameters as arrays')
    for param, value in zip(params, data):
        self[param] = value
[docs]
def save(self, filename):
    """
    Save samples to disk.

    Parameters
    ----------
    filename : str, Path
        Where to save samples. If it ends with '.npz', arrays are saved with :func:`np.savez`
        under 'data.{index}', next to 'params' (state of arrays without values)
        and 'others' (remaining state); else the full state is saved with :func:`np.save`.
    """
    filename = str(filename)
    state = self.__getstate__()
    for array in state['data']: array['value'] = np.asarray(array['value'])  # could be jax
    state = {'__class__': utils.serialize_class(self.__class__), **state}
    self.log_info('Saving {}.'.format(filename))
    utils.mkdir(os.path.dirname(filename))
    if filename.endswith('.npz'):
        statez = {'others': dict(state)}
        statez['others'].pop('data', None)
        # '__class__' is stored at top level rather than in 'others'
        statez['__class__'] = statez['others'].pop('__class__')
        statez['params'] = []
        for iarray, array in enumerate(state['data']):
            statez['data.{:d}'.format(iarray)] = array['value']
            statez['params'].append({key: value for key, value in array.items() if key not in ['value']})
        np.savez(filename, **statez)
    else:
        np.save(filename, state, allow_pickle=True)
[docs]
@classmethod
def load(cls, filename):
    """
    Load samples from disk.

    Parameters
    ----------
    filename : str, Path
        File saved with :meth:`save` ('.npz' archive, or pickled state otherwise).

    Returns
    -------
    new : Samples
        Samples rebuilt with ``cls.from_state``.

    Warning
    -------
    Files are read with ``allow_pickle=True``, which can execute arbitrary code:
    only load files from trusted sources.
    """
    filename = str(filename)
    cls.log_info('Loading {}.'.format(filename))
    if filename.endswith('.npz'):
        # read everything into memory, then close the archive (otherwise the file handle stays open)
        with np.load(filename, allow_pickle=True) as npz:
            state = dict(npz)
        data = [{**param, 'value': state.pop('data.{:d}'.format(iarray))} for iarray, param in enumerate(state.pop('params')[()])]
        if 'others' in state:  # better as does not change type, e.g. _derived remains a list
            others = {name: value for name, value in state['others'][()].items()}
        else:  # backward-compatibility
            others = {name: value[()] for name, value in state.items()}
        state = {**others, 'data': data}
    else:
        state = np.load(filename, allow_pickle=True)[()]
    state.pop('__class__', None)
    return cls.from_state(state)
@staticmethod
def _get_param(item):
    """Return the ``param`` attribute of ``item``."""
    return item.param
@property
def shape(self):
    """Shape of samples: ``ashape`` of the first stored array, or empty tuple if there is none."""
    for array in self.data:
        return array.ashape
    return ()
@shape.setter
def shape(self, shape):
    """Set samples shape (in-place), see :meth:`_reshape`."""
    self._reshape(shape)
def _reshape(self, shape):
    # Reshape every array to ``shape``.
    # Arrays are stored with the parent class' ``set`` to circumvent the automatic reshaping of :meth:`Samples.set`.
    for array in self.data:
        reshaped = _reshape(array, shape)
        super(Samples, self).set(reshaped)
[docs]
def reshape(self, *args):
    """
    Return reshaped samples (with shallow copy).

    Parameters
    ----------
    *args : tuple, int
        New shape, given either as a single argument (e.g. a tuple) or as separate dimensions.

    Returns
    -------
    new : Samples
    """
    shape = args[0] if len(args) == 1 else args
    new = self.copy()
    new._reshape(shape)
    return new
[docs]
def ravel(self):
    """Return flattened samples, i.e. reshaped to 1D of length :attr:`size`."""
    total = self.size
    return self.reshape(total)
@property
def ndim(self):
    """Number of dimensions, i.e. length of :attr:`shape`."""
    shape = self.shape
    return len(shape)
@property
def size(self):
    """Total number of samples, i.e. product of :attr:`shape` (1 for an empty shape)."""
    shape = self.shape
    return np.prod(shape, dtype='intp')
def __len__(self):
    """Length of samples: first dimension of :attr:`shape`, 0 if shape is empty."""
    shape = self.shape
    return shape[0] if shape else 0
[docs]
@classmethod
def concatenate(cls, *others, intersection=False):
    """
    Concatenate input samples along the first axis.

    All samples must hold the same parameters, except if ``intersection == True``,
    in which case only the parameters common to all samples are kept.

    Parameters
    ----------
    *others : Samples
        Samples to concatenate, given as separate arguments or as a single sequence.
        Samples other than the first one are ignored if they have no parameters or zero size.
    intersection : bool, default=False
        Whether to restrict to common parameters instead of raising on mismatch.

    Returns
    -------
    new : Samples
        Copy of the first samples, holding the concatenated arrays; ``cls()`` if no input is given.

    Raises
    ------
    ValueError
        If parameters do not match (and ``intersection`` is ``False``), or arrays cannot be concatenated.
    """
    if len(others) == 1 and utils.is_sequence(others[0]):
        others = others[0]
    if not others:
        return cls()
    first = others[0]
    new = first.copy()
    new.data = []
    new_params = first.params()
    # The first samples are always kept, empty followers are dropped
    others = [first] + [other for other in others[1:] if other.params() and other.size]
    if intersection:
        for other in others:
            new_params &= other.params()
    else:
        new_names = new_params.names()
        for other in others:
            other_names = other.names()
            if set(other_names) != set(new_names):
                raise ValueError('cannot concatenate values as parameters do not match: {} != {}.'.format(new_names, other_names))

    def as_1d(item):
        # Scalars (empty shape) are promoted to shape (1,) so that they can be concatenated
        return _reshape(item, item.ashape or (1,))

    for param in new_params:
        try:
            value = np.concatenate([as_1d(other[param]) for other in others], axis=0)
        except ValueError as exc:
            raise ValueError('error while concatenating array for parameter {}'.format(param)) from exc
        new[param] = first[param].clone(value=value)
    return new
[docs]
def update(self, *args, **kwargs):
    """
    Update samples with new ones, in place.

    Parameters
    ----------
    *args, **kwargs
        Either a single instance of this class,
        or arguments to instantiate such a class (see :meth:`__init__`).
        All its arrays are added with :meth:`set`, and its :attr:`attrs` are merged into ``self.attrs``.
    """
    single_instance = len(args) == 1 and isinstance(args[0], self.__class__)
    other = args[0] if single_instance else self.__class__(*args, **kwargs)
    for item in other:
        self.set(item)
    self.attrs.update(other.attrs)
[docs]
def set(self, item):
    """
    Add new :class:`ParameterArray` to samples.

    ``item`` is reshaped to the current samples :attr:`shape` if samples already hold arrays,
    else to its own ``ashape``.
    """
    shape = self.shape if self.data else item.ashape
    super(Samples, self).set(_reshape(item, shape))
def __setitem__(self, name, item):
    """
    Update array in samples.

    Parameters
    ----------
    name : Parameter, str, int
        Parameter name.
        If :class:`Parameter` instance, search for parameter with same name.
        If integer, index in collection.

    item : ParameterArray, array
        Array. If not an instance of ``self._type``, it is wrapped into a :class:`ParameterArray`.
    """
    if not isinstance(item, self._type):
        # Raw array: find out which parameter to attach to it
        try:
            name = self.data[name].param  # list index
        except TypeError:
            # ``name`` is not a list index, keep it as is
            pass
        is_derived = str(name) in self._derived
        if isinstance(name, Parameter):
            param = name
        else:
            param = Parameter(name, latex=utils.outputs_to_latex(str(name)) if is_derived else None, derived=is_derived)
            if param in self:
                # Update the new parameter with the one already stored in samples
                param = param.clone(self[param].param)
        item = ParameterArray(item, param)
    try:
        self.data[name] = item  # list index
    except TypeError:
        # ``name`` is not a list index: make sure ``item`` is registered under ``name``, then store it with :meth:`set`
        item_name = self._get_name(item)
        if str(name) != item_name:
            item = item.copy()  # do not modify input item
            if isinstance(name, Parameter):
                item.param = name
            else:
                if item.param is None:
                    item.param = Parameter(name)
                item.param = item.param.clone(name=name)
            #raise KeyError('Parameter {} must be indexed by name (incorrect {})'.format(item_name, name))
        self.set(item)
def __getitem__(self, name):
    """
    Get samples parameter ``name`` if :class:`Parameter` or string,
    else return copy with local slice of samples.

    Raises
    ------
    IndexError
        If ``name`` cannot be used to index the arrays.
    """
    if isinstance(name, (Parameter, str)):
        return super().__getitem__(name)
    new = self.copy()
    try:
        if isinstance(name, slice) or np.ndim(name) != 0:
            index = name
        else:
            index = [name]  # scalar index, wrapped into a list
        new.data = [column[index] for column in self.data]
    except IndexError as exc:
        raise IndexError('Unrecognized indices {}'.format(name)) from exc
    return new
def __repr__(self):
    """Return string representation, including shape and parameters."""
    name = self.__class__.__name__
    return '{}(shape={}, params={})'.format(name, self.shape, self.params())
[docs]
def to_array(self, params=None, struct=True, derivs=None):
    """
    Return samples as numpy array.

    Parameters
    ----------
    params : ParameterCollection, list, default=None
        Parameters to use. Defaults to all parameters.

    struct : bool, default=True
        Whether to return structured array, with columns accessible through e.g. ``array['x']``.
        If ``False``, numpy will attempt to cast types of different columns.

    derivs : default=None
        If not ``None``, index applied to each array (``array[derivs]``) before export.

    Returns
    -------
    array : array
    """
    if params is None:
        params = self.params()
    names = [str(param) for param in params]
    values = []
    for name in names:
        value = self[name]
        if derivs is not None:
            value = value[derivs]
        values.append(value)
    if not struct:
        return np.array(values)
    ndim = len(self.shape)
    # One field per parameter; dimensions beyond the samples shape become the field sub-array shape.
    # Names and values must be iterated together, else all fields would share the same (last) name.
    dtype = [(name, value.dtype, value.shape[ndim:]) for name, value in zip(names, values)]
    toret = np.empty(self.shape, dtype=dtype)
    for name, value in zip(names, values):
        toret[name] = value
    return toret
[docs]
def to_dict(self, params=None):
    """
    Return samples as a dictionary.

    Parameters
    ----------
    params : ParameterCollection, list, default=None
        Parameters to use. Defaults to all parameters.

    Returns
    -------
    dict : dict
        Dictionary mapping parameter name to array.
    """
    selected = self.params() if params is None else params
    toret = {}
    for param in selected:
        toret[str(param)] = self[param]
    return toret
[docs]
def match(self, other, eps=1e-7, params=None):
    """
    Match other :class:`Samples` against ``self``, for parameters ``params``.

    Parameters
    ----------
    other : Samples
        Samples to match.

    eps : float, default=1e-7
        Distance upper bound above which samples are not considered equal.
        1e-7 to handle float32/float64 conversions.

    params : ParameterCollection, list, default=None
        Parameters to use. Defaults to all parameters that are not derived
        (common to ``self`` and ``other``).

    Returns
    -------
    index_in_other, index_in_self : tuple
        Indices (as returned by :func:`numpy.unravel_index`) of matched samples in ``other`` and in ``self``.
    """
    from scipy import spatial

    if params is None:
        params = set(self.names(derived=False)) & set(other.names(derived=False))
    points_self = np.column_stack([self[name].ravel() for name in params])
    kdtree = spatial.cKDTree(points_self, leafsize=16, compact_nodes=True, copy_data=False, balanced_tree=True, boxsize=None)
    points_other = np.column_stack([other[name].ravel() for name in params])
    _, indices = kdtree.query(points_other, k=1, eps=0, p=2, distance_upper_bound=eps)
    # Queries without neighbor within ``eps`` get index ``self.size``
    found = indices < self.size
    index_in_other = np.unravel_index(np.flatnonzero(found), shape=other.shape)
    index_in_self = np.unravel_index(indices[found], shape=self.shape)
    return index_in_other, index_in_self
[docs]
@classmethod
@CurrentMPIComm.enable
def bcast(cls, value, mpicomm=None, mpiroot=0):
    """
    Broadcast input samples ``value`` from rank ``mpiroot`` to other processes.

    Parameters
    ----------
    value : Samples
        Samples to broadcast; only used on rank ``mpiroot``.
    mpicomm : MPI communicator, default=None
        The MPI communicator.
    mpiroot : int, default=0
        Rank holding the samples.

    Returns
    -------
    new : Samples
        Samples built from the broadcast state, on all ranks.
    """
    is_root = mpicomm.rank == mpiroot
    state = None
    if is_root:
        state = value.__getstate__()
        # Arrays are broadcast one by one below; only (param, derivs) travel with the state
        state['data'] = [(entry['param'], entry['derivs']) for entry in state['data']]
    state = mpicomm.bcast(state, root=mpiroot)
    data = []
    for ivalue, (param, derivs) in enumerate(state['data']):
        array = mpi.bcast(value.data[ivalue] if is_root else None, mpicomm=mpicomm, mpiroot=mpiroot)
        data.append({'value': array, 'param': param, 'derivs': derivs})
    state['data'] = data
    return cls.from_state(state)
[docs]
@CurrentMPIComm.enable
def send(self, dest, tag=0, mpicomm=None):
    """
    Send ``self`` to rank ``dest``.

    The state (with arrays replaced by their ``(param, derivs)``) is sent first,
    followed by each array, all with the same ``tag``.
    """
    state = self.__getstate__()
    metadata = [(entry['param'], entry['derivs']) for entry in state['data']]
    state['data'] = metadata
    mpicomm.send(state, dest=dest, tag=tag)
    for array in self:
        mpi.send(array, dest=dest, tag=tag, mpicomm=mpicomm)
[docs]
@classmethod
@CurrentMPIComm.enable
def recv(cls, source=mpi.ANY_SOURCE, tag=mpi.ANY_TAG, mpicomm=None):
    """
    Receive samples from rank ``source``.

    Counterpart of :meth:`send`: the state is received first, then one array for each of its entries.
    """
    state = mpicomm.recv(source=source, tag=tag)
    data = []
    for param, derivs in state['data']:
        value = mpi.recv(source, tag=tag, mpicomm=mpicomm)
        data.append({'value': value, 'param': param, 'derivs': derivs})
    state['data'] = data
    return cls.from_state(state)
[docs]
@classmethod
@CurrentMPIComm.enable
def sendrecv(cls, value, source=0, dest=0, tag=0, mpicomm=None):
    """
    Send samples from rank ``source`` to rank ``dest`` and receive them here.

    Returns
    -------
    toret : Samples, None
        Received samples on rank ``dest`` (a deep copy of ``value`` if ``source`` and ``dest``
        are both the current rank), ``None`` on other ranks.
    """
    rank = mpicomm.rank
    if dest == source == rank:
        return value.deepcopy()
    if rank == source:
        value.send(dest=dest, tag=tag, mpicomm=mpicomm)
    if rank == dest:
        return cls.recv(source=source, tag=tag, mpicomm=mpicomm)
    return None
#def tree_flatten(self):
# data = [array.value for array in self.data]
# param_derivs = [(array.param, array.derivs) for array in self.data]
# return tuple(data), (param_derivs, {name: getattr(self, name) for name in self._attrs})
#@classmethod
#def tree_unflatten(cls, aux_data, children):
# new = cls([ParameterArray(value, *param_derivs) for value, param_derivs in zip(children, aux_data[0])])
# for name, value in aux_data[1].items():
# setattr(new, name, value)
# return new
def tree_flatten(self):
    # Return (children, aux_data): the list of arrays, and a dictionary of the attributes listed in ``self._attrs``.
    aux_data = {}
    for name in self._attrs:
        aux_data[name] = getattr(self, name)
    return self.data, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
    # Inverse of tree_flatten: build an empty instance, then restore arrays (``children``) and attributes (``aux_data``).
    instance = cls()
    instance.data = children
    for attr_name in aux_data:
        setattr(instance, attr_name, aux_data[attr_name])
    return instance
def is_parameter_sequence(params):
    """Return ``True`` if ``params`` is a :class:`ParameterCollection` or a sequence (as defined by :func:`utils.is_sequence`)."""
    if isinstance(params, ParameterCollection):
        return True
    return utils.is_sequence(params)
[docs]
class BaseParameterMatrix(BaseClass):
    """Base class representing a parameter matrix."""

    # Value filled in by :meth:`view` for parameters that are not in the matrix
    _fill_value = np.nan
def __init__(self, value, params=None, attrs=None):
    """
    Initialize :class:`BaseParameterMatrix`.

    Parameters
    ----------
    value : array, BaseParameterMatrix
        2D array representing matrix.
        If an instance of this class, its attributes are taken over (without copy) and other arguments are ignored.
        If it has a non-``None`` ``derivs`` attribute, matrix elements are read as ``value[param1, param2]``.

    params : list, ParameterCollection
        Parameters corresponding to input ``value``. They are stored with ``fixed=False``.

    attrs : dict, default=None
        Optionally, other attributes, stored in :attr:`attrs`.

    Raises
    ------
    ValueError
        If parameters are missing or empty, or if ``value`` is not a square 2D matrix matching parameter sizes.
    """
    if isinstance(value, self.__class__):
        self.__dict__.update(value.__dict__)
        return
    if params is None:
        raise ValueError('Provide matrix parameters')
    self._params = ParameterCollection(params)
    self._params = ParameterCollection([param.clone(fixed=False) for param in self._params])
    if not self._params:
        raise ValueError('Got no parameters')
    if getattr(value, 'derivs', None) is not None:
        rows = [[value[param1, param2] for param2 in self._params] for param1 in self._params]
        value = np.array(rows)
    self._value = np.atleast_2d(np.array(value))
    if self._value.ndim != 2:
        raise ValueError('Input matrix must be 2D')
    nrows, ncols = self._value.shape
    if ncols != nrows:
        raise ValueError('Input matrix must be square')
    nparams = len(self._params)
    if nrows != nparams:
        raise ValueError('Number of parameters and matrix size are different: {:d} vs {:d}'.format(nparams, nrows))
    self._sizes  # property access, raises if parameter sizes do not match matrix shape
    self.attrs = dict(attrs or {})
[docs]
def params(self, *args, **kwargs):
    """Return parameters in matrix; arguments are passed on to the ``params`` method of the stored collection."""
    collection = self._params
    return collection.params(*args, **kwargs)
[docs]
def names(self, *args, **kwargs):
    """Return names of parameters in matrix; arguments are passed on to the ``names`` method of the stored collection."""
    collection = self._params
    return collection.names(*args, **kwargs)
[docs]
def select(self, params=None, **kwargs):
    """
    Return a sub-matrix.

    Parameters
    ----------
    params : list, ParameterCollection, default=None
        Optionally, parameters to limit to.

    **kwargs : dict
        If ``params`` is ``None``, optional arguments passed to :meth:`ParameterCollection.select`
        to select parameters (e.g. ``varied=True``).

    Returns
    -------
    new : BaseParameterMatrix
        A sub-matrix.
    """
    selected = self._params.select(**kwargs) if params is None else params
    return self.view(params=selected, return_type=None)
[docs]
def det(self, params=None):
    """Return matrix determinant, limiting to input parameters ``params`` if not ``None``."""
    matrix = self.view(params=params, return_type='nparray')
    return np.linalg.det(matrix)
@property
def _sizes(self):
    # Return the size of each parameter (at least 1), checking that they add up to the matrix size.
    sizes = []
    for param in self._params:
        sizes.append(max(param.size, 1))
    if sum(sizes) != self._value.shape[0]:
        raise ValueError('number * size of input params must match input matrix shape')
    return sizes
[docs]
def clone(self, value=None, params=None, attrs=None):
    """
    Clone this matrix, i.e. copy and optionally update ``value``, ``params`` and ``attrs``.

    Parameters
    ----------
    value : array, default=None
        2D array to replace matrix value with (must be compatible with the shape for ``params``).

    params : list, ParameterCollection, default=None
        New parameters.

    attrs : dict, default=None
        Optionally, other attributes, stored in :attr:`attrs`.

    Returns
    -------
    new : BaseParameterMatrix
        A new matrix, optionally with ``value`` and ``params`` updated.
    """
    cloned = self.view(params=params, return_type=None)
    if value is not None:
        cloned._value[...] = value
    if attrs is not None:
        cloned.attrs = dict(attrs)
    return cloned
[docs]
def view(self, params=None, return_type=None):
    """
    Return matrix for input parameters ``params``.

    Parameters
    ----------
    params : list, ParameterCollection, default=None
        If provided, restrict to these parameters.
        If a single parameter is provided, and ``return_type`` is 'nparray', the returned array is given the shape of this parameter
        (i.e. a scalar for a scalar parameter).
        If a parameter in ``params`` is not in matrix, add it, filling in the returned matrix with zeros,
        except on the diagonal, which is filled with :attr:`_fill_value`.

    return_type : str, default=None
        If 'nparray', return a numpy array.
        Else, return a new :class:`BaseParameterMatrix`, restricting to ``params``.

    Returns
    -------
    new : array, float, BaseParameterMatrix
    """
    if params is None:
        params = self._params
    isscalar = not is_parameter_sequence(params)
    if isscalar:
        params = [params]
    params = [self._params[param] if param in self._params else Parameter(param) for param in params]
    params_in_self = [param for param in params if param in self._params]
    params_not_in_self = [param for param in params if param not in params_in_self]
    sizes = [max(param.size, 1) for param in params]
    new = self.__class__(np.zeros((sum(sizes),) * 2, dtype='f8'), params=params, attrs=self.attrs)
    if params_in_self:
        index_new, index_self = new._index(params_in_self), self._index(params_in_self)
        new._value[np.ix_(index_new, index_new)] = self._value[np.ix_(index_self, index_self)]
    if params_not_in_self:
        index_new = new._index(params_not_in_self)
        # Diagonal only: cross terms between parameters that are not in the matrix stay zero
        new._value[index_new, index_new] = self._fill_value
    if return_type == 'nparray':
        new = new._value
        if isscalar:
            new.shape = params[0].shape
    return new
def __array__(self, *args, **kwargs):
    # numpy array protocol: expose the matrix value; arguments are forwarded to :func:`numpy.asarray`.
    matrix = self._value
    return np.asarray(matrix, *args, **kwargs)
def _index(self, params):
    # Internal method to return indices in matrix array corresponding to input params.
    # Parameter number ii spans rows / columns offsets[ii] to offsets[ii + 1] (excluded).
    offsets = np.cumsum([0] + self._sizes)
    positions = [self._params.index(param) for param in params]
    if not positions:
        return np.array(positions, dtype='i4')
    chunks = [np.arange(offsets[ii], offsets[ii + 1]) for ii in positions]
    return np.concatenate(chunks, dtype='i4')
def __contains__(self, name):
    """Whether parameter ``name`` is in the matrix parameters."""
    collection = self._params
    return name in collection
def __getstate__(self):
    """Return this class' state dictionary, with keys 'value', 'params' and 'attrs'."""
    return {
        'value': self._value,
        'params': self._params.__getstate__(),
        'attrs': self.attrs,
    }
def __setstate__(self, state):
    """Set this class' state dictionary; 'attrs' is optional and defaults to an empty dictionary."""
    params_state = state['params']
    self._params = ParameterCollection.from_state(params_state)
    self._value = state['value']
    self.attrs = state.get('attrs', {})
def __repr__(self):
    """Return string representation of parameter matrix, including parameters."""
    name = self.__class__.__name__
    return '{}({})'.format(name, self._params)
def __eq__(self, other):
    """Is ``self`` equal to ``other``, i.e. same type, parameters and value?"""
    if type(other) != type(self):
        return False
    return all(deep_eq(getattr(other, name), getattr(self, name)) for name in ['_params', '_value'])
[docs]
@classmethod
@CurrentMPIComm.enable
def bcast(cls, value, mpicomm=None, mpiroot=0):
    """
    Broadcast input matrix ``value`` from rank ``mpiroot`` to other processes.

    The state is broadcast without the matrix array, which is broadcast separately.
    """
    is_root = mpicomm.rank == mpiroot
    state = None
    if is_root:
        state = value.__getstate__()
        state['value'] = None  # array is sent separately below
    state = mpicomm.bcast(state, root=mpiroot)
    array = value._value if is_root else None
    state['value'] = mpi.bcast(array, mpicomm=mpicomm, mpiroot=mpiroot)
    return cls.from_state(state)
[docs]
def deepcopy(self):
    """Return a deep copy of this matrix (see :func:`copy.deepcopy`)."""
    duplicate = copy.deepcopy(self)
    return duplicate
def __mul__(self, other):
    """Return a deep copy of matrix, with value multiplied by ``other`` (typically, a float)."""
    product = self.deepcopy()
    product._value *= other
    return product
def __rmul__(self, other):
    """Multiply ``other`` by matrix: same as :meth:`__mul__`."""
    product = self.__mul__(other)
    return product
def __truediv__(self, other):
    """Return a deep copy of matrix, with value divided by ``other`` (typically, a float)."""
    quotient = self.deepcopy()
    quotient._value /= other
    return quotient
def __rtruediv__(self, other):
    """
    Divide ``other`` (typically, a float) by matrix, element-wise.

    Division is not commutative: ``other / matrix`` holds ``other / value``, not ``value / other``.
    """
    new = self.deepcopy()
    new._value = other / new._value
    return new
@property
def shape(self):
    """Return matrix shape, i.e. shape of the underlying array."""
    matrix = self._value
    return matrix.shape
[docs]
class ParameterCovariance(BaseParameterMatrix):
"""Class that represents a parameter covariance matrix."""
[docs]
def view(self, params=None, return_type='nparray', fill=None):
    """
    Return matrix for input parameters ``params``.

    Parameters
    ----------
    params : list, ParameterCollection, default=None
        If provided, restrict to these parameters.
        If a single parameter is provided, this parameter is a scalar, and ``return_type`` is 'nparray', return a scalar.
        Parameters in ``params`` that are not in matrix are added, as in :meth:`BaseParameterMatrix.view`.

    return_type : str, default='nparray'
        If 'nparray', return a numpy array.
        Else, return a new :class:`ParameterCovariance`, restricting to ``params``.

    fill : str, default=None
        If 'proposal', the diagonal for parameters that are not in matrix, and have a :attr:`Parameter.proposal`,
        is set to the squared proposal.

    Returns
    -------
    new : array, float, ParameterCovariance
    """
    new = super(ParameterCovariance, self).view(params=params, return_type=None)
    if fill == 'proposal':
        missing = [param for param in new._params if param not in self._params and param.proposal is not None]
        index = new._index(missing)
        new._value[index, index] = [param.proposal**2 for param in missing]
    toret = super(ParameterCovariance, new).view(return_type=return_type)
    if return_type == 'nparray' and params is not None:
        if not is_parameter_sequence(params):
            # Single parameter: give the array the shape of this parameter
            toret.shape = new._params[params].shape
    return toret
[docs]
def fom(self, **params):
    """
    Figure-of-Merit, as inverse square root of the matrix determinant.
    Optional arguments (e.g. ``params=...``) are passed on to :meth:`det`.
    """
    determinant = self.det(**params)
    return determinant**(-0.5)
[docs]
def corrcoef(self, params=None):
    """Return correlation matrix array (optionally restricted to input parameters)."""
    cov = self.view(params=params, return_type='nparray')
    return utils.cov_to_corrcoef(cov)
[docs]
def var(self, params=None):
    """
    Return variance (optionally restricted to input parameters), i.e. the diagonal of the covariance.
    If a single parameter is given as input and this parameter is a scalar, return a scalar.
    """
    cov = self.view(params=params, return_type='nparray')
    is_single = np.ndim(cov) == 0  # single param
    return cov if is_single else np.diag(cov)
[docs]
def std(self, params=None):
    """
    Return standard deviation (optionally restricted to input parameters), i.e. square root of :meth:`var`.
    If a single parameter is given as input and this parameter is a scalar, return a scalar.
    """
    variance = self.var(params=params)
    return variance**0.5
[docs]
def to_precision(self, params=None, return_type=None):
    """
    Return inverse covariance matrix (precision matrix) for input parameters ``params``.

    Parameters
    ----------
    params : list, ParameterCollection, default=None
        If provided, restrict to these parameters.

    return_type : str, default=None
        If 'nparray', return a numpy array (2D inverse of the covariance restricted to ``params``).
        Else, return a new :class:`ParameterPrecision`.

    Returns
    -------
    new : array, ParameterPrecision
    """
    if params is None:
        params = self._params
    view = self.view(params, return_type=None)
    invcov = utils.inv(view._value)
    if return_type == 'nparray':
        return invcov
    # Use the parameters resolved by the view (not the raw input, which may be names or a single parameter),
    # such that the precision matrix carries the same parameters as the inverted covariance
    return ParameterPrecision(invcov, params=view._params, attrs=view.attrs)
[docs]
def to_stats(self, params=None, sigfigs=2, tablefmt='latex_raw', fn=None):
    """
    Export covariance matrix to string.

    Parameters
    ----------
    params : list, ParameterCollection, default=None
        If provided, restrict to these parameters.

    sigfigs : int, default=2
        Number of significant digits.
        See :func:`utils.round_measurement`.

    tablefmt : str, default='latex_raw'
        Format for summary table.
        See :func:`tabulate.tabulate`.

    fn : str, default=None
        If not ``None``, file name where to save summary table.

    Returns
    -------
    txt : str
        Summary table.
    """
    import tabulate

    use_latex = 'latex_raw' in tablefmt
    view = self.view(params, return_type=None)
    headers = [param.latex(inline=True) if use_latex else str(param) for param in view._params]
    data = []
    for param, row in zip(view._params, view._value):
        # First column is the parameter name, then one rounded entry per column of the matrix
        rounded = [utils.round_measurement(value, value, sigfigs=sigfigs)[0] for value in row]
        data.append([str(param)] + rounded)
    txt = tabulate.tabulate(data, headers=headers, tablefmt=tablefmt)
    if fn is None:
        return txt
    utils.mkdir(os.path.dirname(fn))
    self.log_info('Saving to {}.'.format(fn))
    with open(fn, 'w') as file:
        file.write(txt)
    return txt
[docs]
def to_getdist(self, params=None, label=None, center=None, ignore_limits=True):
    """
    Return a GetDist Gaussian distribution, with this covariance matrix.

    Parameters
    ----------
    params : list, ParameterCollection, default=None
        Parameters to share to GetDist. Defaults to all parameters.

    label : str, default=None
        Name for GetDist to use for this distribution.

    center : list, array, default=None
        Optionally, override :attr:`Parameter.value`.

    ignore_limits : bool, default=True
        GetDist does not seem to be able to integrate over distribution if bounded;
        so drop parameter limits.

    Returns
    -------
    samples : getdist.gaussian_mixtures.MixtureND
    """
    from getdist.gaussian_mixtures import MixtureND

    cov = self.view(params=params, return_type=None)
    labels = [param.latex() for param in cov._params]
    names = [str(param) for param in cov._params]

    def finite_or_none(limit):
        # Infinite limits are passed as ``None``
        return None if limit is None or not np.isfinite(limit) else limit

    # ignore_limits to avoid issue in GetDist with analytic marginalization
    ranges = None
    if not ignore_limits:
        ranges = [tuple(finite_or_none(limit) for limit in param.prior.limits) for param in cov._params]
    if center is None:
        center = [param.value for param in cov._params]
    center = np.asarray(center)
    return MixtureND([center], [cov._value], lims=ranges, names=names, labels=labels, label=label)
[docs]
@classmethod
def read_getdist(cls, base_fn):
    """
    Read covariance matrix from GetDist format.

    Parameters
    ----------
    base_fn : str
        Base *CosmoMC* file name. Will be appended by '.covmat' for parameter covariance matrix.
        The file must hold a header line starting with '#', listing parameter names,
        and one line of (whitespace-separated) floats per matrix row.

    Returns
    -------
    cov : ParameterCovariance

    Raises
    ------
    ValueError
        If no header line listing parameter names is found.
    """
    covmat_fn = '{}.covmat'.format(base_fn)
    cls.log_info('Loading covariance file: {}.'.format(covmat_fn))
    params, covariance = None, []
    with open(covmat_fn, 'r') as file:
        for line in file:
            items = line.split()
            if not items:
                continue
            if items[0].startswith('#'):
                # Header: accept both '# name1 name2' and '#name1 name2'
                params = ' '.join(items)[1:].split()
            else:
                covariance.append([float(value) for value in items])
    if params is None:
        raise ValueError('No header line (starting with #) listing parameter names found in {}'.format(covmat_fn))
    return cls(covariance, params=[Parameter(param, fixed=False) for param in params])
[docs]
class ParameterPrecision(BaseParameterMatrix):
    """Class that represents a parameter precision matrix, i.e. the inverse of a parameter covariance matrix."""
    # Diagonal value used by :meth:`view` for parameters that are not in the matrix
    _fill_value = 0.
[docs]
def fom(self, **params):
    """
    Figure-of-Merit, computed from the corresponding covariance matrix (see :meth:`ParameterCovariance.fom`).
    Optional arguments (e.g. ``params=...`` to restrict to input parameters) are passed on as they are.
    """
    # Forward keyword arguments as such: passing the ``params`` dictionary itself as ``params=`` would
    # make it be interpreted as the list of parameters
    return self.to_covariance().fom(**params)
[docs]
def to_covariance(self, params=None, return_type=None):
    """
    Return inverse precision matrix (covariance matrix) for input parameters ``params``.
    The full precision matrix is inverted first, then restricted to ``params``.

    Parameters
    ----------
    params : list, ParameterCollection, default=None
        If provided, restrict to these parameters.
        If a single parameter is provided, this parameter is a scalar, and ``return_type`` is 'nparray', return a scalar.

    return_type : str, default=None
        If 'nparray', return a numpy array.
        Else, return a new :class:`ParameterCovariance`.

    Returns
    -------
    new : array, float, ParameterCovariance
    """
    inverse = utils.inv(self._value)
    covariance = ParameterCovariance(inverse, params=self._params, attrs=self.attrs)
    return covariance.view(params=params, return_type=return_type)
[docs]
@classmethod
def sum(cls, *others):
    """
    Add precision matrices.

    Parameters
    ----------
    *others : ParameterPrecision
        Matrices to add, given as separate arguments or as a single sequence.
        The result covers the concatenation of all their parameters;
        :attr:`attrs` of all matrices are merged.

    Returns
    -------
    new : ParameterPrecision
    """
    if len(others) == 1 and utils.is_sequence(others[0]):
        others = others[0]
    params = ParameterCollection.concatenate([other._params for other in others])
    total = others[0].view(params, return_type=None)
    for other in others[1:]:
        # Align on the parameters of the sum before adding
        aligned = other.view(total._params, return_type=None)
        total._value += aligned._value
        total.attrs.update(aligned.attrs)
    return total
def __add__(self, other):
    """Sum of ``self`` + ``other`` precision matrices, see :meth:`sum`."""
    total = self.sum(self, other)
    return total
def __radd__(self, other):
    """Sum of ``other`` + ``self``; ``other == 0`` (e.g. the start value of built-in ``sum``) yields a deep copy of ``self``."""
    if other == 0:
        return self.deepcopy()
    return self.__add__(other)