Compare with Gaia distances

Compare with Gaia distances

[1]:
import os

# These environment variables are set ahead of the numerical imports below;
# presumably they have to be in place before OpenBLAS and JAX initialise.
default_n_threads = 1
# Limit OpenBLAS to `default_n_threads` threads per process.
os.environ['OPENBLAS_NUM_THREADS'] = f"{default_n_threads}"
# Ask the XLA client used by JAX not to preallocate device memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from pathlib import Path

import jax
import lsdb
import matplotlib.pyplot as plt
import nested_pandas as npd
import numpy as np
import pandas as pd

from astropy.coordinates import SkyCoord
from dask import delayed
from dask.distributed import Client, get_worker
from dustmaps import sfd
from scipy.interpolate import griddata
from photod.bayes import makeBayesEstimates3d
from photod.locus import LSSTsimsLocus, subsampleLocusData, get3DmodelList
from photod.parameters import GlobalParams

# Directory where the dust maps are cached; it is created if it does not exist.
dustmaps_cache = '/mnt/beegfs/scratch/data/dustmaps'
Path(dustmaps_cache).mkdir(parents=True, exist_ok=True)

# Point dustmaps at that directory before asking it to fetch anything.
import dustmaps.config
dustmaps.config.config['data_dir'] = dustmaps_cache

# Fetch the SFD map files (the log below shows existing files are not overwritten).
sfd.fetch()
Downloading SFD data file to /mnt/beegfs/scratch/data/dustmaps/sfd/SFD_dust_4096_ngp.fits
Checking existing file to see if MD5 sum matches ...
File exists. Not overwriting.
Downloading SFD data file to /mnt/beegfs/scratch/data/dustmaps/sfd/SFD_dust_4096_sgp.fits
Checking existing file to see if MD5 sum matches ...
File exists. Not overwriting.
[2]:
# Location of the HATS prior-map catalog (`s82_priors`).
prior_map_url = "/mnt/beegfs/scratch/data/priors/hats/s82_priors"
# Open the catalog with lsdb; loading is lazy, so only the schema is read here.
prior_map_catalog = lsdb.read_hats(prior_map_url)
# Ending the cell on the bare name displays the catalog summary.
prior_map_catalog
[2]:
lsdb Catalog s82_priors:
rmag kde xGrid yGrid Norder Dir Npix
npartitions=207
Order: 5, Pixel: 0 double[pyarrow] binary[pyarrow] binary[pyarrow] binary[pyarrow] uint8[pyarrow] uint64[pyarrow] uint64[pyarrow]
Order: 5, Pixel: 1 ... ... ... ... ... ... ...
... ... ... ... ... ... ... ...
Order: 5, Pixel: 12286 ... ... ... ... ... ... ...
Order: 5, Pixel: 12287 ... ... ... ... ... ... ...
The catalog has been loaded lazily, meaning no data has been read, only the catalog schema
[3]:
def eq_to_gal(df):
    """Return a copy of ``df`` with Galactic coordinates added.

    The ``ra`` and ``dec`` columns are interpreted as degrees, converted to
    the Galactic frame, and the resulting longitude and latitude (in degrees)
    are attached as the new columns ``glon`` and ``glat``.
    """
    galactic = SkyCoord(df["ra"], df["dec"], unit='deg').galactic
    return df.assign(glon=galactic.l.deg, glat=galactic.b.deg)

