The Galaxy Case

Normal Distribution
Real World

This notebook is part of a set of examples for teaching Bayesian inference methods and probabilistic programming, using the numpyro library.

Imports
import io

import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.handlers
import numpyro.infer
import pandas as pd
import scipy.stats

from arviz.labels import MapLabeller
from IPython.display import display
from numpyro.infer.reparam import LocScaleReparam
Constants
# Default figure size (passed as `figsize`) and base text size for the plots.
FIGSIZE = (6.4, 4.8)
TEXTSIZE = 10

# Keyword arguments for `errorbar` calls: round markers without a connecting line.
MARKER_STYLE = dict(marker="o", ls="none", lw=2, capsize=4)
# Keyword arguments for the dotted gray reference lines (`axhline` / `axvline`).
INDICATOR_STYLE = dict(c="gray", ls="dotted")

# Maps posterior variable names to axis labels (with units) in ArviZ plots.
LABELLER = MapLabeller(
    {
        "sigma": r"$\sigma$ (km/s)",
        "v0": r"$v_0$ (km/s)",
    }
)
RNG Setup
# Root PRNG key with a fixed seed; later cells split sub-keys off it.
rng_key = jax.random.PRNGKey(seed=1)

The data consists of radial (line-of-sight) velocities measured for ten globular-cluster-like objects in the ultra-diffuse galaxy NGC1052–DF2. In this example, we will estimate a velocity dispersion from these measurements, taking into account the given measurement uncertainties. From this velocity dispersion, it is possible to estimate the total mass of the galaxy halo, and thus its dark matter fraction (by comparison with its independently estimated stellar mass).

Data

Data
# Embedded measurements: one row per object with its name, radial velocity `v`
# and measurement uncertainty `v_err` (plotted in km/s elsewhere in this file).
csv = """
name,v,v_err
GC-39,14.728960068829865,7.046858966028704
GC-59,-3.8926505316836715,15.641403662859219
GC-71,2.0484049489377583,6.883925810744802
GC-73,10.576494786228665,3.2586631056780107
GC-77,1.1716309895635035,5.906326879041401
GC-85,-1.9077877208626857,5.3767941243687225
GC-91,-1.2121402228903977,9.69434349499684
GC-92,-14.32028284765569,6.761636322084257
GC-98,-39.335899515592,12.586586245681328
GC-101,-3.3753512069869345,13.523451888563756
"""

# Parse the CSV text from an in-memory buffer into a DataFrame.
with io.StringIO(csv) as _csv_buffer:
    data = pd.read_csv(_csv_buffer)
