"""Support for the family of Gaussian likelihoods."""
from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence
from enum import Enum
from functools import wraps
from typing import TypeVar, final
import numpy as np
import numpy.typing as npt
import sacc
import scipy.linalg
from typing_extensions import ParamSpec
from firecrown.likelihood.likelihood import Likelihood
from firecrown.likelihood.statistic import (
GuardedStatistic,
Statistic,
)
from firecrown.modeling_tools import ModelingTools
from firecrown.parameters import ParamsMap
from firecrown.updatable import UpdatableCollection
from firecrown.utils import save_to_sacc
[docs]
class State(Enum):
"""The states used in GaussFamily.
GaussFamily and all subclasses enforce a state machine behavior based on these
states to ensure that the necessary initialization and setup is done in the correct
order.
"""
INITIALIZED = 1
READY = 2
UPDATED = 3
COMPUTED = 4
T = TypeVar("T")
P = ParamSpec("P")
# See https://peps.python.org/pep-0612/ and
# https://stackoverflow.com/questions/66408662/type-annotations-for-decorators
# for how to specify the types of *args and **kwargs, and the return type of
# the method being decorated.
# Beware
[docs]
def enforce_states(
*,
initial: State | list[State],
terminal: None | State = None,
failure_message: str,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""This decorator wraps a method, and enforces state machine behavior.
If the object is not in one of the states in initial, an
AssertionError is raised with the given failure_message.
If terminal is None the state of the object is not modified.
If terminal is not None and the call to the wrapped method returns
normally the state of the object is set to terminal.
:param initial: The initial states allowable for the wrapped method
:param terminal: The terminal state ensured for the wrapped method. None
indicates no state change happens.
:param failure_message: The failure message for the AssertionError raised
:return: The wrapped method
"""
initials: list[State]
if isinstance(initial, list):
initials = initial
else:
initials = [initial]
def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]:
"""Part of the decorator which is the closure.
This closure is what actually contains the values of initials, terminal, and
failure_message.
:param func: The method to be wrapped
:return: The wrapped method
"""
@wraps(func)
def wrapper_repeat(*args: P.args, **kwargs: P.kwargs) -> T:
"""Part of the decorator which is the actual wrapped method.
It is responsible for confirming a correct initial state, and
establishing the correct final state if the wrapped method
succeeds.
"""
# The odd use of args[0] instead of self seems to be the only way
# to have both the Python runtime and mypy agree on what is being
# passed to the method, and to allow access to the attribute
# 'state'. Recall that the syntax:
# o.foo()
# calls a bound function object accessible as o.foo; this bound
# function object calls the function foo() passing 'o' as the
# first argument, self.
assert isinstance(args[0], GaussFamily)
assert args[0].state in initials, failure_message
value = func(*args, **kwargs)
if terminal is not None:
args[0].state = terminal
return value
return wrapper_repeat
return decorator_enforce_states
[docs]
class GaussFamily(Likelihood):
"""GaussFamily is the base class for likelihoods based on a chi-squared calculation.
It provides an implementation of Likelihood.compute_chisq. Derived classes must
implement the abstract method compute_loglike, which is inherited from Likelihood.
GaussFamily (and all classes that inherit from it) must abide by the the
following rules regarding the order of calling of methods.
1. after a new object is created, :meth:`read` must be called before any
other method in the interface.
2. after :meth:`read` has been called it is legal to call
:meth:`get_data_vector`, or to call :meth:`update`.
3. after :meth:`update` is called it is then legal to call
:meth:`calculate_loglike` or :meth:`get_data_vector`, or to reset
the object (returning to the pre-update state) by calling
:meth:`reset`. It is also legal to call :meth:`compute_theory_vector`.
4. after :meth:`compute_theory_vector` is called it is legal to call
:meth:`get_theory_vector` to retrieve the already-calculated theory
vector.
This state machine behavior is enforced through the use of the decorator
:meth:`enforce_states`, above.
"""
def __init__(
self,
statistics: Sequence[Statistic],
) -> None:
"""Initialize the base class parts of a GaussFamily object.
:param statistics: A list of statistics to be include in chisquared calculations
"""
super().__init__()
self.state: State = State.INITIALIZED
if len(statistics) == 0:
raise ValueError("GaussFamily requires at least one statistic")
for i, s in enumerate(statistics):
if not isinstance(s, Statistic):
raise ValueError(
f"statistics[{i}] is not an instance of Statistic."
f" It is a {type(s)}."
)
self.statistics: UpdatableCollection[GuardedStatistic] = UpdatableCollection(
GuardedStatistic(s) for s in statistics
)
self.cov: None | npt.NDArray[np.float64] = None
self.cholesky: None | npt.NDArray[np.float64] = None
self.inv_cov: None | npt.NDArray[np.float64] = None
self.cov_index_map: None | dict[int, int] = None
self.theory_vector: None | npt.NDArray[np.double] = None
self.data_vector: None | npt.NDArray[np.double] = None
[docs]
@classmethod
def create_ready(
cls, statistics: Sequence[Statistic], covariance: npt.NDArray[np.float64]
) -> GaussFamily:
"""Create a GaussFamily object in the READY state.
:param statistics: A list of statistics to be include in chisquared calculations
:param covariance: The covariance matrix of the statistics
:return: A ready GaussFamily object
"""
obj = cls(statistics)
obj._set_covariance(covariance)
obj.state = State.READY
return obj
@enforce_states(
initial=State.READY,
terminal=State.UPDATED,
failure_message="read() must be called before update()",
)
def _update(self, _: ParamsMap) -> None:
"""Handle the state resetting required by :class:`GaussFamily` likelihoods.
Any derived class that needs to implement :meth:`_update`
for its own reasons must be sure to do what this does: check the state
at the start of the method, and change the state at the end of the
method.
:param _: a ParamsMap object, not used
"""
@enforce_states(
initial=[State.UPDATED, State.COMPUTED],
terminal=State.READY,
failure_message="update() must be called before reset()",
)
def _reset(self) -> None:
"""Handle the state resetting required by :class:`GaussFamily` likelihoods.
Any derived class that needs to implement :meth:`reset`
for its own reasons must be sure to do what this does: check the state
at the start of the method, and change the state at the end of the
method.
"""
self.theory_vector = None
[docs]
@enforce_states(
initial=State.INITIALIZED,
terminal=State.READY,
failure_message="read() must only be called once",
)
def read(self, sacc_data: sacc.Sacc) -> None:
"""Read the covariance matrix for this likelihood from the SACC file.
:param sacc_data: The SACC data object to be read
"""
if sacc_data.covariance is None:
msg = (
f"The {type(self).__name__} likelihood requires a covariance, "
f"but the SACC data object being read does not have one."
)
raise RuntimeError(msg)
for stat in self.statistics:
stat.read(sacc_data)
covariance = sacc_data.covariance.dense
self._set_covariance(covariance)
def _set_covariance(self, covariance: npt.NDArray[np.float64]) -> None:
"""Set the covariance matrix.
This method is used to set the covariance matrix and perform the
necessary calculations to prepare the likelihood for computation.
:param covariance: The covariance matrix for this likelihood
"""
indices_list = []
data_vector_list = []
for stat in self.statistics:
if not stat.statistic.ready:
raise RuntimeError(
f"The statistic {stat.statistic} is not ready to be used."
)
if stat.statistic.sacc_indices is None:
raise RuntimeError(
f"The statistic {stat.statistic} has no sacc_indices."
)
indices_list.append(stat.statistic.sacc_indices.copy())
data_vector_list.append(stat.statistic.get_data_vector())
indices = np.concatenate(indices_list).astype(int)
data_vector = np.concatenate(data_vector_list)
cov = np.zeros((len(indices), len(indices)))
largest_index = int(np.max(indices))
if not (
covariance.ndim == 2
and covariance.shape[0] == covariance.shape[1]
and largest_index < covariance.shape[0]
):
raise ValueError(
f"The covariance matrix has shape {covariance.shape}, "
f"but the expected shape is at least "
f"{(largest_index + 1, largest_index + 1)}."
)
for new_i, old_i in enumerate(indices):
for new_j, old_j in enumerate(indices):
cov[new_i, new_j] = covariance[old_i, old_j]
self.data_vector = data_vector
self.cov_index_map = {old_i: new_i for new_i, old_i in enumerate(indices)}
self.cov = cov
self.cholesky = scipy.linalg.cholesky(self.cov, lower=True).astype(np.float64)
self.inv_cov = np.linalg.inv(cov).astype(np.float64)
[docs]
@final
@enforce_states(
initial=[State.READY, State.UPDATED, State.COMPUTED],
failure_message="read() must be called before get_cov()",
)
def get_cov(
self, statistic: Statistic | list[Statistic] | None = None
) -> npt.NDArray[np.float64]:
"""Gets the current covariance matrix.
:param statistic: The statistic for which the sub-covariance matrix
should be returned. If not specified, return the covariance of all
statistics.
:return: The covariance matrix (or portion thereof)
"""
assert self.cov is not None
if statistic is None:
return self.cov
assert self.cov_index_map is not None
if isinstance(statistic, Statistic):
statistic_list = [statistic]
else:
statistic_list = statistic
indices: list[int] = []
for stat in statistic_list:
assert stat.sacc_indices is not None
temp = [self.cov_index_map[int(idx)] for idx in stat.sacc_indices]
indices += temp
ixgrid = np.ix_(indices, indices)
return self.cov[ixgrid]
[docs]
@final
@enforce_states(
initial=[State.READY, State.UPDATED, State.COMPUTED],
failure_message="read() must be called before get_data_vector()",
)
def get_data_vector(self) -> npt.NDArray[np.float64]:
"""Get the data vector from all statistics in the right order.
:return: The data vector
"""
assert self.data_vector is not None
return self.data_vector.astype(np.float64)
[docs]
@final
@enforce_states(
initial=[State.UPDATED, State.COMPUTED],
terminal=State.COMPUTED,
failure_message="update() must be called before compute_theory_vector()",
)
def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]:
"""Computes the theory vector using the current instance of pyccl.Cosmology.
:param tools: Current ModelingTools object
:return: The computed theory vector
"""
theory_vector_list: list[npt.NDArray[np.float64]] = [
stat.compute_theory_vector(tools) for stat in self.statistics
]
self.theory_vector = np.concatenate(theory_vector_list)
return self.theory_vector
[docs]
@final
@enforce_states(
initial=State.COMPUTED,
failure_message="compute_theory_vector() must be called before "
"get_theory_vector()",
)
def get_theory_vector(self) -> npt.NDArray[np.float64]:
"""Get the already-computed theory vector from all statistics.
:return: The theory vector, with all statistics in the right order
"""
assert (
self.theory_vector is not None
), "theory_vector is None after compute_theory_vector() has been called"
return self.theory_vector.astype(np.float64)
[docs]
@final
@enforce_states(
initial=State.UPDATED,
failure_message="update() must be called before compute()",
)
def compute(
self, tools: ModelingTools
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
"""Calculate and return both the data and theory vectors.
This method is deprecated and will be removed in a future version of Firecrown.
:param tools: the ModelingTools to be used in the calculation of the
theory vector
:return: a tuple containing the data vector and the theory vector
"""
warnings.warn(
"The use of the `compute` method on Statistic is deprecated."
"The Statistic objects should implement `get_data` and "
"`compute_theory_vector` instead.",
category=DeprecationWarning,
)
return self.get_data_vector(), self.compute_theory_vector(tools)
[docs]
@final
@enforce_states(
initial=[State.UPDATED, State.COMPUTED],
terminal=State.COMPUTED,
failure_message="update() must be called before compute_chisq()",
)
def compute_chisq(self, tools: ModelingTools) -> float:
"""Calculate and return the chi-squared for the given cosmology.
:param tools: the ModelingTools to be used in the calculation of the
theory vector
:return: the chi-squared
"""
theory_vector: npt.NDArray[np.float64]
data_vector: npt.NDArray[np.float64]
residuals: npt.NDArray[np.float64]
try:
theory_vector = self.compute_theory_vector(tools)
data_vector = self.get_data_vector()
except NotImplementedError:
data_vector, theory_vector = self.compute(tools)
assert len(data_vector) == len(theory_vector)
residuals = np.array(data_vector - theory_vector, dtype=np.float64)
assert self.cholesky is not None
x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True)
chisq = np.dot(x, x)
return chisq
[docs]
@enforce_states(
initial=[State.READY, State.UPDATED, State.COMPUTED],
failure_message="read() must be called before get_sacc_indices()",
)
def get_sacc_indices(
self, statistic: Statistic | list[Statistic] | None = None
) -> npt.NDArray[np.int64]:
"""Get the SACC indices of the statistic or list of statistics.
If no statistic is given, get the indices of all statistics of the likelihood.
:param statistics: The statistic or list of statistics for which the
SACC indices are desired
:return: The SACC indices
"""
if statistic is None:
statistic = [stat.statistic for stat in self.statistics]
if isinstance(statistic, Statistic):
statistic = [statistic]
sacc_indices_list = []
for stat in statistic:
assert stat.sacc_indices is not None
sacc_indices_list.append(stat.sacc_indices.copy())
sacc_indices = np.concatenate(sacc_indices_list)
return sacc_indices
[docs]
@enforce_states(
initial=State.COMPUTED,
failure_message="compute_theory_vector() must be called before "
"make_realization()",
)
def make_realization(
self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True
) -> sacc.Sacc:
"""Create a new realization of the model.
:param sacc_data: The SACC data object containing the covariance matrix
to be read
:param add_noise: If True, add noise to the realization.
:param strict: If True, check that the indices of the realization cover
all the indices of the SACC data object.
:return: The SACC data object containing the new realization
"""
sacc_indices = self.get_sacc_indices()
if add_noise:
new_data_vector = self.make_realization_vector()
else:
new_data_vector = self.get_theory_vector()
new_sacc = save_to_sacc(
sacc_data=sacc_data,
data_vector=new_data_vector,
indices=sacc_indices,
strict=strict,
)
return new_sacc