Source code for firecrown.utils
"""Some utility functions for patterns common in Firecrown."""
from __future__ import annotations
import numpy as np
import numpy.typing as npt
import sacc
[docs]def upper_triangle_indices(n: int):
"""Returns the upper triangular indices for an (n x n) matrix.
generator that yields a sequence of tuples that carry the indices for an
(n x n) upper-triangular matrix. This is a replacement for the nested loops:
for i in range(n):
for j in range(i, n):
...
"""
for i in range(n):
for j in range(i, n):
yield i, j
[docs]def save_to_sacc(
sacc_data: sacc.Sacc,
data_vector: npt.NDArray[np.float64],
indices: npt.NDArray[np.int64],
strict: bool = True,
) -> sacc.Sacc:
"""Save a data vector into a (new) SACC object, copied from `sacc_data`.
Note that the original object `sacc_data` is not modified. Its contents are
copied into a new object, and the new information is put into that copy,
which is returned by this method.
Arguments
---------
sacc_data: sacc.Sacc
SACC object to be copied. It is not modified.
data_vector: np.ndarray[float]
Data vector to be saved to the new copy of `sacc_data`.
indices: np.ndarray[int]
SACC indices where the data vector should be written.
strict: bool
Whether to check if the data vector covers all the data already present
in the sacc_data.
Returns
-------
new_sacc: sacc.Sacc
A copy of `sacc_data`, with data at `indices` replaced with `data_vector`.
"""
assert len(indices) == len(data_vector)
new_sacc = sacc_data.copy()
if strict:
if set(indices.tolist()) != set(sacc_data.indices()):
raise RuntimeError(
"The data to be saved does not cover all the data in the "
"sacc object. To write only the calculated predictions, "
"set strict=False."
)
for data_idx, sacc_idx in enumerate(indices):
new_sacc.data[sacc_idx].value = data_vector[data_idx]
return new_sacc
[docs]def compare_optional_arrays(x: None | npt.NDArray, y: None | npt.NDArray) -> bool:
"""Compare two arrays, allowing for either or both to be None."""
if x is None and y is None:
return True
if x is not None and y is not None:
return np.array_equal(x, y)
# One is None and the other is not.
return False