"""Two point statistic support."""
from __future__ import annotations
import itertools
import warnings
from collections.abc import Sequence
from typing import Annotated
import numpy as np
import numpy.typing as npt
import pyccl
import sacc
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
PrivateAttr,
)
import firecrown.generators.two_point as gen
import firecrown.metadata_types as mdt
from firecrown.data_types import DataVector, TheoryVector, TwoPointMeasurement
from firecrown.likelihood.cmb import CMBConvergence, CMBConvergenceFactory
from firecrown.likelihood.number_counts import NumberCounts, NumberCountsFactory
from firecrown.likelihood.source import Source, Tracer
from firecrown.likelihood.statistic import Statistic
from firecrown.likelihood.weak_lensing import WeakLensing, WeakLensingFactory
from firecrown.metadata_functions import (
TwoPointHarmonicIndex,
TwoPointRealIndex,
extract_window_function,
make_correlation_space,
measurements_from_index,
)
from firecrown.metadata_types import (
CMB_TYPES,
GALAXY_LENS_TYPES,
GALAXY_SOURCE_TYPES,
InferredGalaxyZDist,
Measurement,
TracerNames,
TwoPointCorrelationSpace,
TwoPointHarmonic,
TwoPointReal,
TypeSource,
)
from firecrown.modeling_tools import ModelingTools
from firecrown.models.two_point import (
ApplyInterpolationWhen,
TwoPointTheory,
calculate_pk,
)
from firecrown.updatable import UpdatableCollection
from firecrown.utils import (
ClIntegrationOptions,
cached_angular_cl,
make_log_interpolator,
)
# Only supported types are listed here; anything else raises a ValueError.
[docs]
def calculate_angular_cl(
    ells: npt.NDArray[np.int64],
    pk_name: str,
    scale0: float,
    scale1: float,
    tools: ModelingTools,
    tracer0: Tracer,
    tracer1: Tracer,
    int_options: ClIntegrationOptions | None = None,
):
    """Calculate the angular multipole moments.

    :param ells: The angular wavenumbers at which to compute the power spectrum.
    :param pk_name: The name of the power spectrum to return.
    :param scale0: The scale factor for the first tracer.
    :param scale1: The scale factor for the second tracer.
    :param tools: The modeling tools to use.
    :param tracer0: The first tracer to use.
    :param tracer1: The second tracer to use.
    :param int_options: Optional integration options, forwarded unchanged to
        ``cached_angular_cl``.
    :return: The angular multipole moments, multiplied by both scale factors.
    """
    pk = calculate_pk(pk_name, tools, tracer0, tracer1)
    # Fetch the cosmology once and reuse the same object both as the argument
    # of the C_ell evaluation and as the provider of the linear power spectrum.
    cosmo = tools.get_ccl_cosmology()
    cells = cached_angular_cl(
        cosmo,
        (tracer0.ccl_tracer, tracer1.ccl_tracer),
        # The ells are flattened into a plain tuple of Python ints before
        # being handed to the cached evaluator.
        tuple(ells.ravel().tolist()),
        p_of_k_a=pk,
        p_of_k_a_lin=cosmo.get_linear_power(),
        int_options=int_options,
    )
    return cells * scale0 * scale1
# pylint: disable=too-many-public-methods
[docs]
class TwoPoint(Statistic):
    """A two-point statistic.

    A two-point statistic represents the correlation between two measurements.
    If the same source is used twice, this produces an autocorrelation.

    This class supports various two-point statistics including shear correlation
    functions, galaxy-shear correlation functions, and galaxy clustering
    statistics in both harmonic and real space.

    The `sacc_data_type` parameter specifies the type of two-point statistic.
    Valid SACC data types map to CCL correlation function types or power spectra:

    * `galaxy_density_cl`: CCL angular power spectrum (cl)
    * `galaxy_density_xi`: CCL angular position correlation (gg)
    * `galaxy_shearDensity_cl_e`: CCL angular power spectrum (cl)
    * `galaxy_shearDensity_xi_t`: CCL position-shear cross-correlation (gl)
    * `galaxy_shear_cl_ee`: CCL angular power spectrum (cl)
    * `galaxy_shear_xi_minus`: CCL angular shear correlation xi-
    * `galaxy_shear_xi_plus`: CCL angular shear correlation xi+
    * `cmbGalaxy_convergenceDensity_xi`: CCL position correlation (gg)
    * `cmbGalaxy_convergenceShear_xi_t`: CCL position-shear cross-correlation (gl)

    The `ell_or_theta` parameter allows generating ell or theta values for
    computing statistics when the corresponding data is not present in the SACC
    file. It accepts a dictionary with keys: `minimum` (float), `maximum`
    (float), `n` (int), and `binning` (str, 'log' or 'lin').

    The `ell_for_xi` parameter configures ell values for computing power spectra
    used in real-space integrations. It accepts a dictionary with keys: `minimum`
    (int, default 2), `midpoint` (int, default 50), `maximum` (int, default
    60000), and `n_log` (int, default 200).

    :ivar ccl_kind: The CCL correlation function kind or 'cl' for power spectra.
    :ivar sacc_tracers: The SACC tracer names for this statistic, set after read.
    """
    # NOTE(review): the constructor of this class accepts `interp_ells_gen`
    # and has no `ell_for_xi` argument, so the `ell_for_xi` paragraph in the
    # docstring above looks stale -- confirm and update.
