"""Abstract base classes for TwoPoint Statistics sources."""
from __future__ import annotations
from typing import Optional, Sequence, final, TypeVar, Generic
from abc import abstractmethod
from dataclasses import dataclass, replace
import numpy as np
import numpy.typing as npt
from scipy.interpolate import Akima1DInterpolator
import sacc
import pyccl
import pyccl.nl_pt
from .....modeling_tools import ModelingTools
from .....parameters import ParamsMap
from ..... import parameters
from .....updatable import Updatable, UpdatableCollection
[docs]class SourceSystematic(Updatable):
"""An abstract systematic class (e.g., shear biases, photo-z shifts, etc.).
This class currently has no methods at all, because the argument types for
the `apply` method of different subclasses are different.
"""
[docs] def read(self, sacc_data: sacc.Sacc):
"""Call to allow this object to read from the appropriate sacc data."""
[docs]class Source(Updatable):
"""An abstract source class (e.g., a sample of lenses).
Parameters
----------
systematics : list of str, optional
A list of the source-level systematics to apply to the source. The
default of `None` implies no systematics.
"""
systematics: Sequence[SourceSystematic]
cosmo_hash: Optional[int]
tracers: Sequence[Tracer]
def __init__(self, sacc_tracer: str) -> None:
"""Create a Source object that uses the named tracer.
:param sacc_tracer: the name of the tracer in the SACC file. This is used
as a prefix for its parameters.
"""
super().__init__(parameter_prefix=sacc_tracer)
self.sacc_tracer = sacc_tracer
[docs] @final
def read(self, sacc_data: sacc.Sacc):
"""Read the data for this source from the SACC file."""
if hasattr(self, "systematics"):
for systematic in self.systematics:
systematic.read(sacc_data)
self._read(sacc_data)
[docs] @abstractmethod
def _read(self, sacc_data: sacc.Sacc):
"""Abstract method to read the data for this source from the SACC file."""
[docs] def _update_source(self, params: ParamsMap):
"""Method to update the source from the given ParamsMap.
Any subclass that needs to do more than update its contained :class:`Updatable`
objects should implement this method.
"""
[docs] @final
def _update(self, params: ParamsMap):
"""Implementation of Updatable interface method `_update`.
This clears the current hash and tracer, and calls the abstract method
`_update_source`, which must be implemented in all subclasses.
"""
self.cosmo_hash = None
self.tracers = []
self._update_source(params)
[docs] @abstractmethod
def get_scale(self) -> float:
"""Abstract method to return the scales for this `Source`."""
[docs] @abstractmethod
def create_tracers(self, tools: ModelingTools):
"""Create tracers for this `Source`, for the given cosmology."""
[docs] @final
def get_tracers(self, tools: ModelingTools) -> Sequence[Tracer]:
"""Return the tracer for the given cosmology.
This method caches its result, so if called a second time with the same
cosmology, no calculation needs to be done.
"""
ccl_cosmo = tools.get_ccl_cosmology()
cur_hash = hash(ccl_cosmo)
if hasattr(self, "cosmo_hash") and self.cosmo_hash == cur_hash:
return self.tracers
self.tracers, _ = self.create_tracers(tools)
self.cosmo_hash = cur_hash
return self.tracers
[docs]class Tracer:
"""Extending the pyccl.Tracer object with additional information.
Bundles together a pyccl.Tracer object with optional information about the
underlying 3D field, a pyccl.nl_pt.PTTracer, and halo profiles.
"""
[docs] @staticmethod
def determine_field_name(field: Optional[str], tracer: Optional[str]) -> str:
"""Gets a field name for a tracer.
This function encapsulates the policy for determining the value to be
assigned to the :attr:`field` attribute of a :class:`Tracer`.
It is a static method only to keep it grouped with the class for which it is
defining the initialization policy.
"""
if field is not None:
return field
if tracer is not None:
return tracer
return "delta_matter"
def __init__(
self,
tracer: pyccl.Tracer,
tracer_name: Optional[str] = None,
field: Optional[str] = None,
pt_tracer: Optional[pyccl.nl_pt.PTTracer] = None,
halo_profile: Optional[pyccl.halos.HaloProfile] = None,
halo_2pt: Optional[pyccl.halos.Profile2pt] = None,
):
"""Initialize a new Tracer based on the pyccl.Tracer which must not be None.
Note that the :class:`pyccl.Tracer` is not copied; we store a reference to the
original tracer. Be careful not to accidentally share :class:`pyccl.Tracer`s.
If no tracer_name is supplied, then the tracer_name is set to the name of the
:class:`pyccl.Tracer` class that was used.
If no `field` is given, then the attribute :attr:`field` is set to either
(1) the tracer_name, if one was given, or (2) 'delta_matter'.
"""
assert tracer is not None
self.ccl_tracer = tracer
self.tracer_name: str = tracer_name or tracer.__class__.__name__
self.field = Tracer.determine_field_name(field, tracer_name)
self.pt_tracer = pt_tracer
self.halo_profile = halo_profile
self.halo_2pt = halo_2pt
@property
def has_pt(self) -> bool:
"""Return True if we have a pt_tracer, and False if not."""
return self.pt_tracer is not None
@property
def has_hm(self) -> bool:
"""Return True if we have a halo_profile, and False if not."""
return self.halo_profile is not None
# Sources of galaxy distributions
[docs]@dataclass(frozen=True)
class SourceGalaxyArgs:
"""Class for galaxy based sources arguments."""
z: npt.NDArray[np.float64]
dndz: npt.NDArray[np.float64]
scale: float = 1.0
field: str = "delta_matter"
_SourceGalaxyArgsT = TypeVar("_SourceGalaxyArgsT", bound=SourceGalaxyArgs)
[docs]class SourceGalaxySystematic(SourceSystematic, Generic[_SourceGalaxyArgsT]):
"""Abstract base class for all galaxy based source systematics."""
[docs] @abstractmethod
def apply(
self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT
) -> _SourceGalaxyArgsT:
"""Apply method to include systematics in the tracer_arg."""
_SourceGalaxySystematicT = TypeVar(
"_SourceGalaxySystematicT", bound=SourceGalaxySystematic
)
[docs]class SourceGalaxyPhotoZShift(
SourceGalaxySystematic[_SourceGalaxyArgsT], Generic[_SourceGalaxyArgsT]
):
"""A photo-z shift bias.
This systematic shifts the photo-z distribution by some amount `delta_z`.
The following parameters are special Updatable parameters, which means that
they can be updated by the sampler, sacc_tracer is going to be used as a
prefix for the parameters:
:ivar delta_z: the photo-z shift.
"""
def __init__(self, sacc_tracer: str) -> None:
"""Create a PhotoZShift object, using the specified tracer name.
:param sacc_tracer: the name of the tracer in the SACC file. This is used
as a prefix for its parameters.
"""
super().__init__(parameter_prefix=sacc_tracer)
self.delta_z = parameters.register_new_updatable_parameter()
[docs] def apply(self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT):
"""Apply a shift to the photo-z distribution of a source."""
dndz_interp = Akima1DInterpolator(tracer_arg.z, tracer_arg.dndz)
dndz = dndz_interp(tracer_arg.z - self.delta_z, extrapolate=False)
dndz[np.isnan(dndz)] = 0.0
return replace(
tracer_arg,
dndz=dndz,
)
[docs]class SourceGalaxySelectField(
SourceGalaxySystematic[_SourceGalaxyArgsT], Generic[_SourceGalaxyArgsT]
):
"""The source galaxy select field systematic.
A systematic that allows specifying the 3D field that will be used
to select the 3D power spectrum when computing the angular power
spectrum.
"""
def __init__(self, field: str = "delta_matter"):
"""Specify which 3D field should be used when computing angular power spectra.
:param field: the name of the 3D field that is associated to the tracer.
Default: `"delta_matter"`
"""
super().__init__()
self.field = field
[docs] def apply(
self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT
) -> _SourceGalaxyArgsT:
"""Apply method to include systematics in the tracer_arg."""
return replace(tracer_arg, field=self.field)
[docs]class SourceGalaxy(Source, Generic[_SourceGalaxyArgsT]):
"""Source class for galaxy based sources."""
def __init__(
self,
*,
sacc_tracer: str,
systematics: Optional[list[SourceGalaxySystematic]] = None,
):
"""Initialize the SourceGalaxy object.
:param sacc_tracer: the name of the tracer in the SACC file. This is used
as a prefix for its parameters.
"""
super().__init__(sacc_tracer)
self.sacc_tracer = sacc_tracer
self.current_tracer_args: Optional[_SourceGalaxyArgsT] = None
self.systematics: UpdatableCollection[SourceGalaxySystematic] = (
UpdatableCollection(systematics)
)
self.tracer_args: _SourceGalaxyArgsT
[docs] def _read(self, sacc_data: sacc.Sacc):
"""Read the galaxy redshift distribution model from a sacc file.
All derived classes must call this method in their own `_read` method
after they have read their own data and initialized their tracer_args.
"""
try:
tracer_args = self.tracer_args
except AttributeError as exc:
raise RuntimeError(
"Must initialize tracer_args before calling _read on SourceGalaxy"
) from exc
tracer = sacc_data.get_tracer(self.sacc_tracer)
z = getattr(tracer, "z").copy().flatten()
nz = getattr(tracer, "nz").copy().flatten()
indices = np.argsort(z)
z = z[indices]
nz = nz[indices]
self.tracer_args = replace(
tracer_args,
z=z,
dndz=nz,
)