Source code for firecrown.utils

"""Some utility functions for patterns common in Firecrown."""

from typing import Generator, TypeVar, Type, Callable, Annotated
from enum import Enum, auto

import functools
from typing_extensions import assert_never
import numpy as np
import pyccl
import scipy.interpolate
from numpy import typing as npt
from pydantic import BaseModel, ConfigDict, BeforeValidator, field_serializer

import sacc

import yaml
from yaml import CLoader as Loader
from yaml import CDumper as Dumper

# Type variable used to annotate ``self``/``cls`` in YAMLSerializable's methods,
# so that ``from_yaml`` is typed as returning an instance of the calling class.
ST = TypeVar("ST")


class YAMLSerializable:
    """Mixin that gives a class YAML round-tripping through PyYAML.

    Objects are written with the module-level ``Dumper`` (``yaml.CDumper``) and
    read back with ``Loader`` (``yaml.CLoader``). ``CLoader`` is not PyYAML's safe
    loader; presumably it is used so that the Python-specific tags written by
    ``to_yaml`` can be turned back into objects by ``from_yaml``.
    """

    def to_yaml(self: ST) -> str:
        """Serialize this object to a YAML string, keeping key order unsorted."""
        return yaml.dump(self, sort_keys=False, Dumper=Dumper)

    @classmethod
    def from_yaml(cls: Type[ST], yaml_str: str) -> ST:
        """Rebuild an object from a YAML string (e.g. one made by ``to_yaml``).

        SECURITY: ``Loader`` is ``yaml.CLoader``, a non-safe loader that can
        construct Python objects from tags in the input. Only pass trusted YAML.
        """
        loaded = yaml.load(yaml_str, Loader=Loader)
        return loaded
def base_model_from_yaml(cls: type, yaml_str: str):
    """Create an instance of a pydantic model from a YAML string.

    The text is parsed with ``yaml.safe_load`` and the result is validated with
    ``cls.model_validate(..., strict=True)``.

    :param cls: the model class to instantiate; must be a subclass of
        ``pydantic.BaseModel``
    :param yaml_str: the YAML text to parse
    :return: the validated ``cls`` instance
    :raises ValueError: if ``cls`` is not a ``BaseModel`` subclass, or if parsing
        or validation fails (the original error is chained as the cause)
    """
    # ``issubclass`` raises TypeError when given a non-class (e.g. an instance),
    # so check ``isinstance(cls, type)`` first to report every bad ``cls`` with
    # the same documented ValueError.
    if not (isinstance(cls, type) and issubclass(cls, BaseModel)):
        raise ValueError("cls must be a subclass of pydantic.BaseModel")
    try:
        return cls.model_validate(
            yaml.safe_load(yaml_str),
            strict=True,
        )
    except Exception as e:  # boundary: converted to ValueError, cause chained
        raise ValueError(
            f"Error creating {cls.__name__} from yaml. Parsing error message:\n{e}"
        ) from e
def base_model_to_yaml(model: BaseModel) -> str:
    """Serialize a pydantic model to a YAML string.

    The model is converted to plain Python data with ``model_dump()`` and then
    dumped without sorting keys (so the order produced by ``model_dump()`` is
    kept), with ``default_flow_style=None`` and an 80-column line width.
    """
    payload = model.model_dump()
    dump_options = {"default_flow_style": None, "sort_keys": False, "width": 80}
    return yaml.dump(payload, **dump_options)
def upper_triangle_indices(n: int) -> Generator[tuple[int, int], None, None]:
    """Yield the (row, column) index pairs of an (n x n) upper triangle.

    The diagonal is included and pairs come out row by row, i.e. in the same
    order as the nested loops::

        for i in range(n):
            for j in range(i, n):
                ...

    A non-positive ``n`` yields nothing.

    :param n: the size of the matrix
    :return: a generator of ``(i, j)`` tuples with ``0 <= i <= j < n``
    """
    for row in range(n):
        yield from ((row, col) for col in range(row, n))
def save_to_sacc(
    sacc_data: sacc.Sacc,
    data_vector: npt.NDArray[np.float64],
    indices: npt.NDArray[np.int64],
    strict: bool = True,
) -> sacc.Sacc:
    """Return a copy of ``sacc_data`` with the data at ``indices`` replaced.

    ``sacc_data`` itself is not modified. It is copied, the values of
    ``data_vector`` are written into the copy, and the copy is returned.

    When ``strict`` is True (the default), the set of ``indices`` must equal the
    set of indices already present in ``sacc_data``, i.e. the whole data vector
    is overwritten. With ``strict=False``, only the given indices are overwritten.

    :param sacc_data: SACC object to copy; it is not modified.
    :param data_vector: values to write, one per entry of ``indices`` (the
        lengths are checked with ``assert``).
    :param indices: positions in the copy's ``data`` list to overwrite.
    :param strict: whether to require that ``indices`` cover all the data
        already present in ``sacc_data``.
    :return: the copy of ``sacc_data`` with ``data_vector`` written at
        ``indices``.
    :raises RuntimeError: if ``strict`` is True and the coverage check fails.
    """
    assert len(indices) == len(data_vector)
    new_sacc = sacc_data.copy()

    if strict:
        requested = set(indices.ravel().tolist())
        if requested != set(sacc_data.indices()):
            raise RuntimeError(
                "The data to be saved does not cover all the data in the "
                "sacc object. To write only the calculated predictions, "
                "set strict=False."
            )

    # Lengths are equal (asserted above), so zip pairs every index with its value.
    for sacc_idx, value in zip(indices, data_vector):
        new_sacc.data[sacc_idx].value = value
    return new_sacc
