"""Some utility functions for patterns common in Firecrown."""
import functools
from collections.abc import Callable, Generator
from enum import Enum, auto
from typing import Annotated, TypeVar
import numpy as np
import pyccl
import sacc
import scipy.interpolate
import yaml
from numpy import typing as npt
from pydantic import BaseModel, BeforeValidator, ConfigDict, field_serializer
from typing_extensions import Self, assert_never
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
ST = TypeVar("ST")  # Generic type variable. NOTE(review): the visible YAMLSerializable uses Self instead; confirm ST is still needed elsewhere.
[docs]
class YAMLSerializable:
    """Mixin that adds a YAML round trip based on PyYAML's C dumper and loader."""

    def to_yaml(self) -> str:
        """Dump this object to a YAML document.

        :return: the YAML text produced by ``yaml.dump``; keys are not sorted.
        """
        yaml_text = yaml.dump(self, Dumper=Dumper, sort_keys=False)
        return yaml_text

    @classmethod
    def from_yaml(cls, yaml_str: str) -> Self:
        """Build an object from a YAML document.

        The loaded object is returned as-is; it is not checked to be an
        instance of ``cls``.

        :param yaml_str: the YAML text to load
        :return: whatever object the loader constructs from ``yaml_str``
        """
        # NOTE(review): Loader is PyYAML's CLoader (not the safe loader), so
        # this should only ever be fed trusted YAML.
        loaded = yaml.load(yaml_str, Loader=Loader)
        return loaded
[docs]
def base_model_from_yaml(cls: type, yaml_str: str):
    """Instantiate a pydantic model class from a YAML document.

    The YAML text is parsed with ``yaml.safe_load`` and the result is
    validated in strict mode by ``cls.model_validate``.

    :param cls: the model class to create; must derive from pydantic.BaseModel
    :param yaml_str: the YAML text describing the model
    :return: the validated instance of ``cls``
    :raises ValueError: if ``cls`` is not a BaseModel subclass, or if parsing
        or validation fails (the original exception is chained as the cause)
    """
    if not issubclass(cls, BaseModel):
        raise ValueError("cls must be a subclass of pydantic.BaseModel")

    try:
        parsed = yaml.safe_load(yaml_str)
        return cls.model_validate(parsed, strict=True)
    except Exception as err:
        message = (
            f"Error creating {cls.__name__} from yaml. Parsing error message:\n{err}"
        )
        raise ValueError(message) from err
[docs]
def base_model_to_yaml(model: BaseModel) -> str:
    """Render a pydantic model as a YAML document.

    The model is first converted with ``model_dump`` and the result is
    passed to ``yaml.dump`` (keys unsorted, line width 80).

    :param model: the model instance to serialize
    :return: the YAML text
    """
    dumped_fields = model.model_dump()
    return yaml.dump(
        dumped_fields,
        default_flow_style=None,
        sort_keys=False,
        width=80,
    )
[docs]
def upper_triangle_indices(n: int) -> Generator[tuple[int, int], None, None]:
    """Yield the index pairs of the upper triangle of an (n x n) matrix.

    The pairs ``(i, j)`` are produced row by row with ``i <= j``, i.e. in
    the same order as the nested loops::

        for i in range(n):
            for j in range(i, n):
                ...

    :param n: the size of the matrix
    :return: a generator over the ``(i, j)`` index tuples
    """
    for row in range(n):
        # Each row starts on the diagonal and runs to the last column.
        yield from ((row, col) for col in range(row, n))
[docs]
def save_to_sacc(
    sacc_data: sacc.Sacc,
    data_vector: npt.NDArray[np.float64],
    indices: npt.NDArray[np.int64],
    strict: bool = True,
) -> sacc.Sacc:
    """Write a data vector into a copy of a SACC object.

    `sacc_data` itself is never modified: it is copied, the values are
    written into the copy, and the copy is returned.

    With `strict` set to True (the default) the given `indices` must be
    exactly the set of indices present in `sacc_data`, so the whole data
    vector is overwritten. With `strict` set to False only the entries at
    `indices` are overwritten.

    :param sacc_data: SACC object to be copied. It is not modified.
    :param data_vector: values to store in the copy of `sacc_data`.
    :param indices: SACC indices at which the values are written; must have
        the same length as `data_vector`.
    :param strict: whether to require that `indices` covers all the data
        already present in `sacc_data`.
    :return: a copy of `sacc_data` with the data at `indices` replaced by
        `data_vector`.
    :raises RuntimeError: if `strict` is True and `indices` does not match
        the indices of `sacc_data`.
    """
    assert len(indices) == len(data_vector)

    updated = sacc_data.copy()

    if strict and set(indices.ravel().tolist()) != set(sacc_data.indices()):
        raise RuntimeError(
            "The data to be saved does not cover all the data in the "
            "sacc object. To write only the calculated predictions, "
            "set strict=False."
        )

    # Position k in data_vector belongs at SACC index indices[k].
    for position, target_index in enumerate(indices):
        updated.data[target_index].value = data_vector[position]

    return updated
