Compute Bayes estimates
[1]:
import os
# The environment variables below are deliberately set BEFORE numpy/jax are imported,
# so the underlying libraries see them when they initialize.
# Limit OpenBLAS to one thread per process (relevant if numpy is linked against OpenBLAS).
default_n_threads = 1
os.environ['OPENBLAS_NUM_THREADS'] = f"{default_n_threads}"
# Disable GPU memory pre-allocation
# (presumably so the several Dask worker processes started later do not each reserve
# most of a GPU's memory up front — confirm against the deployment setup)
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
import jax
import lsdb
import nested_pandas as npd
import numpy as np
import pandas as pd
from dask import delayed
from dask.distributed import Client, get_worker
# Project-local photod modules: Bayes estimator, stellar-locus models, shared params, priors
from photod.bayes import makeBayesEstimates3d, getEstimatesMeta
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams
from photod.priors import initializePriorGrid
[2]:
# Stripe 82 standard-star catalog in HATS format; lsdb only reads the schema here.
# The data root defaults to the shared scratch area; set PHOTOD_DATA_DIR to run elsewhere.
s82StripeUrl = os.path.join(
    os.environ.get("PHOTOD_DATA_DIR", "/mnt/beegfs/scratch/data"),
    "S82_standards",
    "S82_hats",
    "S82_hats_fixed",
)
s82StripeCatalog = lsdb.read_hats(s82StripeUrl)
s82StripeCatalog
[2]:
lsdb Catalog S82_fixed:
| CALIBSTARS | ra | dec | RArms | Decrms | Ntot | Ar | uNobs | umag | ummu | uErr | umrms | umchi2 | gNobs | gmag | gmmu | gErr | gmrms | gmchi2 | rNobs | rmag | rmmu | rErr | rmrms | rmchi2 | iNobs | imag | immu | iErr | imrms | imchi2 | zNobs | zmag | zmmu | zErr | zmrms | zmchi2 | Norder | Dir | Npix | Mr | FeH | MrEst | MrEstUnc | FeHEst | ug | gr | gi | ri | iz | ugErr | grErr | giErr | riErr | izErr | glon | glat | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| npartitions=7 | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Order: 4, Pixel: 0 | string[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | int64[pyarrow] | double[pyarrow] | int64[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | int64[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | int64[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | int64[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | int64[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | int8[pyarrow] | int64[pyarrow] | int64[pyarrow] | int64[pyarrow] | int64[pyarrow] | int64[pyarrow] | int64[pyarrow] | int64[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] | double[pyarrow] |
| Order: 4, Pixel: 768 | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Order: 4, Pixel: 2303 | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Order: 4, Pixel: 3071 | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
The catalog has been loaded lazily, meaning no data has been read, only the catalog schema
[3]:
# HATS catalog holding the prior maps consumed by `initializePriorGrid`; read lazily.
# The data root defaults to the shared scratch area; set PHOTOD_DATA_DIR to run elsewhere.
priorMapUrl = os.path.join(
    os.environ.get("PHOTOD_DATA_DIR", "/mnt/beegfs/scratch/data"),
    "priors",
    "hats",
    "s82_priors",
)
priorMapCatalog = lsdb.read_hats(priorMapUrl)
priorMapCatalog
[3]:
lsdb Catalog s82_priors:
| rmag | kde | xGrid | yGrid | Norder | Dir | Npix | |
|---|---|---|---|---|---|---|---|
| npartitions=207 | |||||||
| Order: 5, Pixel: 0 | double[pyarrow] | binary[pyarrow] | binary[pyarrow] | binary[pyarrow] | uint8[pyarrow] | uint64[pyarrow] | uint64[pyarrow] |
| Order: 5, Pixel: 1 | ... | ... | ... | ... | ... | ... | ... |
| ... | ... | ... | ... | ... | ... | ... | ... |
| Order: 5, Pixel: 12286 | ... | ... | ... | ... | ... | ... | ... |
| Order: 5, Pixel: 12287 | ... | ... | ... | ... | ... | ... | ... |
The catalog has been loaded lazily, meaning no data has been read, only the catalog schema
[4]:
# Colors used in the fit, and the stellar-locus model file they are fit against
fitColors = ("ug", "gr", "ri", "iz")
locusPath = "../../data/MSandRGBcolors_v1.3.txt"
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locusPath)

# Keep only locus entries strictly inside the g-i window (0.2, 3.55)
OKlocus = LSSTlocus[(0.2 < LSSTlocus["gi"]) & (LSSTlocus["gi"] < 3.55)]

# Subsampling factors of 1 leave the FeH/Mr grid at full resolution
# (the printed grid size is unchanged: 51 x 1559)
locusData = subsampleLocusData(OKlocus, kFeH=1, kMr=1)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)
subsampled locus 2D grid in FeH and Mr from 51 1559 to: 51 1559
[5]:
def mergingFunction(
    partition,
    mapPartition,
    partitionPixel,
    mapPixel,
    globalParams,
    workerDict,
    batchSize=10,
    **kwargs,
):
    """Compute Bayes estimates for one catalog partition; used by lsdb `merge_map`.

    Parameters
    ----------
    partition :
        One partition of the object catalog, passed to `makeBayesEstimates3d`.
    mapPartition :
        The matching partition of the prior-map catalog, passed to `initializePriorGrid`.
    partitionPixel, mapPixel :
        Pixel identifiers supplied by `merge_map`; not used here.
    globalParams :
        Shared fit configuration and locus models (a `GlobalParams` instance).
    workerDict : dict
        Maps a Dask worker id (`get_worker().id`) to the index of the JAX device that
        worker should use. Indices wrap around the number of visible devices.
    batchSize : int
        Forwarded to `makeBayesEstimates3d`.
    **kwargs :
        Any extra keyword arguments from `merge_map`; ignored.

    Returns
    -------
    npd.NestedFrame
        The estimates table returned by `makeBayesEstimates3d`.

    Raises
    ------
    KeyError
        If the current worker has no entry in `workerDict`.
    """
    priorGrid = initializePriorGrid(mapPartition, globalParams)

    workerId = get_worker().id
    if workerId not in workerDict:
        raise KeyError(
            f"Dask worker {workerId!r} has no device assignment in workerDict; "
            "rebuild the worker-to-GPU mapping after all workers have started"
        )
    devices = jax.devices()
    # Wrap around so that more workers than visible devices (or a CPU-only host with a
    # single device) still run instead of failing with IndexError.
    gpuDevice = devices[workerDict[workerId] % len(devices)]

    with jax.default_device(gpuDevice):
        # Stack the per-key prior grids into one array placed on the chosen device
        priorGrid = jax.numpy.array(list(priorGrid.values()))
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=batchSize)
    return npd.NestedFrame(estimatesDf)
[6]:
def getWorkerToGpuMapping(nWorkers, client=None):
    """Create a mapping between each Dask worker and a GPU index.

    Parameters
    ----------
    nWorkers : int
        Number of workers expected in the cluster. The call blocks until at least this
        many workers have registered, so none are left out of the mapping.
    client : dask.distributed.Client, optional
        Client used to query the workers. Defaults to the active client (`get_client()`).

    Returns
    -------
    dict
        Maps each worker's `Worker.id` to a distinct integer 0..N-1, assigned in sorted
        id order.
    """
    if client is None:
        from dask.distributed import get_client

        client = get_client()
    # Wait for every worker to join; late workers would otherwise be missing from the map.
    client.wait_for_workers(nWorkers)
    # `client.run` executes once on every worker (injecting the `dask_worker` argument).
    # Mapping over data partitions does not guarantee that: several partitions can be
    # scheduled on the same worker, leaving other workers unmapped. It also avoids
    # reading catalog data just to discover worker ids.
    idsByAddress = client.run(lambda dask_worker: dask_worker.id)
    workerIds = sorted(set(idsByAddress.values()))
    return {workerId: index for index, workerId in enumerate(workerIds)}
[7]:
nWorkers = 4
# Took ~4 minutes to run
with Client(n_workers=nWorkers) as client:
    # Every partition task receives globalParams, so replicate it to all workers up front
    # rather than having each worker fetch it on demand from a single holder.
    future = client.scatter(globalParams, broadcast=True)
    # One device index per worker; mergingFunction looks up its own worker id here.
    workerToGpuMapping = getWorkerToGpuMapping(nWorkers)
    mergeLazy = s82StripeCatalog.merge_map(
        priorMapCatalog,
        mergingFunction,
        globalParams=future,
        workerDict=workerToGpuMapping,
        meta=getEstimatesMeta(),
    )
    # Must run while the client is open: the scattered future lives on this cluster.
    mergeResult = mergeLazy.compute()
mergeResult
2025-01-29 22:52:26,290 - distributed.nanny - WARNING - Worker process still alive after 4.0 seconds, killing
2025-01-29 22:52:26,292 - distributed.nanny - WARNING - Worker process still alive after 4.0 seconds, killing
[7]:
| glon | glat | chi2min | Ar_quantile_hi | Ar_quantile_lo | Ar_quantile_median | ArdS | FeH_quantile_hi | FeH_quantile_lo | FeH_quantile_median | FeHdS | Mr_quantile_hi | Mr_quantile_lo | Mr_quantile_median | MrdS | Qr_quantile_hi | Qr_quantile_lo | Qr_quantile_median | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| _healpix_29 | ||||||||||||||||||
| 122002702160 | 176.940106 | -48.855926 | 6.268703 | 0.386949 | 0.174107 | 0.280442 | -132.473633 | -0.227780 | -0.964648 | -0.523272 | -15.111343 | 10.874006 | 10.172916 | 10.476210 | -306.561066 | 11.128321 | 10.469028 | 10.755150 |
| 162211513082 | 176.914264 | -48.879749 | 0.364965 | 0.414760 | 0.332011 | 0.373790 | -204.202469 | -0.091860 | -0.608757 | -0.324400 | -26.431089 | 10.827707 | 10.317924 | 10.558826 | -359.967987 | 11.195699 | 10.696110 | 10.932160 |
| 187874205331 | 176.875399 | -48.898395 | 17.801376 | 0.281888 | 0.191349 | 0.237319 | -195.978607 | -0.465124 | -0.754734 | -0.608852 | -33.365372 | 6.583451 | 6.234819 | 6.410890 | -382.531738 | 6.786264 | 6.503552 | 6.646283 |
| 268254148314 | 176.88689 | -48.857814 | 0.055334 | 0.493730 | 0.299565 | 0.414506 | -142.391022 | -2.050107 | -2.418640 | -2.218580 | -36.739525 | 4.663245 | 3.311964 | 4.256591 | -242.156769 | 5.037274 | 3.637865 | 4.721453 |
| 282956553349 | 176.959307 | -48.834366 | 21.198009 | 0.412039 | 0.351518 | 0.381228 | -231.246979 | -0.800082 | -0.984448 | -0.895246 | -31.138645 | 6.499413 | 6.285667 | 6.408841 | -400.166687 | 6.857510 | 6.690443 | 6.790002 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 3458764488921378833 | 48.889173 | -28.256075 | 0.503932 | 0.233284 | 0.190768 | 0.212024 | -270.155884 | -0.053726 | -0.512604 | -0.243493 | -23.487144 | 10.355577 | 9.897295 | 10.096420 | -322.765930 | 10.565968 | 10.109548 | 10.309708 |
| 3458764491323291543 | 48.891169 | -28.255065 | 6.493702 | 0.156990 | 0.020727 | 0.081544 | -173.114441 | -0.047796 | -0.520705 | -0.239814 | -30.629833 | 10.302315 | 9.813795 | 10.025238 | -333.407837 | 10.382231 | 9.908035 | 10.110762 |
| 3458764494738379595 | 48.895862 | -28.255315 | 9.436792 | 0.355219 | 0.168725 | 0.262869 | -142.162201 | -0.079470 | -0.594446 | -0.304783 | -26.398663 | 10.727048 | 10.189111 | 10.439279 | -314.661377 | 10.971685 | 10.466538 | 10.698907 |
| 3458764505128080304 | 48.915716 | -28.267155 | 0.091048 | 0.368464 | 0.196369 | 0.278210 | -148.013855 | -1.812392 | -2.172716 | -2.011806 | -30.696327 | 6.212793 | 5.642663 | 5.964199 | -287.271667 | 6.414802 | 6.005089 | 6.242133 |
| 3458764508180429281 | 48.904565 | -28.256011 | 44.671547 | 1.430663 | 1.336793 | 1.384181 | -192.466461 | 0.000000 | -0.153784 | -0.045414 | -50.178368 | 6.148608 | 5.930100 | 6.015741 | -399.110870 | 7.513712 | 7.338460 | 7.396204 |
952400 rows × 18 columns