[docs]
@property
def sacc_data_type(self) -> str:
    """Return the SACC data type held by the wrapped theory object.

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.sacc_data_type
[docs]
@property
def source0(self) -> Source:
    """Return the first source held by the wrapped theory object.

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.source0
[docs]
@property
def source1(self) -> Source:
    """Return the second source held by the wrapped theory object.

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.source1
[docs]
@property
def window(self) -> None | npt.NDArray[np.float64]:
    """Return the window stored on the wrapped theory object (may be None).

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.window
[docs]
@property
def sacc_tracers(self) -> None | TracerNames:
    """Return the SACC tracer names stored on the wrapped theory (may be None).

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.sacc_tracers
[docs]
@property
def ells(self) -> None | npt.NDArray[np.int64]:
    """Return the ells stored on the wrapped theory object (may be None).

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.ells
[docs]
@property
def thetas(self) -> None | npt.NDArray[np.float64]:
    """Return the thetas stored on the wrapped theory object (may be None).

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.thetas
[docs]
@property
def ells_for_xi(self) -> None | npt.NDArray[np.int64]:
    """Return ``ells_for_xi`` from the wrapped theory object (may be None).

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.ells_for_xi
[docs]
@property
def cells(self):
    """Return the ``cells`` mapping stored on the wrapped theory object.

    Provided for backwards compatibility with the attribute of the same name.
    """
    theory = self.theory
    return theory.cells
def __init__(
    self,
    sacc_data_type: str,
    source0: Source,
    source1: Source,
    *,
    interp_ells_gen: gen.LogLinearElls = gen.LogLinearElls(),
    ell_or_theta: None | gen.EllOrThetaConfig = None,
    tracers: None | TracerNames = None,
    int_options: ClIntegrationOptions | None = None,
    apply_interp: ApplyInterpolationWhen = ApplyInterpolationWhen.DEFAULT,
) -> None:
    """Create the statistic and the theory object it delegates to.

    All arguments are forwarded to ``TwoPointTheory``; the two sources are
    passed together as the ``sources`` pair. The data vector starts unset.

    :param sacc_data_type: The SACC data type of this statistic.
    :param source0: The first source.
    :param source1: The second source.
    :param interp_ells_gen: Forwarded to the theory as ``interp_ells_gen``.
    :param ell_or_theta: Forwarded to the theory as ``ell_or_theta``.
    :param tracers: Forwarded to the theory as ``tracers``.
    :param int_options: Forwarded to the theory as ``int_options``.
    :param apply_interp: Forwarded to the theory as ``apply_interp``.
    """
    super().__init__()
    source_pair = (source0, source1)
    theory = TwoPointTheory(
        sacc_data_type=sacc_data_type,
        sources=source_pair,
        interp_ells_gen=interp_ells_gen,
        ell_or_theta=ell_or_theta,
        tracers=tracers,
        int_options=int_options,
        apply_interp=apply_interp,
    )
    self.theory = theory
    # No data until `read` or `set_data_vector` provides it.
    self._data: None | DataVector = None
