Source code for photod.stats
import jax.numpy as jnp
[docs]
def pnorm(pdf, dx):
    """Rescale ``pdf`` so that ``sum(result) * dx`` equals one.

    Parameters
    ----------
    pdf : array
        Unnormalised values sampled on a grid.
    dx : scalar or array
        Divisor applied after the values are scaled to unit sum
        (the grid step, judging by how callers pass ``dX``/``dY``/``dZ``).

    Returns
    -------
    array
        ``pdf`` divided by its total and then by ``dx``.
    """
    # Divide in two steps, in the same order, so rounding is unchanged.
    unit_sum = pdf / jnp.sum(pdf)
    return unit_sum / dx
[docs]
def getMargDistr(arr2d, dX, dY):
    """Return the two normalised marginals of a 2-D array.

    Parameters
    ----------
    arr2d : array
        Two-dimensional array of (unnormalised) probabilities.
    dX, dY : scalar
        Divisors handed to ``pnorm`` for the first and second marginal.

    Returns
    -------
    tuple
        ``(margX, margY)``: ``margX`` is ``arr2d`` summed over axis 0 (so it
        runs along axis 1), ``margY`` is ``arr2d`` summed over axis 1 (so it
        runs along axis 0); each is normalised with ``pnorm``.
    """
    collapsed_rows, collapsed_cols = (jnp.sum(arr2d, axis=ax) for ax in (0, 1))
    return pnorm(collapsed_rows, dX), pnorm(collapsed_cols, dY)
[docs]
def getMargDistr3D(arr3d, dX, dY, dZ):
    """Return the three normalised 1-D marginals of a 3-D array.

    Parameters
    ----------
    arr3d : array
        Three-dimensional array of (unnormalised) probabilities.
    dX, dY, dZ : scalar
        Divisors handed to ``pnorm`` for the respective marginal.

    Returns
    -------
    tuple
        ``(margX, margY, margZ)``: the array summed over axes ``(0, 2)``,
        ``(1, 2)`` and ``(0, 1)`` respectively (i.e. running along axis 1,
        axis 0 and axis 2), each normalised with ``pnorm``.
    """
    # Pair each set of summed-out axes with the step of the surviving axis.
    reductions = (
        ((0, 2), dX),
        ((1, 2), dY),
        ((0, 1), dZ),
    )
    return tuple(
        pnorm(jnp.sum(arr3d, axis=summed_axes), step)
        for summed_axes, step in reductions
    )
[docs]
def Entropy(p):
    """Return ``-sum(p * log2(p))`` over the strictly positive entries of ``p``.

    Parameters
    ----------
    p : array
        Probabilities. The array is used as given; it is not normalised here.

    Returns
    -------
    scalar array
        The entropy in bits (base-2 logarithm).
    """
    # Boolean-mask filtering is not possible on non-concrete (traced) arrays,
    # so non-positive entries are replaced by 1 instead: log2(1) == 0, which
    # makes their contribution to the sum vanish.
    is_positive = p > 0
    safe_p = jnp.where(is_positive, p, 1)
    contributions = safe_p * jnp.log2(safe_p)
    return -jnp.sum(contributions)
[docs]
def getStats(x, pdf):
    """Return the ``pdf``-weighted mean and standard deviation of ``x``.

    Parameters
    ----------
    x : array
        Grid values.
    pdf : array
        Weights for each grid value; they need not be normalised because
        both moments are divided by ``sum(pdf)``.

    Returns
    -------
    tuple
        ``(mean, sigma)`` where ``sigma`` is the square root of the weighted
        variance about ``mean``.
    """
    total_weight = jnp.sum(pdf)
    mean = jnp.sum(x * pdf) / total_weight
    squared_dev = (x - mean) ** 2
    variance = jnp.sum(squared_dev * pdf) / total_weight
    return mean, jnp.sqrt(variance)
[docs]
def getPosteriorQuantiles(x, pdf, quantiles=(0.14, 0.5, 0.86)):
    """Interpolate quantiles of a gridded 1-D distribution.

    Parameters
    ----------
    x : array
        Grid values at which ``pdf`` is sampled.
    pdf : array
        Weights on the grid; they need not be normalised because the
        cumulative sum is divided by its final value.
    quantiles : sequence of float, optional
        Cumulative-probability levels to evaluate. Defaults to
        ``(0.14, 0.5, 0.86)``, the levels previously hard-coded here.
        NOTE(review): the conventional +/-1 sigma levels are about 0.16 and
        0.84; confirm that 0.14 / 0.86 is intended.

    Returns
    -------
    array
        Values of ``x`` at the requested levels, in the order given.
    """
    running_total = jnp.cumsum(pdf)
    # Midpoint rule: each grid point is assigned the cumulative weight up to
    # the middle of its own bin (full sum minus half of that bin's weight).
    cdf = (running_total - 0.5 * pdf) / running_total[-1]
    levels = jnp.asarray(quantiles)
    # Invert the CDF by linear interpolation; levels outside the range of
    # ``cdf`` are clamped to the end values of ``x`` by jnp.interp.
    return jnp.interp(levels, cdf, x)
[docs]
def getQrQuantiles(postCube, QrGrid, QrIndices, quantiles=(0.14, 0.5, 0.86)):
    """Interpolate quantiles of the distribution of ``postCube`` over ``QrGrid``.

    ``postCube`` is summed over its first axis and transposed, and the
    resulting weights are accumulated into bins of ``QrGrid`` selected by
    ``QrIndices``. Quantiles of that binned distribution are then read off
    by interpolation.

    Parameters
    ----------
    postCube : array
        Posterior values. Axis 0 is summed out (presumably the FeH axis,
        judging by the original local name ``margPdfOverFeH`` -- confirm
        against the caller).
    QrGrid : array
        Grid on which the weights are accumulated and the quantiles reported.
    QrIndices : integer array
        Index into ``QrGrid`` for every element of the summed, transposed
        cube; entries sharing an index have their weights added together.
    quantiles : sequence of float, optional
        Cumulative-probability levels to evaluate. Defaults to
        ``(0.14, 0.5, 0.86)``, the levels previously hard-coded here.
        NOTE(review): the conventional +/-1 sigma levels are about 0.16 and
        0.84; confirm that 0.14 / 0.86 is intended.

    Returns
    -------
    array
        Values of ``QrGrid`` at the requested levels, in the order given.
    """
    # Collapse the first axis, then transpose so the layout lines up with
    # QrIndices.
    marginal = jnp.sum(postCube, axis=0).T
    # Scatter-add: several cube cells may map onto the same grid bin.
    weightsQr = jnp.zeros_like(QrGrid).at[QrIndices].add(marginal)
    running_total = jnp.cumsum(weightsQr)
    # Midpoint rule: cumulative weight up to the middle of each bin.
    cdf = (running_total - 0.5 * weightsQr) / running_total[-1]
    levels = jnp.asarray(quantiles)
    # Levels outside the range of ``cdf`` are clamped to the ends of QrGrid.
    return jnp.interp(levels, cdf, QrGrid)