Compute Bayes estimates

Compute Bayes estimates

[1]:
import os
# Number of threads written into OPENBLAS_NUM_THREADS for this process.
# NOTE(review): presumably kept at 1 so that the Dask workers started later in
# the notebook do not oversubscribe CPU cores — confirm.
default_n_threads = 1
# Set before the numerical imports below so the value is already in the
# environment when those libraries are loaded.
os.environ['OPENBLAS_NUM_THREADS'] = f"{default_n_threads}"

# Disable GPU memory pre-allocation (set before `import jax` below so it is
# already in the environment when JAX is loaded).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
import lsdb
import nested_pandas as npd
import numpy as np
import pandas as pd

from dask import delayed
from dask.distributed import Client, get_worker
from photod.bayes import makeBayesEstimates3d, getEstimatesMeta
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams
from photod.priors import initializePriorGrid
[2]:
# Location of the S82 catalog in HATS format.
# NOTE(review): absolute, site-specific path — adjust when running elsewhere.
s82StripeUrl = "/mnt/beegfs/scratch/data/S82_standards/S82_hats/S82_hats_fixed"
# Open the catalog with lsdb. As the cell output states, the load is lazy:
# only the catalog schema is read at this point.
s82StripeCatalog = lsdb.read_hats(s82StripeUrl)
# Last expression of the cell: display the catalog summary.
s82StripeCatalog
[2]:
lsdb Catalog S82_fixed:
CALIBSTARS ra dec RArms Decrms Ntot Ar uNobs umag ummu uErr umrms umchi2 gNobs gmag gmmu gErr gmrms gmchi2 rNobs rmag rmmu rErr rmrms rmchi2 iNobs imag immu iErr imrms imchi2 zNobs zmag zmmu zErr zmrms zmchi2 Norder Dir Npix Mr FeH MrEst MrEstUnc FeHEst ug gr gi ri iz ugErr grErr giErr riErr izErr glon glat
npartitions=7
Order: 4, Pixel: 0 string[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] int64[pyarrow] double[pyarrow] int64[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] int64[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] int64[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] int64[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] int64[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] int8[pyarrow] int64[pyarrow] int64[pyarrow] int64[pyarrow] int64[pyarrow] int64[pyarrow] int64[pyarrow] int64[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow] double[pyarrow]
Order: 4, Pixel: 768 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Order: 4, Pixel: 2303 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Order: 4, Pixel: 3071 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
The catalog has been loaded lazily, meaning no data has been read, only the catalog schema
[3]:
# Location of the prior-map catalog in HATS format.
# NOTE(review): absolute, site-specific path — adjust when running elsewhere.
priorMapUrl = "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
# Open the catalog with lsdb. As the cell output states, the load is lazy:
# only the catalog schema is read at this point.
priorMapCatalog = lsdb.read_hats(priorMapUrl)
# Last expression of the cell: display the catalog summary.
priorMapCatalog
[3]:
lsdb Catalog s82_priors:
rmag kde xGrid yGrid Norder Dir Npix
npartitions=207
Order: 5, Pixel: 0 double[pyarrow] binary[pyarrow] binary[pyarrow] binary[pyarrow] uint8[pyarrow] uint64[pyarrow] uint64[pyarrow]
Order: 5, Pixel: 1 ... ... ... ... ... ... ...
... ... ... ... ... ... ... ...
Order: 5, Pixel: 12286 ... ... ... ... ... ... ...
Order: 5, Pixel: 12287 ... ... ... ... ... ... ...
The catalog has been loaded lazily, meaning no data has been read, only the catalog schema
[4]:
# Locus data file handed to LSSTsimsLocus (path is relative to the notebook).
locusPath = "../../data/MSandRGBcolors_v1.3.txt"
# Colors used in the fit; passed to get3DmodelList and GlobalParams below.
fitColors = ("ug", "gr", "ri", "iz")
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locusPath)
# Keep only locus rows with 0.2 < gi < 3.55.
OKlocus = LSSTlocus[(LSSTlocus["gi"] > 0.2) & (LSSTlocus["gi"] < 3.55)]
# With kMr=1 and kFeH=1 the printed message reports the same grid size before
# and after ("51 1559 to: 51 1559"), i.e. no thinning happens with these values.
locusData = subsampleLocusData(OKlocus, kMr=1, kFeH=1)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
# Bundle the fit configuration into one object; it is later scattered to the
# Dask workers and passed to mergingFunction.
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)
subsampled locus 2D grid in FeH and Mr from 51 1559 to: 51 1559
[5]:
def mergingFunction(
    partition,
    mapPartition,
    partitionPixel,
    mapPixel,
    globalParams,
    workerDict,
    batchSize=10,
    **kwargs,
):
    """Compute Bayes estimates for one catalog partition (lsdb `merge_map` callback).

    Builds the prior grid from the matching prior-map partition, selects the
    JAX device assigned to the Dask worker running this task, and runs
    `makeBayesEstimates3d` on that device.

    Parameters
    ----------
    partition
        Catalog partition forwarded unchanged to `makeBayesEstimates3d`.
    mapPartition
        Prior-map partition forwarded to `initializePriorGrid`.
    partitionPixel, mapPixel
        Not used by this function; accepted so the signature matches what
        `merge_map` passes to its callback.
    globalParams
        Fit configuration forwarded to `initializePriorGrid` and
        `makeBayesEstimates3d`.
    workerDict : dict
        Maps a Dask worker id to a non-negative device index.
    batchSize : int, default 10
        Forwarded to `makeBayesEstimates3d`.
    **kwargs
        Ignored.

    Returns
    -------
    nested_pandas.NestedFrame
        The estimates frame returned by `makeBayesEstimates3d` (its second
        return value is discarded).

    Raises
    ------
    KeyError
        If the current worker's id is not a key of `workerDict`.
    """
    priorGrid = initializePriorGrid(mapPartition, globalParams)
    availableDevices = jax.devices()
    # Wrap the assigned index so that running more workers than there are
    # visible devices (e.g. a single-GPU or CPU-only host) shares devices
    # round-robin instead of raising IndexError.
    deviceIndex = workerDict[get_worker().id] % len(availableDevices)
    with jax.default_device(availableDevices[deviceIndex]):
        # Stack the per-key prior grids into a single array on the chosen device.
        priorGrid = jax.numpy.array(list(priorGrid.values()))
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=batchSize)
    return npd.NestedFrame(estimatesDf)