@classmethod
def _from_metadata_single(
    cls, metadata: TwoPointHarmonic | TwoPointReal, tp_factory: TwoPointFactory
) -> TwoPoint:
    """Build one ready-to-use TwoPoint statistic from a metadata object.

    The sources are created through ``tp_factory``. Harmonic metadata supplies
    the ells and the window; real-space metadata supplies the thetas and
    leaves the window unset. The result is flagged as ready, so no call to
    `read` is needed afterwards.

    :param metadata: A TwoPointHarmonic or TwoPointReal metadata object.
    :param tp_factory: The factory used to create the sources.
    :return: The initialized TwoPoint statistic.
    :raises ValueError: If ``metadata`` is of neither supported type.
    """
    # Reject unsupported metadata before any source is created.
    if not isinstance(metadata, (TwoPointHarmonic, TwoPointReal)):
        raise ValueError(f"Metadata of type {type(metadata)} is not supported!")

    statistic = cls._from_metadata_single_base(metadata, tp_factory)
    if isinstance(metadata, TwoPointHarmonic):
        statistic.theory.ells = metadata.ells
        statistic.theory.window = metadata.window
    else:
        statistic.theory.thetas = metadata.thetas
        statistic.theory.window = None

    statistic.ready = True
    return statistic
@classmethod
def _from_metadata_single_base(
    cls, metadata: TwoPointHarmonic | TwoPointReal, tp_factory: TwoPointFactory
) -> TwoPoint:
    """Create a single TwoPoint statistic from metadata.

    Base method for creating a single TwoPoint statistic from metadata. It
    builds both sources through ``use_source_factory`` and constructs the
    statistic; ells, thetas and window are left for the caller to fill in.

    :param metadata: The metadata object to initialize the TwoPoint statistic.
    :param tp_factory: The two-point factory used to create both sources; its
        ``int_options`` are passed on to the new statistic.
    :return: A TwoPoint statistic.
    """
    xy = metadata.XY
    source0 = use_source_factory(xy.x, xy.x_measurement, tp_factory)
    source1 = use_source_factory(xy.y, xy.y_measurement, tp_factory)
    return cls(
        metadata.get_sacc_name(),
        source0,
        source1,
        tracers=xy.get_tracer_names(),
        int_options=tp_factory.int_options,
    )
[docs]
@classmethod
def create_two_point(
    cls, measurement: TwoPointMeasurement, tp_factory: TwoPointFactory
) -> TwoPoint:
    """Build one TwoPoint statistic, including its data, from a measurement.

    :param measurement: The measurement providing metadata, SACC indices and
        data for the statistic.
    :param tp_factory: The factory used to create the sources.
    :return: A TwoPoint statistic flagged as ready.
    """
    statistic = cls._from_metadata_single(measurement.metadata, tp_factory)
    statistic.sacc_indices = measurement.indices
    data_vector = DataVector.create(measurement.data)
    statistic.set_data_vector(data_vector)
    statistic.ready = True
    return statistic
[docs]
@classmethod
def from_measurement(
    cls,
    measurements: Sequence[TwoPointMeasurement],
    tp_factory: TwoPointFactory,
) -> UpdatableCollection[TwoPoint]:
    """Create an UpdatableCollection of TwoPoint statistics from measurements.

    This constructor creates an UpdatableCollection of TwoPoint statistics from a
    list of TwoPointMeasurement objects. The measurements are used to initialize the
    TwoPoint statistics. The sources are initialized using the factory provided.

    Note that TwoPoint created with this constructor are ready to be used and
    contain data.

    :param measurements: The measurements objects to initialize the TwoPoint
        statistics.
    :param tp_factory: The two-point factory used to create the sources of
        every statistic.
    :return: An UpdatableCollection of TwoPoint statistics.
    """
    return UpdatableCollection(
        [cls.create_two_point(measurement, tp_factory) for measurement in measurements]
    )
[docs]
def read(self, sacc_data: sacc.Sacc) -> None:
    """Read the data for this statistic from the SACC file.

    The sources are initialized first; the harmonic-space reader is used when
    the theory's ``ccl_kind`` is ``"cl"`` and the real-space reader otherwise.
    The parent class ``read`` is called last.

    :param sacc_data: The data in the sacc format.
    """
    self.theory.initialize_sources(sacc_data)
    is_harmonic = self.theory.ccl_kind == "cl"
    reader = self.read_harmonic_space if is_harmonic else self.read_real_space
    reader(sacc_data)
    super().read(sacc_data)
