This notebook is part of a set of examples for teaching Bayesian inference methods and probabilistic programming using the NumPyro
library.
Imports
import io
import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.handlers
import numpyro.infer
import pandas as pd
import scipy.stats
from arviz.labels import MapLabeller
from IPython.display import display
from numpyro.infer.reparam import LocScaleReparam
Constants
# Shared plotting configuration used by every figure in this notebook.
# (width, height) handed to the ``figsize`` arguments of the plotting calls.
FIGSIZE = (6.4, 4.8)
# Font size forwarded to the ArviZ plotting helpers.
TEXTSIZE = 10
# Keyword arguments for ``errorbar``: point markers only, no connecting line.
MARKER_STYLE = dict(marker="o", ls="none", lw=2, capsize=4)
# Keyword arguments for the dotted gray reference lines (axhline/axvline).
INDICATOR_STYLE = dict(c="gray", ls="dotted")
# Maps the posterior variable names "sigma" and "v0" to LaTeX axis labels
# (with km/s units) for the ArviZ plots that accept a ``labeller``.
LABELLER = MapLabeller({"sigma" : r"$\sigma$ (km/s)" , "v0" : r"$v_0$ (km/s)" })
RNG Setup
# Root JAX PRNG key (seed 1); it is split before sampling rather than used directly.
rng_key = jax.random.PRNGKey(1 )
The data consists of radial (line-of-sight) velocities measured for ten globular-cluster-like objects in the ultra-diffuse galaxy NGC1052–DF2. In this example, we will estimate a velocity dispersion from these measurements, taking into account the given measurement uncertainties. From this velocity dispersion, it is possible to estimate the total mass of the galaxy halo, and thus its dark matter fraction (by comparison with its independently estimated stellar mass).
Data
Data
# Inline CSV table with one row per object: its name, the measured velocity
# ``v`` and the measurement uncertainty ``v_err``.
csv = """
name,v,v_err
GC-39,14.728960068829865,7.046858966028704
GC-59,-3.8926505316836715,15.641403662859219
GC-71,2.0484049489377583,6.883925810744802
GC-73,10.576494786228665,3.2586631056780107
GC-77,1.1716309895635035,5.906326879041401
GC-85,-1.9077877208626857,5.3767941243687225
GC-91,-1.2121402228903977,9.69434349499684
GC-92,-14.32028284765569,6.761636322084257
GC-98,-39.335899515592,12.586586245681328
GC-101,-3.3753512069869345,13.523451888563756
"""
# Parse the literal through an in-memory text buffer into a DataFrame.
with io.StringIO(csv) as _csv_buffer:
    data = pd.read_csv(_csv_buffer)