[6]:
def getWorkerToGpuMapping(nWorkers):
    """Create a mapping between each worker and a GPU index.

    Every worker connected to the current Dask client is asked for its id
    directly, so no worker can be missed. Scheduling probe tasks and recording
    where they happened to run gives no such guarantee, because the scheduler
    is free to place several of them on the same worker.

    Parameters
    ----------
    nWorkers : int
        Expected number of workers. Kept for interface compatibility; the
        mapping always covers every worker the client currently knows about.

    Returns
    -------
    dict
        Worker id -> consecutive integer index (0, 1, 2, ...), assigned in
        sorted order of the worker ids.

    Raises
    ------
    ValueError
        If no Dask client is active (raised by `Client.current`).
    """
    client = Client.current()
    # `Client.run` executes the function once on every worker and fills the
    # `dask_worker` keyword with the worker instance; the result is keyed by
    # worker address.
    idsByAddress = client.run(lambda dask_worker: dask_worker.id)
    workerIds = sorted(set(idsByAddress.values()))
    return {workerId: index for index, workerId in enumerate(workerIds)}
[7]:
# Number of local Dask workers to start; each one is mapped to a device index
# by getWorkerToGpuMapping.
nWorkers = 4

# Took ~4 minutes to run
with Client(n_workers=nWorkers) as client:
    # Send globalParams to the cluster once and hand the tasks a future to it,
    # rather than embedding the object in every task.
    future = client.scatter(globalParams)
    workerToGpuMapping = getWorkerToGpuMapping(nWorkers)
    # Lazily pair the catalog with the prior map and apply mergingFunction;
    # the extra keyword arguments are forwarded to mergingFunction.
    mergeLazy = s82StripeCatalog.merge_map(
        priorMapCatalog,
        mergingFunction,
        globalParams=future,
        workerDict=workerToGpuMapping,
        meta=getEstimatesMeta(),
    )
    # Trigger the computation while the client (and the scattered future) is
    # still alive.
    mergeResult = mergeLazy.compute()