[docs]
def read_real_space(self, sacc_data: sacc.Sacc):
    """Read the real-space data for this statistic from the SACC file.

    If the SACC object has data for the tracers, it is used (with a warning if
    a theta configuration was also given). Otherwise thetas and xis are
    generated from the configuration. Window functions are not supported for
    real-space statistics.

    :param sacc_data: The data in the sacc format.
    :raises RuntimeError: If the SACC object has no data for the tracers and
        no configuration was provided.
    """
    assert self.theory.sacc_tracers is not None
    from_sacc = read_reals(self.theory, sacc_data)

    if from_sacc is None:
        # The SACC file has no data points, just a tracer: the statistic is
        # built from scratch, which requires the user-provided configuration.
        if self.theory.ell_or_theta_config is None:
            raise RuntimeError(
                f"Tracers '{self.theory.sacc_tracers}' for data type "
                f"'{self.theory.sacc_data_type}' "
                "have no 2pt data in the SACC file and no input theta values "
                "were given!"
            )
        thetas, xis = gen.generate_reals(self.theory.ell_or_theta_config)
        sacc_indices = None
    else:
        thetas, xis, sacc_indices = from_sacc
        if self.theory.ell_or_theta_config is not None:
            # Data in the SACC object takes precedence over the configuration;
            # tell the user the configuration is not used.
            warnings.warn(
                f"Tracers '{self.theory.sacc_tracers}' have 2pt data and you have "
                f"specified `theta` in the configuration. `theta` is being "
                f"ignored!",
                stacklevel=2,
            )

    self.theory.thetas = thetas
    self.sacc_indices = sacc_indices
    self._data = DataVector.create(xis)
[docs]
def read_harmonic_space(self, sacc_data: sacc.Sacc) -> None:
    """Read the harmonic-space data for this statistic from the SACC file.

    Stores the ells and window on the theory object, and the SACC indices and
    the data vector on this statistic.

    :param sacc_data: The data in the sacc format.
    """
    assert self.theory.sacc_tracers is not None
    raw_spectrum = read_ell_cells(self.theory, sacc_data)
    spectrum = self.read_harmonic_spectrum_data(raw_spectrum, sacc_data)
    cell_values, ell_values, indices, window_matrix = spectrum

    self.theory.ells = ell_values
    self.theory.window = window_matrix
    self.sacc_indices = indices
    self._data = DataVector.create(cell_values)
[docs]
def read_harmonic_spectrum_data(
    self,
    ells_cells_indices: (
        None
        | tuple[
            npt.NDArray[np.int64], npt.NDArray[np.float64], npt.NDArray[np.int64]
        ]
    ),
    sacc_data: sacc.Sacc,
) -> tuple[
    npt.NDArray[np.float64],
    npt.NDArray[np.int64],
    npt.NDArray[np.int64] | None,
    npt.NDArray[np.float64] | None,
]:
    """Read all the data for this statistic from the SACC file.

    :param ells_cells_indices: The ells, the cells and the indices of the
        data in the SACC file, or None when the SACC file has no data for
        this statistic.
    :param sacc_data: The data in the sacc format.
    :return: A tuple ``(Cells, ells, sacc_indices, window)`` -- note that the
        Cells come first. ``sacc_indices`` and ``window`` are None when the
        values are generated from the configuration; ``window`` is also None
        when no window function is found for the data.
    :raises RuntimeError: If there is no data in the SACC file and no
        configuration from which to generate ell values.
    """
    if ells_cells_indices is None:
        if self.theory.ell_or_theta_config is None:
            # The SACC file has no data points, just a tracer, in this case we
            # are building the statistic from scratch. In this case the user
            # must have set the dictionary ell_or_theta, containing the
            # minimum, maximum and number of bins to generate the ell values.
            raise RuntimeError(
                f"Tracers '{self.theory.sacc_tracers}' for data type "
                f"'{self.theory.sacc_data_type}' "
                "have no 2pt data in the SACC file and no input ell values "
                "were given!"
            )
        ells, Cells = gen.generate_ells_cells(self.theory.ell_or_theta_config)
        # Generated values have neither SACC indices nor a window function.
        return Cells, ells, None, None

    ells, Cells, sacc_indices = ells_cells_indices
    if self.theory.ell_or_theta_config is not None:
        # If we have data from our construction, and also have data in the
        # SACC object, emit a warning and use the information read from the
        # SACC object.
        warnings.warn(
            f"Tracers '{self.theory.sacc_tracers}' have 2pt data and you have "
            f"specified `ell` in the configuration. `ell` is being ignored!",
            stacklevel=2,
        )

    replacement_ells: None | npt.NDArray[np.int64]
    window: None | npt.NDArray[np.float64]
    replacement_ells, window = extract_window_function(sacc_data, sacc_indices)
    if window is not None:
        # When using a window function, we do not calculate all Cl's.
        # For this reason we have a default set of ells that we use
        # to compute Cl's, and we have a set of ells used for
        # interpolation.
        assert replacement_ells is not None
        ells = replacement_ells
    return Cells, ells, sacc_indices, window
