Source code for firecrown.metadata_functions
"""This module deals with two-point metadata functions.
It contains functions used to manipulate two-point metadata, including extracting
metadata from a sacc file and creating new metadata objects.
"""
from itertools import combinations_with_replacement, product
from typing import TypedDict, Any
import numpy as np
import numpy.typing as npt
import sacc
from sacc.data_types import required_tags
from firecrown.metadata_types import (
TracerNames,
Measurement,
InferredGalaxyZDist,
TwoPointXY,
TwoPointHarmonic,
TwoPointReal,
LENS_REGEX,
SOURCE_REGEX,
MEASURED_TYPE_STRING_MAP,
measurement_is_compatible,
GALAXY_LENS_TYPES,
GALAXY_SOURCE_TYPES,
ALL_MEASUREMENT_TYPES,
Galaxies,
CMB,
Clusters,
)
# TwoPointRealIndex is a type used to create intermediate objects when reading SACC
# objects. They should not be seen directly by users of Firecrown.
TwoPointRealIndex = TypedDict(
"TwoPointRealIndex",
{
"data_type": str,
"tracer_names": TracerNames,
},
)
# TwoPointHarmonicIndex is a type used to create intermediate objects when reading SACC
# objects. They should not be seen directly by users of Firecrown.
TwoPointHarmonicIndex = TypedDict(
"TwoPointHarmonicIndex",
{
"data_type": str,
"tracer_names": TracerNames,
},
)
[docs]def make_measurement(value: Measurement | dict[str, Any]) -> Measurement:
"""Create a Measurement object from a dictionary."""
if isinstance(value, ALL_MEASUREMENT_TYPES):
return value
if not isinstance(value, dict):
raise ValueError(f"Invalid Measurement: {value} is not a dictionary")
if "subject" not in value:
raise ValueError("Invalid Measurement: dictionary does not contain 'subject'")
subject = value["subject"]
match subject:
case "Galaxies":
return Galaxies[value["property"]]
case "CMB":
return CMB[value["property"]]
case "Clusters":
return Clusters[value["property"]]
case _:
raise ValueError(
f"Invalid Measurement: subject: '{subject}' is not recognized"
)
[docs]def make_measurements(
value: set[Measurement] | list[dict[str, Any]],
) -> set[Measurement]:
"""Create a Measurement object from a dictionary."""
if isinstance(value, set) and all(
isinstance(v, ALL_MEASUREMENT_TYPES) for v in value
):
return value
measurements: set[Measurement] = set()
for measurement_dict in value:
measurements.update([make_measurement(measurement_dict)])
return measurements
[docs]def make_measurement_dict(value: Measurement) -> dict[str, str]:
"""Create a dictionary from a Measurement object.
:param value: the measurement to turn into a dictionary
"""
return {"subject": type(value).__name__, "property": value.name}
[docs]def make_measurements_dict(value: set[Measurement]) -> list[dict[str, str]]:
"""Create a dictionary from a Measurement object.
:param value: the measurement to turn into a dictionary
"""
return [make_measurement_dict(measurement) for measurement in value]
def _extract_all_candidate_measurement_types(
data_points: list[sacc.DataPoint],
include_maybe_types: bool = False,
) -> dict[str, set[Measurement]]:
"""Extract all candidate Measurement from the data points.
The candidate Measurement are the ones that appear in the data points.
"""
all_data_types: set[tuple[str, str, str]] = {
(d.data_type, d.tracers[0], d.tracers[1]) for d in data_points
}
sure_types, maybe_types = _extract_sure_and_maybe_types(all_data_types)
# Remove the sure types from the maybe types.
for tracer0, sure_types0 in sure_types.items():
maybe_types[tracer0] -= sure_types0
# Filter maybe types.
for data_type, tracer1, tracer2 in all_data_types:
if data_type not in MEASURED_TYPE_STRING_MAP:
continue
a, b = MEASURED_TYPE_STRING_MAP[data_type]
if a == b:
continue
if a in sure_types[tracer1] and b in sure_types[tracer2]:
maybe_types[tracer1].discard(b)
maybe_types[tracer2].discard(a)
elif a in sure_types[tracer2] and b in sure_types[tracer1]:
maybe_types[tracer1].discard(a)
maybe_types[tracer2].discard(b)
if include_maybe_types:
return {
tracer0: sure_types0 | maybe_types[tracer0]
for tracer0, sure_types0 in sure_types.items()
}
return sure_types
def _extract_sure_and_maybe_types(all_data_types):
sure_types: dict[str, set[Measurement]] = {}
maybe_types: dict[str, set[Measurement]] = {}
for data_type, tracer1, tracer2 in all_data_types:
sure_types[tracer1] = set()
sure_types[tracer2] = set()
maybe_types[tracer1] = set()
maybe_types[tracer2] = set()
# Getting the sure and maybe types for each tracer.
for data_type, tracer1, tracer2 in all_data_types:
if data_type not in MEASURED_TYPE_STRING_MAP:
continue
a, b = MEASURED_TYPE_STRING_MAP[data_type]
if a == b:
sure_types[tracer1].update({a})
sure_types[tracer2].update({a})
else:
name_match, n1, a, n2, b = match_name_type(tracer1, tracer2, a, b)
if name_match:
sure_types[n1].update({a})
sure_types[n2].update({b})
if not name_match:
maybe_types[tracer1].update({a, b})
maybe_types[tracer2].update({a, b})
return sure_types, maybe_types
[docs]def match_name_type(
tracer1: str,
tracer2: str,
a: Measurement,
b: Measurement,
require_convetion: bool = False,
) -> tuple[bool, str, Measurement, str, Measurement]:
"""Use the naming convention to assign the right measurement to each tracer."""
for n1, n2 in ((tracer1, tracer2), (tracer2, tracer1)):
if LENS_REGEX.match(n1) and SOURCE_REGEX.match(n2):
if a in GALAXY_SOURCE_TYPES and b in GALAXY_LENS_TYPES:
return True, n1, b, n2, a
if b in GALAXY_SOURCE_TYPES and a in GALAXY_LENS_TYPES:
return True, n1, a, n2, b
raise ValueError(
"Invalid SACC file, tracer names do not respect "
"the naming convetion."
)
if require_convetion:
if LENS_REGEX.match(tracer1) and LENS_REGEX.match(tracer2):
return False, tracer1, a, tracer2, b
if SOURCE_REGEX.match(tracer1) and SOURCE_REGEX.match(tracer2):
return False, tracer1, a, tracer2, b
raise ValueError(
f"Invalid tracer names ({tracer1}, {tracer2}) "
f"do not respect the naming convetion."
)
return False, tracer1, a, tracer2, b
[docs]def extract_all_tracers_inferred_galaxy_zdists(
sacc_data: sacc.Sacc, include_maybe_types=False
) -> list[InferredGalaxyZDist]:
"""Extracts the two-point function metadata from a Sacc object.
The Sacc object contains a set of tracers (one-dimensional bins) and data
points (measurements of the correlation between two tracers).
This function extracts the two-point function metadata from the Sacc object
and returns it in a list.
"""
tracers: list[sacc.tracers.BaseTracer] = sacc_data.tracers.values()
tracer_types = extract_all_measured_types(
sacc_data, include_maybe_types=include_maybe_types
)
for tracer0, tracer_types0 in tracer_types.items():
if len(tracer_types0) == 0:
raise ValueError(
f"Tracer {tracer0} does not have data points associated with it. "
f"Inconsistent SACC object."
)
return [
InferredGalaxyZDist(
bin_name=tracer.name,
z=tracer.z,
dndz=tracer.nz,
measurements=tracer_types[tracer.name],
)
for tracer in tracers
]
[docs]def extract_all_measured_types(
sacc_data: sacc.Sacc,
include_maybe_types: bool = False,
) -> dict[str, set[Measurement]]:
"""Extracts the two-point function metadata from a Sacc object.
The Sacc object contains a set of tracers (one-dimensional bins) and data
points (measurements of the correlation between two tracers).
This function extracts the two-point function metadata from the Sacc object
and returns it in a list.
"""
data_points = sacc_data.get_data_points()
return _extract_all_candidate_measurement_types(data_points, include_maybe_types)
[docs]def extract_all_real_metadata_indices(
sacc_data: sacc.Sacc,
allowed_data_type: None | list[str] = None,
) -> list[TwoPointRealIndex]:
"""Extract all two-point function metadata from a sacc file.
Extracts the two-point function measurement metadata for all measurements
made in real space from a Sacc object.
"""
tag_name = "theta"
data_types = sacc_data.get_data_types()
data_types_reals = [
data_type for data_type in data_types if tag_name in required_tags[data_type]
]
if allowed_data_type is not None:
data_types_reals = [
data_type
for data_type in data_types_reals
if data_type in allowed_data_type
]
all_real_indices: list[TwoPointRealIndex] = []
for data_type in data_types_reals:
for combo in sacc_data.get_tracer_combinations(data_type):
if len(combo) != 2:
raise ValueError(
f"Tracer combination {combo} does not have exactly two tracers."
)
all_real_indices.append(
{
"data_type": data_type,
"tracer_names": TracerNames(*combo),
}
)
return all_real_indices
[docs]def extract_all_harmonic_metadata_indices(
sacc_data: sacc.Sacc, allowed_data_type: None | list[str] = None
) -> list[TwoPointHarmonicIndex]:
"""Extracts the two-point function metadata from a sacc file."""
tag_name = "ell"
data_types = sacc_data.get_data_types()
data_types_cells = [
data_type for data_type in data_types if tag_name in required_tags[data_type]
]
if allowed_data_type is not None:
data_types_cells = [
data_type
for data_type in data_types_cells
if data_type in allowed_data_type
]
all_harmonic_indices: list[TwoPointHarmonicIndex] = []
for data_type in data_types_cells:
for combo in sacc_data.get_tracer_combinations(data_type):
if len(combo) != 2:
raise ValueError(
f"Tracer combination {combo} does not have exactly two tracers."
)
all_harmonic_indices.append(
{
"data_type": data_type,
"tracer_names": TracerNames(*combo),
}
)
return all_harmonic_indices
[docs]def extract_all_harmonic_metadata(
sacc_data: sacc.Sacc,
allowed_data_type: None | list[str] = None,
include_maybe_types=False,
) -> list[TwoPointHarmonic]:
"""Extract the two-point function metadata and data from a sacc file."""
inferred_galaxy_zdists_dict = {
igz.bin_name: igz
for igz in extract_all_tracers_inferred_galaxy_zdists(
sacc_data, include_maybe_types=include_maybe_types
)
}
result: list[TwoPointHarmonic] = []
for cell_index in extract_all_harmonic_metadata_indices(
sacc_data, allowed_data_type
):
tracer_names = cell_index["tracer_names"]
dt = cell_index["data_type"]
XY = make_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, dt)
t1, t2 = tracer_names
ells, _, indices = sacc_data.get_ell_cl(
data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True
)
replacement_ells, weights = extract_window_function(sacc_data, indices)
if replacement_ells is not None:
ells = replacement_ells
result.append(TwoPointHarmonic(XY=XY, window=weights, ells=ells))
return result
# Extracting all real metadata from a SACC object.
[docs]def extract_all_real_metadata(
sacc_data: sacc.Sacc,
allowed_data_type: None | list[str] = None,
include_maybe_types=False,
) -> list[TwoPointReal]:
"""Extract the two-point function metadata and data from a sacc file."""
inferred_galaxy_zdists_dict = {
igz.bin_name: igz
for igz in extract_all_tracers_inferred_galaxy_zdists(
sacc_data, include_maybe_types=include_maybe_types
)
}
tprs: list[TwoPointReal] = []
for real_index in extract_all_real_metadata_indices(sacc_data, allowed_data_type):
tracer_names = real_index["tracer_names"]
dt = real_index["data_type"]
XY = make_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, dt)
t1, t2 = tracer_names
thetas, _, _ = sacc_data.get_theta_xi(
data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True
)
tprs.append(TwoPointReal(XY=XY, thetas=thetas))
return tprs
[docs]def extract_all_photoz_bin_combinations(
sacc_data: sacc.Sacc,
include_maybe_types: bool = False,
) -> list[TwoPointXY]:
"""Extracts the two-point function metadata from a sacc file."""
inferred_galaxy_zdists = extract_all_tracers_inferred_galaxy_zdists(
sacc_data, include_maybe_types=include_maybe_types
)
bin_combinations = make_all_photoz_bin_combinations(inferred_galaxy_zdists)
return bin_combinations
[docs]def extract_window_function(
sacc_data: sacc.Sacc, indices: npt.NDArray[np.int64]
) -> tuple[None | npt.NDArray[np.int64], None | npt.NDArray[np.float64]]:
"""Extract ells and weights for a window function.
:params sacc_data: the Sacc object from which we read.
:params indices: the indices of the data points in the Sacc object which
are computed by the window function.
:returns: the ells and weights of the window function that match the
given indices from a sacc object, or a tuple of (None, None)
if the indices represent the measured Cells directly.
"""
bandpower_window = sacc_data.get_bandpower_windows(indices)
if bandpower_window is None:
return None, None
ells = bandpower_window.values
weights = bandpower_window.weight / bandpower_window.weight.sum(axis=0)
return ells, weights
[docs]def make_all_photoz_bin_combinations(
inferred_galaxy_zdists: list[InferredGalaxyZDist],
) -> list[TwoPointXY]:
"""Extract the two-point function metadata from a sacc file."""
bin_combinations = [
TwoPointXY(
x=igz1, y=igz2, x_measurement=x_measurement, y_measurement=y_measurement
)
for igz1, igz2 in combinations_with_replacement(inferred_galaxy_zdists, 2)
for x_measurement, y_measurement in product(
igz1.measurements, igz2.measurements
)
if measurement_is_compatible(x_measurement, y_measurement)
]
return bin_combinations
[docs]def make_two_point_xy(
inferred_galaxy_zdists_dict: dict[str, InferredGalaxyZDist],
tracer_names: TracerNames,
data_type: str,
) -> TwoPointXY:
"""Build a TwoPointXY object from the inferred galaxy z distributions.
The TwoPointXY object is built from the inferred galaxy z distributions, the data
type, and the tracer names.
:param inferred_galaxy_zdists_dict: a dictionary of inferred galaxy z distributions.
:param tracer_names: a tuple of tracer names.
:param data_type: the data type.
:returns: a TwoPointXY object.
"""
a, b = MEASURED_TYPE_STRING_MAP[data_type]
igz1 = inferred_galaxy_zdists_dict[tracer_names[0]]
igz2 = inferred_galaxy_zdists_dict[tracer_names[1]]
ab = a in igz1.measurements and b in igz2.measurements
ba = b in igz1.measurements and a in igz2.measurements
if a != b and ab and ba:
raise ValueError(
f"Ambiguous measurements for tracers {tracer_names}. "
f"Impossible to determine which measurement is from which tracer."
)
XY = TwoPointXY(
x=igz1, y=igz2, x_measurement=a if ab else b, y_measurement=b if ab else a
)
return XY
[docs]def measurements_from_index(
index: TwoPointRealIndex | TwoPointHarmonicIndex,
) -> tuple[str, Measurement, str, Measurement]:
"""Return the measurements from a TwoPointXiThetaIndex object."""
a, b = MEASURED_TYPE_STRING_MAP[index["data_type"]]
_, n1, a, n2, b = match_name_type(
index["tracer_names"].name1,
index["tracer_names"].name2,
a,
b,
require_convetion=True,
)
return n1, a, n2, b