Source code for firecrown.likelihood.gauss_family.statistic.statistic

"""Gaussian Family Statistic Module.

The Statistic class describing objects that implement methods to compute the
data and theory vectors for a :class:`GaussFamily` subclass.
"""

from __future__ import annotations
from typing import Optional, final
from dataclasses import dataclass
from abc import abstractmethod
import warnings
import numpy as np
import numpy.typing as npt
import sacc

import firecrown.parameters
from firecrown.parameters import DerivedParameterCollection, RequiredParameters
from firecrown.modeling_tools import ModelingTools
from firecrown.updatable import Updatable


[docs]class DataVector(npt.NDArray[np.float64]): """This class wraps a np.ndarray that represents some observed data values."""
[docs] @classmethod def create(cls, vals: npt.NDArray[np.float64]) -> DataVector: """Create a DataVector that wraps a copy of the given array vals.""" return vals.view(cls)
[docs] @classmethod def from_list(cls, vals: list[float]) -> DataVector: """Create a DataVector from the given list of floats.""" array = np.array(vals) return cls.create(array)
[docs]class TheoryVector(npt.NDArray[np.float64]): """This class represents an observation predicted by some theory."""
[docs] @classmethod def create(cls, vals: npt.NDArray[np.float64]) -> TheoryVector: """Create a TheoryVector that wraps a copy of the given array vals.""" return vals.view(cls)
[docs] @classmethod def from_list(cls, vals: list[float]) -> TheoryVector: """Create a TheoryVector from the given list of floats.""" array = np.array(vals) return cls.create(array)
[docs]def residuals(data: DataVector, theory: TheoryVector) -> npt.NDArray[np.float64]: """Return a bare np.ndarray with the difference between `data` and `theory`. This is to be preferred to using arithmetic on the vectors directly. """ assert isinstance(data, DataVector) assert isinstance(theory, TheoryVector) return (data - theory).view(np.ndarray)
[docs]@dataclass class StatisticsResult: """This is the type returned by the `compute` method of any `Statistic`.""" data: DataVector theory: TheoryVector
[docs] def __post_init__(self): """Make sure the data and theory vectors are of the same shape.""" assert self.data.shape == self.theory.shape
[docs] def residuals(self) -> npt.NDArray[np.float64]: """Return the residuals -- the difference between data and theory.""" return self.data - self.theory
[docs] def __iter__(self): """Iterate through the data members. This is to allow automatic unpacking, as if the StatisticsResult were a tuple of (data, theory). This method is a temporary measure to help code migrate to the newer, safer interface for Statistic.compute(). """ warnings.warn( "Iteration and tuple unpacking for StatisticsResult is " "deprecated.\nPlease use the StatisticsResult class accessors" ".data and .theory by name." ) yield self.data yield self.theory
[docs]class StatisticUnreadError(RuntimeError): """Error raised when accessing an un-read statistic. Run-time error indicating an attempt has been made to use a statistic that has not had `read` called in it. """ def __init__(self, stat: Statistic): msg = ( f"The statistic {stat} was used for calculation before `read` " f"was called.\nIt may be that a likelihood factory function did not" f"call `read` before returning the likelihood." ) super().__init__(msg) self.statstic = stat
[docs]class Statistic(Updatable): """The abstract base class for all physics-related statistics. Statistics read data from a SACC object as part of a multi-phase initialization. They manage a :class:`DataVector` and, given a :class:`ModelingTools` object, can compute a :class:`TheoryVector`. Statistics represent things like two-point functions and mass functions. """ def __init__(self, parameter_prefix: Optional[str] = None): super().__init__(parameter_prefix=parameter_prefix) self.sacc_indices: Optional[npt.NDArray[np.int64]] self.ready = False self.computed_theory_vector = False self.theory_vector: Optional[TheoryVector] = None
[docs] def read(self, _: sacc.Sacc) -> None: """Read the data for this statistic and mark it as ready for use. Derived classes that override this function should make sure to call the base class method using: super().read(sacc_data) as the last thing they do in `__init__`. Note that currently the argument is not used; it is present so that this method will have the correct argument type for the override. """ self.ready = True if len(self.get_data_vector()) == 0: raise RuntimeError( f"the statistic {self} has read a data vector " f"of length 0; the length must be positive" )
[docs] def _reset(self): """Reset this statistic. All subclasses implementations must call super()._reset() """ self.computed_theory_vector = False self.theory_vector = None
[docs] @abstractmethod def get_data_vector(self) -> DataVector: """Gets the statistic data vector."""
[docs] @final def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: """Compute a statistic from sources, applying any systematics.""" if not self.is_updated(): raise RuntimeError( f"The statistic {self} has not been updated with parameters." ) self.theory_vector = self._compute_theory_vector(tools) self.computed_theory_vector = True return self.theory_vector
[docs] @abstractmethod def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: """Compute a statistic from sources, concrete implementation."""
[docs] def get_theory_vector(self) -> TheoryVector: """Returns the last computed theory vector. Raises a RuntimeError if the vector has not been computed. """ if not self.computed_theory_vector: raise RuntimeError( f"The theory for statistic {self} has not been computed yet." ) assert self.theory_vector is not None, ( "implementation error, " "computed_theory_vector is True but theory_vector is None" ) return self.theory_vector
[docs]class GuardedStatistic(Updatable): """An internal class used to maintain state on statistics. :class:`GuardedStatistic` is used by the framework to maintain and validate the state of instances of classes derived from :class:`Statistic`. """ def __init__(self, stat: Statistic): """Initialize the GuardedStatistic to contain the given :class:`Statistic`.""" super().__init__() assert isinstance(stat, Statistic) self.statistic = stat
[docs] def read(self, sacc_data: sacc.Sacc) -> None: """Read whatever data is needed from the given :class:`sacc.Sacc` object. After this function is called, the object should be prepared for the calling of the methods :meth:`get_data_vector` and :meth:`compute_theory_vector`. """ if self.statistic.ready: raise RuntimeError("Firecrown has called read twice on a GuardedStatistic") try: self.statistic.read(sacc_data) except TypeError as exc: msg = ( f"A statistic of type {type(self.statistic).__name__} has raised " f"an exception during `read`.\nThe problem may be a malformed " f"SACC data object." ) raise RuntimeError(msg) from exc
[docs] def get_data_vector(self) -> DataVector: """Return the contained :class:`Statistic`'s data vector. :class:`GuardedStatistic` ensures that :meth:`read` has been called. first. """ if not self.statistic.ready: raise StatisticUnreadError(self.statistic) return self.statistic.get_data_vector()
[docs] def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: """Return the contained :class:`Statistic`'s computed theory vector. :class:`GuardedStatistic` ensures that :meth:`read` has been called. first. """ if not self.statistic.ready: raise StatisticUnreadError(self.statistic) return self.statistic.compute_theory_vector(tools)
[docs]class TrivialStatistic(Statistic): """A minimal statistic only to be used for testing Gaussian likelihoods. It returns a :class:`DataVector` and :class:`TheoryVector` each of which is three elements long. The SACC data provided to :meth:`TrivialStatistic.read` must supply the necessary values. """ def __init__(self) -> None: """Initialize this statistic.""" super().__init__() # Data and theory will both be of length self.count self.count = 3 self.data_vector: Optional[DataVector] = None self.mean = firecrown.parameters.register_new_updatable_parameter() self.computed_theory_vector = False
[docs] def read(self, sacc_data: sacc.Sacc): """Read the necessary items from the sacc data.""" our_data = sacc_data.get_mean(data_type="count") assert len(our_data) == self.count self.data_vector = DataVector.from_list(our_data) self.sacc_indices = np.arange(len(self.data_vector)) super().read(sacc_data)
[docs] @final def _required_parameters(self) -> RequiredParameters: """Return an empty RequiredParameters.""" return RequiredParameters([])
[docs] @final def _get_derived_parameters(self) -> DerivedParameterCollection: """Return an empty DerivedParameterCollection.""" return DerivedParameterCollection([])
[docs] def get_data_vector(self) -> DataVector: """Return the data vector; raise exception if there is none.""" assert self.data_vector is not None return self.data_vector
[docs] def _compute_theory_vector(self, _: ModelingTools) -> TheoryVector: """Return a fixed theory vector.""" return TheoryVector.from_list([self.mean] * self.count)