[docs]
def get_data_vector(self) -> DataVector:
    """Return this statistic's data vector, which must already be set."""
    data = self._data
    assert data is not None
    return data
[docs]
def set_data_vector(self, value: DataVector) -> None:
    """Store ``value`` as this statistic's data vector.

    :param value: The new data vector; it must not be None.
    """
    # A missing data vector is never a valid value to store.
    assert value is not None
    self._data = value
[docs]
def compute_theory_vector_real_space(self, tools: ModelingTools) -> TheoryVector:
    """Compute a two-point statistic in real space.

    The Cl's are computed in harmonic space first and then handed to
    ``pyccl.correlation`` to obtain the real-space correlation function at
    the statistic's thetas.

    :param tools: The modeling tools providing the cosmology and tracers.
    :return: The real-space theory vector.
    """
    theory = self.theory
    assert theory.ccl_kind != "cl"
    assert theory.thetas is not None
    assert theory.ells_for_xi is not None

    tracers0, scale0, tracers1, scale1 = theory.get_tracers_and_scales(tools)

    # Choose the multipoles at which the C_ell are evaluated: `ells_for_xi`
    # when the REAL flag is set in `apply_interp`, otherwise every multipole
    # produced by the ell generator.
    ells = (
        theory.ells_for_xi
        if theory.apply_interp & ApplyInterpolationWhen.REAL
        else theory.interp_ells_gen.generate_all()
    )
    cells_for_xi = self.compute_cells(
        ells, scale0, scale1, tools, tracers0, tracers1, interpolate=False
    )

    # Transform the C_ell to xi(theta) with CCL.
    xi = pyccl.correlation(
        tools.get_ccl_cosmology(),
        ell=ells,
        C_ell=cells_for_xi,
        # NOTE(review): presumably converts arcminutes to degrees -- confirm.
        theta=theory.thetas / 60,
        type=theory.ccl_kind,
    )
    return TheoryVector.create(xi)
[docs]
def compute_theory_vector_harmonic_space(
    self, tools: ModelingTools
) -> TheoryVector:
    """Compute a two-point statistic in harmonic space.

    Without a window function the Cl's are computed directly at the
    statistic's ells. With a window function the Cl's are computed at the
    ells and then contracted with the window; the window-weighted mean ells
    are stored on the theory object as ``mean_ells``.

    :param tools: The modeling tools providing the cosmology and tracers.
    :return: The harmonic-space theory vector.
    """
    assert self.theory.ccl_kind == "cl"
    assert self.theory.ells is not None
    tracers0, scale0, tracers1, scale1 = self.theory.get_tracers_and_scales(tools)

    window = self.theory.window
    if window is None:
        # Harmonic space without a window function.
        assert self.theory.ells is not None
        unbinned = self.compute_cells(
            self.theory.ells,
            scale0,
            scale1,
            tools,
            tracers0,
            tracers1,
            interpolate=ApplyInterpolationWhen.HARMONIC in self.theory.apply_interp,
        )
        return TheoryVector.create(unbinned)

    # With a window function we have effective ells, and effective Cells at
    # those effective ells.
    cells = self.compute_cells(
        self.theory.ells,
        scale0,
        scale1,
        tools,
        tracers0,
        tracers1,
        interpolate=ApplyInterpolationWhen.HARMONIC_WINDOW
        in self.theory.apply_interp,
    )
    # Left-multiply the computed Cl's by the window to get the final Cl's.
    binned = np.einsum("lb, l -> b", window, cells)
    # Also record the mean ell value associated with each bin.
    self.theory.mean_ells = np.einsum("lb, l -> b", window, self.theory.ells)
    assert self._data is not None
    return TheoryVector.create(binned)
