from functools import partial
import astropy.units as u
import jax
import jax.numpy as jnp
import nested_pandas as npd
import numpy as np
import pandas as pd
from lsdb.catalog.map_catalog import MapCatalog
from lsdb.core.search.moc_search import MOCSearch
from mocpy import MOC
from photod.parameters import GlobalParams
from photod.priors import getPriorMapIndex, initializePriorGrid, make3Dprior
from photod.results import BayesResults
from photod.stats import Entropy, getMargDistr3D, getPosteriorQuantiles, getQrQuantiles
from photod.column_map.base import mapper_from_glossary
from pathlib import Path
[docs]
def set_column_mapping(variable_mapping: Path):
"""
Set the column mapping used for this module. Updates the module-global 'cc'
ColumnMap object.
"""
global cc
cc = mapper_from_glossary("CatalogColumnMap", "Reference catalog columns", variable_mapping)
set_column_mapping(Path(__file__).parent / "column_map" / "variables.yaml")
[docs]
def makeBayesEstimates3d(
starsData: npd.NestedFrame,
priorGrid: np.ndarray,
globalParams: GlobalParams,
batchSize: int = 100,
returnPosteriors: bool = False,
):
"""Compute the bayes statistics for all the stars in a catalog partition.
Used for fast, large-scale processing. It leverages parallelization with JAX.
"""
colorsAndIndices = getColorsAndPriorIndices(starsData, globalParams)
# Use `jax.lax.map` to batch computations with scan/vmap and use memory efficiently.
# The BayesResult object is populated with the chi2min and statistics for each star.
# If `returnAllInfo` is True, the prior and posterior arrays will be included in the results.
func = partial(
loopOverEachStar,
priorGrid=priorGrid,
globalParams=globalParams.getArgs(),
returnPosteriors=returnPosteriors,
)
results = BayesResults(*jax.lax.map(func, colorsAndIndices, batch_size=batchSize))
# Create the DataFrame with the expectation values and uncertainties
estimatesDf = pd.DataFrame(
{
cc.galactic_longitude: starsData[cc.galactic_longitude],
cc.galactic_latitude: starsData[cc.galactic_latitude],
cc.chi_sq_min: results.chi2min,
**results.statistics,
}
)
return estimatesDf, results
[docs]
def makeBayesPosteriors3d(
starsData: npd.NestedFrame,
mapCatalog: MapCatalog,
globalParams: GlobalParams,
):
"""Compute the bayes posteriors for all the stars in a DataFrame.
The posterior arrays are extremely large in memory and, therefore, this method should
only be used with a handful set of stars.
It does not use JAX for parallelization.
"""
maxMapOrder = mapCatalog.hc_structure.pixel_tree.get_max_depth()
results = []
for index in range(len(starsData)):
star = starsData.iloc[[index]]
ra = star[cc.right_ascension].to_numpy() * u.deg
dec = star[cc.declination].to_numpy() * u.deg
mapMoc = MOC.from_lonlat(ra, dec, max_norder=maxMapOrder)
# Find the map partition corresponding to each star
mapPartitionDf = mapCatalog.search(MOCSearch(mapMoc, fine=False)).compute()
priorGrid = initializePriorGrid(mapPartitionDf, globalParams)
priorGrid = jax.numpy.array(list(priorGrid.values()))
_, starResults = makeBayesEstimates3d(star, priorGrid, globalParams, returnPosteriors=True)
results.append(starResults)
return results
[docs]
def getColorsAndPriorIndices(catalog, params):
colors = catalog[list(params.fitColors)].to_numpy(dtype=np.float64)
colorErrNames = [color + "Err" for color in params.fitColors]
colorsErr = catalog[colorErrNames].to_numpy(dtype=np.float64)
priorIndices = jnp.array(getPriorMapIndex(catalog[cc.observed_mag_r]))
return colors, colorsErr, priorIndices
@partial(jax.jit, static_argnames="returnPosteriors")
[docs]
def loopOverEachStar(starData, priorGrid, globalParams, returnPosteriors):
"""Internal method with the logic to be run for each star."""
colors, colorsErr, priorIndices = starData
locusColors, Ar1d, FeH1d, Mr1d, dFeH, dMr, QrGrid, QrIndices = globalParams
chi2map = calculateChi2(colors, colorsErr, locusColors)
dAr, likeCube, priorCube, chi2min = likeAndPrior(Ar1d, FeH1d, Mr1d, chi2map, priorGrid, priorIndices)
postCube = priorCube * likeCube
margPost = getMargPosteriors(priorCube, likeCube, postCube, dMr, dFeH, dAr)
statistics = postProcess(Ar1d, FeH1d, Mr1d, postCube, QrGrid, QrIndices, *margPost)
otherInfo = [likeCube, priorCube, postCube, *margPost] if returnPosteriors else []
return chi2min, statistics, *otherInfo
[docs]
def calculateChi2(colors, colorsErr, locusColors):
"""Compute chi-squared map using provided 3D model locus."""
# Remove the last axis (color), hence axis=-1
return jnp.sum(jnp.square((colors - locusColors) / colorsErr), axis=-1)
[docs]
def likeAndPrior(Ar1d, FeH1d, Mr1d, chi2map, priorGrid, priorIndices):
"""Compute the likelihood map, the 3D (Mr, FeH, Ar) prior from 2D (Mr, FeH) prior
using uniform prior for Ar, and the chi_sq_min."""
likeGrid = jnp.exp(-0.5 * chi2map)
likeCube = likeGrid.reshape(FeH1d.size, Mr1d.size, Ar1d.size)
dAr = Ar1d[1] - Ar1d[0] if Ar1d.size > 1 else 0.01
## generate 3D (Mr, FeH, Ar) prior from 2D (Mr, FeH) prior using uniform prior for Ar
prior2d = priorGrid[priorIndices].reshape(FeH1d.size, Mr1d.size)
priorCube = make3Dprior(prior2d, Ar1d.size)
return dAr, likeCube, priorCube, jnp.min(chi2map)
[docs]
def getMargPosteriors(priorCube, likeCube, postCube, dMr, dFeH, dAr):
"""Get posterior information"""
margpostMr, margpostFeH, margpostAr = {}, {}, {}
for idx, cube in enumerate([priorCube, likeCube, postCube]):
distrMr, distrFeH, distrAr = getMargDistr3D(cube, dMr, dFeH, dAr)
margpostMr[idx] = distrMr
margpostFeH[idx] = distrFeH
margpostAr[idx] = distrAr
return margpostMr, margpostFeH, margpostAr
[docs]
def postProcess(Ar1d, FeH1d, Mr1d, postCube, QrGrid, QrIndices, margpostMr, margpostFeH, margpostAr):
"""Get expectation values and uncertainties marginalize and get statistics."""
MrQuantiles = getPosteriorQuantiles(Mr1d, margpostMr[2])
FeHQuantiles = getPosteriorQuantiles(FeH1d, margpostFeH[2])
ArQuantiles = getPosteriorQuantiles(Ar1d, margpostAr[2])
QrQuantiles = getQrQuantiles(postCube, QrGrid, QrIndices)
# Calculate the quantiles for Mr, FeH and Ar and add them as columns to the result
quantile_names = ["lo", "median", "hi"]
posteriorsDict = {
f"{statisticsName}_quantile_{quantile_names[i]}": quantile
for quantiles, statisticsName in zip(
[MrQuantiles, FeHQuantiles, ArQuantiles, QrQuantiles],
[cc.abs_mag_r, cc.metallicity, cc.extinction_r, cc.abs_mag_ext_r],
)
for i, quantile in enumerate(quantiles)
}
return {
**posteriorsDict,
cc.abs_mag_r_entropy_drop: Entropy(margpostMr[2]) - Entropy(margpostMr[0]),
cc.metallicity_entropy_drop: Entropy(margpostFeH[2]) - Entropy(margpostFeH[0]),
cc.extinction_r_entropy_drop: Entropy(margpostAr[2]) - Entropy(margpostAr[0]),
}