Source code for firecrown.likelihood.two_point

"""Two point statistic support."""

from __future__ import annotations
import itertools
import warnings
from typing import Sequence

import numpy as np
import numpy.typing as npt
import pyccl
import pyccl.nl_pt
import sacc.windows


import firecrown.generators.two_point as gen
from firecrown.likelihood.source import Source, Tracer
from firecrown.likelihood.source_factories import (
    use_source_factory,
    use_source_factory_metadata_index,
)
from firecrown.likelihood.weak_lensing import (
    WeakLensingFactory,
)
from firecrown.likelihood.number_counts import (
    NumberCountsFactory,
)
from firecrown.likelihood.statistic import (
    Statistic,
)
from firecrown.metadata_types import (
    TracerNames,
    TwoPointHarmonic,
    TwoPointReal,
)

from firecrown.metadata_functions import (
    TwoPointHarmonicIndex,
    TwoPointRealIndex,
    extract_window_function,
    measurements_from_index,
)
from firecrown.data_types import TwoPointMeasurement, DataVector, TheoryVector
from firecrown.modeling_tools import ModelingTools
from firecrown.models.two_point import TwoPointTheory, calculate_pk
from firecrown.updatable import UpdatableCollection
from firecrown.utils import cached_angular_cl, make_log_interpolator
import firecrown.metadata_types as mdt

# only supported types are here, anything else will throw
# a value error


def calculate_angular_cl(
    ells: npt.NDArray[np.int64],
    pk_name: str,
    scale0: float,
    scale1: float,
    tools: ModelingTools,
    tracer0: Tracer,
    tracer1: Tracer,
):
    """Calculate the angular multipole moments for a pair of tracers.

    The power spectrum named ``pk_name`` is obtained through ``calculate_pk``,
    the angular power spectrum is evaluated through ``cached_angular_cl``, and
    the result is multiplied by both scale factors.

    :param ells: The angular wavenumbers at which to compute the power
        spectrum.
    :param pk_name: The name of the power spectrum to use.
    :param scale0: The scale factor for the first tracer.
    :param scale1: The scale factor for the second tracer.
    :param tools: The modeling tools to use.
    :param tracer0: The first tracer to use.
    :param tracer1: The second tracer to use.
    :return: The scaled angular multipole moments.
    """
    power_spectrum = calculate_pk(pk_name, tools, tracer0, tracer1)

    cosmology = tools.get_ccl_cosmology()
    ccl_tracer_pair = (tracer0.ccl_tracer, tracer1.ccl_tracer)
    # The ells are flattened into a plain tuple before being handed over;
    # presumably this makes them usable as a cache key -- confirm against
    # cached_angular_cl.
    ell_values = tuple(ells.ravel().tolist())

    unscaled_cells = cached_angular_cl(
        cosmology,
        ccl_tracer_pair,
        ell_values,
        p_of_k_a=power_spectrum,
    )
    return unscaled_cells * scale0 * scale1