def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
    """Compute the theory vector in the space selected by ``ccl_kind``.

    A ``ccl_kind`` of ``"cl"`` selects harmonic space; any other value
    selects real space.
    """
    in_harmonic_space = self.theory.ccl_kind == "cl"
    compute = (
        self.compute_theory_vector_harmonic_space
        if in_harmonic_space
        else self.compute_theory_vector_real_space
    )
    return compute(tools)
def _compute_cells_all_orders(
    self,
    ells: npt.NDArray[np.int64],
    scale0: float,
    scale1: float,
    tools: ModelingTools,
    tracers0: Sequence[Tracer],
    tracers1: Sequence[Tracer],
) -> npt.NDArray[np.float64]:
    """Compute the power spectrum for the given ells and tracers.

    The ``cells`` mapping on the theory object is reset, filled with one
    entry per (tracer0, tracer1) pair keyed by their tracer names, and then
    given a total entry under ``mdt.TRACER_NAMES_TOTAL``, which is returned.

    :param ells: The angular wavenumbers at which to compute the power spectrum.
    :param scale0: The scale factor for the first tracer.
    :param scale1: The scale factor for the second tracer.
    :param tools: The modeling tools to use.
    :param tracers0: The first tracers to use.
    :param tracers1: The second tracers to use.
    :return: The sum of the power spectra over all tracer pairs.
    """
    self.theory.cells = {}
    if tracers0 == tracers1:
        assert scale0 == scale1

    # We should consider how to avoid doing the same calculation twice,
    # if possible.
    for first in tracers0:
        for second in tracers1:
            pair_names = TracerNames(first.tracer_name, second.tracer_name)
            self.theory.cells[pair_names] = calculate_angular_cl(
                ells,
                f"{first.field}:{second.field}",
                scale0,
                scale1,
                tools,
                first,
                second,
                self.theory.int_options,
            )

    # The total is formed before its own key is inserted, so it only sums the
    # per-pair contributions.
    total = np.array(sum(self.theory.cells.values()))
    self.theory.cells[mdt.TRACER_NAMES_TOTAL] = total
    return self.theory.cells[mdt.TRACER_NAMES_TOTAL]
def _compute_cells_interpolated(
    self,
    ells: npt.NDArray[np.int64],
    ells_for_interpolation: npt.NDArray[np.int64],
    scale0: float,
    scale1: float,
    tools: ModelingTools,
    tracers0: Sequence[Tracer],
    tracers1: Sequence[Tracer],
) -> npt.NDArray[np.float64]:
    """Compute the interpolated power spectrum for the given ells and tracers.

    The power spectrum is computed at ``ells_for_interpolation`` and the
    result is interpolated to ``ells``. Entries for ells of 0 and 1 are not
    interpolated and stay at zero.

    :param ells: The angular wavenumbers at which to compute the power spectrum.
    :param ells_for_interpolation: The angular wavenumbers at which the power
        spectrum is computed for interpolation.
    :param scale0: The scale factor for the first tracer.
    :param scale1: The scale factor for the second tracer.
    :param tools: The modeling tools to use.
    :param tracers0: The first tracers to use.
    :param tracers1: The second tracers to use.
    :return: The interpolated power spectrum.
    """
    table = self.compute_cells(
        ells_for_interpolation, scale0, scale1, tools, tracers0, tracers1
    )
    interpolator = make_log_interpolator(ells_for_interpolation, table)

    # We should not interpolate ell 0 and 1; those entries remain zero.
    above_one = ells > 1
    interpolated = np.zeros(len(ells))
    interpolated[above_one] = interpolator(ells[above_one])
    return interpolated
[docs]
def compute_cells(
    self,
    ells: npt.NDArray[np.int64],
    scale0: float,
    scale1: float,
    tools: ModelingTools,
    tracers0: Sequence[Tracer],
    tracers1: Sequence[Tracer],
    interpolate: bool = False,
) -> npt.NDArray[np.float64]:
    """Compute the power spectrum for the given ells and tracers.

    With ``interpolate`` false, every requested multipole is computed
    directly. With ``interpolate`` true, the spectrum is computed on the ells
    produced by ``generate_ells_for_interpolation`` and interpolated to the
    requested ells.

    :param ells: The angular wavenumbers at which to compute the power spectrum.
    :param scale0: The scale factor for the first tracer.
    :param scale1: The scale factor for the second tracer.
    :param tools: The modeling tools to use.
    :param tracers0: The first tracers to use.
    :param tracers1: The second tracers to use.
    :param interpolate: Whether to interpolate instead of computing every ell.
    :return: The power spectrum at ``ells``.
    """
    if not interpolate:
        # No interpolation, all multipoles are computed exactly.
        return self._compute_cells_all_orders(
            ells, scale0, scale1, tools, tracers0, tracers1
        )

    # These ells form the interpolation table: the "exact" C_ells are computed
    # there and then interpolated to the required ell values.
    table_ells = self.theory.generate_ells_for_interpolation()
    return self._compute_cells_interpolated(
        ells,
        table_ells,
        scale0,
        scale1,
        tools,
        tracers0,
        tracers1,
    )
