This notebook is part of a set of examples for teaching Bayesian inference methods and probabilistic programming, using the numpyro library.
Imports
import io
import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.infer
import pandas as pd
from arviz.labels import MapLabeller
from IPython.display import display
Constants
FIGSIZE = (6.4, 4.8)
TEXTSIZE = 10
LABELLER = MapLabeller({"alpha": r"$\alpha$", "beta": r"$\beta$"})
RNG Setup
np_rng = np.random.default_rng(1)
rng_key = jax.random.PRNGKey(1)
In this example, we will fit a line to some data points using three separate models that handle outliers in different ways, and compare them.
Data
csv = """
x,y,y_err
-1.991,-1.265,0.2
-1.942,-1.718,0.2
-1.866,-0.539,0.2
-1.451,-1.468,0.2
-1.383,-0.04,0.2
-0.947,-1.069,0.2
-0.865,-0.935,0.2
0.135,-0.006,0.2
0.424,0.219,0.2
0.96,1.366,0.2
1.411,-0.392,0.2
1.603,1.777,0.2
1.675,1.857,0.2
1.777,1.986,0.2
1.828,1.59,0.2
"""
data = pd.read_csv(io.StringIO(csv))
Figure
def plot_data(scatter_kwargs=None, ax=None):
    if ax is None:
        _, ax = plt.subplots(figsize=FIGSIZE)
    scatter_kwargs = {"color": "black", **(scatter_kwargs or {})}
    data.plot.scatter(x="x", y="y", yerr="y_err", **scatter_kwargs, ax=ax)
    return ax
plot_data()
plt.show()
Models
Model A
Let’s start with a minimal linear model parametrized by a slope \(\beta\) and intercept \(\alpha\), which does not take outliers into account:
\[y(x_i)=\alpha+\beta\,x_i\]
To make it easy to choose convenient, minimally informative priors, let’s reparametrize this model: use the angle \(\theta\) between the line and the \(x\)-axis instead of the slope, and the perpendicular intercept \(\alpha_p\) instead of the \(y\)-axis intercept:
\[\theta=\arctan(\beta)\] \[\alpha_p=\alpha\cos(\theta)\]
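With a uniform prior on \(\theta\) over \((-\pi/2, \pi/2)\), the implied prior on the slope follows from the change of variables
\[p(\beta)=p(\theta)\left|\frac{d\theta}{d\beta}\right|=\frac{1}{\pi}\,\frac{1}{1+\beta^2},\]
i.e. a standard Cauchy distribution, which, unlike a uniform prior on \(\beta\) itself, does not artificially favor steep slopes.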
def model_A(df):
    # Wrap the observed data so it shows up in the rendered model graph.
    x_meas = numpyro.param("x_meas", df["x"].values)
    y_meas = numpyro.param("y_meas", df["y"].values)
    y_err = numpyro.param("y_err", df["y_err"].values)
    # Minimally informative priors on the reparametrized line.
    alpha_p = numpyro.sample("alpha_p", dist.Uniform(-0.5, 0.5))
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    # Recover the conventional intercept and slope.
    alpha = numpyro.deterministic("alpha", alpha_p / jnp.cos(theta))
    beta = numpyro.deterministic("beta", jnp.tan(theta))
    with numpyro.plate("data", len(df)):
        y = numpyro.deterministic("y", alpha + beta * x_meas)
        numpyro.sample("y_obs", dist.Normal(y, y_err), obs=y_meas)
Figure
numpyro.render_model(
    model=model_A,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
)
Model B
To make the model more robust against outliers, let’s use Student’s \(t\) distribution as the sampling distribution instead of the normal distribution. The \(t\) distribution has heavier tails, whose weight is controlled by an additional free parameter, the degrees of freedom \(\nu\).
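As a quick illustration of the heavier tails (an aside, not part of the model): the log-density of a point five standard deviations from the mean under a standard normal versus a \(t\) distribution with \(\nu=1\):

# Tail comparison (illustrative): log-density at 5 sigma.
print(dist.Normal(0.0, 1.0).log_prob(5.0))         # approx. -13.42
print(dist.StudentT(1.0, 0.0, 1.0).log_prob(5.0))  # approx. -4.40

A \(5\sigma\) outlier is thus vastly less surprising under the \(t\) distribution, so it pulls the fitted line far less.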
def model_B(df):
    x_meas = numpyro.param("x_meas", df["x"].values)
    y_meas = numpyro.param("y_meas", df["y"].values)
    y_err = numpyro.param("y_err", df["y_err"].values)
    alpha_p = numpyro.sample("alpha_p", dist.Uniform(-0.5, 0.5))
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    # Degrees of freedom of the t distribution; small nu means heavy tails.
    nu = numpyro.sample("nu", dist.InverseGamma(1, 1))
    alpha = numpyro.deterministic("alpha", alpha_p / jnp.cos(theta))
    beta = numpyro.deterministic("beta", jnp.tan(theta))
    with numpyro.plate("data", len(df)):
        y = numpyro.deterministic("y", alpha + beta * x_meas)
        numpyro.sample("y_obs", dist.StudentT(nu, y, y_err), obs=y_meas)
Figure
numpyro.render_model(
    model=model_B,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
)
Model C
Let’s also extend model A into a two-component mixture model: we keep model A as the “foreground” model, but add an explicit “background” model (a broad normal distribution) for the outliers.
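For each point, the mixture likelihood is
\[p(y_i)=q\,p_\mathrm{bg}(y_i)+(1-q)\,p_\mathrm{fg}(y_i),\]
where \(q\) is the mixture weight of the background component, and the posterior probability that point \(i\) belongs to the background is
\[p(\mathrm{bg}\mid y_i)=\frac{q\,p_\mathrm{bg}(y_i)}{q\,p_\mathrm{bg}(y_i)+(1-q)\,p_\mathrm{fg}(y_i)}.\]
The model below records this quantity in log space as log_p.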
def model_C(df):
    x_meas = numpyro.param("x_meas", df["x"].values)
    y_meas = numpyro.param("y_meas", df["y"].values)
    y_err = numpyro.param("y_err", df["y_err"].values)
    alpha_p = numpyro.sample("alpha_p", dist.Uniform(-0.5, 0.5))
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    alpha = numpyro.deterministic("alpha", alpha_p / jnp.cos(theta))
    beta = numpyro.deterministic("beta", jnp.tan(theta))
    # Background (outlier) component and mixture weight q.
    bg_mean = numpyro.sample("bg_mean", dist.Normal(0.0, 1.0))
    bg_sigma = numpyro.sample("bg_sigma", dist.HalfNormal(3.0))
    q = numpyro.sample("q", dist.Uniform(0, 1))
    # Component 0 is the background, so q is the prior outlier probability.
    mixing_dist = dist.Categorical(probs=jnp.array([q, 1 - q]))
    with numpyro.plate("data", len(df)):
        y = numpyro.deterministic("y", alpha + beta * x_meas)
        fg_dist = dist.Normal(y, y_err)
        bg_dist = dist.Normal(bg_mean, jnp.sqrt(y_err**2 + bg_sigma**2))
        mixture = dist.MixtureGeneral(mixing_dist, [bg_dist, fg_dist])
        y_obs = numpyro.sample("y_obs", mixture, obs=y_meas)
        # Log posterior probability that each point belongs to the background.
        log_probs = mixture.component_log_probs(y_obs)
        logsumexp = jax.nn.logsumexp(log_probs, axis=-1, keepdims=True)
        numpyro.deterministic("log_p", log_probs[:, 0] - logsumexp[:, 0])
Figure
numpyro.render_model(
    model=model_C,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
)
Sampling
def sample_model(key, model):
    sampler = numpyro.infer.MCMC(
        sampler=numpyro.infer.NUTS(model),
        num_warmup=2000,
        num_samples=10000,
        num_chains=4,
        progress_bar=False,
    )
    # Split the key so that each call consumes fresh randomness.
    key, _key = jax.random.split(key)
    sampler.run(_key, data)
    return key, sampler, az.from_numpyro(sampler)
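Note that with num_chains=4, numpyro samples the chains sequentially when JAX only sees a single device (the default on CPU). One way to run them in parallel, sketched here, is to expose more host devices before JAX is first used (e.g. right after the imports):

numpyro.set_host_device_count(4)  # must run before any JAX computation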
Model A
rng_key, sampler_A, inf_data_A = sample_model(rng_key, model_A)
Figure
def plot_trace_A(ax=None):
    return az.plot_trace(
        data=inf_data_A,
        figsize=(FIGSIZE[0], 2 * 2),
        var_names=["~alpha", "~beta", "~y"],
        axes=ax,
    )
plot_trace_A()
plt.tight_layout()
plt.show()
Table
az.summary(inf_data_A, kind="all", var_names=["~y"])
          mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha    0.123  0.051   0.028    0.220        0.0      0.0   38063.0   27611.0    1.0
alpha_p  0.098  0.041   0.021    0.175        0.0      0.0   38019.0   26810.0    1.0
beta     0.757  0.036   0.690    0.824        0.0      0.0   38251.0   26859.0    1.0
theta    0.648  0.023   0.605    0.691        0.0      0.0   38251.0   26859.0    1.0
Model B
rng_key, sampler_B, inf_data_B = sample_model(rng_key, model_B)
Figure
def plot_trace_B(ax=None):
    return az.plot_trace(
        data=inf_data_B,
        figsize=(FIGSIZE[0], 3 * 2),
        var_names=["~alpha", "~beta", "~y"],
        axes=ax,
    )
plot_trace_B()
plt.tight_layout()
plt.show()
Table
az.summary(inf_data_B, kind="all", var_names=["~y"])
          mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha    0.053  0.110  -0.148    0.255      0.001    0.001   20228.0   14070.0    1.0
alpha_p  0.038  0.081  -0.110    0.184      0.001    0.001   19871.0   13157.0    1.0
beta     0.983  0.074   0.845    1.119      0.001    0.000   20588.0   19928.0    1.0
nu       1.138  0.448   0.438    1.949      0.003    0.002   26560.0   26607.0    1.0
theta    0.775  0.038   0.704    0.844      0.000    0.000   20588.0   19928.0    1.0
Model C
rng_key, sampler_C, inf_data_C = sample_model(rng_key, model_C)
inf_data_C.posterior["p"] = np.exp(inf_data_C.posterior["log_p"])
Figure
def plot_trace_C(ax=None):
    return az.plot_trace(
        data=inf_data_C,
        figsize=(FIGSIZE[0], 5 * 2),
        var_names=["~alpha", "~beta", "~y", "~log_p", "~p"],
        axes=ax,
    )
plot_trace_C()
plt.tight_layout()
plt.show()
Table
az.summary(inf_data_C, kind="all", var_names=["~y", "~log_p", "~p"])
           mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha     0.052  0.077  -0.095    0.195      0.000    0.000   28809.0   26061.0    1.0
alpha_p   0.037  0.055  -0.066    0.139      0.000    0.000   28310.0   25799.0    1.0
beta      1.007  0.057   0.902    1.112      0.000    0.000   24814.0   22155.0    1.0
bg_mean  -0.403  0.415  -1.186    0.419      0.003    0.002   22564.0   17542.0    1.0
bg_sigma  0.785  0.539   0.001    1.708      0.004    0.003   18153.0   10511.0    1.0
q         0.332  0.126   0.103    0.563      0.001    0.000   33729.0   24843.0    1.0
theta     0.788  0.029   0.735    0.839      0.000    0.000   24814.0   22155.0    1.0
Results
Figure
def plot_pair(*items, ax=None):
    for i, inf_data in enumerate(items):
        ax = az.plot_pair(
            data=inf_data,
            figsize=FIGSIZE,
            var_names=["alpha", "beta"],
            labeller=LABELLER,
            kind=["scatter", "kde"],
            scatter_kwargs={"color": f"C{i}"},
            marginals=True,
            marginal_kwargs={"color": f"C{i}"},
            point_estimate="mean",
            reference_values={
                LABELLER.var_name_to_str("alpha"): 0,
                LABELLER.var_name_to_str("beta"): 1,
            },
            textsize=TEXTSIZE,
            ax=ax,
        )
    return ax
plot_pair(inf_data_A, inf_data_B, inf_data_C)
plt.show()
p_outlier = az.extract(inf_data_C, var_names=["p"]).mean("sample")
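As a quick check (an aside; the 0.5 threshold is our own choice, not part of the model), we can list the points that the mixture model considers more likely background than foreground:

# Points with mean posterior outlier probability above one half.
display(data[np.asarray(p_outlier) > 0.5])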
Figure
def plot_samples(*items, p=None, ax=None):
    marker_style = {"marker": "o", "s": 36, "zorder": 2}
    ax = plot_data(scatter_kwargs=marker_style, ax=ax)
    x_min, x_max = -2.8, 2.8
    x = np.linspace(x_min, x_max, 50)
    for i, inf_data in enumerate(items):
        samples = az.extract(
            data=inf_data, var_names=["alpha", "beta"], num_samples=100
        ).to_dataframe()
        for _, d in samples.iterrows():
            y = d["alpha"] + d["beta"] * x
            ax.plot(x, y, c=f"C{i}", alpha=0.1, zorder=1)
    if p is not None:
        marker_style = {**marker_style, "zorder": 3}
        ax.scatter(data["x"], data["y"], c=p, cmap="gray", ec="black", **marker_style)
    ax.set_xlim(x_min, x_max)
    return ax
plot_samples(inf_data_A, inf_data_B, inf_data_C, p= p_outlier)
plt.show()
comp_data = az.compare(
    compare_dict={"A": inf_data_A, "B": inf_data_B, "C": inf_data_C},
    ic="loo",
    method="BB-pseudo-BMA",
    seed=np_rng,
)
   rank    elpd_loo      p_loo  elpd_diff        weight         se        dse  warning  scale
C     0  -14.548210   5.000393   0.000000  8.824918e-01  32.993495   0.000000     True    log
B     1  -18.289472   4.520744   3.741261  1.175078e-01   5.073979   2.878680    False    log
A     2  -72.965036  19.576818  58.416826  3.312725e-07   3.815517  32.273156     True    log
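The mixture model C ranks first with a weight of about 0.88, with the \(t\)-distribution model B close behind in ELPD (elpd_diff ≈ 3.7), while model A trails by roughly 58 ELPD points: ignoring the outliers substantially hurts predictive performance.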
Figure
def plot_compare(ax=None):
    ax = az.plot_compare(
        comp_df=comp_data,
        figsize=(FIGSIZE[0], 0.5 * len(comp_data)),
        plot_ic_diff=False,
        legend=False,
        title=False,
        textsize=TEXTSIZE,
        ax=ax,
    )
    ax.set_xlabel("ELPD (LOO)")
    ax.set_ylabel("Model")
    return ax
plot_compare()
plt.show()
Watermark
Python implementation: CPython
Python version : 3.11.7
IPython version : 8.20.0
jax : 0.4.23
arviz : 0.17.0
matplotlib: 3.8.2
numpyro : 0.13.2
numpy : 1.26.3
pandas : 2.1.4