Figure
def plot_data(ax=None):
    """Plot the measured velocities with their error bars, one marker per object.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure of size ``FIGSIZE`` is created when omitted.

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        ax = plt.subplots(figsize=FIGSIZE)[1]

    positions = data.index
    ax.errorbar(x=positions, y=data["v"], yerr=data["v_err"], **MARKER_STYLE)
    # Horizontal reference line at zero velocity.
    ax.axhline(0, **INDICATOR_STYLE)

    # Label each x position with the object's name.
    ax.set_xticks(positions, labels=data["name"], rotation=45)
    ax.set_ylabel(r"$\Delta\,v$ (km/s)")
    ax.set_ylim(-60, 60)
    ax.set_xlim(-1, 10)
    return ax


plot_data()
plt.show()

Figure 1: The radial velocities of ten globular-cluster-like objects in NGC1052–DF2.
name v v_err
0 GC-39 14.728960 7.046859
1 GC-59 -3.892651 15.641404
2 GC-71 2.048405 6.883926
3 GC-73 10.576495 3.258663
4 GC-77 1.171631 5.906327
5 GC-85 -1.907788 5.376794
6 GC-91 -1.212140 9.694343
7 GC-92 -14.320283 6.761636
8 GC-98 -39.335900 12.586586
9 GC-101 -3.375351 13.523452

Model

def model(df):
    """NumPyro model for the velocity offset and velocity dispersion.

    Each object has a latent velocity ``v`` drawn from ``Normal(v0, sigma)``;
    the measured velocity is ``Normal(v, v_err)`` around that latent value.

    Priors:
      * ``v0``        ~ Uniform(-50, 50)
      * ``log_sigma`` ~ Uniform(log 0.5, log 50), i.e. a log-uniform prior on
        ``sigma`` between 0.5 and 50.

    Parameters
    ----------
    df : DataFrame
        Must provide the columns ``v`` (measured values) and ``v_err``
        (their uncertainties).
    """
    n_obs = len(df)

    # The measurements are registered as `param` sites rather than passed in
    # directly — presumably so they show up in the rendered model graph.
    v_meas = numpyro.param("v_meas", df["v"].values)
    v_err = numpyro.param("v_err", df["v_err"].values)

    v0 = numpyro.sample("v0", dist.Uniform(-50, 50))
    log_sigma_prior = dist.Uniform(np.log(0.5), np.log(50))
    log_sigma = numpyro.sample("log_sigma", log_sigma_prior)
    sigma = jnp.exp(log_sigma)

    with numpyro.plate("data", n_obs):
        # Only the latent site "v" is reparameterised (LocScaleReparam with
        # centered=0); this adds the auxiliary site "v_decentered".
        with numpyro.handlers.reparam(config={"v": LocScaleReparam(0)}):
            v = numpyro.sample("v", dist.Normal(v0, sigma))
        numpyro.sample("v_obs", dist.Normal(v, v_err), obs=v_meas)
Figure
# Draw the model graph, including the `param` sites and each site's distribution.
numpyro.render_model(
    model,
    model_args=(data,),
    render_distributions=True,
    render_params=True,
)

Figure 2: The model graph.

Sampling

# NUTS with 4 chains; 2000 warm-up and 10000 sampling iterations are requested.
sampler = numpyro.infer.MCMC(
    numpyro.infer.NUTS(model),
    num_chains=4,
    num_warmup=2000,
    num_samples=10000,
    progress_bar=False,
)

# Split a sub-key off for this run so `rng_key` itself is never reused.
rng_key, _rng_key = jax.random.split(rng_key)
sampler.run(_rng_key, data)

# Convert to ArviZ and add sigma = exp(log_sigma) as a derived posterior variable.
inf_data = az.from_numpyro(sampler)
inf_data.posterior["sigma"] = np.exp(inf_data.posterior["log_sigma"])
Figure
def plot_trace(ax=None):
    """Draw the MCMC trace plot for the sampled variables.

    The auxiliary site ``v_decentered`` and the derived ``sigma`` are excluded
    (the ``~`` prefix negates a variable name).

    Parameters
    ----------
    ax : optional
        Forwarded to ``az.plot_trace`` as ``axes``.

    Returns
    -------
    Whatever ``az.plot_trace`` returns.
    """
    excluded = ["~v_decentered", "~sigma"]
    width = FIGSIZE[0]
    return az.plot_trace(
        data=inf_data,
        var_names=excluded,
        figsize=(width, 3 * 2),
        axes=ax,
    )


plot_trace()
plt.tight_layout()
plt.show()

Figure 3: The MCMC trace plot.
Table
# Summary statistics and diagnostics for every variable except `v_decentered` and `sigma`.
az.summary(inf_data, var_names=["~v_decentered", "~sigma"], kind="all")
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
log_sigma 2.158 0.581 1.095 3.161 0.006 0.004 9418.0 11047.0 1.0
v[0] 8.597 6.266 -2.779 20.338 0.034 0.024 33127.0 32598.0 1.0
v[1] -1.687 8.877 -19.277 14.680 0.044 0.042 40434.0 30650.0 1.0
v[2] 1.087 5.516 -9.506 11.499 0.023 0.024 57722.0 34346.0 1.0
v[3] 8.840 3.351 2.587 15.143 0.020 0.014 28522.0 23647.0 1.0
v[4] 0.710 4.983 -8.748 10.205 0.020 0.023 62013.0 34231.0 1.0
v[5] -1.376 4.739 -10.285 7.407 0.021 0.020 51347.0 34295.0 1.0
v[6] -0.907 6.929 -14.456 11.955 0.031 0.032 48965.0 33178.0 1.0
v[7] -8.860 6.563 -20.812 3.454 0.043 0.030 22420.0 17931.0 1.0
v[8] -14.667 11.318 -34.934 5.646 0.091 0.065 14527.0 18427.0 1.0
v[9] -1.548 8.397 -17.877 14.303 0.042 0.041 39891.0 29849.0 1.0
v0 -1.038 4.532 -9.701 7.217 0.041 0.034 13092.0 13036.0 1.0

Results

Figure
def plot_pair(ax=None):
    """Plot the joint posterior of ``sigma`` and ``v0`` as hexbin plus KDE contours.

    Parameters
    ----------
    ax : optional
        Forwarded to ``az.plot_pair`` as ``ax``.

    Returns
    -------
    The Axes returned by ``az.plot_pair``, with a reference line and fixed limits.
    """
    joint_ax = az.plot_pair(
        data=inf_data,
        var_names=["sigma", "v0"],
        kind=["hexbin", "kde"],
        labeller=LABELLER,
        figsize=FIGSIZE,
        textsize=TEXTSIZE,
        ax=ax,
    )

    # Reference line at v0 = 0 (v0 is on the vertical axis).
    joint_ax.axhline(0, **INDICATOR_STYLE)

    joint_ax.set_xlim(0, 35)
    joint_ax.set_ylim(-15, 15)
    return joint_ax


plot_pair()
plt.show()

Figure 4: The joint posterior distribution of the velocity dispersion \(\sigma\) and the velocity offset \(v_0\).
Figure
def plot_posterior(ax=None):
    """Plot the posterior histogram of ``sigma`` with two annotated markers.

    The markers are a fixed value of 10 km/s, labelled with its percentile
    rank among the samples, and the 90th percentile of the samples.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure of size ``FIGSIZE`` is created when omitted.

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=FIGSIZE)

    sigma = az.extract(inf_data, var_names=["sigma"])

    # With density=True the histogram is normalised over the samples that fall
    # inside the bins only. The bins therefore have to cover every draw;
    # stopping at 35 would silently drop larger draws and inflate the plotted
    # density relative to the percentiles below, which use all samples.
    # The visible range is still limited to 0-35 by set_xlim.
    upper_edge = max(35, int(np.ceil(float(np.max(sigma)))))
    ax.hist(sigma, bins=np.arange(0, upper_edge + 1, 1), density=True)

    # Each entry: (x position of the marker, y position of the label, percentile).
    for x, y, p in [
        (10, 0.1, scipy.stats.percentileofscore(sigma, 10)),
        (np.percentile(sigma, 90), 0.04, 90),
    ]:
        ax.axvline(x, **INDICATOR_STYLE)
        ax.text(x + 0.5, y + 0.005, f"{p:.1f} %", c=INDICATOR_STYLE["c"])
        ax.text(x + 0.5, y, f"{x:.1f} km/s", c=INDICATOR_STYLE["c"])

    ax.set_xlabel(r"$\sigma$ (km/s)")
    ax.set_ylabel("Probability Density")
    ax.set_xlim(0, 35)
    ax.set_ylim(0, 0.12)
    return ax