[docs]
def read_reals(
    theory: TwoPointTheory,
    sacc_data: sacc.Sacc,
) -> (
    None
    | tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.int64]]
):
    """Read and return theta, xi and the matching SACC indices.

    :param theory: The theory, carrying data type and tracers.
    :param sacc_data: The SACC data object to be read.
    :return: A tuple ``(thetas, xis, sacc_indices)``, or None when the SACC
        object holds no data points for this data type and tracers.
    """
    tracers = theory.sacc_tracers
    assert tracers is not None

    thetas, xis = sacc_data.get_theta_xi(
        theory.sacc_data_type, *tracers, return_cov=False
    )
    # As of version 0.13 of sacc, the method get_theta_xi returns the
    # theta values and the xi values in arrays of the same length.
    n_points = len(thetas)
    assert n_points == len(xis)
    if n_points == 0:
        return None

    # atleast_1d turns a scalar index into a one-element array.
    sacc_indices = np.atleast_1d(sacc_data.indices(theory.sacc_data_type, tracers))
    assert sacc_indices is not None  # Needed for mypy
    assert len(sacc_indices) == n_points
    return thetas, xis, sacc_indices
[docs]
def read_ell_cells(
    theory: TwoPointTheory,
    sacc_data: sacc.Sacc,
) -> (
    None | tuple[npt.NDArray[np.int64], npt.NDArray[np.float64], npt.NDArray[np.int64]]
):
    """Read and return ell, Cell and the matching SACC indices.

    :param theory: The theory, carrying data type and tracers.
    :param sacc_data: The SACC data object to be read.
    :return: A tuple ``(ells, cells, sacc_indices)``, or None when the SACC
        object holds no data points for this data type and tracers.
    """
    tracer_names = theory.sacc_tracers
    assert tracer_names is not None

    data_type = theory.sacc_data_type
    ell_values, cell_values = sacc_data.get_ell_cl(
        data_type, *tracer_names, return_cov=False
    )
    # As of version 0.13 of sacc, the method get_ell_cl returns the
    # ell values and the Cl values in arrays of the same length.
    assert len(ell_values) == len(cell_values)
    n_points = len(ell_values)
    if n_points == 0:
        return None

    # atleast_1d turns a scalar index into a one-element array.
    indices = np.atleast_1d(sacc_data.indices(theory.sacc_data_type, tracer_names))
    assert indices is not None  # Needed for mypy
    assert len(indices) == n_points
    return ell_values, cell_values, indices
[docs]
class TwoPointFactory(BaseModel):
    """Factory class for TwoPoint statistics.

    It holds the weak lensing, number counts and CMB convergence source
    factories, each looked up by ``type_source``, and can turn a list of
    TwoPointMeasurement objects into a collection of TwoPoint statistics.
    """

    # Unknown fields are rejected and instances are frozen.
    model_config = ConfigDict(extra="forbid", frozen=True)
    # The input is passed through `make_correlation_space` before validation.
    correlation_space: Annotated[
        TwoPointCorrelationSpace,
        BeforeValidator(make_correlation_space),
        Field(description="The two-point correlation space."),
    ]
    # Source factories; each list defaults to empty.
    weak_lensing_factories: list[WeakLensingFactory] = Field(default_factory=list)
    number_counts_factories: list[NumberCountsFactory] = Field(default_factory=list)
    cmb_factories: list[CMBConvergenceFactory] = Field(default_factory=list)
    # Passed on to every TwoPoint statistic created through this factory.
    int_options: ClIntegrationOptions | None = None
    # Lookup tables keyed by type_source, filled in `model_post_init`.
    _wl_factory_map: dict[TypeSource, WeakLensingFactory] = PrivateAttr()
    _nc_factory_map: dict[TypeSource, NumberCountsFactory] = PrivateAttr()
    _cmb_factory_map: dict[TypeSource, CMBConvergenceFactory] = PrivateAttr()