[docs] def compare_optional_arrays(x: None | npt.NDArray, y: None | npt.NDArray) -> bool: """Compare two arrays, allowing for either or both to be None. :param x: first array :param y: second array :return: whether the arrays are equal """ if x is None and y is None: return True if x is not None and y is not None: return np.array_equal(x, y) # One is None and the other is not. return False
[docs] def compare_optionals(x: None | object, y: None | object) -> bool: """Compare two objects, allowing for either or both to be None. :param x: first object :param y: second object :return: whether the objects are equal """ if x is None and y is None: return True if x is not None and y is not None: return x == y # One is None and the other is not. return False
class ClLimberMethod(YAMLSerializable, str, Enum):
    """Limber integration methods selectable in ClIntegrationOptions.

    Each member is a ``str`` whose value is the lower-cased member name.
    """

    @staticmethod
    def _generate_next_value_(name, _start, _count, _last_values):
        # Hook used by ``auto()``: a member's value is its name in lower case.
        return name.lower()

    GSL_QAG_QUAD = "gsl_qag_quad"
    GSL_SPLINE = "gsl_spline"
def _validate_cl_limber_method(value: ClLimberMethod | str):
    """Coerce a string into a ClLimberMethod, ignoring case.

    Used as a pydantic ``BeforeValidator``. Non-string values are returned
    unchanged for pydantic's own validation.

    :raises ValueError: if the string is not the value of any ClLimberMethod
    """
    if not isinstance(value, str):
        return value
    try:
        # Members are looked up by value, which is the lower-cased member name.
        return ClLimberMethod(value.lower())
    except ValueError as exc:
        raise ValueError(f"Invalid value for ClLimberMethod: {value}") from exc
class ClIntegrationMethod(YAMLSerializable, str, Enum):
    """Angular power spectrum integration strategies for ClIntegrationOptions.

    Each member is a ``str`` whose value is the lower-cased member name.
    """

    @staticmethod
    def _generate_next_value_(name, _start, _count, _last_values):
        # Hook used by ``auto()``: a member's value is its name in lower case.
        return name.lower()

    LIMBER = "limber"
    FKEM_AUTO = "fkem_auto"
    FKEM_L_LIMBER = "fkem_l_limber"
def _validate_cl_integration_method(value: ClIntegrationMethod | str):
    """Coerce a string into a ClIntegrationMethod, ignoring case.

    Used as a pydantic ``BeforeValidator``. Non-string values are returned
    unchanged for pydantic's own validation.

    :raises ValueError: if the string is not the value of any ClIntegrationMethod
    """
    if not isinstance(value, str):
        return value
    try:
        # Members are looked up by value, which is the lower-cased member name.
        return ClIntegrationMethod(value.lower())
    except ValueError as exc:
        raise ValueError(f"Invalid value for ClIntegrationMethod: {value}") from exc
