Source code for firecrown.likelihood.gauss_family.gauss_family

"""Gaussian Family Module.

Some notes.
"""

from __future__ import annotations

from enum import Enum
from functools import wraps
from typing import Optional, Sequence, Callable, Union, TypeVar
from typing import final
import warnings

from typing_extensions import ParamSpec
import numpy as np
import numpy.typing as npt
import scipy.linalg

import sacc

from ...parameters import ParamsMap
from ..likelihood import Likelihood
from ...modeling_tools import ModelingTools
from ...updatable import UpdatableCollection
from .statistic.statistic import Statistic, GuardedStatistic
from ...utils import save_to_sacc


[docs]class State(Enum): """The states used in GaussFamily.""" INITIALIZED = 1 READY = 2 UPDATED = 3 COMPUTED = 4
T = TypeVar("T") P = ParamSpec("P") # See https://peps.python.org/pep-0612/ and # https://stackoverflow.com/questions/66408662/type-annotations-for-decorators # for how to specify the types of *args and **kwargs, and the return type of # the method being decorated. # Beware
[docs]def enforce_states( *, initial: Union[State, list[State]], terminal: Optional[State] = None, failure_message: str, ) -> Callable[[Callable[P, T]], Callable[P, T]]: """This decorator wraps a method, and enforces state machine behavior. If the object is not in one of the states in initial, an AssertionError is raised with the given failure_message. If terminal is None the state of the object is not modified. If terminal is not None and the call to the wrapped method returns normally the state of the object is set to terminal. """ initials: list[State] if isinstance(initial, list): initials = initial else: initials = [initial] def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]: """Part of the decorator which is the closure. This closure is what actually contains the values of initials, terminal, and failure_message. """ @wraps(func) def wrapper_repeat(*args: P.args, **kwargs: P.kwargs) -> T: """Part of the decorator which is the actual wrapped method. It is responsible for confirming a correct initial state, and establishing the correct final state if the wrapped method succeeds. """ # The odd use of args[0] instead of self seems to be the only way # to have both the Python runtime and mypy agree on what is being # passed to the method, and to allow access to the attribute # 'state'. Recall that the syntax: # o.foo() # calls a bound function object accessible as o.foo; this bound # function object calls the function foo() passing 'o' as the # first argument, self. assert isinstance(args[0], GaussFamily) assert args[0].state in initials, failure_message value = func(*args, **kwargs) if terminal is not None: args[0].state = terminal return value return wrapper_repeat return decorator_enforce_states
[docs]class GaussFamily(Likelihood): """GaussFamily is the base class for likelihoods based on a chi-squared calculation. It provides an implementation of Likelihood.compute_chisq. Derived classes must implement the abstract method compute_loglike, which is inherited from Likelihood. GaussFamily (and all classes that inherit from it) must abide by the the following rules regarding the order of calling of methods. 1. after a new object is created, :meth:`read` must be called before any other method in the interfaqce. 2. after :meth:`read` has been called it is legal to call :meth:`get_data_vector`, or to call :meth:`update`. 3. after :meth:`update` is called it is then legal to call :meth:`calculate_loglike` or :meth:`get_data_vector`, or to reset the object (returning to the pre-update state) by calling :meth:`reset`. It is also legal to call :meth:`compute_theory_vector`. 4. after :meth:`compute_theory_vector` is called it is legal to call :meth:`get_theory_vector` to retrieve the already-calculated theory vector. This state machine behavior is enforced through the use of the decorator :meth:`enforce_states`, above. """ def __init__( self, statistics: Sequence[Statistic], ): """Initialize the base class parts of a GaussFamily object.""" super().__init__() self.state: State = State.INITIALIZED if len(statistics) == 0: raise ValueError("GaussFamily requires at least one statistic") for i, s in enumerate(statistics): if not isinstance(s, Statistic): raise ValueError( f"statistics[{i}] is not an instance of Statistic: {s}" f"it is a {type(s)} instead." ) self.statistics: UpdatableCollection[GuardedStatistic] = UpdatableCollection( GuardedStatistic(s) for s in statistics ) self.cov: Optional[npt.NDArray[np.float64]] = None self.cholesky: Optional[npt.NDArray[np.float64]] = None self.inv_cov: Optional[npt.NDArray[np.float64]] = None self.cov_index_map: Optional[dict[int, int]] = None self.theory_vector: Optional[npt.NDArray[np.double]] = None self.data_vector: Optional[npt.NDArray[np.double]] = None
[docs] @enforce_states( initial=State.READY, terminal=State.UPDATED, failure_message="read() must be called before update()", ) def _update(self, _: ParamsMap) -> None: """Handle the state resetting required by :class:`GaussFamily` likelihoods. Any derived class that needs to implement :meth:`_update` for its own reasons must be sure to do what this does: check the state at the start of the method, and change the state at the end of the method. """
[docs] @enforce_states( initial=[State.UPDATED, State.COMPUTED], terminal=State.READY, failure_message="update() must be called before reset()", ) def _reset(self) -> None: """Handle the state resetting required by :class:`GaussFamily` likelihoods. Any derived class that needs to implement :meth:`reset` for its own reasons must be sure to do what this does: check the state at the start of the method, and change the state at the end of the method. """ self.theory_vector = None
[docs] @enforce_states( initial=State.INITIALIZED, terminal=State.READY, failure_message="read() must only be called once", ) def read(self, sacc_data: sacc.Sacc) -> None: """Read the covariance matrix for this likelihood from the SACC file.""" if sacc_data.covariance is None: msg = ( f"The {type(self).__name__} likelihood requires a covariance, " f"but the SACC data object being read does not have one." ) raise RuntimeError(msg) covariance = sacc_data.covariance.dense indices_list = [] data_vector_list = [] for stat in self.statistics: stat.read(sacc_data) if stat.statistic.sacc_indices is None: raise RuntimeError( f"The statistic {stat.statistic} has no sacc_indices." ) indices_list.append(stat.statistic.sacc_indices.copy()) data_vector_list.append(stat.statistic.get_data_vector()) indices = np.concatenate(indices_list) data_vector = np.concatenate(data_vector_list) cov = np.zeros((len(indices), len(indices))) for new_i, old_i in enumerate(indices): for new_j, old_j in enumerate(indices): cov[new_i, new_j] = covariance[old_i, old_j] self.data_vector = data_vector self.cov_index_map = {old_i: new_i for new_i, old_i in enumerate(indices)} self.cov = cov self.cholesky = scipy.linalg.cholesky(self.cov, lower=True) self.inv_cov = np.linalg.inv(cov)
[docs] @enforce_states( initial=[State.READY, State.UPDATED, State.COMPUTED], failure_message="read() must be called before get_cov()", ) @final def get_cov( self, statistic: Union[Statistic, list[Statistic], None] = None ) -> npt.NDArray[np.float64]: """Gets the current covariance matrix. :param statistic: The statistic for which the sub-covariance matrix should be return. If not specified, return the covariance of all statistics. """ assert self.cov is not None if statistic is None: return self.cov assert self.cov_index_map is not None if isinstance(statistic, Statistic): statistic_list = [statistic] else: statistic_list = statistic indices: list[int] = [] for stat in statistic_list: assert stat.sacc_indices is not None temp = [self.cov_index_map[idx] for idx in stat.sacc_indices] indices += temp ixgrid = np.ix_(indices, indices) return self.cov[ixgrid]
[docs] @final @enforce_states( initial=[State.READY, State.UPDATED, State.COMPUTED], failure_message="read() must be called before get_data_vector()", ) def get_data_vector(self) -> npt.NDArray[np.float64]: """Get the data vector from all statistics in the right order.""" assert self.data_vector is not None return self.data_vector
[docs] @final @enforce_states( initial=State.UPDATED, terminal=State.COMPUTED, failure_message="update() must be called before compute_theory_vector()", ) def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]: """Computes the theory vector using the current instance of pyccl.Cosmology. :param tools: Current ModelingTools object """ theory_vector_list: list[npt.NDArray[np.float64]] = [ stat.compute_theory_vector(tools) for stat in self.statistics ] self.theory_vector = np.concatenate(theory_vector_list) return self.theory_vector
[docs] @final @enforce_states( initial=State.COMPUTED, failure_message="compute_theory_vector() must be called before " "get_theory_vector()", ) def get_theory_vector(self) -> npt.NDArray[np.float64]: """Get the theory vector from all statistics in the right order.""" assert ( self.theory_vector is not None ), "theory_vector is None after compute_theory_vector() has been called" return self.theory_vector
[docs] @final @enforce_states( initial=State.UPDATED, failure_message="update() must be called before compute()", ) def compute( self, tools: ModelingTools ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: """Calculate and return both the data and theory vectors.""" warnings.simplefilter("always", DeprecationWarning) warnings.warn( "The use of the `compute` method on Statistic is deprecated." "The Statistic objects should implement `get_data` and " "`compute_theory_vector` instead.", category=DeprecationWarning, ) return self.get_data_vector(), self.compute_theory_vector(tools)
[docs] @final @enforce_states( initial=[State.UPDATED, State.COMPUTED], terminal=State.COMPUTED, failure_message="update() must be called before compute_chisq()", ) def compute_chisq(self, tools: ModelingTools) -> float: """Calculate and return the chi-squared for the given cosmology.""" theory_vector: npt.NDArray[np.float64] data_vector: npt.NDArray[np.float64] residuals: npt.NDArray[np.float64] try: theory_vector = self.compute_theory_vector(tools) data_vector = self.get_data_vector() except NotImplementedError: data_vector, theory_vector = self.compute(tools) assert len(data_vector) == len(theory_vector) residuals = data_vector - theory_vector x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True) chisq = np.dot(x, x) return chisq
[docs] @enforce_states( initial=[State.READY, State.UPDATED, State.COMPUTED], failure_message="read() must be called before get_sacc_indices()", ) def get_sacc_indices( self, statistic: Union[Statistic, list[Statistic], None] = None ) -> npt.NDArray[np.int64]: """Get the SACC indices of the statistic or list of statistics. If no statistic is given, get the indices of all statistics of the likelihood. """ if statistic is None: statistic = [stat.statistic for stat in self.statistics] if isinstance(statistic, Statistic): statistic = [statistic] sacc_indices_list = [] for stat in statistic: assert stat.sacc_indices is not None sacc_indices_list.append(stat.sacc_indices.copy()) sacc_indices = np.concatenate(sacc_indices_list) return sacc_indices
[docs] @enforce_states( initial=State.COMPUTED, failure_message="compute_theory_vector() must be called before " "make_realization()", ) def make_realization( self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True ) -> sacc.Sacc: """Create a new realization of the model.""" sacc_indices = self.get_sacc_indices() if add_noise: new_data_vector = self.make_realization_vector() else: new_data_vector = self.get_theory_vector() new_sacc = save_to_sacc( sacc_data=sacc_data, data_vector=new_data_vector, indices=sacc_indices, strict=strict, ) return new_sacc