# Load the Gaia-SDSS catalog restricted to the box -52 < ra < 60,
# |dec| < 1.266 and, partition by partition:
#   1. rename the SDSS PSF magnitudes/errors to `<band>mag` / `<band>magErr`,
#   2. derive the four colors, their errors and the parallax signal-to-noise,
#   3. append Galactic coordinates via `eq_to_gal`.
# Color errors are the quadrature sum of the two magnitude errors:
# sqrt(err_a**2 + err_b**2).
s82_stripe_catalog = lsdb.read_hats(
    "/mnt/beegfs/scratch/data/Gaia-SDSS/hats/Gaia-SDSS",
    search_filter=lsdb.BoxSearch(ra=[-52, 60], dec=[-1.266, 1.266]),
).map_partitions(
    lambda df: eq_to_gal(df.rename(
        columns={f"psfmag_{b}": f"{b}mag" for b in "ugriz"} | {f"psfmagerr_{b}": f"{b}magErr" for b in "ugriz"}
    ).eval(
        """
        ug = umag - gmag
        gr = gmag - rmag
        ri = rmag - imag
        iz = imag - zmag
        ugErr = sqrt(umagErr*umagErr + gmagErr*gmagErr)
        grErr = sqrt(gmagErr*gmagErr + rmagErr*rmagErr)
        riErr = sqrt(rmagErr*rmagErr + imagErr*imagErr)
        izErr = sqrt(imagErr*imagErr + zmagErr*zmagErr)
        parallax_over_error = parallax / parallax_error
        """,
    ))
)
s82_stripe_catalog
[3]:
lsdb Catalog Gaia-SDSS:
random_index ra dec phot_g_mean_mag parallax parallax_error source_id r_med_geo r_lo_geo r_hi_geo r_med_photogeo r_lo_photogeo r_hi_photogeo flag objid type umag gmag rmag imag zmag umagErr gmagErr rmagErr imagErr zmagErr Norder Dir Npix ug gr ri iz ugErr grErr riErr izErr parallax_over_error glon glat
npartitions=13
Order: 1, Pixel: 0 int64[pyarrow] double[pyarrow] double[pyarrow] float[pyarrow] double[pyarrow] float[pyarrow] int64[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] string[pyarrow] int64[pyarrow] int16[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] uint8[pyarrow] uint64[pyarrow] uint64[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] float[pyarrow] double[pyarrow] float64 float64
Order: 2, Pixel: 48 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Order: 0, Pixel: 8 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Order: 0, Pixel: 11 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
The catalog has been loaded lazily, meaning no data has been read, only the catalog schema
[4]:
def merging_function(partition, map_partition, partition_pixel, map_pixel, globalParams, worker_dict, *kwargs):
    """Run the Bayesian estimator on one catalog partition with its prior maps.

    Parameters
    ----------
    partition : DataFrame with the stars of this partition.
    map_partition : DataFrame with one prior map per ``rmag`` value; the
        ``kde``, ``xGrid`` and ``yGrid`` columns hold raw float64 buffers that
        are reshaped to (96, 36).
    partition_pixel, map_pixel : not used by this function.
    globalParams : object providing ``locusData`` and ``MrColumn``; it is also
        forwarded to ``makeBayesEstimates3d``.
    worker_dict : mapping from Dask worker id to an index into ``jax.devices()``.
    *kwargs : extra positional arguments, ignored.

    Returns
    -------
    The selected input columns of ``partition`` concatenated column-wise with
    the estimates returned by ``makeBayesEstimates3d``.
    """
    map_shape = (96, 36)
    prior_slices = []
    for r in np.sort(map_partition["rmag"].to_numpy()):
        # interpolate prior map onto locus Mr-FeH grid
        row = map_partition[map_partition["rmag"] == r].iloc[0]
        kde, x_grid, y_grid = (
            np.frombuffer(row[column], dtype=np.float64).reshape(map_shape)
            for column in ("kde", "xGrid", "yGrid")
        )
        nodes = np.array((x_grid.flatten(), y_grid.flatten())).T
        locus_nodes = (globalParams.locusData["FeH"], globalParams.locusData[globalParams.MrColumn])
        # actual (linear) interpolation; points outside the map get a prior of 0
        prior_slices.append(
            griddata(nodes, kde.flatten(), locus_nodes, method="linear", fill_value=0)
        )

    # Each worker runs on the JAX device assigned to it through worker_dict.
    device = jax.devices()[worker_dict[get_worker().id]]
    with jax.default_device(device):
        priorGrid = jax.numpy.array(prior_slices)
        estimatesDf, _ = makeBayesEstimates3d(partition, priorGrid, globalParams, batchSize=100)

    # Append ra and dec to be able to later crossmatch
    kept_columns = ["ra", "dec", "r_med_geo", "r_med_photogeo", "parallax_over_error", "rmag", "umag", "gr"]
    return pd.concat([partition[kept_columns], npd.NestedFrame(estimatesDf)], axis=1)