# pylint: disable=too-many-public-methods
[docs] class TwoPoint(Statistic): """A statistic that represents the correlation between two measurements. If the same source is used twice in the same TwoPoint object, this produces an autocorrelation. For example, shear correlation function, galaxy-shear correlation function, etc. Parameters ---------- sacc_data_type : str The kind of two-point statistic. This must be a valid SACC data type that maps to one of the CCL correlation function kinds or a power spectra. Possible options are - galaxy_density_cl : maps to 'cl' (a CCL angular power spectrum) - galaxy_density_xi : maps to 'gg' (a CCL angular position corr. function) - galaxy_shearDensity_cl_e : maps to 'cl' (a CCL angular power spectrum) - galaxy_shearDensity_xi_t : maps to 'gl' (a CCL angular cross-correlation between position and shear) - galaxy_shear_cl_ee : maps to 'cl' (a CCL angular power spectrum) - galaxy_shear_xi_minus : maps to 'l-' (a CCL angular shear corr. function xi-) - galaxy_shear_xi_plus : maps to 'l+' (a CCL angular shear corr. function xi-) - cmbGalaxy_convergenceDensity_xi : maps to 'gg' (a CCL angular position corr. function) - cmbGalaxy_convergenceShear_xi_t : maps to 'gl' (a CCL angular cross- correlation between position and shear) source0 : Source The first sources needed to compute this statistic. source1 : Source The second sources needed to compute this statistic. ell_or_theta : dict, optional A dictionary of options for generating the ell or theta values at which to compute the statistics. This option can be used to have firecrown generate data without the corresponding 2pt data in the input SACC file. The options are: - minimun : float - The start of the binning. - maximun : float - The end of the binning. - n : int - The number of bins. Note that the edges of the bins start at `min` and end at `max`. The actual bin locations will be at the (possibly geometric) midpoint of the bin. 
- binning : str, optional - Pass 'log' to get logarithmic spaced bins and 'lin' to get linearly spaced bins. Default is 'log'. ell_or_theta_min : float, optional The minimum ell or theta value to keep. This minimum is applied after the ell or theta values are read and/or generated. ell_or_theta_max : float, optional The maximum ell or theta value to keep. This maximum is applied after the ell or theta values are read and/or generated. ell_for_xi : dict, optional A dictionary of options for making the ell values at which to compute Cls for use in real-space integrations. The possible keys are: - minimum : int, optional - The minimum angular wavenumber to use for real-space integrations. Default is 2. - midpoint : int, optional - The midpoint angular wavenumber to use for real-space integrations. The angular wavenumber samples are linearly spaced at integers between `minimum` and `midpoint`. Default is 50. - maximum : int, optional - The maximum angular wavenumber to use for real-space integrations. The angular wavenumber samples are logarithmically spaced between `midpoint` and `maximum`. Default is 60,000. - n_log : int, optional - The number of logarithmically spaced angular wavenumber samples between `mid` and `max`. Default is 200. Attributes ---------- ccl_kind : str The CCL correlation function kind or 'cl' for power spectra corresponding to the SACC data type. sacc_tracers : 2-tuple of str A tuple of the SACC tracer names for this 2pt statistic. Set after a call to read. 
""" @property def sacc_data_type(self) -> str: """Backwards compatibility for sacc_data_type.""" return self.theory.sacc_data_type @property def source0(self) -> Source: """Backwards compatibility for source0.""" return self.theory.source0 @property def source1(self) -> Source: """Backwards compatibility for source1.""" return self.theory.source1 @property def window(self) -> None | npt.NDArray[np.float64]: """Backwards compatibility for window.""" return self.theory.window @property def sacc_tracers(self) -> None | TracerNames: """Backwards compatibility for sacc_tracers.""" return self.theory.sacc_tracers @property def ells(self) -> None | npt.NDArray[np.int64]: """Backwards compatibility for ells.""" return self.theory.ells @property def thetas(self) -> None | npt.NDArray[np.float64]: """Backwards compatibility for thetas.""" return self.theory.thetas @property def ells_for_xi(self) -> None | npt.NDArray[np.int64]: """Backwards compatibility for ells_for_xi.""" return self.theory.ells_for_xi @property def cells(self): """Backwards compatibility for cells.""" return self.theory.cells def __init__( self, sacc_data_type: str, source0: Source, source1: Source, *, ell_for_xi: None | dict[str, int] = None, ell_or_theta: None | gen.EllOrThetaConfig = None, ell_or_theta_min: None | float | int = None, ell_or_theta_max: None | float | int = None, tracers: None | TracerNames = None, ) -> None: super().__init__() self.theory = TwoPointTheory( sacc_data_type=sacc_data_type, sources=(source0, source1), ell_or_theta_min=ell_or_theta_min, ell_or_theta_max=ell_or_theta_max, ell_for_xi=ell_for_xi, ell_or_theta=ell_or_theta, tracers=tracers, ) self._data: None | DataVector = None
[docs] @classmethod def from_metadata_index( cls, metadata_indices: Sequence[TwoPointHarmonicIndex | TwoPointRealIndex], wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, ) -> UpdatableCollection[TwoPoint]: """Create an UpdatableCollection of TwoPoint statistics. This constructor creates an UpdatableCollection of TwoPoint statistics from a list of TwoPointCellsIndex or TwoPointXiThetaIndex metadata index objects. The purpose of this constructor is to create a TwoPoint statistic from metadata index, which requires a follow-up call to `read` to read the data and metadata from the SACC object. :param metadata_index: The metadata index objects to initialize the TwoPoint statistics. :param wl_factory: The weak lensing factory to use. :param nc_factory: The number counts factory to use. :return: An UpdatableCollection of TwoPoint statistics. """ two_point_list = [ cls( sacc_data_type=metadata_index["data_type"], source0=use_source_factory_metadata_index( n1, a, wl_factory=wl_factory, nc_factory=nc_factory ), source1=use_source_factory_metadata_index( n2, b, wl_factory=wl_factory, nc_factory=nc_factory ), ) for metadata_index in metadata_indices for n1, a, n2, b in [measurements_from_index(metadata_index)] ] return UpdatableCollection(two_point_list)
@classmethod def _from_metadata_single( cls, *, metadata: TwoPointHarmonic | TwoPointReal, wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, ) -> TwoPoint: """Create a single TwoPoint statistic from metadata. This constructor creates a single TwoPoint statistic from a TwoPointHarmonic or TwoPointReal metadata object. It requires the sources to be initialized before calling this constructor. The metadata object is used to initialize the TwoPoint statistic. No further calls to `read` are needed. """ match metadata: case TwoPointHarmonic(): two_point = cls._from_metadata_single_base( metadata, wl_factory, nc_factory ) two_point.theory.ells = metadata.ells two_point.theory.window = metadata.window case TwoPointReal(): two_point = cls._from_metadata_single_base( metadata, wl_factory, nc_factory ) two_point.theory.thetas = metadata.thetas two_point.theory.window = None two_point.theory.ells_for_xi = gen.log_linear_ells( **two_point.theory.ell_for_xi_config ) case _: raise ValueError(f"Metadata of type {type(metadata)} is not supported!") two_point.ready = True return two_point @classmethod def _from_metadata_single_base(cls, metadata, wl_factory, nc_factory): """Create a single TwoPoint statistic from metadata. Base method for creating a single TwoPoint statistic from metadata. :param metadata: The metadata object to initialize the TwoPoint statistic. :param wl_factory: The weak lensing factory to use. :param nc_factory: The number counts factory to use. :return: A TwoPoint statistic. """ source0 = use_source_factory( metadata.XY.x, metadata.XY.x_measurement, wl_factory=wl_factory, nc_factory=nc_factory, ) source1 = use_source_factory( metadata.XY.y, metadata.XY.y_measurement, wl_factory=wl_factory, nc_factory=nc_factory, ) two_point = cls( metadata.get_sacc_name(), source0, source1, tracers=metadata.XY.get_tracer_names(), ) return two_point
[docs] @classmethod def from_metadata( cls, metadata_seq: Sequence[TwoPointHarmonic | TwoPointReal], wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, ) -> UpdatableCollection[TwoPoint]: """Create an UpdatableCollection of TwoPoint statistics from metadata. This constructor creates an UpdatableCollection of TwoPoint statistics from a list of TwoPointHarmonic or TwoPointReal metadata objects. The metadata objects are used to initialize the TwoPoint statistics. The sources are initialized using the factories provided. Note that TwoPoint created with this constructor are ready to be used, but contain no data. :param metadata_seq: The metadata objects to initialize the TwoPoint statistics. :param wl_factory: The weak lensing factory to use. :param nc_factory: The number counts factory to use. :return: An UpdatableCollection of TwoPoint statistics. """ two_point_list = [ cls._from_metadata_single( metadata=metadata, wl_factory=wl_factory, nc_factory=nc_factory ) for metadata in metadata_seq ] return UpdatableCollection(two_point_list)
[docs] @classmethod def create_two_point( cls, measurement: TwoPointMeasurement, wl_factory: None | WeakLensingFactory, nc_factory: None | NumberCountsFactory, ) -> TwoPoint: """Create a single TwoPoint statistic from a measurement. :param measurement: The measurement object to initialize the TwoPoint statistic. """ two_point = cls._from_metadata_single( metadata=measurement.metadata, wl_factory=wl_factory, nc_factory=nc_factory, ) two_point.sacc_indices = measurement.indices two_point.set_data_vector(DataVector.create(measurement.data)) two_point.ready = True return two_point
[docs] @classmethod def from_measurement( cls, measurements: Sequence[TwoPointMeasurement], wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, ) -> UpdatableCollection[TwoPoint]: """Create an UpdatableCollection of TwoPoint statistics from measurements. This constructor creates an UpdatableCollection of TwoPoint statistics from a list of TwoPointMeasurement objects. The measurements are used to initialize the TwoPoint statistics. The sources are initialized using the factories provided. Note that TwoPoint created with this constructor are ready to be used and contain data. :param measurements: The measurements objects to initialize the TwoPoint statistics. :param wl_factory: The weak lensing factory to use. :param nc_factory: The number counts factory to use. :return: An UpdatableCollection of TwoPoint statistics. """ two_point_list: list[TwoPoint] = [ cls.create_two_point(m, wl_factory, nc_factory) for m in measurements ] return UpdatableCollection(two_point_list)
[docs] def read(self, sacc_data: sacc.Sacc) -> None: """Read the data for this statistic from the SACC file. :param sacc_data: The data in the sacc format. """ self.theory.initialize_sources(sacc_data) if self.theory.ccl_kind == "cl": self.read_harmonic_space(sacc_data) else: self.read_real_space(sacc_data) super().read(sacc_data)
[docs] def read_real_space(self, sacc_data: sacc.Sacc): """Read the data for this statistic from the SACC file.""" assert self.theory.sacc_tracers is not None thetas_xis_indices = read_reals(self.theory, sacc_data) # We do not support window functions for real space statistics if thetas_xis_indices is not None: thetas, xis, sacc_indices = thetas_xis_indices if self.theory.ell_or_theta_config is not None: # If we have data from our construction, and also have data in the # SACC object, emit a warning and use the information read from the # SACC object. warnings.warn( f"Tracers '{self.theory.sacc_tracers}' have 2pt data and you have " f"specified `theta` in the configuration. `theta` is being " f"ignored!", stacklevel=2, ) else: if self.theory.ell_or_theta_config is None: # The SACC file has no data points, just a tracer, in this case we # are building the statistic from scratch. In this case the user # must have set the dictionary ell_or_theta, containing the # minimum, maximum and number of bins to generate the ell values. raise RuntimeError( f"Tracers '{self.theory.sacc_tracers}' for data type " f"'{self.theory.sacc_data_type}' " "have no 2pt data in the SACC file and no input theta values " "were given!" ) thetas, xis = gen.generate_reals(self.theory.ell_or_theta_config) sacc_indices = None assert isinstance(self.theory.ell_or_theta_min, (float, type(None))) assert isinstance(self.theory.ell_or_theta_max, (float, type(None))) thetas, xis, sacc_indices = gen.apply_theta_min_max( thetas, xis, sacc_indices, self.theory.ell_or_theta_min, self.theory.ell_or_theta_max, ) self.theory.ells_for_xi = gen.log_linear_ells(**self.theory.ell_for_xi_config) self.theory.thetas = thetas self.sacc_indices = sacc_indices self._data = DataVector.create(xis)
[docs] def read_harmonic_space(self, sacc_data: sacc.Sacc) -> None: """Read the data for this statistic from the SACC file.""" assert self.theory.sacc_tracers is not None ells_cells_indices = read_ell_cells(self.theory, sacc_data) Cells, ells, sacc_indices, window = self.read_harmonic_spectrum_data( ells_cells_indices, sacc_data ) assert isinstance(self.theory.ell_or_theta_min, (int, type(None))) assert isinstance(self.theory.ell_or_theta_max, (int, type(None))) ells, Cells, sacc_indices = gen.apply_ells_min_max( ells, Cells, sacc_indices, self.theory.ell_or_theta_min, self.theory.ell_or_theta_max, ) self.theory.ells = ells if self.theory.ell_or_theta_min is not None: assert np.min(self.theory.ells) >= self.theory.ell_or_theta_min if self.theory.ell_or_theta_max is not None: assert np.max(self.theory.ells) <= self.theory.ell_or_theta_max self.theory.window = window self.sacc_indices = sacc_indices self._data = DataVector.create(Cells)
[docs] def read_harmonic_spectrum_data( self, ells_cells_indices: ( None | tuple[ npt.NDArray[np.int64], npt.NDArray[np.float64], npt.NDArray[np.int64] ] ), sacc_data: sacc.Sacc, ) -> tuple[ npt.NDArray[np.float64], npt.NDArray[np.int64], npt.NDArray[np.int64] | None, npt.NDArray[np.float64] | None, ]: """Read all the data for this statistic from the SACC file. :param ells_cells_indices: The ells, the cells and the indices of the data in the SACC file. :param sacc_data: The data in the sacc format. :return: The ells, the cells and the indices, and window function if there is one. """ if ells_cells_indices is not None: ells, Cells, sacc_indices = ells_cells_indices if self.theory.ell_or_theta_config is not None: # If we have data from our construction, and also have data in the # SACC object, emit a warning and use the information read from the # SACC object. warnings.warn( f"Tracers '{self.theory.sacc_tracers}' have 2pt data and you have " f"specified `ell` in the configuration. `ell` is being ignored!", stacklevel=2, ) replacement_ells: None | npt.NDArray[np.int64] window: None | npt.NDArray[np.float64] replacement_ells, window = extract_window_function(sacc_data, sacc_indices) if window is not None: # When using a window function, we do not calculate all Cl's. # For this reason we have a default set of ells that we use # to compute Cl's, and we have a set of ells used for # interpolation. assert replacement_ells is not None ells = replacement_ells else: if self.theory.ell_or_theta_config is None: # The SACC file has no data points, just a tracer, in this case we # are building the statistic from scratch. In this case the user # must have set the dictionary ell_or_theta, containing the # minimum, maximum and number of bins to generate the ell values. raise RuntimeError( f"Tracers '{self.theory.sacc_tracers}' for data type " f"'{self.theory.sacc_data_type}' " "have no 2pt data in the SACC file and no input ell values " "were given!" 
) ells, Cells = gen.generate_ells_cells(self.theory.ell_or_theta_config) sacc_indices = None # When generating the ells and Cells we do not have a window function window = None return Cells, ells, sacc_indices, window
[docs] def get_data_vector(self) -> DataVector: """Return this statistic's data vector.""" assert self._data is not None return self._data
[docs] def set_data_vector(self, value: DataVector) -> None: """Set this statistic's data vector.""" assert value is not None self._data = value
[docs] def compute_theory_vector_real_space(self, tools: ModelingTools) -> TheoryVector: """Compute a two-point statistic in real space. This method computes the two-point statistic in real space. It first computes the Cl's in harmonic space and then translates them to real space using CCL. """ assert self.theory.ccl_kind != "cl" assert self.theory.thetas is not None assert self.theory.ells_for_xi is not None tracers0, scale0, tracers1, scale1 = self.theory.get_tracers_and_scales(tools) cells_for_xi = self.compute_cells( self.theory.ells_for_xi, scale0, scale1, tools, tracers0, tracers1 ) theory_vector = pyccl.correlation( tools.get_ccl_cosmology(), ell=self.theory.ells_for_xi, C_ell=cells_for_xi, theta=self.theory.thetas / 60, type=self.theory.ccl_kind, ) return TheoryVector.create(theory_vector)
[docs] def compute_theory_vector_harmonic_space( self, tools: ModelingTools ) -> TheoryVector: """Compute a two-point statistic in harmonic space. This method computes the two-point statistic in harmonic space. It computes either the Cl's at the ells provided by the SACC file or the ells required for the window function. """ assert self.theory.ccl_kind == "cl" assert self.theory.ells is not None tracers0, scale0, tracers1, scale1 = self.theory.get_tracers_and_scales(tools) if self.theory.window is not None: ells_for_interpolation = gen.calculate_ells_for_interpolation( self.theory.ells[0], self.theory.ells[-1] ) cells_interpolated = self.compute_cells_interpolated( self.theory.ells, ells_for_interpolation, scale0, scale1, tools, tracers0, tracers1, ) # Here we left multiply the computed Cl's by the window function to get the # final Cl's. theory_vector = np.einsum( "lb, l -> b", self.theory.window, cells_interpolated ) # We also compute the mean ell value associated with each bin. self.theory.mean_ells = np.einsum( "lb, l -> b", self.theory.window, self.theory.ells ) assert self._data is not None return TheoryVector.create(theory_vector) # If we get here, we are working in harmonic space without a window function. assert self.theory.ells is not None theory_vector = self.compute_cells( self.theory.ells, scale0, scale1, tools, tracers0, tracers1, ) return TheoryVector.create(theory_vector)
def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: """Compute a two-point statistic from sources.""" if self.theory.ccl_kind == "cl": return self.compute_theory_vector_harmonic_space(tools) return self.compute_theory_vector_real_space(tools)
[docs] def compute_cells( self, ells: npt.NDArray[np.int64], scale0: float, scale1: float, tools: ModelingTools, tracers0: Sequence[Tracer], tracers1: Sequence[Tracer], ) -> npt.NDArray[np.float64]: """Compute the power spectrum for the given ells and tracers.""" self.theory.cells = {} if tracers0 == tracers1: assert scale0 == scale1 # We should consider how to avoid doing the same calculation twice, # if possible. for tracer0, tracer1 in itertools.product(tracers0, tracers1): pk_name = f"{tracer0.field}:{tracer1.field}" tn = TracerNames(tracer0.tracer_name, tracer1.tracer_name) result = calculate_angular_cl( ells, pk_name, scale0, scale1, tools, tracer0, tracer1 ) self.theory.cells[tn] = result self.theory.cells[mdt.TRACER_NAMES_TOTAL] = np.array( sum(self.theory.cells.values()) ) theory_vector = self.theory.cells[mdt.TRACER_NAMES_TOTAL] return theory_vector
[docs] def compute_cells_interpolated( self, ells: npt.NDArray[np.int64], ells_for_interpolation: npt.NDArray[np.int64], scale0: float, scale1: float, tools: ModelingTools, tracers0: Sequence[Tracer], tracers1: Sequence[Tracer], ) -> npt.NDArray[np.float64]: """Compute the interpolated power spectrum for the given ells and tracers. :param ells: The angular wavenumbers at which to compute the power spectrum. :param ells_for_interpolation: The angular wavenumbers at which the power spectrum is computed for interpolation. :param scale0: The scale factor for the first tracer. :param scale1: The scale factor for the second tracer. :param tools: The modeling tools to use. :param tracers0: The first tracers to use. :param tracers1: The second tracers to use. Compute the power spectrum for the given ells and tracers and interpolate the result to the ells provided. :return: The interpolated power spectrum. """ computed_cells = self.compute_cells( ells_for_interpolation, scale0, scale1, tools, tracers0, tracers1 ) cell_interpolator = make_log_interpolator( ells_for_interpolation, computed_cells ) cell_interpolated = np.zeros(len(ells)) # We should not interpolate ell 0 and 1 ells_larger_than_1 = ells > 1 cell_interpolated[ells_larger_than_1] = cell_interpolator( ells[ells_larger_than_1] ) return cell_interpolated
def read_reals(
    theory: TwoPointTheory,
    sacc_data: sacc.Sacc,
) -> (
    None
    | tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.int64]]
):
    """Read and return theta, xi and the matching SACC indices.

    :param theory: The theory, carrying data type and tracers.
    :param sacc_data: The SACC data object to be read.

    :return: The theta values, the xi values and the indices of the data points
        in the SACC object, or None if there is no data for this data type and
        tracer combination.
    """
    data_type = theory.sacc_data_type
    tracer_names = theory.sacc_tracers

    thetas, xis = sacc_data.get_theta_xi(data_type, *tracer_names, return_cov=False)

    # get_theta_xi is expected to return the theta values and the xi values in
    # arrays of the same length.
    n_points = len(thetas)
    assert n_points == len(xis)

    if n_points == 0:
        return None

    # atleast_1d keeps a single data point from collapsing to a scalar index.
    sacc_indices = np.atleast_1d(sacc_data.indices(data_type, tracer_names))

    assert sacc_indices is not None  # Needed for mypy
    assert len(sacc_indices) == n_points

    return thetas, xis, sacc_indices
def read_ell_cells(
    theory: TwoPointTheory,
    sacc_data: sacc.Sacc,
) -> (
    None
    | tuple[npt.NDArray[np.int64], npt.NDArray[np.float64], npt.NDArray[np.int64]]
):
    """Read and return ell, Cell and the matching SACC indices.

    :param theory: The theory, carrying data type and tracers.
    :param sacc_data: The SACC data object to be read.

    :return: The ell values, the Cell values and the indices of the data points
        in the SACC object, or None if there is no data for this data type and
        tracer combination.
    """
    data_type = theory.sacc_data_type
    tracer_names = theory.sacc_tracers

    ells, cells = sacc_data.get_ell_cl(data_type, *tracer_names, return_cov=False)

    # get_ell_cl is expected to return the ell values and the Cl values in
    # arrays of the same length.
    n_points = len(ells)
    assert n_points == len(cells)

    if n_points == 0:
        return None

    # atleast_1d keeps a single data point from collapsing to a scalar index.
    sacc_indices = np.atleast_1d(sacc_data.indices(data_type, tracer_names))

    assert sacc_indices is not None  # Needed for mypy
    assert len(sacc_indices) == n_points

    return ells, cells, sacc_indices