[docs]
def compare_optional_arrays(x: None | npt.NDArray, y: None | npt.NDArray) -> bool:
"""Compare two arrays, allowing for either or both to be None.
:param x: first array
:param y: second array
:return: whether the arrays are equal
"""
if x is None and y is None:
return True
if x is not None and y is not None:
return np.array_equal(x, y)
# One is None and the other is not.
return False
[docs]
def compare_optionals(x: None | object, y: None | object) -> bool:
"""Compare two objects, allowing for either or both to be None.
:param x: first object
:param y: second object
:return: whether the objects are equal
"""
if x is None and y is None:
return True
if x is not None and y is not None:
return x == y
# One is None and the other is not.
return False
[docs]
class ClLimberMethod(YAMLSerializable, str, Enum):
    """Enumeration of the available Cl Limber methods.

    Members are strings; the value of each member is its name in lower
    case (for example ``GSL_SPLINE`` has the value ``"gsl_spline"``).
    """

    @staticmethod
    def _generate_next_value_(name, _start, _count, _last_values):
        """Make ``auto()`` produce the lower-cased member name."""
        return str.lower(name)

    GSL_QAG_QUAD = auto()
    GSL_SPLINE = auto()
def _validate_cl_limber_method(value: ClLimberMethod | str):
    """Convert a string to the matching ClLimberMethod member.

    Strings are lower-cased before the lookup, so the match is
    case-insensitive. Non-string values are returned untouched.

    :param value: an enum member or the textual name of one
    :return: the corresponding ClLimberMethod (or ``value`` if not a string)
    :raises ValueError: if the string does not name a ClLimberMethod
    """
    if not isinstance(value, str):
        return value
    try:
        return ClLimberMethod(value.lower())
    except ValueError as exc:
        raise ValueError(f"Invalid value for ClLimberMethod: {value}") from exc
[docs]
class ClIntegrationMethod(YAMLSerializable, str, Enum):
    """Enumeration of the available Cl integration methods.

    Members are strings; the value of each member is its name in lower
    case (for example ``FKEM_AUTO`` has the value ``"fkem_auto"``).
    """

    @staticmethod
    def _generate_next_value_(name, _start, _count, _last_values):
        """Make ``auto()`` produce the lower-cased member name."""
        return str.lower(name)

    LIMBER = auto()
    FKEM_AUTO = auto()
    FKEM_L_LIMBER = auto()
def _validate_cl_integration_method(value: ClIntegrationMethod | str):
    """Convert a string to the matching ClIntegrationMethod member.

    Strings are lower-cased before the lookup, so the match is
    case-insensitive. Non-string values are returned untouched.

    :param value: an enum member or the textual name of one
    :return: the corresponding ClIntegrationMethod (or ``value`` if not a
        string)
    :raises ValueError: if the string does not name a ClIntegrationMethod
    """
    if not isinstance(value, str):
        return value
    try:
        return ClIntegrationMethod(value.lower())
    except ValueError as exc:
        raise ValueError(f"Invalid value for ClIntegrationMethod: {value}") from exc
