Source code for firecrown.likelihood.statistic

"""Gaussian Family Statistic Module.

The :class:`Statistic` class describes objects that implement methods to compute
the data and theory vectors for a :class:`GaussFamily` subclass.
"""

from __future__ import annotations

from abc import abstractmethod
from typing import final

import numpy as np
import numpy.typing as npt
import sacc

import firecrown.parameters
from firecrown.data_types import DataVector, TheoryVector
from firecrown.modeling_tools import ModelingTools
from firecrown.parameters import DerivedParameterCollection, RequiredParameters
from firecrown.updatable import Updatable


class StatisticUnreadError(RuntimeError):
    """Error raised when accessing an un-read statistic.

    Run-time error indicating an attempt has been made to use a statistic
    that has not had `read` called on it.
    """

    def __init__(self, stat: Statistic):
        """Initialize a new StatisticUnreadError.

        :param stat: the statistic that was accessed before `read` was called
        """
        # The trailing space after "did not" is required: adjacent string
        # literals are concatenated with nothing in between.
        msg = (
            f"The statistic {stat} was used for calculation before `read` "
            "was called.\nIt may be that a likelihood factory function did not "
            "call `read` before returning the likelihood."
        )
        super().__init__(msg)
        # Keep a reference so handlers can inspect the offending statistic.
        self.statistic = stat
class Statistic(Updatable):
    """The abstract base class for all physics-related statistics.

    Statistics read data from a SACC object as part of a multi-phase
    initialization. They manage a :class:`DataVector` and, given a
    :class:`ModelingTools` object, can compute a :class:`TheoryVector`.

    Statistics represent things like two-point functions and mass functions.
    """

    def __init__(self, parameter_prefix: None | str = None):
        """Initialize a new Statistic.

        Derived classes should make sure to call this method using:

        .. code-block:: python

            super().__init__(parameter_prefix=parameter_prefix)

        as the first thing they do in `__init__`.

        :param parameter_prefix: The prefix to prepend to all parameter names
        """
        super().__init__(parameter_prefix=parameter_prefix)
        # Indices associated with this statistic's data points (presumably
        # their positions in the SACC data; derived classes assign this in
        # `read`). It is None until then, rather than being left unset.
        self.sacc_indices: None | npt.NDArray[np.int64] = None
        # Set to True by `read` once the data vector is available.
        self.ready = False
        # True only while `theory_vector` holds a freshly computed value;
        # cleared by `_reset`.
        self.computed_theory_vector = False
        self.theory_vector: None | TheoryVector = None

    def read(self, _: sacc.Sacc) -> None:
        """Read the data for this statistic and mark it as ready for use.

        Derived classes that override this function should make sure to call
        the base class method using:

        .. code-block:: python

            super().read(sacc_data)

        as the last thing they do.

        :param _: currently unused, but required by the interface.
        """
        # By the time the base class is reached, the derived class must have
        # loaded a non-empty data vector.
        assert (
            len(self.get_data_vector()) > 0
        ), f"The statistic {self} has an empty data vector after `read`."
        self.ready = True

    def _reset(self):
        """Reset this statistic.

        Derived classes that override this function should make sure to call
        the base class method using:

        .. code-block:: python

            super()._reset()

        as the last thing they do.
        """
        # Invalidate the cached theory vector so it must be recomputed.
        self.computed_theory_vector = False
        self.theory_vector = None

    @abstractmethod
    def get_data_vector(self) -> DataVector:
        """Gets the statistic data vector.

        :return: The data vector.
        """

    @final
    def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
        """Compute a statistic from sources, applying any systematics.

        :param tools: the modeling tools used to compute the theory vector.
        :return: The computed theory vector.
        :raises RuntimeError: if the statistic has not been updated with
            parameters.
        """
        if not self.is_updated():
            raise RuntimeError(
                f"The statistic {self} has not been updated with parameters."
            )
        self.theory_vector = self._compute_theory_vector(tools)
        self.computed_theory_vector = True

        return self.theory_vector

    @abstractmethod
    def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
        """Compute a statistic from sources, concrete implementation."""

    def get_theory_vector(self) -> TheoryVector:
        """Returns the last computed theory vector.

        :return: The already-computed theory vector.
        :raises RuntimeError: if the vector has not been computed since the
            last reset.
        """
        if not self.computed_theory_vector:
            raise RuntimeError(
                f"The theory for statistic {self} has not been computed yet."
            )
        assert self.theory_vector is not None, (
            "implementation error, "
            "computed_theory_vector is True but theory_vector is None"
        )
        return self.theory_vector
