Source code for photod.plotting

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

from photod.stats import getMargDistr, getMargDistr3D, getStats


def showMargPosteriors3D(
    x1d1,
    margp1,
    xLab1,
    yLab1,
    x1d2,
    margp2,
    xLab2,
    yLab2,
    x1d3,
    margp3,
    xLab3,
    yLab3,
    trueX1,
    trueX2,
    trueX3,
):
    """Plot three panels of 1-D curves side by side, then save and show them.

    For panel ``n`` (1, 2, 3) the function draws ``margpn[2]`` as a thick red
    line, ``margpn[1]`` in green and ``margpn[0]`` in blue against ``x1dn``,
    adds a thin black vertical line at ``trueXn`` and a dashed red vertical
    line at the mean that ``getStats(x1dn, margpn[2])`` returns.

    Parameters
    ----------
    x1d1, x1d2, x1d3
        Grids used as the x coordinate of each panel.
    margp1, margp2, margp3
        Indexable containers; entries 0, 1 and 2 are plotted.
    xLab1, yLab1, xLab2, yLab2, xLab3, yLab3
        Axis labels of the respective panels.
    trueX1, trueX2, trueX3
        x positions of the black reference lines.

    Notes
    -----
    The figure is written to ``plots/margPosteriors3D.png`` (the ``plots``
    directory is not created here) and then displayed with ``plt.show()``.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12.7, 4))
    fig.subplots_adjust(wspace=0.25, left=0.1, right=0.95, bottom=0.12, top=0.95)

    panels = (
        (x1d1, margp1, xLab1, yLab1, trueX1),
        (x1d2, margp2, xLab2, yLab2, trueX2),
        (x1d3, margp3, xLab3, yLab3, trueX3),
    )
    for ax, (grid, curves, xLabel, yLabel, trueValue) in zip(axs, panels):
        ax.plot(grid, curves[2], "r", lw=3)
        ax.plot(grid, curves[1], "g")
        ax.plot(grid, curves[0], "b")
        ax.set(xlabel=xLabel, ylabel=yLabel)

        # Both vertical lines reach 5% above the larger of curves 0 and 2.
        lineTop = 1.05 * np.max([curves[0], curves[2]])
        ax.plot([trueValue, trueValue], [0, lineTop], "k", lw=1)

        # Only the mean is drawn; the second value from getStats is unused.
        meanValue, _ = getStats(grid, curves[2])
        ax.plot([meanValue, meanValue], [0, lineTop], "--r")

    plt.savefig("plots/margPosteriors3D.png")
    plt.show()
def showCornerPlot3(
    postCube, Mr1d, FeH1d, Ar1d, md, xLab, yLab, x0=-99, y0=-99, z0=-99, logScale=False, cmap="Blues"
):
    """Draw a 3x3 corner plot of a 3-D cube and save it to ``plots/cornerPlot3.png``.

    The lower-left triangle holds the three 2-D images obtained by summing
    ``postCube`` over one axis; the diagonal holds the three 1-D marginal
    curves returned by ``getMargDistr3D``. The remaining panels are blank.

    Parameters
    ----------
    postCube
        3-D array. The summation axes and labels used below treat its axes
        as (FeH, Mr, Ar).
    Mr1d, FeH1d, Ar1d
        1-D grids. ``FeH1d`` and ``Mr1d`` need at least two entries (their
        spacing is read from the first two); a single-entry ``Ar1d`` falls
        back to a step of 0.01.
    md
        Metadata sequence; ``md[0]``/``md[1]`` give the FeH image limits and
        ``md[3]``/``md[4]`` the Mr image limits.
    xLab, yLab
        Accepted for interface compatibility; not used.
    x0, y0, z0
        Reference FeH, Mr and Ar values. A value of ``-99`` (or lower) means
        "not supplied": the matching reference line is skipped, and a marker
        is drawn on an image only when both of its coordinates are supplied.
    logScale
        If true, images use a logarithmic color scale spanning from 1/100 of
        the image maximum to the maximum.
    cmap
        Matplotlib colormap used for the three images.

    Notes
    -----
    The figure is saved (the ``plots`` directory is not created here) and
    then displayed with ``plt.show()``.
    """

    def oneImage(ax, image, extent, title, showTrue, x0, y0, origin, logScale=True, cmap="Blues"):
        """Draw one peak-normalized image on ``ax`` and return the image artist."""
        im = image / image.max()
        # ``origin`` is honored for both color scales so that the linear and
        # logarithmic renderings share the same orientation.
        imshowArgs = {"origin": origin, "aspect": "auto", "extent": extent, "cmap": cmap}
        if logScale:
            imshowArgs["norm"] = LogNorm(im.max() / 100, vmax=im.max())
        artist = ax.imshow(im.T, **imshowArgs)
        ax.set_title(title)
        if showTrue:
            ax.scatter(x0, y0, s=150, c="red", alpha=0.3)
            ax.scatter(x0, y0, s=40, c="yellow", alpha=0.3)
        return artist

    def referenceLine(ax, value, curve):
        """Mark ``value`` with a dashed vertical line, unless it is the -99 sentinel."""
        if value > -99:
            ax.plot([value, value], [0, 1.1 * np.max(curve)], "--k", lw=1)

    # unpack metadata
    xMin, xMax = md[0], md[1]  # FeH
    yMin, yMax = md[3], md[4]  # Mr
    zMin, zMax = Ar1d[0], Ar1d[-1]  # Ar

    # grid steps
    dFeH = FeH1d[1] - FeH1d[0]
    dMr = Mr1d[1] - Mr1d[0]
    dAr = Ar1d[1] - Ar1d[0] if Ar1d.size > 1 else 0.01

    # 1-D marginal distributions
    margMr, margFeH, margAr = getMargDistr3D(postCube, dMr, dFeH, dAr)

    # 2-D marginal (summed) distributions
    im1 = np.sum(postCube, axis=2)  # Mr vs. FeH
    im2 = np.sum(postCube, axis=1)  # Ar vs. FeH
    im3 = np.sum(postCube, axis=0)  # Ar vs. Mr

    # Which reference values were actually supplied.
    haveX = bool(x0 > -99)
    haveY = bool(y0 > -99)
    haveZ = bool(z0 > -99)

    fig, axs = plt.subplots(3, 3, figsize=(12, 12))
    fig.subplots_adjust(wspace=0.25, left=0.1, right=0.95, bottom=0.12, top=0.95)

    # row 1: marginal FeH
    axs[0, 0].plot(FeH1d, margFeH, "r", lw=3)
    referenceLine(axs[0, 0], x0, margFeH)
    axs[0, 0].set(xlabel="FeH", ylabel="p(FeH)")
    axs[0, 1].set_axis_off()
    axs[0, 2].set_axis_off()

    # row 2: im1 and marginal Mr
    oneImage(
        axs[1, 0],
        im1,
        [xMin, xMax, yMin, yMax],
        "",
        haveX and haveY,
        x0,
        y0,
        origin="upper",
        logScale=logScale,
        cmap=cmap,
    )
    axs[1, 0].set(xlabel="FeH", ylabel="Mr")
    axs[1, 1].plot(Mr1d, margMr, "r", lw=3)
    referenceLine(axs[1, 1], y0, margMr)
    axs[1, 1].set(xlabel="Mr", ylabel="p(Mr)")
    axs[1, 2].set_axis_off()

    # row 3: im2, im3, and marginal Ar
    oneImage(
        axs[2, 0],
        im2,
        [xMin, xMax, zMin, zMax],
        "",
        haveX and haveZ,
        x0,
        z0,
        origin="lower",
        logScale=logScale,
        cmap=cmap,
    )
    axs[2, 0].set(xlabel="FeH", ylabel="Ar")
    # The Mr axis of this panel runs from yMax (left) to yMin (right).
    oneImage(
        axs[2, 1],
        im3,
        [yMax, yMin, zMin, zMax],
        "",
        haveY and haveZ,
        y0,
        z0,
        origin="lower",
        logScale=logScale,
        cmap=cmap,
    )
    axs[2, 1].set(xlabel="Mr", ylabel="Ar")
    axs[2, 2].plot(Ar1d, margAr, "r", lw=3)
    referenceLine(axs[2, 2], z0, margAr)
    axs[2, 2].set(xlabel="Ar", ylabel="p(Ar)")

    # Invisible placeholder axes kept from the original layout; no colorbar
    # is drawn at present.
    cax = fig.add_axes([0.84, 0.1, 0.1, 0.75])
    cax.set_axis_off()

    plt.savefig("plots/cornerPlot3.png")
    plt.show()
def showQrCornerPlot(postCube, Mr1d, FeH1d, Ar1d, x0=-99, y0=-99, z0=-99, logScale=False, cmap="Blues"):
    """Plot the Qr vs. FeH map and its two 1-D marginals; save to ``plots/QrCornerPlot.png``.

    The 2-D map and the Qr grid come from ``getQmap``; the marginal curves
    come from ``getMargDistr``. Three panels are drawn: the map, p(Qr) and
    p(FeH).

    Parameters
    ----------
    postCube
        3-D array handed to ``getQmap`` together with the three grids.
    Mr1d, FeH1d, Ar1d
        1-D grids. ``FeH1d`` and the derived Qr grid need at least two
        entries because their spacing is read from the first two.
    x0, y0, z0
        Reference FeH, Mr and Ar values; the reference Qr is ``y0 + z0``.
        A value of ``-99`` (or lower) means "not supplied": the FeH line is
        drawn only when ``x0`` is supplied, the Qr line only when both
        ``y0`` and ``z0`` are, and the map marker only when all three are.
    logScale
        If true, the map uses a logarithmic color scale spanning from 1/100
        of the map maximum to the maximum.
    cmap
        Matplotlib colormap used for the map.

    Returns
    -------
    tuple
        ``(Qr1d, margQr)``: the Qr grid and the 1-D marginal curve for Qr.

    Notes
    -----
    The figure is saved (the ``plots`` directory is not created here) and
    then displayed with ``plt.show()``.
    """

    def oneImage(ax, image, extent, title, showTrue, x0, y0, origin, logScale=True, cmap="Blues"):
        """Draw one peak-normalized image on ``ax`` and return the image artist."""
        im = image / image.max()
        # ``origin`` is honored for both color scales so that the linear and
        # logarithmic renderings share the same orientation.
        imshowArgs = {"origin": origin, "aspect": "auto", "extent": extent, "cmap": cmap}
        if logScale:
            imshowArgs["norm"] = LogNorm(im.max() / 100, vmax=im.max())
        artist = ax.imshow(im.T, **imshowArgs)
        ax.set_title(title)
        if showTrue:
            ax.scatter(x0, y0, s=150, c="red", alpha=0.3)
            ax.scatter(x0, y0, s=40, c="yellow", alpha=0.3)
        return artist

    def referenceLine(ax, value, curve):
        """Mark ``value`` with a dashed vertical line reaching 10% above the curve."""
        ax.plot([value, value], [0, 1.1 * np.max(curve)], "--k", lw=1)

    # 2-D distribution in the Qr vs. FeH plane
    Qmap, Qr1d = getQmap(postCube, FeH1d, Mr1d, Ar1d)

    # 1-D marginal distributions for Qr and FeH
    dFeH = FeH1d[1] - FeH1d[0]
    dQr = Qr1d[1] - Qr1d[0]
    margQr, margFeH = getMargDistr(Qmap, dFeH, dQr)

    # map plotting limits
    xMin, xMax = np.min(FeH1d), np.max(FeH1d)
    yMin, yMax = np.min(Qr1d), np.max(Qr1d)

    # Which reference values were actually supplied. Qr needs both Mr and Ar.
    haveFeH = bool(x0 > -99)
    haveQr = bool(y0 > -99) and bool(z0 > -99)
    trueQr = y0 + z0

    fig, axs = plt.subplots(1, 3, figsize=(10, 3))
    fig.subplots_adjust(wspace=0.25, left=0.1, right=0.95, bottom=0.12, top=0.95)

    # With origin="upper" the first Qr row sits at the top, labelled yMin.
    oneImage(
        axs[0],
        Qmap,
        [xMin, xMax, yMax, yMin],
        "",
        haveFeH and haveQr,
        x0,
        trueQr,
        origin="upper",
        logScale=logScale,
        cmap=cmap,
    )
    axs[0].set(xlabel="FeH", ylabel="Qr = Mr + Ar")

    axs[1].plot(Qr1d, margQr, "r", lw=3)
    if haveQr:
        referenceLine(axs[1], trueQr, margQr)
    axs[1].set(xlabel="Qr", ylabel="p(Qr)")

    axs[2].plot(FeH1d, margFeH, "r", lw=3)
    if haveFeH:
        referenceLine(axs[2], x0, margFeH)
    axs[2].set(xlabel="FeH", ylabel="p(FeH)")

    # Invisible placeholder axes kept from the original layout; no colorbar
    # is drawn at present.
    cax = fig.add_axes([0.84, 0.1, 0.1, 0.75])
    cax.set_axis_off()

    plt.savefig("plots/QrCornerPlot.png")
    plt.show()

    return Qr1d, margQr
def show3Flat2Dmaps(Z1, Z2, Z3, md, xLab, yLab, x0=-99, y0=-99, logScale=False, minFac=1000, cmap="Blues"):
    """Show three maps as "Prior", "Likelihood" and "Posterior" panels.

    Each input is reshaped to ``(md[2], md[5])``, divided by its own maximum
    and drawn with its own colorbar. The figure is written to
    ``plots/bayesPanels.png`` (the ``plots`` directory is not created here)
    and then displayed with ``plt.show()``.

    Parameters
    ----------
    Z1, Z2, Z3
        Arrays holding ``md[2] * md[5]`` values each; they become the
        "Prior", "Likelihood" and "Posterior" panels, in that order.
    md
        Metadata sequence ``(xMin, xMax, nXbin, yMin, yMax, nYbin)``.
    xLab, yLab
        Axis labels applied to all three panels.
    x0, y0
        Reference point. A red/yellow marker is drawn on every panel only
        when both values are greater than ``-99``.
    logScale
        If true, use a logarithmic color scale from ``max / minFac`` to the
        maximum of the normalized image.
    minFac
        Dynamic range of the logarithmic color scale.
    cmap
        Matplotlib colormap used for all three panels.
    """

    def oneImage(ax, image, extent, minFactor, title, showTrue, x0, y0, logScale=True, cmap="Blues"):
        """Draw one peak-normalized image on ``ax`` and return the image artist."""
        im = image / image.max()
        imshowArgs = {"origin": "upper", "aspect": "auto", "extent": extent, "cmap": cmap}
        if logScale:
            imshowArgs["norm"] = LogNorm(im.max() / minFactor, vmax=im.max())
        artist = ax.imshow(im.T, **imshowArgs)
        ax.set_title(title)
        if showTrue:
            ax.scatter(x0, y0, s=150, c="red")
            ax.scatter(x0, y0, s=40, c="yellow")
        return artist

    # unpack metadata
    xMin, xMax, nXbin = md[0], md[1], md[2]
    yMin, yMax, nYbin = md[3], md[4], md[5]
    myExtent = [xMin, xMax, yMin, yMax]

    # int() truncates like the previous .astype(int) and also accepts plain
    # Python numbers.
    Xpts = int(nXbin)
    Ypts = int(nYbin)
    print("pts:", Xpts, Ypts)

    showTrue = bool((x0 > -99) & (y0 > -99))

    fig, axs = plt.subplots(1, 3, figsize=(14, 4))

    panels = ((Z1, "Prior"), (Z2, "Likelihood"), (Z3, "Posterior"))
    for ax, (flatMap, title) in zip(axs, panels):
        # reshape the (flattened) input to get an "image"
        image = flatMap.reshape((Xpts, Ypts))
        artist = oneImage(ax, image, myExtent, minFac, title, showTrue, x0, y0, logScale=logScale, cmap=cmap)
        fig.colorbar(artist, ax=ax)

    # Invisible placeholder axes kept from the original layout.
    cax = fig.add_axes([0.84, 0.1, 0.1, 0.75])
    cax.set_axis_off()

    for ax in axs.flat:
        ax.set(xlabel=xLab, ylabel=yLab)

    plt.savefig("plots/bayesPanels.png")
    plt.show()
def getQmap(cube, FeH1d, Mr1d, Ar1d):
    """Project a 3-D cube(FeH, Mr, Ar) onto a 2-D Qr = Mr + Ar vs. FeH grid.

    For every Qr grid point and every Ar grid point the matching Mr value
    ``Qr - Ar`` is looked up in ``Mr1d``, and the cube entries found that way
    are summed over Ar.

    Parameters
    ----------
    cube
        3-D array indexed as ``cube[iFeH, iMr, iAr]``.
    FeH1d, Mr1d, Ar1d
        1-D grids of the three axes. ``Mr1d`` is assumed to be sorted in
        ascending order, as required by ``np.searchsorted``.

    Returns
    -------
    Qmap : numpy.ndarray
        Array of shape ``(len(FeH1d), len(Qr1d))``.
    Qr1d : numpy.ndarray
        Qr grid with as many points as ``Mr1d``, running from ``min(Mr1d)``
        to ``max(Mr1d) + max(Ar1d)``.
    """
    # Q grid, same size as Mr1d array. It has to exist before Qmap is
    # allocated because Qmap's second dimension is len(Qr1d).
    Qr1d = np.linspace(np.min(Mr1d), np.max(Ar1d) + np.max(Mr1d), np.size(Mr1d))
    Qmap = np.zeros((len(FeH1d), len(Qr1d)))

    # Mr value needed for every (Qr, Ar) pair: shape (len(Qr1d), len(Ar1d)).
    Mr_values = Qr1d[:, None] - Ar1d

    # Index of the last Mr1d entry strictly below each requested Mr value;
    # -1 marks requests at or below the first grid point.
    jk_indices = np.searchsorted(Mr1d, Mr_values) - 1

    # Pairs that fall below the Mr grid contribute nothing.
    # NOTE(review): after the "- 1" the upper-bound test can never fail, so
    # Mr values above max(Mr1d) are assigned to the last Mr bin and counted.
    # Confirm that this is intended.
    valid_mask = (jk_indices >= 0) & (jk_indices < len(Mr1d))
    jk_indices = np.clip(jk_indices, 0, len(Mr1d) - 1)

    arIndices = np.arange(len(Ar1d))
    for i in range(len(FeH1d)):
        # Pick cube[i, Mr index, Ar index] for every (Qr, Ar) pair, zero the
        # invalid pairs and sum over Ar.
        Qmap[i, :] = np.where(valid_mask, cube[i, jk_indices, arIndices], 0).sum(axis=1)

    return Qmap, Qr1d
def plotStar(
    star,
    margpostAr,
    margpostMr,
    margpostFeH,
    likeCube,
    priorCube,
    postCube,
    mdLocus,
    xLabel,
    yLabel,
    Mr1d,
    FeH1d,
    Ar1d,
):
    """Produce the four diagnostic figures for one star (testing and illustration).

    In order, the function calls ``show3Flat2Dmaps`` (prior, likelihood and
    posterior slices at the Ar index where ``margpostAr[2]`` peaks),
    ``showMargPosteriors3D``, ``showCornerPlot3`` and ``showQrCornerPlot``.
    The star's ``"FeH"``, ``"Mr"`` and ``"Ar"`` entries are passed to each of
    them as the reference values.

    Parameters
    ----------
    star
        Mapping-like object with ``"FeH"``, ``"Mr"`` and ``"Ar"`` entries.
    margpostAr, margpostMr, margpostFeH
        Indexable containers of 1-D curves, forwarded to
        ``showMargPosteriors3D``.
    likeCube, priorCube, postCube
        3-D arrays, sliced along their last axis.
    mdLocus, xLabel, yLabel
        Metadata and axis labels forwarded to the map plots.
    Mr1d, FeH1d, Ar1d
        1-D grids.

    Returns
    -------
    tuple
        ``(QrEst, QrEstUnc)``: the two values ``getStats`` returns for the
        Qr grid and its marginal curve.
    """
    trueFeH = star["FeH"]
    trueMr = star["Mr"]
    trueAr = star["Ar"]

    # Slice every cube at the Ar index where margpostAr[2] is largest.
    peakArIndex = np.argmax(margpostAr[2])
    priorSlice = priorCube[:, :, peakArIndex]
    likeSlice = likeCube[:, :, peakArIndex]
    postSlice = postCube[:, :, peakArIndex]
    show3Flat2Dmaps(
        priorSlice, likeSlice, postSlice, mdLocus, xLabel, yLabel, logScale=True, x0=trueFeH, y0=trueMr
    )

    # One panel each for Mr, FeH and Ar.
    mrPanel = (Mr1d, margpostMr, "Mr", "p(Mr)")
    fehPanel = (FeH1d, margpostFeH, "FeH", "p(FeH)")
    arPanel = (Ar1d, margpostAr, "Ar", "p(Ar)")
    showMargPosteriors3D(*mrPanel, *fehPanel, *arPanel, trueMr, trueFeH, trueAr)

    # Marginal 2-D and 1-D distributions (aka "corner plot").
    showCornerPlot3(
        postCube, Mr1d, FeH1d, Ar1d, mdLocus, xLabel, yLabel, logScale=True, x0=trueFeH, y0=trueMr, z0=trueAr
    )

    # Qr vs. FeH map plus the marginal 1-D distributions for Qr and FeH.
    qrGrid, qrMarginal = showQrCornerPlot(
        postCube, Mr1d, FeH1d, Ar1d, x0=trueFeH, y0=trueMr, z0=trueAr, logScale=True
    )

    qrEstimate, qrUncertainty = getStats(qrGrid, qrMarginal)
    return qrEstimate, qrUncertainty
def plotStars(starsData, bayesResults, *plottingArgs):
    """Create the plots for the specified stars.

    Parameters
    ----------
    starsData
        Table of stars; its index is dropped with ``reset_index(drop=True)``
        so that rows can be addressed by position.
    bayesResults
        One result object per star, in the same order as ``starsData``.
    *plottingArgs
        Extra positional arguments forwarded unchanged to ``plotStar``.

    Raises
    ------
    ValueError
        If ``starsData`` and ``bayesResults`` differ in length.
    """

    def firstOfEach(statDict):
        # The results were packed with JAX, so every value is a one-element
        # array; keep only that element.
        return {name: values[0] for name, values in statDict.items()}

    # Drop the _healpix_29 index
    stars = starsData.reset_index(drop=True)

    if len(stars) != len(bayesResults):
        raise ValueError("Stars data and results have different size")

    for position, starResult in enumerate(bayesResults):
        print(f"Plotting star {position}...")
        estimate, uncertainty = plotStar(
            stars.iloc[position],
            firstOfEach(starResult.margpostAr),
            firstOfEach(starResult.margpostMr),
            firstOfEach(starResult.margpostFeH),
            starResult.likeCube[0],
            starResult.priorCube[0],
            starResult.postCube[0],
            *plottingArgs,
        )
        print(estimate, uncertainty)