[docs]
def model_post_init(self, _, /) -> None:
    """Build the per-``type_source`` lookup tables of the source factories.

    Each of the three factory lists (weak lensing, number counts, CMB
    convergence) is indexed by the ``type_source`` of its factories.

    :raises ValueError: If two factories of the same kind share a
        ``type_source``.
    """
    self._wl_factory_map: dict[TypeSource, WeakLensingFactory] = {}
    self._nc_factory_map: dict[TypeSource, NumberCountsFactory] = {}
    self._cmb_factory_map: dict[TypeSource, CMBConvergenceFactory] = {}

    # One row per factory kind: the configured factories, the table to fill,
    # and the name used in the duplicate error message.
    registrations = (
        (self.weak_lensing_factories, self._wl_factory_map, "WeakLensingFactory"),
        (self.number_counts_factories, self._nc_factory_map, "NumberCountsFactory"),
        (self.cmb_factories, self._cmb_factory_map, "CMBConvergenceFactory"),
    )
    for factories, factory_map, factory_name in registrations:
        for factory in factories:
            type_source = factory.type_source
            if type_source in factory_map:
                raise ValueError(
                    f"Duplicate {factory_name} found for "
                    f"type_source {type_source}."
                )
            factory_map[type_source] = factory
[docs]
def get_factory(
    self, measurement: Measurement, type_source: TypeSource = TypeSource.DEFAULT
) -> WeakLensingFactory | NumberCountsFactory | CMBConvergenceFactory:
    """Get the Factory for the given Measurement and TypeSource.

    The measurement selects the kind of factory (galaxy source types map to
    weak lensing, galaxy lens types to number counts, CMB types to CMB
    convergence); the type source then selects the factory of that kind.

    :param measurement: The measurement the factory must support.
    :param type_source: The type source the factory was registered under.
    :return: The matching factory.
    :raises ValueError: If the measurement belongs to no supported kind, or
        no factory of the selected kind exists for ``type_source``.
    """
    factory_map: dict
    if measurement in GALAXY_SOURCE_TYPES:
        factory_map, factory_name = self._wl_factory_map, "WeakLensingFactory"
    elif measurement in GALAXY_LENS_TYPES:
        factory_map, factory_name = self._nc_factory_map, "NumberCountsFactory"
    elif measurement in CMB_TYPES:
        factory_map, factory_name = self._cmb_factory_map, "CMBConvergenceFactory"
    else:
        raise ValueError(
            f"Factory not found for measurement {measurement}, it is not supported."
        )

    if type_source not in factory_map:
        raise ValueError(f"No {factory_name} found for type_source {type_source}.")
    return factory_map[type_source]
[docs]
def from_measurement(
    self, tpms: list[TwoPointMeasurement]
) -> UpdatableCollection[TwoPoint]:
    """Create a collection of TwoPoint objects from TwoPointMeasurement objects.

    Delegates to ``TwoPoint.from_measurement`` with this factory.

    :param tpms: The measurements to turn into TwoPoint statistics.
    :return: An UpdatableCollection of TwoPoint statistics.
    """
    two_points = TwoPoint.from_measurement(measurements=tpms, tp_factory=self)
    return two_points
[docs]
def use_source_factory(
    inferred_galaxy_zdist: InferredGalaxyZDist,
    measurement: Measurement,
    tp_factory: TwoPointFactory,
) -> WeakLensing | NumberCounts | CMBConvergence:
    """Apply the factory to the inferred galaxy redshift distribution.

    :param inferred_galaxy_zdist: The redshift distribution to build a source
        for; it must list ``measurement`` among its measurements.
    :param measurement: The measurement the source is needed for.
    :param tp_factory: The two-point factory providing the source factory.
    :return: The source created by the selected factory.
    :raises ValueError: If ``measurement`` is not among the measurements of
        ``inferred_galaxy_zdist``.
    """
    supported = inferred_galaxy_zdist.measurements
    if measurement not in supported:
        raise ValueError(
            f"Measurement {measurement} not found in inferred galaxy redshift "
            f"distribution {inferred_galaxy_zdist.bin_name}!"
        )

    # The factory is selected by the measurement and the bin's type source.
    factory = tp_factory.get_factory(measurement, inferred_galaxy_zdist.type_source)
    return factory.create(inferred_galaxy_zdist)