[5]:
# Locus data file, given relative to the notebook's working directory.
locus_path = "../../data/MSandRGBcolors_v1.3.txt"
# Colors handed to the 3D model builder and to GlobalParams.
fitColors = ("ug", "gr", "ri", "iz")
LSSTlocus = LSSTsimsLocus(fixForStripe82=False, datafile=locus_path)
# Keep only locus rows with 0.2 < gi < 3.55 (presumably the g-i color).
OKlocus = LSSTlocus[(LSSTlocus["gi"] > 0.2) & (LSSTlocus["gi"] < 3.55)]
# Subsample the locus grid; the message printed below reports the grid size
# in FeH and Mr before and after.
locusData = subsampleLocusData(OKlocus, kMr=10, kFeH=2)
ArGridList, locus3DList = get3DmodelList(locusData, fitColors)
# Bundle the pieces that are later passed to the per-partition estimator.
globalParams = GlobalParams(fitColors, locusData, ArGridList, locus3DList)
subsampled locus 2D grid in FeH and Mr from 51 1559 to: 25 155
[6]:
# Column layout of the cross-match result, declared up front as an empty
# frame that is passed as `meta` to the lazy merge.
quantile_cols = [
    f"{statisticsName}_quantile_{quantile}"
    for statisticsName in ["Mr", "FeH", "Ar", "Qr"]
    for quantile in ["lo", "median", "hi"]
]
estimate_cols = sorted(quantile_cols + ["MrdS", "FeHdS", "ArdS"])
col_names = [
    "ra", "dec", "r_med_geo", "r_med_photogeo", "parallax_over_error",
    "rmag", "umag", "gr", "glon", "glat", "chi2min",
] + estimate_cols

# Every column is declared as an empty float32 series.
meta = npd.NestedFrame.from_dict(
    {col: pd.Series([], dtype=np.float32) for col in col_names}
)
meta.index.name = "_healpix_29"
meta
[6]:
ra dec r_med_geo r_med_photogeo parallax_over_error rmag umag gr glon glat ... FeH_quantile_lo FeH_quantile_median FeHdS Mr_quantile_hi Mr_quantile_lo Mr_quantile_median MrdS Qr_quantile_hi Qr_quantile_lo Qr_quantile_median
_healpix_29

0 rows × 26 columns

[7]:
def get_worker_dict(client=None):
    """Map every Dask worker id to a consecutive integer index.

    The indices are used to pick one entry of ``jax.devices()`` per worker.
    All workers are queried directly with ``Client.run``, so the mapping
    contains every worker of the cluster rather than only those that happened
    to receive a probe task.

    Parameters
    ----------
    client : distributed.Client, optional
        Client connected to the cluster. Defaults to the current client,
        e.g. the one opened by an enclosing ``with Client(...)`` block.

    Returns
    -------
    dict
        ``{worker_id: index}`` with indices assigned in sorted id order.
        The mapping is also printed.
    """
    if client is None:
        client = Client.current()
    # `Client.run` fills an argument named `dask_worker` with the Worker
    # instance on each worker and returns {worker_address: result}.
    ids_by_address = client.run(lambda dask_worker: dask_worker.id)
    worker_ids = sorted(ids_by_address.values())
    worker_dict = {worker_id: i for i, worker_id in enumerate(worker_ids)}
    print(worker_dict)
    return worker_dict
[8]:
with Client(n_workers=4) as client:
    # get_worker_dict() prints the mapping itself, so it is not printed again.
    worker_dict = get_worker_dict()
    # The shared parameters are wrapped in one delayed object before being
    # handed to the merge.
    delayed_global_params = delayed(globalParams)
    merge_lazy = s82_stripe_catalog.merge_map(
        prior_map_catalog,
        merging_function,
        globalParams=delayed_global_params,
        worker_dict=worker_dict,
        meta=meta,
    )
    # Everything above is lazy; the computation happens here, while the
    # cluster is still alive.
    xmatch_result = merge_lazy.compute()