# Last expression of the cell: display the resulting estimates table.
mergeResult
2025-01-29 22:52:26,290 - distributed.nanny - WARNING - Worker process still alive after 4.0 seconds, killing
2025-01-29 22:52:26,292 - distributed.nanny - WARNING - Worker process still alive after 4.0 seconds, killing
[7]:
glon glat chi2min Ar_quantile_hi Ar_quantile_lo Ar_quantile_median ArdS FeH_quantile_hi FeH_quantile_lo FeH_quantile_median FeHdS Mr_quantile_hi Mr_quantile_lo Mr_quantile_median MrdS Qr_quantile_hi Qr_quantile_lo Qr_quantile_median
_healpix_29
122002702160 176.940106 -48.855926 6.268703 0.386949 0.174107 0.280442 -132.473633 -0.227780 -0.964648 -0.523272 -15.111343 10.874006 10.172916 10.476210 -306.561066 11.128321 10.469028 10.755150
162211513082 176.914264 -48.879749 0.364965 0.414760 0.332011 0.373790 -204.202469 -0.091860 -0.608757 -0.324400 -26.431089 10.827707 10.317924 10.558826 -359.967987 11.195699 10.696110 10.932160
187874205331 176.875399 -48.898395 17.801376 0.281888 0.191349 0.237319 -195.978607 -0.465124 -0.754734 -0.608852 -33.365372 6.583451 6.234819 6.410890 -382.531738 6.786264 6.503552 6.646283
268254148314 176.88689 -48.857814 0.055334 0.493730 0.299565 0.414506 -142.391022 -2.050107 -2.418640 -2.218580 -36.739525 4.663245 3.311964 4.256591 -242.156769 5.037274 3.637865 4.721453
282956553349 176.959307 -48.834366 21.198009 0.412039 0.351518 0.381228 -231.246979 -0.800082 -0.984448 -0.895246 -31.138645 6.499413 6.285667 6.408841 -400.166687 6.857510 6.690443 6.790002
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
3458764488921378833 48.889173 -28.256075 0.503932 0.233284 0.190768 0.212024 -270.155884 -0.053726 -0.512604 -0.243493 -23.487144 10.355577 9.897295 10.096420 -322.765930 10.565968 10.109548 10.309708
3458764491323291543 48.891169 -28.255065 6.493702 0.156990 0.020727 0.081544 -173.114441 -0.047796 -0.520705 -0.239814 -30.629833 10.302315 9.813795 10.025238 -333.407837 10.382231 9.908035 10.110762
3458764494738379595 48.895862 -28.255315 9.436792 0.355219 0.168725 0.262869 -142.162201 -0.079470 -0.594446 -0.304783 -26.398663 10.727048 10.189111 10.439279 -314.661377 10.971685 10.466538 10.698907
3458764505128080304 48.915716 -28.267155 0.091048 0.368464 0.196369 0.278210 -148.013855 -1.812392 -2.172716 -2.011806 -30.696327 6.212793 5.642663 5.964199 -287.271667 6.414802 6.005089 6.242133
3458764508180429281 48.904565 -28.256011 44.671547 1.430663 1.336793 1.384181 -192.466461 0.000000 -0.153784 -0.045414 -50.178368 6.148608 5.930100 6.015741 -399.110870 7.513712 7.338460 7.396204

952400 rows × 18 columns