Figure
def plot_data(ax=None):
    """Plot the measured velocities with their error bars.

    Each row of the module-level ``data`` table becomes one point, placed at
    its row index and labelled with the object's name.

    Parameters
    ----------
    ax : matplotlib axes, optional
        Axes to draw on. When omitted, a new figure of size ``FIGSIZE`` is
        created.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=FIGSIZE)
    ax.errorbar(x=data.index, y=data["v"], yerr=data["v_err"], **MARKER_STYLE)
    # Dotted reference line at zero velocity offset.
    ax.axhline(0, **INDICATOR_STYLE)
    ax.set_xticks(data.index, labels=data["name"], rotation=45)
    ax.set_ylabel(r"$\Delta\,v$ (km/s)")
    # One slot of padding on each side of the points. The upper limit follows
    # the table length (previously a literal 10) so the axis stays correct if
    # rows are added or removed; this assumes the default 0..n-1 row index.
    ax.set_xlim(-1, len(data))
    ax.set_ylim(-60, 60)
    return ax


plot_data()
plt.show()
0
GC-39
14.728960
7.046859
1
GC-59
-3.892651
15.641404
2
GC-71
2.048405
6.883926
3
GC-73
10.576495
3.258663
4
GC-77
1.171631
5.906327
5
GC-85
-1.907788
5.376794
6
GC-91
-1.212140
9.694343
7
GC-92
-14.320283
6.761636
8
GC-98
-39.335900
12.586586
9
GC-101
-3.375351
13.523452
Model
def model(df):
    """NumPyro model for the intrinsic velocity dispersion.

    Generative structure, as written in the body:

    * ``v0 ~ Uniform(-50, 50)``
    * ``log_sigma ~ Uniform(log 0.5, log 50)`` with ``sigma = exp(log_sigma)``
    * per row: ``v ~ Normal(v0, sigma)``
    * per row: ``v_obs ~ Normal(v, v_err)``, observed at the measured value

    Parameters
    ----------
    df : table with columns ``"v"`` and ``"v_err"``
        Measured velocities and their uncertainties; ``len(df)`` sets the
        size of the ``"data"`` plate.
    """
    # The inputs are registered as ``param`` sites rather than used directly.
    # NOTE(review): presumably so that they show up in the rendered model
    # graph -- confirm this is the intended use of ``numpyro.param``.
    v_meas = numpyro.param("v_meas", df["v"].values)
    v_err = numpyro.param("v_err", df["v_err"].values)

    v0 = numpyro.sample("v0", dist.Uniform(-50, 50))

    # A uniform prior on log(sigma); sigma itself is recovered by exponentiation.
    log_sigma_prior = dist.Uniform(np.log(0.5), np.log(50))
    log_sigma = numpyro.sample("log_sigma", log_sigma_prior)
    sigma = jnp.exp(log_sigma)

    # Reparameterise only the latent site "v" (LocScaleReparam with
    # centered=0); other sites are unaffected by this handler.
    decenter_v = numpyro.handlers.reparam(config={"v": LocScaleReparam(0)})
    with numpyro.plate("data", len(df)), decenter_v:
        v = numpyro.sample("v", dist.Normal(v0, sigma))
        numpyro.sample("v_obs", dist.Normal(v, v_err), obs=v_meas)
Figure
# Draw the model graph for the data table, including the ``param`` sites and
# the distribution attached to each sample site.
numpyro.render_model(
model= model,
model_args= (data,),
render_params= True ,
render_distributions= True ,
)
Sampling
# Configure MCMC with a NUTS kernel: 4 chains, 2000 warm-up steps and
# 10000 samples, with the progress bar disabled.
# NOTE(review): no host-device-count setup is visible here; presumably the
# chains run sequentially on a single CPU device -- confirm if parallel chains
# are expected.
sampler = numpyro.infer.MCMC(
sampler= numpyro.infer.NUTS(model),
num_warmup= 2000 ,
num_samples= 10000 ,
num_chains= 4 ,
progress_bar= False ,
)
# Split the global key: keep one half as the new ``rng_key`` and spend the
# other half on this run, so the same key is never consumed twice.
rng_key, _rng_key = jax.random.split(rng_key)
sampler.run(_rng_key, data)
# Convert the samples to ArviZ InferenceData and add ``sigma`` as a derived
# posterior variable, since the model only samples ``log_sigma``.
inf_data = az.from_numpyro(sampler)
inf_data.posterior["sigma" ] = np.exp(inf_data.posterior["log_sigma" ])
Figure
def plot_trace(ax=None):
    """Draw ArviZ trace plots for the sampled variables.

    The auxiliary ``v_decentered`` site and the derived ``sigma`` variable are
    excluded via the ``~`` prefix.

    Parameters
    ----------
    ax : optional
        Forwarded unchanged to ``az.plot_trace`` as its ``axes`` argument.

    Returns
    -------
    Whatever ``az.plot_trace`` returns.
    """
    excluded = ["~v_decentered", "~sigma"]
    width = FIGSIZE[0]
    height = 3 * 2
    return az.plot_trace(
        data=inf_data,
        var_names=excluded,
        figsize=(width, height),
        axes=ax,
    )


plot_trace()
plt.tight_layout()
plt.show()
Table
# Summary table (statistics and diagnostics) for every variable except the
# auxiliary ``v_decentered`` site and the derived ``sigma``.
az.summary(inf_data, kind= "all" , var_names= ["~v_decentered" , "~sigma" ])
log_sigma
2.158
0.581
1.095
3.161
0.006
0.004
9418.0
11047.0
1.0
v[0]
8.597
6.266
-2.779
20.338
0.034
0.024
33127.0
32598.0
1.0
v[1]
-1.687
8.877
-19.277
14.680
0.044
0.042
40434.0
30650.0
1.0
v[2]
1.087
5.516
-9.506
11.499
0.023
0.024
57722.0
34346.0
1.0
v[3]
8.840
3.351
2.587
15.143
0.020
0.014
28522.0
23647.0
1.0
v[4]
0.710
4.983
-8.748
10.205
0.020
0.023
62013.0
34231.0
1.0
v[5]
-1.376
4.739
-10.285
7.407
0.021
0.020
51347.0
34295.0
1.0
v[6]
-0.907
6.929
-14.456
11.955
0.031
0.032
48965.0
33178.0
1.0
v[7]
-8.860
6.563
-20.812
3.454
0.043
0.030
22420.0
17931.0
1.0
v[8]
-14.667
11.318
-34.934
5.646
0.091
0.065
14527.0
18427.0
1.0
v[9]
-1.548
8.397
-17.877
14.303
0.042
0.041
39891.0
29849.0
1.0
v0
-1.038
4.532
-9.701
7.217
0.041
0.034
13092.0
13036.0
1.0
Results
Figure
def plot_pair(ax=None):
    """Plot the joint posterior of ``sigma`` and ``v0``.

    The pair plot overlays a hexbin density and KDE contours, then adds a
    dotted reference line at ``v0 = 0`` and fixes the axis ranges.

    Parameters
    ----------
    ax : optional
        Forwarded to ``az.plot_pair`` as its ``ax`` argument.

    Returns
    -------
    The axes object returned by ``az.plot_pair``, after decoration.
    """
    pair_options = {
        "data": inf_data,
        "var_names": ["sigma", "v0"],
        "kind": ["hexbin", "kde"],
        "labeller": LABELLER,
        "figsize": FIGSIZE,
        "textsize": TEXTSIZE,
        "ax": ax,
    }
    ax = az.plot_pair(**pair_options)

    ax.axhline(0, **INDICATOR_STYLE)
    ax.set_xlim(0, 35)
    ax.set_ylim(-15, 15)
    return ax


plot_pair()
plt.show()
Figure
def plot_posterior(ax=None):
    """Plot a histogram of the ``sigma`` posterior with two annotated markers.

    The markers are vertical lines at ``sigma = 10`` (labelled with the
    percentile rank of that value among the samples) and at the 90th
    percentile of the samples.

    Parameters
    ----------
    ax : matplotlib axes, optional
        Axes to draw on. When omitted, a new figure of size ``FIGSIZE`` is
        created.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        ax = plt.subplots(figsize=FIGSIZE)[1]

    sigma = az.extract(inf_data, var_names=["sigma"])
    unit_bins = np.arange(0, 36, 1)
    ax.hist(sigma, bins=unit_bins, density=True)

    # Each marker is (x position, text baseline, percentile shown in the label).
    markers = (
        (10, 0.1, scipy.stats.percentileofscore(sigma, 10)),
        (np.percentile(sigma, 90), 0.04, 90),
    )
    label_color = INDICATOR_STYLE["c"]
    for position, baseline, percent in markers:
        text_x = position + 0.5
        ax.axvline(position, **INDICATOR_STYLE)
        ax.text(text_x, baseline + 0.005, f" {percent:.1f} %", c=label_color)
        ax.text(text_x, baseline, f" {position:.1f} km/s", c=label_color)

    ax.set_xlabel(r"$\sigma$ (km/s)")
    ax.set_ylabel("Probability Density")
    ax.set_xlim(0, 35)
    ax.set_ylim(0, 0.12)
    return ax