xmatch_result
/home/kmalanch/.virtualenvs/photoD/lib/python3.10/site-packages/distributed/node.py:187: UserWarning: Port 8787 is already in use.
Perhaps you already have a cluster running?
Hosting the HTTP server on port 44785 instead
  warnings.warn(
{'Worker-633b7d75-b798-4898-b317-ce0af5f1bd39': 0, 'Worker-84960edd-d4c6-4095-904d-67612e92a0cb': 1, 'Worker-9feffa31-f12d-4219-871c-5848d963f3b0': 2, 'Worker-d47fdb7f-2045-4178-9340-b77b95320622': 3}
{'Worker-633b7d75-b798-4898-b317-ce0af5f1bd39': 0, 'Worker-84960edd-d4c6-4095-904d-67612e92a0cb': 1, 'Worker-9feffa31-f12d-4219-871c-5848d963f3b0': 2, 'Worker-d47fdb7f-2045-4178-9340-b77b95320622': 3}
/home/kmalanch/.virtualenvs/photoD/lib/python3.10/site-packages/distributed/client.py:3371: UserWarning: Sending large graph of size 78.95 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
  warnings.warn(
[8]:
ra dec r_med_geo r_med_photogeo parallax_over_error rmag umag gr glon glat ... FeH_quantile_lo FeH_quantile_median FeHdS Mr_quantile_hi Mr_quantile_lo Mr_quantile_median MrdS Qr_quantile_hi Qr_quantile_lo Qr_quantile_median
_healpix_29
29153808338 45.004978 0.01988 314.098663 313.454712 140.47131 14.1454 17.21501 0.83732 176.944767 -48.885267 ... NaN NaN 0.291286 NaN NaN NaN -27.945087 NaN NaN NaN
29640498453 45.00432 0.021048 306.347778 310.828705 26.857705 18.101839 21.903919 1.52433 176.942794 -48.884931 ... -0.590813 -0.278518 -11.072062 10.594156 10.078468 10.290326 -33.299908 10.864006 10.387707 10.568277
282950616891 45.048282 0.048254 614.936218 616.651428 45.99728 14.971 17.75176 0.79687 176.959371 -48.834395 ... -0.533224 -0.289617 -5.525905 5.296045 -0.168435 0.333894 -18.281178 5.988453 0.033629 0.588651
425704743710 45.02362 0.068419 814.872498 811.626465 18.006064 16.55686 20.05662 1.0364 176.911311 -48.838160 ... -0.392313 -0.203972 -14.542096 6.844335 6.338311 6.570840 -30.928381 6.997999 6.658961 6.791443
643735774625 44.993271 0.076334 1476.288086 1521.849976 13.096099 16.204041 18.44335 0.58482 176.870634 -48.854504 ... -0.442113 -0.229036 -10.431520 5.507360 4.881606 5.149137 -26.814789 5.637996 5.146863 5.345767
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
3458764334488336214 315.027158 -0.037642 1023.075745 1106.234863 6.42036 17.969391 21.889391 1.393219 48.897672 -28.302597 ... -0.517343 -0.263206 -11.556311 8.095885 7.439503 7.790954 -23.818600 8.326488 7.885603 8.076653
3458764335048204908 315.029101 -0.038178 595.732178 593.664551 87.681984 13.2611 15.06771 0.43824 48.898293 -28.304536 ... NaN NaN 0.559260 NaN NaN NaN -27.933327 NaN NaN NaN
3458764488918291245 314.983418 -0.020028 390.564972 400.299591 16.613003 18.512329 22.417471 1.416811 48.889083 -28.256054 ... -0.485831 -0.235875 -14.399448 10.643343 10.224490 10.392323 -33.054119 10.704497 10.319132 10.462954
3458764504882974191 315.003638 -0.006759 464.123688 462.984772 9.646154 19.184 24.736589 1.520161 48.913855 -28.266487 ... -0.458691 -0.214811 -16.543257 11.052170 10.507268 10.759591 -28.616232 11.276964 10.896560 11.042253
3458764508179655431 314.990401 -0.008369 836.059204 836.891296 9.183089 17.697729 21.815559 1.331572 48.904511 -28.255994 ... -0.489374 -0.234240 -11.432226 8.678715 8.240970 8.420457 -29.453726 8.695536 8.291556 8.419674

375831 rows × 26 columns

[9]:
def pc_to_distmod(d):
    """Convert a distance in parsec to a distance modulus, 5 * log10(d / 10).

    Works on scalars as well as on arrays or series, element by element.
    """
    in_units_of_10pc = d / 10.0
    return 5.0 * np.log10(in_units_of_10pc)

plot_samples = None  # number of stars to scatter; None plots all of them
poe_window = 1000  # number of stars per running-statistics window

df = xmatch_result.copy()
coord = SkyCoord(df["ra"], df["dec"], unit="deg")
# Coefficient is for SDSS r-band, RV=3.1
# Taken from table 6, Schlafly & Finkbeiner 2011
# https://ui.adsabs.harvard.edu/abs/2011ApJ...737..103S/abstract
df["sfd_Ar"] = 2.285 * sfd.SFDQuery()(coord)
# Absolute magnitudes implied by the two Gaia-based distance estimates.
df["Mr_gaia_geo"] = df["rmag"] - pc_to_distmod(df["r_med_geo"]) - df["sfd_Ar"]
df["Mr_gaia_photogeo"] = df["rmag"] - pc_to_distmod(df["r_med_photogeo"]) - df["sfd_Ar"]
# Sorting by parallax S/N lets consecutive rows form the running windows below.
df = df.sort_values("parallax_over_error").query(
    "Mr_gaia_geo > 4"
    " and Mr_gaia_photogeo > 4"
    " and umag < 21"
    " and 0.2 < gr < 0.6"
    " and parallax_over_error >= 10"
)
print(df.shape)

poe = df["parallax_over_error"].to_numpy()
obs = df["Mr_quantile_median"].to_numpy()
mod = df["Mr_gaia_geo"].to_numpy()
residuals = mod - obs

# Split the sorted residuals into consecutive windows of `poe_window` stars;
# a trailing partial window is dropped.
n_windows = len(residuals) // poe_window
windowed = residuals[:n_windows * poe_window].reshape(n_windows, poe_window)
# One abscissa per full window: the parallax S/N of its first star. Slicing up
# to n_windows * poe_window keeps this the same length as the statistics even
# when the row count is an exact multiple of poe_window.
poe_grid = poe[:n_windows * poe_window:poe_window]
medians = np.nanmedian(windowed, axis=1)
mean = np.nanmean(windowed, axis=1)
std = np.nanstd(windowed, axis=1)

# Optionally thin the scatter to a reproducible random subset.
if plot_samples is None:
    idx_samples = slice(None)
else:
    idx_samples = np.random.default_rng(0).choice(
        len(residuals), min(plot_samples, len(residuals)), replace=False
    )

fig, ax = plt.subplots()
ax.scatter(poe[idx_samples], residuals[idx_samples], marker='o', s=3, alpha=0.2)

ax.plot(poe_grid, medians, '-', color='green', label='median')
ax.plot(poe_grid, mean, '-', color='red', label='mean')
# Label only one of the two envelope lines so the legend shows a single entry.
ax.plot(poe_grid, mean - std, '--', color='red', label=r'mean$\pm$std')
ax.plot(poe_grid, mean + std, '--', color='red')

ax.set_xlim(0, 100)
ax.set_ylim(-4, 4)
ax.set_xlabel("Gaia DR3 parallax / error")
ax.set_ylabel("Bailer-Jones+20 geo $-$ photoD, mag")
ax.grid()
ax.legend()
fig.savefig("MYPLOT.png")
(43435, 29)
../../_images/pre_executed_validation_validate_with_gaia_dist_9_1.png
[10]:
# Mean of Mr + Ar - Qr over the selected stars (median estimates);
# presumably a closure check that Qr is close to Mr + Ar - TODO confirm.
qr_closure = (
    df["Mr_quantile_median"]
    + df["Ar_quantile_median"]
    - df["Qr_quantile_median"]
)
np.mean(qr_closure)
[10]:
np.float32(-0.011516952)