[docs] class ClIntegrationOptions(BaseModel): """Options for angular power spectrum integration.""" model_config = ConfigDict(extra="forbid", frozen=True) method: Annotated[ ClIntegrationMethod, BeforeValidator(_validate_cl_integration_method) ] limber_method: Annotated[ ClLimberMethod, BeforeValidator(_validate_cl_limber_method) ] l_limber: int | None = None limber_max_error: float | None = None fkem_chi_min: float | None = None fkem_Nchi: int | None = None
[docs] @field_serializer("method") @classmethod def serialize_method(cls, value: ClIntegrationMethod) -> str: """Serialize the method parameter.""" return value.name
[docs] @field_serializer("limber_method") @classmethod def serialize_limber_method(cls, value: ClLimberMethod) -> str: """Serialize the limber_method parameter.""" return value.name
[docs] def model_post_init(self, _, /) -> None: """Initialize the WeakLensingFactory object.""" match self.method: case ClIntegrationMethod.LIMBER: incompatible_options = [ "limber_max_error", "l_limber", "fkem_chi_min", "fkem_Nchi", ] case ClIntegrationMethod.FKEM_AUTO: incompatible_options = ["l_limber"] case ClIntegrationMethod.FKEM_L_LIMBER: incompatible_options = ["limber_max_error"] if self.l_limber is None or self.l_limber < 0: raise ValueError("l_limber must be set for FKEM_L_LIMBER.") case _ as unreachable: assert_never(unreachable) for option in incompatible_options: if getattr(self, option) is not None: raise ValueError(f"{option} is incompatible with {str(self.method)}.")
[docs] def get_angular_cl_args(self): """Get the arguments to pass to pyccl.angular_cl.""" match self.limber_method: case ClLimberMethod.GSL_QAG_QUAD: arg = {"limber_integration_method": "qag_quad"} case ClLimberMethod.GSL_SPLINE: arg = {"limber_integration_method": "spline"} case _ as unreachable: assert_never(unreachable) out: dict[str, str | int | float] match self.method: case ClIntegrationMethod.LIMBER: return arg | {"l_limber": -1} case ClIntegrationMethod.FKEM_AUTO: out = { "l_limber": "auto", "non_limber_integration_method": "FKEM", } if self.limber_max_error is not None: out["limber_max_error"] = self.limber_max_error if self.fkem_chi_min is not None: out["fkem_chi_min"] = self.fkem_chi_min if self.fkem_Nchi is not None: out["fkem_Nchi"] = self.fkem_Nchi return arg | out case ClIntegrationMethod.FKEM_L_LIMBER: assert self.l_limber is not None out = { "l_limber": self.l_limber, "non_limber_integration_method": "FKEM", } if self.fkem_chi_min is not None: out["fkem_chi_min"] = self.fkem_chi_min if self.fkem_Nchi is not None: out["fkem_Nchi"] = self.fkem_Nchi return arg | out case _ as unreachable_method: assert_never(unreachable_method)
@functools.lru_cache(maxsize=128)
def cached_angular_cl(
    cosmo: pyccl.Cosmology,
    tracers: tuple[pyccl.Tracer, pyccl.Tracer],
    ells: tuple[int, ...],
    p_of_k_a: (
        None | Callable[[npt.NDArray[np.int64]], npt.NDArray[np.float64]]
    ) = None,
    p_of_k_a_lin: None | pyccl.Pk2D | str = None,
    int_options: ClIntegrationOptions | None = None,
):
    """Wrapper for pyccl.angular_cl, with automatic caching.

    Results are memoized with ``functools.lru_cache`` (128 entries), so every
    argument must be hashable. In particular, pass ``ells`` as a tuple, not as
    a numpy array.

    :param cosmo: the current cosmology
    :param tracers: tracers indicating the measurements to be correlated
    :param ells: ell values at which to calculate the power spectrum
    :param p_of_k_a: power spectrum to use; if None, it is not forwarded and
        pyccl's own default applies
    :param p_of_k_a_lin: linear power spectrum to use; if None, it is not
        forwarded and pyccl's own default applies
    :param int_options: integration options whose ``get_angular_cl_args()`` are
        forwarded to pyccl.angular_cl; if None, none are forwarded
    :return: the value returned by pyccl.angular_cl
    """
    # The previous signature used ``p_of_k_a=None | Callable[...]`` (``=`` for
    # ``:``), which made the *default value* a typing Union object that was then
    # passed to pyccl. Only spectra the caller actually supplied are forwarded.
    pk_kwargs = {}
    if p_of_k_a is not None:
        pk_kwargs["p_of_k_a"] = p_of_k_a
    if p_of_k_a_lin is not None:
        pk_kwargs["p_of_k_a_lin"] = p_of_k_a_lin
    integration_kwargs = (
        int_options.get_angular_cl_args() if int_options is not None else {}
    )
    return pyccl.angular_cl(
        cosmo,
        tracers[0],
        tracers[1],
        np.array(ells),
        **pk_kwargs,
        **integration_kwargs,
    )
def make_log_interpolator(
    x: npt.NDArray[np.int64], y: npt.NDArray[np.float64]
) -> Callable[[npt.NDArray[np.int64]], npt.NDArray[np.float64]]:
    """Build a 1D spline interpolator in log(x).

    If every y value is strictly positive, log(y) is interpolated as a function
    of log(x) and the result is exponentiated. Otherwise, y itself is
    interpolated as a function of log(x).

    The splines are built with ``ext=2``, so the returned interpolator does not
    extrapolate: calling it outside the range of ``x`` raises a ValueError.

    :param x: sample abscissae (used through ``np.log``, so they should be
        positive)
    :param y: sample values at ``x``
    :return: a function mapping new x values to interpolated y values
    """
    log_x = np.log(x)

    if not np.all(y > 0):
        # Non-positive values present: log(y) is undefined, so only x is logged.
        lin_spline = scipy.interpolate.InterpolatedUnivariateSpline(log_x, y, ext=2)

        def log_x_interpolator(x_: npt.NDArray[np.int64]) -> npt.NDArray[np.float64]:
            """Interpolate on log-x scale."""
            return lin_spline(np.log(x_))

        return log_x_interpolator

    log_spline = scipy.interpolate.InterpolatedUnivariateSpline(
        log_x, np.log(y), ext=2
    )

    def log_log_interpolator(x_: npt.NDArray[np.int64]) -> npt.NDArray[np.float64]:
        """Interpolate on log-log scale."""
        return np.exp(log_spline(np.log(x_)))

    return log_log_interpolator