import itertools
import numpy as np
from desilike import PipelineError
from .base import _params_args_or_kwargs
from .parameter import Parameter, ParameterCollection, ParameterArray, Samples, Deriv, ParameterPriorError
from .utils import BaseClass, expand_dict, is_sequence
from .jax import jax
from .jax import numpy as jnp
from . import mpi
[docs]
def deriv_ncoeffs(order, acc=2):
    """
    Return number of coefficients given input derivative order and accuracy.

    Parameters
    ----------
    order : int
        The derivative order.

    acc : int, default=2
        The accuracy order.

    Returns
    -------
    ncoeffs : int
        ``2 * ((order + 1) // 2) - 1 + acc``, i.e. the derivative order rounded up to the next even number,
        minus one, plus the accuracy order.
    """
    # Odd orders are rounded up to the next even number: orders 1 and 2 give the same count
    even_order = 2 * ((order + 1) // 2)
    return even_order - 1 + acc
[docs]
def coefficients(order, acc, coords, idx):
    """
    Calculate the finite difference coefficients for given derivative order and accuracy order.
    Assume that the underlying grid is non-uniform.

    Adapted from https://github.com/maroba/findiff/blob/master/findiff/coefs.py

    Parameters
    ----------
    order : int
        The derivative order (positive integer).

    acc : int
        The accuracy order (even positive integer).

    coords : np.ndarray
        The coordinates of the axis for the partial derivative.

    idx : int
        Index of the grid position where to calculate the coefficients.

    Returns
    -------
    coeffs : array
        Finite difference weights, one per stencil point.

    offsets : array
        Integer offsets of the stencil points relative to ``idx``.

    Raises
    ------
    ValueError
        If ``acc`` is not a positive even integer, or if ``order`` is negative.
    """
    from math import factorial

    if acc % 2 or acc <= 0:
        raise ValueError('Accuracy order acc must be positive EVEN integer')
    if order < 0:
        raise ValueError('Derive degree must be positive integer')
    order, acc = int(order), int(acc)

    ncentral = deriv_ncoeffs(order, acc=acc)
    nside = ncentral // 2
    # One-sided stencils take one extra point when the derivative order is even
    nonesided = ncentral + 1 if order % 2 == 0 else ncentral

    # Choose how many points are taken below / above ``idx``
    if idx < nside:
        # Too close to the lower edge: forward stencil
        nbelow, nabove = 0, nonesided - 1
    elif idx >= len(coords) - nside:
        # Too close to the upper edge: backward stencil
        nbelow, nabove = nonesided - 1, 0
    else:
        # Interior point: central stencil
        nbelow, nabove = nside, nside

    offsets = list(range(-nbelow, nabove + 1))
    npoints = len(offsets)

    # Equation system: row i holds the i-th powers of the distances to the evaluation point
    rows = [[1] * npoints]
    for power in range(1, npoints):
        rows.append([(coords[idx + offset] - coords[idx])**power for offset in offsets])
    matrix = np.array(rows, dtype='float')

    # Right-hand side: only the row matching the requested derivative order is non-zero
    rhs = np.zeros(npoints, dtype='float')
    rhs[order] = factorial(order)

    return np.linalg.solve(matrix, rhs), np.array(offsets)
[docs]
def deriv_nd(X, Y, orders, center=None, atol=0.):
    """
    Compute n-dimensional derivative.

    The derivative along the last axis listed in ``orders`` is written as a weighted sum of
    lower-order derivatives (obtained by recursion on the remaining ``orders``) evaluated on
    the slices of ``X`` selected by the finite-difference stencil.

    Parameters
    ----------
    X : array
        Array of shape (nsamples, ndim), with ndim the number of variables.

    Y : array
        Array of shape (nsamples, ysize), with ysize the size of the vector to derive.

    orders : list
        List of tuples (derivation axis between 0 and ndim - 1, derivative order, derivative accuracy).
        Entries with derivative order 0 are ignored.

    center : array, default=None
        The center around which to take derivatives, of size ndim.
        If ``None``, defaults to the median of the unique values of each column of input ``X``.

    atol : list, float
        Absolute tolerance to find the center.

    Returns
    -------
    deriv : array
        Derivative of Y, of size ysize.

    Raises
    ------
    ValueError
        If the center cannot be located in ``X``, or if there are not enough distinct coordinates
        along an axis for the requested derivative order and accuracy.
    """
    # Zeroth-order entries do not require any differentiation
    orders = [(axis, order, acc) for axis, order, acc in orders if order]

    if center is None:
        center = [np.median(np.unique(column)) for column in X.T]
    if np.ndim(atol) == 0:
        atol = [atol] * X.shape[1]
    atol = list(atol)

    if not orders:
        # End of recursion: return Y at the sample matching the center along all axes
        at_center = np.all([np.isclose(column, cc, rtol=0., atol=at) for column, cc, at in zip(X.T, center, atol)], axis=0)
        found = Y[at_center]
        if not found.size:
            raise ValueError('Global center point not found')
        return found[0]

    axis, order, acc = orders[-1]
    remaining = orders[:-1]
    ncoeffs = deriv_ncoeffs(order, acc=acc)
    coord = np.unique(X[..., axis])
    if coord.size < ncoeffs:
        raise ValueError('Grid is not large enough ({:d} < {:d}) to estimate {:d}-th order derivative'.format(coord.size, ncoeffs, order))
    matches = np.flatnonzero(np.isclose(coord, center[axis], rtol=0., atol=atol[axis]))
    if not matches.size:
        raise ValueError('Global center point not found')
    cidx = matches[0]

    weights, offsets = coefficients(order, acc, coord, cidx)
    total = 0.
    for weight, offset in zip(weights, offsets):
        node = coord[cidx + offset]
        # Restrict to the samples sitting exactly on this stencil node along ``axis``
        on_node = X[..., axis] == node
        node_center = center.copy()
        node_center[axis] = node
        # We could fill in atol[axis] = 0., but it should be useless?
        y = deriv_nd(X[on_node], Y[on_node], remaining, center=node_center, atol=atol)
        total += y * weight
    return total
[docs]
def deriv_grid(grids, current_order=0):
    """
    Return grid of points where to compute function to estimate its derivatives.

    Parameters
    ----------
    grids : list
        List of tuples (1D grid coordinates, array of derivative orders attached to each point of the 1D grid,
        maximum derivative order). A point with non-zero order is kept only if its order, added to
        ``current_order``, does not exceed the maximum derivative order; points with order 0 are always kept.

    current_order : int, default=0
        Sum of the derivative orders already selected along the axes processed so far
        (the recursion runs from the last entry of ``grids`` to the first).

    Returns
    -------
    grid : list
        List of coordinates; each entry lists one coordinate per input grid, in the order of ``grids``.
    """
    coords, orders, maxorder = grids[-1]
    points = []
    # Largest orders first
    for order in np.unique(orders)[::-1]:
        if order != 0 and order + current_order > maxorder:
            continue
        if len(grids) > 1:
            heads = deriv_grid(grids[:-1], current_order=order + current_order)
        else:
            heads = [[]]
        selected = coords[orders == order]
        points.extend(head + [value] for head in heads for value in selected)
    return points
[docs]
class Differentiation(BaseClass):

    """
    Estimate derivatives of ``calculator`` quantities, with auto- or finite-differentiation.

    The grid of points required by finite differentiation is built once, at initialization,
    around each parameter's :attr:`Parameter.delta` center; it is then shifted to the requested
    parameter values at each call.
    """

    def __init__(self, calculator, getter=None, order=1, method=None, accuracy=2, delta_scale=1., mpicomm=None):
        """
        Initialize differentiation.

        Parameters
        ----------
        calculator : BaseCalculator
            Input calculator.

        getter : callable, default=None
            Function (without input arguments) that returns a quantity,
            or a list or dictionary mapping names to quantities from ``calculator`` to be differentiated.
            If ``None``, defaults to derived parameters.

        order : int, dict, default=1
            A dictionary mapping parameter name (including wildcard) to maximum derivative order.
            If a single value is provided, applies to all varied parameters.
            ``None`` is interpreted as 0.

        method : str, dict, default=None
            A dictionary mapping parameter name (including wildcard) to method to use to estimate derivatives,
            either 'auto' for automatic differentiation, or 'finite' for finite differentiation.
            If ``None``, 'auto' will be used if possible, else 'finite'.
            If a single value is provided, applies to all varied parameters.

        accuracy : int, dict, default=2
            A dictionary mapping parameter name (including wildcard) to derivative accuracy (number of points used to estimate it).
            If a single value is provided, applies to all varied parameters.
            Must be a positive even integer for parameters using ``method = 'finite'``.
            Not used if ``method = 'auto'`` for this parameter.

        delta_scale : float, default=1.
            Parameter grid ranges for the estimation of finite derivatives are inferred from parameters' :attr:`Parameter.delta`.
            These values are then scaled by ``delta_scale`` (< 1. means smaller ranges).

        mpicomm : mpi.COMM_WORLD, default=None
            MPI communicator. If ``None``, defaults to ``calculator``'s :attr:`BaseCalculator.mpicomm`.
        """
        if mpicomm is None:
            mpicomm = calculator.mpicomm
        self.mpicomm = mpicomm
        self.calculator = calculator
        self.calculator()  # dry run
        self.pipeline = self.calculator.runtime_info.pipeline
        self.varied_params = self.calculator.varied_params
        # In case of likelihood marginalization self.calculator.runtime_info.pipeline._varied_params is changed
        # Make sure these parameters are included in all_params with + self.varied_params
        self.all_params = self.calculator.all_params.select(derived=False) + self.varied_params
        if not self.varied_params:
            raise ValueError('No parameters to be varied!')
        if mpicomm.rank == 0:
            self.log_info('Varied parameters: {}.'.format(self.varied_params.names()))

        # Expand order, method and accuracy into dictionaries keyed by varied parameter name
        for name, item in zip(['order', 'method', 'accuracy'], [order, method, accuracy]):
            setattr(self, name, expand_dict(item, self.varied_params.names()))

        for param, value in self.order.items():
            if value is None:
                value = 0
            self.order[param] = int(value)

        self.getter = getter

        if getter is None:
            # Default getter: collect the derived parameters of each calculator
            calculators, _fixed, varied = self.pipeline._classify_derived(self.pipeline.calculators)
            varied_by_calculator = []
            for cc, vv in zip(calculators, varied):
                base_names = cc.runtime_info.base_names
                tmp = ParameterCollection()
                for v in vv:
                    if v in base_names:
                        p = self.pipeline.params[base_names[v]]
                        if p.derived:
                            tmp.set(p.copy())
                varied_by_calculator.append(tmp)
            if not any(varied_by_calculator):
                raise ValueError('No varied parameter is derived, so nothing to differentiate')

            def getter():
                toret = {}
                for calculator, varied in zip(calculators, varied_by_calculator):
                    state = calculator.__getstate__()
                    for param in varied:
                        name = param.basename
                        # Take the value from the calculator state if available, else from its attributes
                        if name in state:
                            value = state[name]
                        else:
                            value = getattr(calculator, name)
                        toret[param] = value = jnp.array(value)
                        param._shape = value.shape  # a bit hacky, but no need to update parameters for this...
                return toret

            self.getter = getter

        for param, method in self.method.items():
            if self.order[param] == 0:
                # Nothing to differentiate for this parameter
                self.method[param] = 'auto'
                continue
            if method in [None, 'auto']:
                # Try auto-differentiation on a single point to check it is supported
                try:
                    self._calculate({param: [self.pipeline.input_values[param]]}, autoderivs=[(), (param,)])  # This takes time because the model is evaluated for each parameter
                except Exception as exc:
                    if method is None:
                        method = 'finite'
                    else:
                        raise ValueError('Cannot use auto-differentiation (with jax) for parameter {}'.format(param)) from exc
                else:
                    method = 'auto'
                if self.method[param] is None and mpicomm.rank == 0:
                    self.log_info('Using {}-differentiation for parameter {}.'.format(method, param))
            self.method[param] = method
            if method == 'finite':
                value = self.accuracy[param]
                if value is None:
                    raise ValueError('accuracy not specified for parameter {}'.format(param))
                value = int(value)
                if value < 1:
                    raise ValueError('accuracy is {} < 1 for parameter {}'.format(value, param))
                if value % 2:
                    raise ValueError('accuracy is {} for parameter {}, but it must be a positive EVEN integer'.format(value, param))
                self.accuracy[param] = value

        # Build, for each varied parameter, the 1D grid used for finite differentiation
        self._grid_center, grids = {}, []
        for param in self.varied_params:
            center = param.delta[0]
            if self.method[param.name] == 'finite' and self.order[param.name]:
                size = deriv_ncoeffs(self.order[param.name], acc=self.accuracy[param.name])
                delta, limits = param.delta[1:], param.prior.limits
                if not (limits[0] <= center <= limits[-1]):
                    raise ValueError('for {} center {} is not within prior limits {}'.format(param.name, center, limits))
                delta = tuple(delta_scale * dd for dd in delta)
                if any(dd <= 0 for dd in delta):
                    raise ValueError('for {} delta {} is not > 0'.format(param.name, delta))
                hsize = size // 2
                grid_min = limits[1] - hsize * (delta[0] + delta[1])  # if we start from upper limit
                grid_min = max(limits[0], min(center - delta[0] * hsize, grid_min))
                grid = [grid_min + np.arange(hsize + 1) * delta[0]]  # below center
                # The center may have been shifted to fit the grid within prior limits
                center = grid[0][-1]
                grid.append(center + np.arange(1, hsize + 1) * delta[1])  # above center
                grid = np.concatenate(grid)
                if grid[-1] > limits[1]:
                    raise ValueError('for {}, cannot fit {:d} steps in prior limits {} with delta = {}; increase prior limits or decrease delta'.format(param.name, size, limits, delta))
                cindex = hsize
                # Attach to each grid point the lowest derivative order whose stencil includes it (0 for the center)
                grid_order = np.zeros(len(grid), dtype='i')
                for iorder in range(self.order[param.name], 0, -1):
                    s = deriv_ncoeffs(iorder, acc=self.accuracy[param.name])
                    grid_order[cindex - s // 2:cindex + s // 2 + 1] = iorder
                grid_order[cindex] = 0
                grid = (grid, grid_order, self.order[param.name])
                if mpicomm.rank == 0:
                    self.log_info('{} grid is {}.'.format(param, grid[0]))
            else:
                grid = (np.array([center]), np.array([0]), 0)
            self._grid_center[param.name] = center
            grids.append(grid)

        # The grid of samples, and the index of its central point, are computed on rank 0 only
        self._grid_samples = self._grid_cidx = None
        if mpicomm.rank == 0:
            samples = np.array(deriv_grid(grids)).T
            self._grid_samples = Samples(samples, params=self.varied_params)
            self._grid_cidx = True
            for array, grid in zip(self._grid_samples, grids):
                grid = grid[0]
                center = grid[len(grid) // 2]
                atol = 0.
                self._grid_cidx &= np.isclose(array, center, rtol=0., atol=atol)
            self._grid_cidx = tuple(np.flatnonzero(self._grid_cidx))
            assert len(self._grid_cidx) == 1
            self.log_info('Differentiation will evaluate {:d} points.'.format(len(self._grid_samples)))
        self._grid_cidx = mpicomm.bcast(self._grid_cidx, root=0)

        # self.autoderivs[0] is the function itself;
        # self.autoderivs[n] lists the parameters auto-differentiated at order n
        autoparams, autoorder, self.autoderivs = [], [], []
        for param, method in self.method.items():
            autoparams.append(param)
            autoorder.append(self.order[param] if method == 'auto' else 0)
        self.autoderivs.append(())
        for maxorder in range(1, max([0] + autoorder) + 1):
            self.autoderivs.append([autoparams[i] for i, o in enumerate(autoorder) if o >= maxorder])

    @property
    def mpicomm(self):
        """Current MPI communicator."""
        return self._mpicomm

    @mpicomm.setter
    def mpicomm(self, mpicomm):
        """Set the MPI communicator; if it differs from the previous one, exchange ``_grid_samples`` over the previous communicator."""
        mpicomm_bak = getattr(self, '_mpicomm', None)
        if mpicomm_bak is not None and mpicomm is not mpicomm_bak:
            # Broadcast self._grid_samples to the new rank = 0 processes
            # NOTE(review): the ranks (in the previous communicator) of the new rank = 0 processes are
            # passed as ``mpiroot``; confirm against Samples.bcast that this is the intended direction
            ranks = mpicomm_bak.allgather(mpicomm_bak.rank if mpicomm.rank == 0 else None)
            ranks = [rank for rank in ranks if rank is not None]
            for mpiroot in ranks:
                grid_samples = Samples.bcast(self._grid_samples, mpiroot=mpiroot, mpicomm=mpicomm_bak)
                if grid_samples is not None:
                    self._grid_samples = grid_samples
        self._mpicomm = mpicomm

    def _calculate(self, params, autoderivs=None):
        """
        Evaluate the getter, and its automatic derivatives, at all input points.

        Parameters
        ----------
        params : dict
            Dictionary mapping parameter name to array of values; only read on rank 0 of :attr:`mpicomm`.

        autoderivs : list, default=None
            ``autoderivs[0]`` is ignored (it stands for the function itself); ``autoderivs[n]`` lists the
            parameters to differentiate with respect to at n-th order.
            If ``None``, defaults to :attr:`autoderivs`.

        Returns
        -------
        samples : list
            On rank 0, nested list indexed as ``samples[igetter][iautoderiv][isample]``; ``None`` on other ranks.

        getter_inst : object
            Last output of the getter, as returned by the first process which evaluated it.

        getter_size : int
            Number of quantities returned by the getter, 0 if the getter does not return a sequence or a dictionary.

        Raises
        ------
        PipelineError
            If any evaluation failed, on any process.
        """
        if autoderivs is None:
            autoderivs = self.autoderivs
        mpicomm = self.pipeline.mpicomm
        names = self.mpicomm.bcast(list(params) if self.mpicomm.rank == 0 else None, root=0)
        values = []
        for name in names:
            value = np.atleast_1d(params[name]) if self.mpicomm.rank == 0 else None
            values.append(value)
        csize = self.mpicomm.bcast(value.size if self.mpicomm.rank == 0 else None)

        # Filled by __calculate below; kept as closure variables rather than module-level globals,
        # so that state is not shared between instances or nested calls
        getter_inst, getter_size = None, None

        def __calculate(*values):
            nonlocal getter_inst, getter_size
            assert len(names) == len(values)
            self.pipeline.calculate(dict(zip(names, values)))
            toret = self.getter()
            getter_inst = toret
            if hasattr(toret, 'values'):
                toret = list(toret.values())
            getter_size = int(is_sequence(toret))
            if getter_size:
                getter_size = len(toret)
            else:
                toret = [toret]
            toret = list(toret)
            if not toret:
                raise ValueError('getter returns nothing to differentiate')
            return toret

        getter_samples = []
        max_chunk_size = getattr(self, '_mpi_max_chunk_size', 100)
        nchunks = (csize // max_chunk_size) + 1
        import traceback

        for ichunk in range(nchunks):  # divide in chunks to save memory for MPI comm
            # Each process evaluates its own points independently
            self.pipeline.mpicomm = mpi.COMM_SELF
            chunk_params = {}
            for name, value in zip(names, values):
                chunk_params[name] = mpi.scatter(value[csize * ichunk // nchunks:csize * (ichunk + 1) // nchunks] if self.mpicomm.rank == 0 else None, mpicomm=self.mpicomm, mpiroot=0)
                chunk_size = len(chunk_params[name])
            tmp_samples, errors = [], []
            for ivalue in range(chunk_size):
                chunk_values = [chunk_params[name][ivalue] for name in chunk_params]
                tmp_i_samples = []
                try:
                    try:
                        jac = __calculate
                        for autoderiv in autoderivs[1:]:
                            if jax is None:
                                raise ValueError('jax is required to compute the Jacobian')
                            argnums = [names.index(p) for p in autoderiv]
                            # Each level differentiates the previous one: n-th level is the n-th order derivative
                            jac = jax.jacfwd(jac, argnums=argnums, has_aux=False, holomorphic=False)
                            tmp_i_samples.append(jac(*chunk_values))
                    finally:
                        # The plain evaluation comes last, so getter_inst / getter_size are set by it,
                        # not by an evaluation performed as part of a derivative
                        tmp_samples.append([__calculate(*chunk_values)] + tmp_i_samples)
                except Exception as exc:
                    errors.append((exc, traceback.format_exc()))
            errors = self.mpicomm.allreduce(errors)
            self.pipeline.mpicomm = mpicomm
            if errors:
                raise PipelineError('got these errors: {}'.format(errors))
            tmp_samples = self.mpicomm.reduce(tmp_samples, root=0)
            if self.mpicomm.rank == 0:
                getter_samples += tmp_samples

        toret = None
        # On rank 0, take getter output from the first process which evaluated it
        for getter_size, getter_inst in (self.mpicomm.gather((getter_size, getter_inst), root=0) or []):
            if getter_size is not None:
                break
        if self.mpicomm.rank == 0:
            # Reorder getter_samples[isample][iautoderiv][igetter] into toret[igetter][iautoderiv][isample]
            toret = [[[None for isample in range(csize)] for iautoderiv in range(len(autoderivs))] for igetter in range(max(getter_size, 1))]
            for isample in range(csize):
                items = getter_samples[isample]
                for ideriv, derivs in enumerate(items):
                    for iitem, item in enumerate(derivs):
                        toret[iitem][ideriv][isample] = item
        return toret, getter_inst, getter_size

    def run(self, *args, **kwargs):
        """
        Compute derivatives around input parameter values, and store the result in :attr:`samples`
        (``None`` on processes other than rank 0).

        Parameters not provided default to the pipeline's current input values.

        Raises
        ------
        ValueError
            If some finite-difference derivatives are NaN.
        """
        params = _params_args_or_kwargs(args, kwargs)
        # Getter, or calculator, dict[param1, param2]
        self.center = {}
        for param in self.all_params:
            self.center[param.name] = params.get(param.name, self.pipeline.input_values[param.name])
        if self.mpicomm.rank == 0:
            samples = self._grid_samples.copy()
            for param in self.all_params:
                if param.name in self._grid_center:
                    # Shift the precomputed grid to the requested center
                    offset = self.center[param.name] - self._grid_center[param.name]
                    samples[param] = self._grid_samples[param] + offset
                else:
                    samples[param] = np.full(samples.shape, self.center[param.name])
        nsamples = self.mpicomm.bcast(samples.size if self.mpicomm.rank == 0 else None, root=0)
        getter_samples, getter_inst, getter_size = self._calculate(samples.to_dict(params=self.all_params) if self.mpicomm.rank == 0 else {})

        toret = None
        if self.mpicomm.rank == 0:
            finiteparams, finiteorder, finiteaccuracy = [], [], []
            for param in self._grid_samples.names():
                if self.method[param] == 'finite':
                    finiteparams.append(param)
                    finiteorder.append(self.order[param])
                    finiteaccuracy.append(self.accuracy[param])
            getter_samples = [[np.array(s) for s in getter_sample] for getter_sample in getter_samples]
            degrees, derivatives = [], [[] for i in range(max(getter_size, 1))]
            cidx = self._grid_cidx
            if finiteparams:
                X = np.concatenate([samples[param].reshape(nsamples, 1) for param in finiteparams], axis=-1)
                ndim = X.shape[1]
                center = X[cidx]
            autodegrees, autoindices = [Deriv()], [()]
            for autoorder, autoderiv in enumerate(self.autoderivs):
                nautodegrees, nautoindices = [], []
                for autodegree, autoindex in zip(autodegrees, autoindices):
                    for iautoparam, autoparam in enumerate(autoderiv or (None,)):
                        if autoorder > 0:
                            nautodegree = autodegree + Deriv([autoparam])
                            nautoindex = autoindex + (iautoparam,)
                        else:
                            nautodegree = autodegree
                            nautoindex = autoindex
                        if nautodegree in degrees:
                            continue
                        nautodegrees.append(nautodegree)
                        nautoindices.append(nautoindex)
                        degrees.append(nautodegree)
                        Y = [getter_sample[autoorder][(slice(None),) + nautoindex + (Ellipsis,)] for getter_sample in getter_samples]
                        if autodegree:  # with jax nan derivatives are zero derivatives...
                            for y in Y:
                                y[np.isnan(y)] = 0.
                        for iy, y in enumerate(Y):
                            derivatives[iy].append(y[cidx])
                        # Now finite differentiation
                        yshapes = [y.shape[samples.ndim:] for y in Y]
                        Y = [y.reshape(nsamples, -1) for y in Y]
                        for order in range(1, max(finiteorder + [0]) + 1):
                            for indices in itertools.product(range(ndim), repeat=order):
                                orders = np.bincount(indices, minlength=ndim).astype('i4')
                                # Total order (finite + auto) cannot exceed the maximum order of any of the differentiated parameters
                                if sum(orders) + autoorder > min(order for o, order in zip(orders, finiteorder) if o):
                                    continue
                                degree = nautodegree + Deriv(dict(zip(finiteparams, orders)))
                                if degree in degrees:
                                    continue
                                orders = [(iparam, order, accuracy) for iparam, (order, accuracy) in enumerate(zip(orders, finiteaccuracy)) if order > 0]
                                dx = [deriv_nd(X, y, orders, center=center, atol=0.) for y in Y]
                                if any(np.isnan(ddx).any() for ddx in dx):
                                    raise ValueError('some derivatives are NaN')
                                degrees.append(degree)
                                for iy, (ddx, yshape) in enumerate(zip(dx, yshapes)):
                                    derivatives[iy].append(ddx.reshape(yshape))
                autodegrees = nautodegrees
                autoindices = nautoindices
            toret = derivatives = [ParameterArray(derivative, derivs=degrees, param=Parameter('param_{:d}'.format(ideriv), shape=derivative[0].shape)) for ideriv, derivative in enumerate(derivatives)]
            if isinstance(getter_inst, dict):
                toret = Samples()
                for param in self.varied_params:
                    toret[param] = ParameterArray(self.center[param.name], param=param)
                for param, derivative in zip(getter_inst, derivatives):
                    derivative.param = Parameter(param)
                    toret[param] = derivative
                toret.attrs['center'] = self.center
            elif not getter_size:
                toret = toret[0]
        self.samples = toret

    def __call__(self, *args, **kwargs):
        """
        Return derivatives for input parameter values.
        If ``getter`` returns a list (resp. dict), a list (resp. :class:`Samples`) of derivatives.
        """
        self.run(*args, **kwargs)
        return self.samples