plot_posterior()
plt.show()
Figure 5: The posterior distribution of the velocity dispersion \(\sigma\).
Figure
def plot_forest(ax=None):
    """Plot ridgelines of the per-object ``v`` posteriors with the measurements.

    ArviZ draws the ridgeplot; the measured velocities and their error bars
    are then drawn on the first returned axes, one row per object.

    Parameters
    ----------
    ax : optional
        Forwarded to ``az.plot_forest`` as its ``ax`` argument.

    Returns
    -------
    The first axes returned by ``az.plot_forest``, after decoration.
    """
    forest_axes = az.plot_forest(
        data=inf_data,
        var_names=["v"],
        kind="ridgeplot",
        combined=True,
        hdi_prob=0.99,
        ridgeplot_overlap=0.6,
        ridgeplot_alpha=0,
        colors="black",
        figsize=FIGSIZE,
        textsize=TEXTSIZE,
        ax=ax,
    )
    ax = forest_axes[0]

    # Vertical spacing between rows.
    # NOTE(review): 0.825 is hard-coded; presumably tuned to line up with the
    # ridgelines drawn by ArviZ -- confirm if the ArviZ version changes.
    dy = 0.825
    # The first table row gets the largest y, i.e. it is drawn at the top.
    y = (np.arange(len(data)) * dy)[::-1]

    ax.errorbar(x=data["v"], y=y, xerr=data["v_err"], **MARKER_STYLE)
    ax.axvline(0, **INDICATOR_STYLE)
    ax.set_yticks(y, labels=data["name"])
    ax.set_xlabel(r"$\Delta\,v$ (km/s)")
    ax.set_xlim(-60, 60)
    ax.set_ylim(-dy / 2, y.max() + dy)
    return ax


plot_forest()
plt.show()
Watermark
Python implementation: CPython
Python version : 3.11.7
IPython version : 8.20.0
jax : 0.4.23
pandas : 2.1.4
matplotlib: 3.8.2
scipy : 1.11.4
numpyro : 0.13.2
arviz : 0.17.0
numpy : 1.26.3