class GuardedStatistic(Updatable):
    """Wrapper that enforces the read-then-use lifecycle of a :class:`Statistic`.

    The framework holds each instance of a :class:`Statistic` subclass inside a
    :class:`GuardedStatistic`. The wrapper rejects a second call to ``read``
    and refuses to hand out data or theory vectors before ``read`` has been
    called.
    """

    def __init__(self, stat: Statistic):
        """Wrap ``stat`` in a new GuardedStatistic.

        :param stat: The statistic to wrap; it must be a :class:`Statistic`.
        """
        super().__init__()
        assert isinstance(stat, Statistic)
        self.statistic = stat

    def _require_read(self) -> Statistic:
        """Return the wrapped statistic, provided it has already been read.

        :return: The wrapped statistic.
        :raises StatisticUnreadError: if ``read`` has not been called yet.
        """
        wrapped = self.statistic
        if wrapped.ready:
            return wrapped
        raise StatisticUnreadError(wrapped)

    def read(self, sacc_data: sacc.Sacc) -> None:
        """Forward ``sacc_data`` to the wrapped statistic's ``read``.

        Once this returns, :meth:`get_data_vector` and
        :meth:`compute_theory_vector` may be called.

        :param sacc_data: The SACC data object to read from.
        :raises RuntimeError: if the wrapped statistic was already read.
        """
        already_read = self.statistic.ready
        if not already_read:
            self.statistic.read(sacc_data)
            return
        raise RuntimeError("Firecrown has called read twice on a GuardedStatistic")

    def get_data_vector(self) -> DataVector:
        """Return the wrapped statistic's data vector.

        :return: The data vector of the wrapped :class:`Statistic`.
        :raises StatisticUnreadError: if ``read`` has not been called first.
        """
        return self._require_read().get_data_vector()

    def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
        """Compute and return the wrapped statistic's theory vector.

        :param tools: the modeling tools used to compute the theory vector.
        :return: The freshly computed theory vector.
        :raises StatisticUnreadError: if ``read`` has not been called first.
        """
        return self._require_read().compute_theory_vector(tools)
class TrivialStatistic(Statistic):
    """A three-element statistic intended only for testing Gaussian likelihoods.

    Its :class:`DataVector` and :class:`TheoryVector` each have exactly three
    entries. The data are the ``"count"`` values taken from the SACC object
    given to :meth:`TrivialStatistic.read`, which must supply exactly three of
    them. Every theory entry is the value of the updatable parameter ``mean``.
    """

    def __init__(self) -> None:
        """Initialize this statistic with no data and a ``mean`` parameter."""
        super().__init__()
        # Length shared by the data vector and the theory vector.
        self.count = 3
        self.data_vector: None | DataVector = None
        self.mean = firecrown.parameters.register_new_updatable_parameter(
            default_value=0.0
        )

    def read(self, sacc_data: sacc.Sacc) -> None:
        """Load the ``"count"`` data points, then mark the statistic as read.

        :param sacc_data: The SACC data object to be read
        """
        counts = sacc_data.get_mean(data_type="count")
        assert len(counts) == self.count
        vector = DataVector.from_list(counts)
        self.data_vector = vector
        self.sacc_indices = np.arange(len(vector))
        # Base-class `read` must run last: it checks the data vector.
        super().read(sacc_data)

    @final
    def _required_parameters(self) -> RequiredParameters:
        """Report no required parameters.

        :return: an empty RequiredParameters.
        """
        no_params: list = []
        return RequiredParameters(no_params)

    @final
    def _get_derived_parameters(self) -> DerivedParameterCollection:
        """Report no derived parameters.

        :return: an empty DerivedParameterCollection.
        """
        no_derived: list = []
        return DerivedParameterCollection(no_derived)

    def get_data_vector(self) -> DataVector:
        """Return the data vector loaded by ``read``.

        :return: The data vector.
        :raises AssertionError: if ``read`` has not populated it yet.
        """
        vector = self.data_vector
        assert vector is not None
        return vector

    def _compute_theory_vector(self, _: ModelingTools) -> TheoryVector:
        """Build a theory vector that repeats ``mean`` ``count`` times.

        :param _: unused, but required by the interface
        :return: A theory vector whose entries all equal ``mean``
        """
        values = [self.mean] * self.count
        return TheoryVector.from_list(values)