[docs]
class ClIntegrationOptions(BaseModel):
"""Options for angular power spectrum integration."""
model_config = ConfigDict(extra="forbid", frozen=True)
method: Annotated[
ClIntegrationMethod, BeforeValidator(_validate_cl_integration_method)
]
limber_method: Annotated[
ClLimberMethod, BeforeValidator(_validate_cl_limber_method)
]
l_limber: int | None = None
limber_max_error: float | None = None
fkem_chi_min: float | None = None
fkem_Nchi: int | None = None
[docs]
@field_serializer("method")
@classmethod
def serialize_method(cls, value: ClIntegrationMethod) -> str:
"""Serialize the method parameter."""
return value.name
[docs]
@field_serializer("limber_method")
@classmethod
def serialize_limber_method(cls, value: ClLimberMethod) -> str:
"""Serialize the limber_method parameter."""
return value.name
[docs]
def model_post_init(self, _, /) -> None:
"""Initialize the WeakLensingFactory object."""
match self.method:
case ClIntegrationMethod.LIMBER:
incompatible_options = [
"limber_max_error",
"l_limber",
"fkem_chi_min",
"fkem_Nchi",
]
case ClIntegrationMethod.FKEM_AUTO:
incompatible_options = ["l_limber"]
case ClIntegrationMethod.FKEM_L_LIMBER:
incompatible_options = ["limber_max_error"]
if self.l_limber is None or self.l_limber < 0:
raise ValueError("l_limber must be set for FKEM_L_LIMBER.")
case _ as unreachable:
assert_never(unreachable)
for option in incompatible_options:
if getattr(self, option) is not None:
raise ValueError(f"{option} is incompatible with {self.method!s}.")
[docs]
def get_angular_cl_args(self):
"""Get the arguments to pass to pyccl.angular_cl."""
match self.limber_method:
case ClLimberMethod.GSL_QAG_QUAD:
arg = {"limber_integration_method": "qag_quad"}
case ClLimberMethod.GSL_SPLINE:
arg = {"limber_integration_method": "spline"}
case _ as unreachable:
assert_never(unreachable)
out: dict[str, str | int | float]
match self.method:
case ClIntegrationMethod.LIMBER:
return arg | {"l_limber": -1}
case ClIntegrationMethod.FKEM_AUTO:
out = {
"l_limber": "auto",
"non_limber_integration_method": "FKEM",
}
if self.limber_max_error is not None:
out["limber_max_error"] = self.limber_max_error
if self.fkem_chi_min is not None:
out["fkem_chi_min"] = self.fkem_chi_min
if self.fkem_Nchi is not None:
out["fkem_Nchi"] = self.fkem_Nchi
return arg | out
case ClIntegrationMethod.FKEM_L_LIMBER:
assert self.l_limber is not None
out = {
"l_limber": self.l_limber,
"non_limber_integration_method": "FKEM",
}
if self.fkem_chi_min is not None:
out["fkem_chi_min"] = self.fkem_chi_min
if self.fkem_Nchi is not None:
out["fkem_Nchi"] = self.fkem_Nchi
return arg | out
case _ as unreachable_method:
assert_never(unreachable_method)
[docs]
@functools.lru_cache(maxsize=128)
def cached_angular_cl(
    cosmo: pyccl.Cosmology,
    tracers: tuple[pyccl.Tracer, pyccl.Tracer],
    ells: npt.NDArray[np.int64],
    p_of_k_a: (
        None
        | pyccl.Pk2D
        | str
        | Callable[[npt.NDArray[np.int64]], npt.NDArray[np.float64]]
    ) = None,
    p_of_k_a_lin: None | pyccl.Pk2D | str = None,
    int_options: ClIntegrationOptions | None = None,
):
    """Wrapper for pyccl.angular_cl, with automatic caching.

    Results are memoized with ``functools.lru_cache`` (up to 128 entries),
    so every argument must be hashable. In particular a numpy array cannot
    be used for `ells`; pass a hashable sequence such as a tuple, which is
    converted with ``np.array`` before the call to pyccl.

    :param cosmo: the current cosmology
    :param tracers: tracers indicating the measurements to be correlated
    :param ells: ell values at which to calculate the power spectrum
    :param p_of_k_a: the power spectrum to use. When None (the default) the
        argument is not forwarded, so pyccl.angular_cl applies its own
        default.
    :param p_of_k_a_lin: the linear power spectrum to use. When None (the
        default) the argument is not forwarded, so pyccl.angular_cl applies
        its own default.
    :param int_options: integration options; when given, the result of its
        ``get_angular_cl_args()`` is passed on as keyword arguments.
    :return: the value returned by pyccl.angular_cl
    """
    # Forward only the power spectra that were supplied. The former
    # signature used the type unions as *default values*, so omitting these
    # arguments handed a typing object to pyccl instead of a power spectrum.
    power_spectra = {
        name: value
        for name, value in (("p_of_k_a", p_of_k_a), ("p_of_k_a_lin", p_of_k_a_lin))
        if value is not None
    }
    integration_args = int_options.get_angular_cl_args() if int_options else {}
    return pyccl.angular_cl(
        cosmo,
        tracers[0],
        tracers[1],
        np.array(ells),
        **power_spectra,
        **integration_args,
    )
[docs]
def make_log_interpolator(
    x: npt.NDArray[np.int64], y: npt.NDArray[np.float64]
) -> Callable[[npt.NDArray[np.int64]], npt.NDArray[np.float64]]:
    """Build a 1D spline interpolator that works in log(x).

    When every y value is strictly positive the spline is fitted to
    log(y) as a function of log(x) and the result is exponentiated;
    otherwise the spline is fitted to y as a function of log(x).

    The returned function does not extrapolate: evaluating it outside the
    range of `x` raises a ValueError.

    :param x: the abscissae of the points to interpolate
    :param y: the ordinates of the points to interpolate
    :return: a function mapping new x values to interpolated y values
    """
    all_positive = np.all(y > 0)
    ordinates = np.log(y) if all_positive else y
    # ext=2 makes the spline raise ValueError instead of extrapolating.
    spline = scipy.interpolate.InterpolatedUnivariateSpline(
        np.log(x), ordinates, ext=2
    )

    if all_positive:

        def log_log_interpolator(x_: npt.NDArray[np.int64]) -> npt.NDArray[np.float64]:
            """Interpolate on log-log scale."""
            log_y = spline(np.log(x_))
            return np.exp(log_y)

        return log_log_interpolator

    def log_x_interpolator(x_: npt.NDArray[np.int64]) -> npt.NDArray[np.float64]:
        """Interpolate on log-x scale."""
        return spline(np.log(x_))

    return log_x_interpolator