plot_posterior()
plt.show()

Figure 5: The posterior distribution of the velocity dispersion \(\sigma\).
Figure
def plot_forest(ax=None):
    """Overlay the measured velocities on a ridgeplot of the posterior of ``v``.

    Parameters
    ----------
    ax : optional
        Forwarded to ``az.plot_forest`` as ``ax``.

    Returns
    -------
    The first Axes returned by ``az.plot_forest``, with the measurements added.
    """
    axes = az.plot_forest(
        data=inf_data,
        var_names=["v"],
        kind="ridgeplot",
        combined=True,
        hdi_prob=0.99,
        ridgeplot_overlap=0.6,
        ridgeplot_alpha=0,
        colors="black",
        figsize=FIGSIZE,
        textsize=TEXTSIZE,
        ax=ax,
    )
    ridge_ax = axes[0]

    # Vertical positions of the measurements, one per row of `data`, reversed
    # so the first object ends up at the top.
    # NOTE(review): 0.825 looks hand-tuned to the ridgeline spacing ArviZ
    # produces with these settings — confirm after any ArviZ upgrade.
    dy = 0.825
    y = (np.arange(len(data)) * dy)[::-1]

    ridge_ax.errorbar(x=data["v"], y=y, xerr=data["v_err"], **MARKER_STYLE)
    # Vertical reference line at zero velocity.
    ridge_ax.axvline(0, **INDICATOR_STYLE)

    ridge_ax.set_yticks(y, labels=data["name"])
    ridge_ax.set_xlabel(r"$\Delta\,v$ (km/s)")
    ridge_ax.set_xlim(-60, 60)
    ridge_ax.set_ylim(-dy / 2, y.max() + dy)
    return ridge_ax


plot_forest()
plt.show()

Figure 6: The posterior distributions of the true velocities (black), shown together with the measured values and their uncertainties (blue).

Resources

Watermark

Python implementation: CPython
Python version       : 3.11.7
IPython version      : 8.20.0

jax       : 0.4.23
pandas    : 2.1.4
matplotlib: 3.8.2
scipy     : 1.11.4
numpyro   : 0.13.2
arviz     : 0.17.0
numpy     : 1.26.3